Skip to content

Commit

Permalink
Merge branch 'main' into zhrua/exp_orchestrator
Browse files Browse the repository at this point in the history
  • Loading branch information
lalala123123 authored Feb 4, 2024
2 parents e3de7b1 + 8cc7e20 commit 425e714
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 137 deletions.
82 changes: 60 additions & 22 deletions src/promptflow/promptflow/_sdk/_service/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
import json
import logging
import os
import platform
import subprocess
import sys

import waitress

from promptflow._cli._utils import _get_cli_activity_name
from promptflow._constants import PF_NO_INTERACTIVE_LOGIN
from promptflow._sdk._constants import LOGGER_NAME
from promptflow._sdk._service.app import create_app
from promptflow._sdk._service.utils.utils import (
check_pfs_service_status,
dump_port_to_config,
get_port_from_config,
get_started_service_info,
Expand All @@ -24,6 +25,15 @@
from promptflow._sdk._utils import get_promptflow_sdk_version, print_pf_version
from promptflow.exceptions import UserErrorException

app = None


def get_app():
    """Return the module-level service app, creating it on the first call.

    The app is built once via ``create_app()`` and cached in the module
    global ``app``; later calls hand back the cached object.
    """
    global app
    if app is not None:
        return app
    # create_app() returns a pair; only the first element (the app) is kept.
    app, _ = create_app()
    return app


def add_start_service_action(subparsers):
"""Add action to start pfs."""
Expand All @@ -38,6 +48,11 @@ def add_start_service_action(subparsers):
action="store_true",
help="If the port is used, the existing service will be terminated and restart a new service.",
)
start_pfs_parser.add_argument(
"--synchronous",
action="store_true",
help=argparse.SUPPRESS,
)
start_pfs_parser.set_defaults(action="start")


Expand All @@ -52,26 +67,52 @@ def add_show_status_action(subparsers):


def start_service(args):
# User Agent will be set based on header in request, so not set globally here.
os.environ[PF_NO_INTERACTIVE_LOGIN] = "true"
port = args.port
app, _ = create_app()
if port and is_port_in_use(port) and not args.force:
app.logger.warning(f"Service port {port} is used.")
raise UserErrorException(f"Service port {port} is used.")
if not port:
get_app()

def validate_port(port, force_start):
if is_port_in_use(port):
if force_start:
app.logger.warning(f"Force restart the service on the port {port}.")
kill_exist_service(port)
else:
app.logger.warning(f"Service port {port} is used.")
raise UserErrorException(f"Service port {port} is used.")

if port:
dump_port_to_config(port)
validate_port(port, args.force)
else:
port = get_port_from_config(create_if_not_exists=True)
validate_port(port, args.force)
# Set host to localhost, only allow request from localhost.
cmd = [
sys.executable,
"-m",
"waitress",
"--host",
"127.0.0.1",
f"--port={port}",
"--call",
"promptflow._sdk._service.entry:get_app",
]
if args.synchronous:
subprocess.call(cmd)
else:
dump_port_to_config(port)

if is_port_in_use(port):
if args.force:
app.logger.warning(f"Force restart the service on the port {port}.")
kill_exist_service(port)
# Start a pfs process using detach mode
if platform.system() == "Windows":
os.spawnv(os.P_DETACH, sys.executable, cmd)
else:
app.logger.warning(f"Service port {port} is used.")
raise UserErrorException(f"Service port {port} is used.")
# Set host to localhost, only allow request from localhost.
app.logger.info(f"Start Prompt Flow Service on http://localhost:{port}, version: {get_promptflow_sdk_version()}")
waitress.serve(app, host="127.0.0.1", port=port)
os.system(" ".join(["nohup"] + cmd + ["&"]))
is_healthy = check_pfs_service_status(port)
if is_healthy:
app.logger.info(
f"Start Prompt Flow Service on http://localhost:{port}, version: {get_promptflow_sdk_version()}"
)
else:
app.logger.warning(f"Pfs service start failed in {port}.")


def main():
Expand All @@ -81,9 +122,6 @@ def main():
return json.dumps(version_dict, ensure_ascii=False, indent=2, sort_keys=True, separators=(",", ": ")) + "\n"
if len(command_args) == 0:
command_args.append("-h")

# User Agent will be set based on header in request, so not set globally here.
os.environ[PF_NO_INTERACTIVE_LOGIN] = "true"
entry(command_args)


Expand All @@ -106,7 +144,7 @@ def entry(command_args):
activity_name = _get_cli_activity_name(cli=parser.prog, args=args)
logger = get_telemetry_logger()

with log_activity(logger, activity_name, activity_type=ActivityType.INTERNALCALL):
with log_activity(logger, activity_name, activity_type=ActivityType.PUBLICAPI):
run_command(args)


Expand Down
33 changes: 33 additions & 0 deletions src/promptflow/promptflow/_sdk/_service/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
# ---------------------------------------------------------
import getpass
import socket
import time
from dataclasses import InitVar, dataclass, field
from datetime import datetime
from functools import wraps

import psutil
import requests
from flask import abort, make_response, request

from promptflow._sdk._constants import DEFAULT_ENCODING, HOME_PROMPT_FLOW_DIR, PF_SERVICE_PORT_FILE
from promptflow._sdk._errors import ConnectionNotFoundError, RunNotFoundError
from promptflow._sdk._utils import read_write_by_user
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.yaml_utils import dump_yaml, load_yaml
from promptflow._version import VERSION
from promptflow.exceptions import PromptflowException, UserErrorException

logger = get_cli_sdk_logger()


def local_user_only(func):
@wraps(func)
Expand Down Expand Up @@ -100,6 +105,34 @@ def make_response_no_content():
return make_response("", 204)


def is_pfs_service_healthy(pfs_port, timeout=5) -> bool:
    """Check if pfs service is running by probing its heartbeat endpoint.

    :param pfs_port: Local port the service is expected to listen on.
    :param timeout: Seconds to wait for the heartbeat response before the
        service is treated as unreachable.
    :return: True if ``GET /heartbeat`` answered with HTTP 200, otherwise False.
    """
    heartbeat_url = f"http://localhost:{pfs_port}/heartbeat"
    try:
        # Bound the request: without a timeout, a process that holds the port
        # but never answers would block this probe (and any polling loop
        # built on top of it) indefinitely.
        response = requests.get(heartbeat_url, timeout=timeout)
        if response.status_code == 200:
            logger.debug(f"Pfs service is already running on port {pfs_port}.")
            return True
    except Exception:  # pylint: disable=broad-except
        # Best-effort probe: any failure to reach the service (refused
        # connection, timeout, ...) simply means "not healthy".
        pass
    logger.warning(f"Pfs service can't be reached through port {pfs_port}, will try to start/force restart pfs.")
    return False


def check_pfs_service_status(pfs_port, time_delay=5, time_threshold=30) -> bool:
    """Poll the pfs heartbeat until it is healthy or the wait budget is spent.

    Sleeps ``time_delay`` seconds before every probe (including the first one)
    and stops as soon as a probe succeeds or the accumulated wait reaches
    ``time_threshold`` seconds.

    :param pfs_port: Port passed through to ``is_pfs_service_healthy``.
    :param time_delay: Seconds to sleep before each probe.
    :param time_threshold: Upper bound, in seconds, on the accumulated wait.
    :return: Result of the last health probe.
    """
    elapsed = time_delay
    time.sleep(time_delay)
    healthy = is_pfs_service_healthy(pfs_port)
    while True:
        # Stop once the service answered or the wait budget is exhausted.
        if healthy is not False or elapsed >= time_threshold:
            return healthy
        logger.info(
            f"Pfs service is not ready. It has been waited for {elapsed}s, will wait for at most "
            f"{time_threshold}s."
        )
        elapsed += time_delay
        time.sleep(time_delay)
        healthy = is_pfs_service_healthy(pfs_port)


@dataclass
class ErrorInfo:
exception: InitVar[Exception]
Expand Down
8 changes: 8 additions & 0 deletions src/promptflow/promptflow/_sdk/entities/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ class _LineRunData:
"""Basic data structure for line run, no matter if it is a main or evaluation."""

line_run_id: str
trace_id: str
root_span_id: str
inputs: typing.Dict
outputs: typing.Dict
start_time: datetime.datetime
Expand Down Expand Up @@ -180,6 +182,8 @@ def _from_root_span(span: Span) -> "_LineRunData":
cumulative_token_count = None
return _LineRunData(
line_run_id=line_run_id,
trace_id=span.trace_id,
root_span_id=span.span_id,
inputs=json.loads(attributes[SpanAttributeFieldName.INPUTS]),
outputs=json.loads(attributes[SpanAttributeFieldName.OUTPUT]),
start_time=start_time,
Expand All @@ -197,6 +201,8 @@ class LineRun:
"""Line run is an abstraction of spans related to prompt flow."""

line_run_id: str
trace_id: str
root_span_id: str
inputs: typing.Dict
outputs: typing.Dict
start_time: str
Expand Down Expand Up @@ -228,6 +234,8 @@ def _from_spans(spans: typing.List[Span]) -> "LineRun":
evaluations[eval_name] = eval_line_run_data
return LineRun(
line_run_id=main_line_run_data.line_run_id,
trace_id=main_line_run_data.trace_id,
root_span_id=main_line_run_data.root_span_id,
inputs=main_line_run_data.inputs,
outputs=main_line_run_data.outputs,
start_time=main_line_run_data.start_time.isoformat(),
Expand Down
54 changes: 8 additions & 46 deletions src/promptflow/promptflow/_trace/_start_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
# ---------------------------------------------------------

import os
import platform
import sys
import time
import typing
import uuid

import requests
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
Expand All @@ -19,11 +15,10 @@
from promptflow._constants import SpanAttributeFieldName
from promptflow._core.openai_injector import inject_openai_api
from promptflow._core.operation_context import OperationContext
from promptflow._sdk._service.utils.utils import is_pfs_service_healthy
from promptflow._utils.logger_utils import get_cli_sdk_logger

_logger = get_cli_sdk_logger()
time_threshold = 30
time_delay = 10


def start_trace(*, session: typing.Optional[str] = None, **kwargs):
Expand All @@ -38,7 +33,7 @@ def start_trace(*, session: typing.Optional[str] = None, **kwargs):
from promptflow._sdk._service.utils.utils import get_port_from_config

pfs_port = get_port_from_config(create_if_not_exists=True)
_start_pfs_in_background(pfs_port)
_start_pfs(pfs_port)
_logger.debug("PFS is serving on port %s", pfs_port)

# provision a session
Expand Down Expand Up @@ -70,51 +65,18 @@ def start_trace(*, session: typing.Optional[str] = None, **kwargs):
print(f"You can view the trace from UI url: {ui_url}")


def _start_pfs_in_background(pfs_port) -> None:
"""Start a pfs process in background."""
def _start_pfs(pfs_port) -> None:
from promptflow._sdk._service.entry import entry
from promptflow._sdk._service.utils.utils import is_port_in_use

args = [sys.executable, "-m", "promptflow._sdk._service.entry", "start", "--port", str(pfs_port)]
command_args = ["start", "--port", str(pfs_port)]
if is_port_in_use(pfs_port):
_logger.warning(f"Service port {pfs_port} is used.")
if _check_pfs_service_status(pfs_port) is True:
if is_pfs_service_healthy(pfs_port) is True:
return
else:
args += ["--force"]
# Start a pfs process using detach mode
if platform.system() == "Windows":
os.spawnv(os.P_DETACH, sys.executable, args)
else:
os.system(" ".join(["nohup"] + args + ["&"]))

wait_time = time_delay
time.sleep(time_delay)
is_healthy = _check_pfs_service_status(pfs_port)
while is_healthy is False and time_threshold > wait_time:
_logger.info(
f"Pfs service is not ready. It has been waited for {wait_time}s, will wait for at most "
f"{time_threshold}s."
)
wait_time += time_delay
time.sleep(time_delay)
is_healthy = _check_pfs_service_status(pfs_port)

if is_healthy is False:
_logger.error(f"Pfs service start failed in {pfs_port}.")
sys.exit(1)


def _check_pfs_service_status(pfs_port) -> bool:
"""Check if pfs service is running."""
try:
response = requests.get("http://localhost:{}/heartbeat".format(pfs_port))
if response.status_code == 200:
_logger.info(f"Pfs service is already running on port {pfs_port}.")
return True
except Exception: # pylint: disable=broad-except
pass
_logger.warning(f"Pfs service can't be reached through port {pfs_port}, will try to start/force restart pfs.")
return False
command_args += ["--force"]
entry(command_args)


def _provision_session(session_id: typing.Optional[str] = None) -> str:
Expand Down
Loading

0 comments on commit 425e714

Please sign in to comment.