diff --git a/pkg/cli/build.go b/pkg/cli/build.go index ebdc2e1d92..92dda89c6b 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -64,7 +64,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { imageName = config.DockerImageName(projectDir) } - err = config.ValidateModelPythonVersion(cfg.Build.PythonVersion) + err = config.ValidateModelPythonVersion(cfg) if err != nil { return err } diff --git a/pkg/config/config.go b/pkg/config/config.go index 4a282deebd..f3e28eb7f6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,9 +30,10 @@ var ( // TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1) const ( - MinimumMajorPythonVersion int = 3 - MinimumMinorPythonVersion int = 8 - MinimumMajorCudaVersion int = 11 + MinimumMajorPythonVersion int = 3 + MinimumMinorPythonVersion int = 8 + MinimumMinorPythonVersionForConcurrency int = 11 + MinimumMajorCudaVersion int = 11 ) type RunItem struct { @@ -58,16 +59,21 @@ type Build struct { pythonRequirementsContent []string } +type Concurrency struct { + Max int `json:"max,omitempty" yaml:"max"` +} + type Example struct { Input map[string]string `json:"input" yaml:"input"` Output string `json:"output" yaml:"output"` } type Config struct { - Build *Build `json:"build" yaml:"build"` - Image string `json:"image,omitempty" yaml:"image"` - Predict string `json:"predict,omitempty" yaml:"predict"` - Train string `json:"train,omitempty" yaml:"train"` + Build *Build `json:"build" yaml:"build"` + Image string `json:"image,omitempty" yaml:"image"` + Predict string `json:"predict,omitempty" yaml:"predict"` + Train string `json:"train,omitempty" yaml:"train"` + Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"` } func DefaultConfig() *Config { @@ -244,7 +250,9 @@ func splitPythonVersion(version string) (major int, minor int, err error) { return major, minor, nil } -func ValidateModelPythonVersion(version string) error { +func ValidateModelPythonVersion(cfg *Config) error { + version := cfg.Build.PythonVersion + // we check for minimum supported here major, minor, err := splitPythonVersion(version) if err != nil { @@ -255,6 +263,10 @@ func ValidateModelPythonVersion(version string) error { return fmt.Errorf("minimum supported Python version is %d.%d. requested %s", MinimumMajorPythonVersion, MinimumMinorPythonVersion, version) } + if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency { + return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s", + MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version) + } return nil } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index d0ec825ab8..934fba5a70 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -13,47 +13,68 @@ import ( func TestValidateModelPythonVersion(t *testing.T) { testCases := []struct { - name string - input string - expectedErr bool + name string + pythonVersion string + concurrencyMax int + expectedErr string }{ { - name: "ValidVersion", - input: "3.12", - expectedErr: false, + name: "ValidVersion", + pythonVersion: "3.12", }, { - name: "MinimumVersion", - input: "3.8", - expectedErr: false, + name: "MinimumVersion", + pythonVersion: "3.8", }, { - name: "FullyQualifiedVersion", - input: "3.12.1", - expectedErr: false, + name: "MinimumVersionForConcurrency", + pythonVersion: "3.11", + concurrencyMax: 5, }, { - name: "InvalidFormat", - input: "3-12", - expectedErr: true, + name: "TooOldForConcurrency", + pythonVersion: "3.8", + concurrencyMax: 5, + expectedErr: "when concurrency.max is set, minimum supported Python version is 3.11. requested 3.8", }, { - name: "InvalidMissingMinor", - input: "3", - expectedErr: true, + name: "FullyQualifiedVersion", + pythonVersion: "3.12.1", }, { - name: "LessThanMinimum", - input: "3.7", - expectedErr: true, + name: "InvalidFormat", + pythonVersion: "3-12", + expectedErr: "invalid Python version format: missing minor version in 3-12", + }, + { + name: "InvalidMissingMinor", + pythonVersion: "3", + expectedErr: "invalid Python version format: missing minor version in 3", + }, + { + name: "LessThanMinimum", + pythonVersion: "3.7", + expectedErr: "minimum supported Python version is 3.8. requested 3.7", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := ValidateModelPythonVersion(tc.input) - if tc.expectedErr { - require.Error(t, err) + cfg := &Config{ + Build: &Build{ + PythonVersion: tc.pythonVersion, + }, + } + if tc.concurrencyMax != 0 { + // the Concurrency key is optional, only populate it if + // concurrencyMax is a non-default value + cfg.Concurrency = &Concurrency{ + Max: tc.concurrencyMax, + } + } + err := ValidateModelPythonVersion(cfg) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) } else { require.NoError(t, err) } @@ -649,17 +670,6 @@ func TestBlankBuild(t *testing.T) { require.Equal(t, false, config.Build.GPU) } -func TestModelPythonVersionValidation(t *testing.T) { - err := ValidateModelPythonVersion("3.8") - require.NoError(t, err) - err = ValidateModelPythonVersion("3.8.1") - require.NoError(t, err) - err = ValidateModelPythonVersion("3.7") - require.Equal(t, "minimum supported Python version is 3.8. requested 3.7", err.Error()) - err = ValidateModelPythonVersion("3.7.1") - require.Equal(t, "minimum supported Python version is 3.8. requested 3.7.1", err.Error()) -} - func TestSplitPinnedPythonRequirement(t *testing.T) { testCases := []struct { input string diff --git a/pyproject.toml b/pyproject.toml index cfe833b8ac..e467288ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ tests = [ "numpy", "pillow", "pytest", + "pytest-asyncio", "pytest-httpserver", "pytest-timeout", "pytest-xdist", @@ -70,6 +71,9 @@ reportUnusedExpression = "warning" [tool.pyright.defineConstant] PYDANTIC_V2 = true +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" + [tool.setuptools] include-package-data = false diff --git a/python/cog/__init__.py b/python/cog/__init__.py index d6ad24d9af..72f1399cd0 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -6,6 +6,7 @@ from .mimetypes_ext import install_mime_extensions from .server.scope import current_scope, emit_metric from .types import ( + AsyncConcatenateIterator, ConcatenateIterator, ExperimentalFeatureWarning, File, @@ -26,6 +27,7 @@ "__version__", "current_scope", "emit_metric", + "AsyncConcatenateIterator", "BaseModel", "BasePredictor", "ConcatenateIterator", diff --git a/python/cog/config.py b/python/cog/config.py index ec367e4fb7..bc560a1ddc 100644 --- a/python/cog/config.py +++ b/python/cog/config.py @@ -33,6 +33,7 @@ COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP" COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP" COG_GPU_ENV_VAR = "COG_GPU" +COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY" PREDICT_METHOD_NAME = "predict" TRAIN_METHOD_NAME = "train" @@ -101,6 +102,12 @@ def requires_gpu(self) -> bool: """Whether this cog requires the use of a GPU.""" return bool(self._cog_config.get("build", {}).get("gpu", False)) + @property + @env_property(COG_MAX_CONCURRENCY_ENV_VAR) + def max_concurrency(self) -> int: + """The maximum concurrency of predictions supported by this model. Defaults to 1.""" + return int(self._cog_config.get("concurrency", {}).get("max", 1)) + def _predictor_code( self, module_path: str, diff --git a/python/cog/server/helpers.py b/python/cog/server/helpers.py index 3aa967ed2a..990fa49cb0 100644 --- a/python/cog/server/helpers.py +++ b/python/cog/server/helpers.py @@ -33,7 +33,7 @@ def __init__( callback: Callable[[str, str], None], tee: bool = False, ) -> None: - super().__init__(buffer, line_buffering=True) + super().__init__(buffer) self._callback = callback self._tee = tee @@ -44,11 +44,10 @@ def write(self, s: str) -> int: self._buffer.append(s) if self._tee: super().write(s) - else: - # If we're not teeing, we have to handle automatic flush on - # newline. When `tee` is true, this is handled by the write method. - if "\n" in s or "\r" in s: - self.flush() + + if "\n" in s or "\r" in s: + self.flush() + return length def flush(self) -> None: diff --git a/python/cog/server/http.py b/python/cog/server/http.py index ae350fcfb3..f7da747dd4 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -165,9 +165,11 @@ async def start_shutdown() -> Any: return app worker = make_worker( - predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async + predictor_ref=cog_config.get_predictor_ref(mode=mode), + is_async=is_async, + max_concurrency=cog_config.max_concurrency, ) - runner = PredictionRunner(worker=worker) + runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency) class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): pass @@ -219,7 +221,7 @@ class TrainingRequest( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train( + async def train( request: TrainingRequest = Body(default=None), prefer: Optional[str] = Header(default=None), traceparent: Optional[str] = Header( @@ -232,7 +234,7 @@ def train( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -243,7 +245,7 @@ def train( response_model=TrainingResponse, response_model_exclude_unset=True, ) - def train_idempotent( + async def train_idempotent( training_id: str = Path(..., title="Training ID"), request: TrainingRequest = Body(..., title="Training Request"), prefer: Optional[str] = Header(default=None), @@ -280,7 +282,7 @@ def train_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=TrainingResponse, respond_async=respond_async, @@ -359,7 +361,7 @@ async def predict( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, @@ -407,13 +409,13 @@ async def predict_idempotent( respond_async = prefer == "respond-async" with trace_context(make_trace_context(traceparent, tracestate)): - return _predict( + return await _predict( request=request, response_type=PredictionResponse, respond_async=respond_async, ) - def _predict( + async def _predict( *, request: Optional[PredictionRequest], response_type: Type[schema.PredictionResponse], @@ -455,7 +457,7 @@ def _predict( ) # Otherwise, wait for the prediction to complete... - predict_task.wait() + await predict_task.wait_async() # ...and return the result. if PYDANTIC_V2: diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 7468be2572..1b752ae27a 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -1,5 +1,8 @@ +import asyncio import io +import threading import traceback +import uuid from abc import ABC, abstractmethod from concurrent.futures import Future from datetime import datetime, timezone @@ -60,13 +63,15 @@ class PredictionRunner: def __init__( self, *, + max_concurrency: int = 1, worker: Worker, ) -> None: self._worker = worker + self._max_concurrency = max_concurrency self._setup_task: Optional[SetupTask] = None - self._predict_task: Optional[PredictTask] = None - self._prediction_id = None + self._predict_tasks: Dict[str, PredictTask] = {} + self._predict_tasks_lock = threading.Lock() def setup(self) -> "SetupTask": assert self._setup_task is None, "do not call setup twice" @@ -88,8 +93,14 @@ def predict( task_kwargs = task_kwargs or {} - self._predict_task = PredictTask(prediction, **task_kwargs) - self._prediction_id = prediction.id + tag = prediction.id + if tag is None: + tag = uuid.uuid4().hex + + task = PredictTask(prediction, **task_kwargs) + + with self._predict_tasks_lock: + self._predict_tasks[tag] = task if isinstance(prediction.input, BaseInput): if PYDANTIC_V2: @@ -101,18 +112,23 @@ def predict( else: payload = prediction.input.copy() - sid = self._worker.subscribe(self._predict_task.handle_event) - self._predict_task.track(self._worker.predict(payload)) - self._predict_task.add_done_callback(lambda _: self._worker.unsubscribe(sid)) + sid = self._worker.subscribe(task.handle_event, tag=tag) + task.track(self._worker.predict(payload, tag=tag)) + task.add_done_callback(self._task_done_callback(tag, sid)) + + return task + + def _task_done_callback(self, tag: str, sid: int) -> Callable[[Any], None]: + def _callback(_) -> None: + self._worker.unsubscribe(sid) + with self._predict_tasks_lock: + del self._predict_tasks[tag] - return self._predict_task + return _callback def get_predict_task(self, id: str) -> Optional["PredictTask"]: - if not self._predict_task: - return None - if self._predict_task.result.id != id: - return None - return self._predict_task + with self._predict_tasks_lock: + return self._predict_tasks.get(id, None) def is_busy(self) -> bool: try: @@ -124,9 +140,13 @@ def is_busy(self) -> bool: def cancel(self, prediction_id: str) -> None: if not prediction_id: raise ValueError("prediction_id is required") - if self._prediction_id != prediction_id: - raise UnknownPredictionError("id mismatch") - self._worker.cancel() + with self._predict_tasks_lock: + if ( + prediction_id not in self._predict_tasks + or self._predict_tasks[prediction_id].done() + ): + raise UnknownPredictionError("unknown prediction id") + self._worker.cancel(tag=prediction_id) def _raise_if_busy(self) -> None: if self._setup_task is None: @@ -135,9 +155,17 @@ def _raise_if_busy(self) -> None: if not self._setup_task.done(): # Setup is still running. raise RunnerBusyError("setup is not complete") - if self._predict_task is not None and not self._predict_task.done(): - # Prediction is still running. - raise RunnerBusyError("prediction running") + + with self._predict_tasks_lock: + processing_tasks = [ + id for id in self._predict_tasks if not self._predict_tasks[id].done() + ] + + if len(processing_tasks) >= self._max_concurrency: + # We're at max concurrency + if self._max_concurrency == 1: + raise RunnerBusyError("prediction running") + raise RunnerBusyError("max predictions running") T = TypeVar("T") @@ -317,6 +345,11 @@ def done(self) -> bool: assert self._fut, "call track before checking done" return self._fut.done() + async def wait_async(self) -> None: + assert self._fut, "call track before waiting" + await asyncio.wrap_future(self._fut) + return None + def wait(self, timeout: Optional[float] = None) -> None: assert self._fut, "call track before waiting" self._fut.result(timeout=timeout) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 713e0a38f8..2cba7c3f7e 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -11,6 +11,7 @@ import traceback import types import uuid +import warnings from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto, unique from multiprocessing.connection import Connection @@ -351,6 +352,7 @@ def __init__( *, is_async: bool, events: Connection, + max_concurrency: int = 1, tee_output: bool = True, ) -> None: self._predictor_ref = predictor_ref @@ -360,6 +362,7 @@ def __init__( ) self._tee_output = tee_output self._cancelable = False + self._max_concurrency = max_concurrency # for synchronous predictors only! async predictors use _tag_var instead self._sync_tag: Optional[str] = None @@ -394,6 +397,7 @@ def run(self) -> None: # it has sent a error Done event and we're done here. if not self._predictor: return + self._predictor.log = self._log # type: ignore predict = get_predict(self._predictor) if self._is_async: @@ -459,6 +463,25 @@ def _setup( # Could be a function or a class if hasattr(self._predictor, "setup"): run_setup(self._predictor) + + predict = get_predict(self._predictor) + + is_async_predictor = inspect.iscoroutinefunction( + predict + ) or inspect.isasyncgenfunction(predict) + + # Async models require python >= 3.11 so we can use asyncio.TaskGroup + # We should check for this before getting to this point + if is_async_predictor and sys.version_info < (3, 11): + raise FatalWorkerException( + "Cog requires Python >=3.11 for `async def predict()` support" + ) + + if self._max_concurrency > 1 and not is_async_predictor: + raise FatalWorkerException( + "max_concurrency > 1 requires an async predict function, e.g. `async def predict()`" + ) + except Exception as e: # pylint: disable=broad-exception-caught traceback.print_exc() done.error = True @@ -512,20 +535,19 @@ async def _aloop( task = None - while True: - e = cast(Envelope, await self._events.recv()) - if isinstance(e.event, Cancel) and task and self._cancelable: - task.cancel() - elif isinstance(e.event, Shutdown): - break - elif isinstance(e.event, PredictionInput): - task = asyncio.create_task( - self._apredict(e.tag, e.event.payload, predict, redirector) - ) - else: - print(f"Got unexpected event: {e.event}", file=sys.stderr) - if task: - await task + async with asyncio.TaskGroup() as tg: + while True: + e = cast(Envelope, await self._events.recv()) + if isinstance(e.event, Cancel) and task and self._cancelable: + task.cancel() + elif isinstance(e.event, Shutdown): + break + elif isinstance(e.event, PredictionInput): + task = tg.create_task( + self._apredict(e.tag, e.event.payload, predict, redirector) + ) + else: + print(f"Got unexpected event: {e.event}", file=sys.stderr) def _predict( self, @@ -707,6 +729,18 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None: Envelope(event=Log(data, source="stderr"), tag=self._current_tag) ) + def _log(self, *messages: str, source: str = "stderr") -> None: + """ + DEPRECATED: This function will be removed in a future version of cog. + """ + warnings.warn( + "log() is deprecated and will be removed in a future version. Use `print` or `logging` module instead", + category=DeprecationWarning, + stacklevel=1, + ) + file = sys.stdout if source == "stdout" else sys.stderr + print(*messages, file=file, end="") + def make_worker( predictor_ref: str, @@ -717,7 +751,11 @@ def make_worker( ) -> Worker: parent_conn, child_conn = _spawn.Pipe() child = _ChildWorker( - predictor_ref, events=child_conn, tee_output=tee_output, is_async=is_async + predictor_ref, + is_async=is_async, + events=child_conn, + tee_output=tee_output, + max_concurrency=max_concurrency, ) parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency) return parent diff --git a/python/cog/types.py b/python/cog/types.py index 29d868c9e7..c27247afa9 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -9,6 +9,7 @@ import urllib.response from typing import ( Any, + AsyncIterator, Dict, Iterator, List, @@ -43,6 +44,7 @@ class ExperimentalFeatureWarning(Warning): class CogConfig(TypedDict): # pylint: disable=too-many-ancestors build: "CogBuildConfig" + concurrency: "CogConcurrencyConfig" image: NotRequired[str] predict: NotRequired[str] train: NotRequired[str] @@ -58,15 +60,19 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest run: Optional[Union[List[str], List[Dict[str, Any]]]] +class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many-ancestors + max: NotRequired[int] + + def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., - description: str = None, - ge: float = None, - le: float = None, - min_length: int = None, - max_length: int = None, - regex: str = None, - choices: List[Union[str, int]] = None, + description: Optional[str] = None, + ge: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + choices: Optional[List[Union[str, int]]] = None, ) -> Any: """Input is similar to pydantic.Field, but doesn't require a default value to be the first argument.""" field_kwargs = { @@ -410,6 +416,12 @@ def get_filename(url: str) -> str: Item = TypeVar("Item") +_concatenate_iterator_schema = { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", +} class ConcatenateIterator(Iterator[Item]): # pylint: disable=abstract-method @@ -445,14 +457,7 @@ def __get_pydantic_json_schema__( ) -> "JsonSchemaValue": # type: ignore # noqa: F821 json_schema = handler(core_schema) json_schema.pop("allOf", None) - json_schema.update( - { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - } - ) + json_schema.update(_concatenate_iterator_schema) return json_schema else: @@ -465,15 +470,62 @@ def __get_validators__(cls) -> Iterator[Any]: def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: """Defines what this type should be in openapi.json""" field_schema.pop("allOf", None) - field_schema.update( - { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - } + field_schema.update(_concatenate_iterator_schema) + + +class AsyncConcatenateIterator(AsyncIterator[Item]): + @classmethod + def validate(cls, value: AsyncIterator[Any]) -> AsyncIterator[Any]: + return value + + if PYDANTIC_V2: + from pydantic import GetCoreSchemaHandler + from pydantic.json_schema import JsonSchemaValue + from pydantic_core import CoreSchema + + @classmethod + def __get_pydantic_core_schema__( + cls, + source: Type[Any], # pylint: disable=unused-argument + handler: "pydantic.GetCoreSchemaHandler", # pylint: disable=unused-argument + ) -> "CoreSchema": + from pydantic_core import ( # pylint: disable=import-outside-toplevel + core_schema, ) + return core_schema.union_schema( + [ + core_schema.is_instance_schema(AsyncIterator), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: "CoreSchema", handler: "pydantic.GetJsonSchemaHandler" + ) -> "JsonSchemaValue": # type: ignore # noqa: F821 + json_schema = handler(core_schema) + json_schema.pop("allOf", None) + json_schema.update(_concatenate_iterator_schema) + return json_schema + else: + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + """Defines what this type should be in openapi.json""" + field_schema.pop("allOf", None) + field_schema.update(_concatenate_iterator_schema) + + @classmethod + def __get_validators__(cls) -> Iterator[Any]: + yield cls.validate + + +def get_filename_from_urlopen(resp: urllib.response.addinfourl) -> str: + mime_type = resp.headers.get_content_type() + extension = mimetypes.guess_extension(mime_type) + return ("file" + extension) if extension else "file" + def _len_bytes(s: str, encoding: str = "utf-8") -> int: return len(s.encode(encoding)) diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 3bf6d71c01..4d5dfd8bb7 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -1,8 +1,9 @@ import os +import sys import threading import time from contextlib import ExitStack -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, Optional, Sequence, Tuple from unittest import mock import pytest @@ -27,6 +28,7 @@ class WorkerConfig: is_async: bool = False setup: bool = True max_concurrency: int = 1 + min_python: Optional[Tuple[int, int]] = None def pytest_make_parametrize_id(config, val): @@ -72,7 +74,9 @@ def uses_predictor_with_client_options(name, **options): ) -def uses_worker(name_or_names, setup=True, max_concurrency=1, is_async=False): +def uses_worker( + name_or_names, setup=True, max_concurrency=1, min_python=None, is_async=False +): """ Decorator for tests that require a Worker instance. `name_or_names` can be a single fixture name, or a sequence (list, tuple) of fixture names. If @@ -81,31 +85,34 @@ def uses_worker(name_or_names, setup=True, max_concurrency=1, is_async=False): If `setup` is True (the default) setup will be run before the test runs. """ if isinstance(name_or_names, (tuple, list)): - values = [ + values = ( WorkerConfig( fixture_name=n, setup=setup, max_concurrency=max_concurrency, + min_python=min_python, is_async=is_async, ) for n in name_or_names - ] + ) else: - values = [ + values = ( WorkerConfig( fixture_name=name_or_names, setup=setup, max_concurrency=max_concurrency, + min_python=min_python, is_async=is_async, ), - ] - return uses_worker_configs(values) + ) + return uses_worker_configs(list(values)) def uses_worker_configs(values: Sequence[WorkerConfig]): """ - Decorator for tests that require a Worker instance. `configs` can be - a sequence of `WorkerConfig` instances. + Decorator for tests that require a Worker instance. The test will be + run once for each worker. `configs` is a sequence (list, tuple, generator) + of WorkerConfig. """ return pytest.mark.parametrize("worker", values, indirect=True) @@ -168,6 +175,13 @@ def static_schema(client) -> dict: @pytest.fixture def worker(request): ref = _fixture_path(request.param.fixture_name) + if ( + request.param.min_python is not None + and sys.version_info < request.param.min_python + ): + pytest.skip( + f"Test requires python {request.param.min_python[0]}.{request.param.min_python[1]}" + ) w = make_worker( predictor_ref=ref, is_async=request.param.is_async, diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index b012e5f159..5e27357213 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -38,18 +38,23 @@ def __call__(self): class FakeWorker: def __init__(self): self.subscribers = {} - self.last_prediction_payload = None + self.subscribers_by_tag = {} self._setup_future = None - self._predict_future = None + self._predict_futures = {} + self.last_prediction_payload = None def subscribe(self, subscriber, tag=None): sid = uuid.uuid4() - self.subscribers[sid] = subscriber + self.subscribers[sid] = tag + if tag not in self.subscribers_by_tag: + self.subscribers_by_tag[tag] = {} + self.subscribers_by_tag[tag][sid] = subscriber return sid def unsubscribe(self, sid): - del self.subscribers[sid] + tag = self.subscribers.pop(sid) + del self.subscribers_by_tag[tag][sid] def setup(self): assert self._setup_future is None @@ -61,32 +66,38 @@ def run_setup(self, events): if isinstance(event, Exception): self._setup_future.set_exception(event) return - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(None, {}).values(): subscriber(event) if isinstance(event, Done): self._setup_future.set_result(event) def predict(self, payload, tag=None): - assert self._predict_future is None or self._predict_future.done() + assert tag not in self._predict_futures or self._predict_futures[tag].done() self.last_prediction_payload = payload - self._predict_future = Future() - return self._predict_future - - def run_predict(self, events): + self._predict_futures[tag] = Future() + print(f"setting {tag}, now {self._predict_futures}") + return self._predict_futures[tag] + + def run_predict(self, events, id=None): + if id is None: + if len(self._predict_futures) != 1: + raise ValueError("Could not guess prediction id, please specify") + id = next(iter(self._predict_futures)) for event in events: if isinstance(event, Exception): - self._predict_future.set_exception(event) + self._predict_futures[id].set_exception(event) return - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(id, {}).values(): subscriber(event) if isinstance(event, Done): - self._predict_future.set_result(event) + print(f"reading {id} from {self._predict_futures}") + self._predict_futures[id].set_result(event) def cancel(self, tag=None): done = Done(canceled=True) - for subscriber in self.subscribers.values(): + for subscriber in self.subscribers_by_tag.get(tag, {}).values(): subscriber(done) - self._predict_future.set_result(done) + self._predict_futures[tag].set_result(done) def test_prediction_runner_setup_success(): @@ -229,11 +240,11 @@ def test_prediction_runner_predict_after_predict_completes(): r.setup() w.run_setup([Done()]) - r.predict(PredictionRequest(input={"text": "giraffes"})) - w.run_predict([Done()]) + r.predict(PredictionRequest(id="p-1", input={"text": "giraffes"})) + w.run_predict([Done()], id="p-1") - r.predict(PredictionRequest(input={"text": "elephants"})) - w.run_predict([Done()]) + r.predict(PredictionRequest(id="p-2", input={"text": "elephants"})) + w.run_predict([Done()], id="p-2") assert w.last_prediction_payload == {"text": "elephants"} @@ -257,6 +268,31 @@ def test_prediction_runner_is_busy(): assert not r.is_busy() +def test_prediction_runner_is_busy_concurrency(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=3) + + assert r.is_busy() + + r.setup() + assert r.is_busy() + + w.run_setup([Done()]) + assert not r.is_busy() + + r.predict(PredictionRequest(id="1", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="2", input={"text": "elephants"})) + assert not r.is_busy() + + r.predict(PredictionRequest(id="3", input={"text": "elephants"})) + assert r.is_busy() + + w.run_predict([Done()], id="1") + assert not r.is_busy() + + def test_prediction_runner_predict_cancelation(): w = FakeWorker() r = PredictionRunner(worker=w) @@ -299,6 +335,23 @@ def test_prediction_runner_predict_cancelation_multiple_predictions(): assert task2.result.status == Status.CANCELED +def test_prediction_runner_predict_cancelation_concurrent_predictions(): + w = FakeWorker() + r = PredictionRunner(worker=w, max_concurrency=5) + + r.setup() + w.run_setup([Done()]) + + task1 = r.predict(PredictionRequest(id="abcd1234", input={"text": "giraffes"})) + + task2 = r.predict(PredictionRequest(id="defg6789", input={"text": "elephants"})) + + r.cancel("abcd1234") + w.run_predict([Done()], id="defg6789") + assert task1.result.status == Status.CANCELED + assert task2.result.status == Status.SUCCEEDED + + def test_prediction_runner_setup_e2e(): w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False) r = PredictionRunner(worker=w) diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index cb6a469430..d7aa69cad1 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -1,5 +1,6 @@ import multiprocessing import os +import sys import threading import time import uuid @@ -76,7 +77,7 @@ }, ), ( - WorkerConfig("record_metric_async", is_async=True), + WorkerConfig("record_metric_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, { "foo": 123, @@ -90,7 +91,7 @@ }, ), ( - WorkerConfig("emit_metric_async", is_async=True), + WorkerConfig("emit_metric_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, { "foo": 123, @@ -105,7 +106,7 @@ lambda x: f"hello, {x['name']}", ), ( - WorkerConfig("hello_world_async", is_async=True), + WorkerConfig("hello_world_async", min_python=(3, 11), is_async=True), {"name": ST_NAMES}, lambda x: f"hello, {x['name']}", ), @@ -132,7 +133,7 @@ "writing to stderr at import time\n", ), ( - WorkerConfig("logging_async", is_async=True, setup=False), + WorkerConfig("logging_async", setup=False, min_python=(3, 11), is_async=True), ("writing to stdout at import time\n" "setting up predictor\n"), "writing to stderr at import time\n", ), @@ -145,12 +146,22 @@ ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ( - WorkerConfig("logging_async", is_async=True), + WorkerConfig("logging_async", min_python=(3, 11), is_async=True), ("writing with print\n"), ("WARNING:root:writing log message\n" "writing to stderr\n"), ), ] +SLEEP_FIXTURES = [ + WorkerConfig("sleep"), + WorkerConfig("sleep_async", min_python=(3, 11), is_async=True), +] + +SLEEP_NO_SETUP_FIXTURES = [ + WorkerConfig("sleep", setup=False), + WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True), +] + @define class Result: @@ -255,9 +266,11 @@ def test_no_exceptions_from_recoverable_failures(worker): _process(worker, lambda: worker.predict({})) -# TODO test this works with errors and cancelations and the like @uses_worker_configs( - [WorkerConfig("simple"), WorkerConfig("simple_async", is_async=True)] + [ + WorkerConfig("simple"), + WorkerConfig("simple_async", min_python=(3, 11), is_async=True), + ] ) def test_can_subscribe_for_a_specific_tag(worker): tag = "123" @@ -280,12 +293,12 @@ def test_can_subscribe_for_a_specific_tag(worker): worker.unsubscribe(subid) -@uses_worker("sleep_async", is_async=True, max_concurrency=5) +@uses_worker("sleep_async", max_concurrency=5, min_python=(3, 11), is_async=True) def test_can_run_predictions_concurrently_on_async_predictor(worker): subids = [] try: - start = time.time() + start = time.perf_counter() futures = [] results = [] for i in range(5): @@ -299,7 +312,7 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): for fut in futures: fut.result() - end = time.time() + end = time.perf_counter() duration = end - start # we should take at least 0.5 seconds (the time for 1 prediction) but @@ -319,6 +332,41 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker): worker.unsubscribe(subid) +@pytest.mark.skipif( + sys.version_info >= (3, 11), reason="Testing error message on python versions <3.11" +) +@uses_worker("simple_async", setup=False) +def test_async_predictor_on_python_3_10_or_older_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail + == "Cog requires Python >=3.11 for `async def predict()` support" + ) + + +@uses_worker("simple", max_concurrency=5, setup=False) +def test_concurrency_with_sync_predictor_raises_error(worker): + fut = worker.setup() + result = Result() + worker.subscribe(result.handle_event) + + with pytest.raises(FatalWorkerException): + fut.result() + assert result.done + assert result.done.error + assert ( + result.done.error_detail + == "max_concurrency > 1 requires an async predict function, e.g. `async def predict()`" + ) + + @uses_worker("stream_redirector_race_condition") def test_stream_redirector_race_condition(worker): """ @@ -403,12 +451,7 @@ def test_predict_logging(worker, expected_stdout, expected_stderr): assert result.stderr == expected_stderr -@uses_worker_configs( - [ - WorkerConfig("sleep", setup=False), - WorkerConfig("sleep_async", is_async=True, setup=False), - ] -) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_is_safe(worker): """ Calls to cancel at any time should not result in unexpected things @@ -442,12 +485,7 @@ def test_cancel_is_safe(worker): assert result2.output == "done in 0.1 seconds" -@uses_worker_configs( - [ - WorkerConfig("sleep", setup=False), - WorkerConfig("sleep_async", is_async=True, setup=False), - ] -) +@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) def test_cancel_idempotency(worker): """ Multiple calls to cancel within the same prediction, while not necessary or @@ -479,9 +517,7 @@ def cancel_a_bunch(_): assert result2.output == "done in 0.1 seconds" -@uses_worker_configs( - [WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)] -) +@uses_worker_configs(SLEEP_FIXTURES) def test_cancel_multiple_predictions(worker): """ Multiple predictions cancelled in a row shouldn't be a problem. This test @@ -499,9 +535,7 @@ def test_cancel_multiple_predictions(worker): assert not worker.predict({"sleep": 0}).result().canceled -@uses_worker_configs( - [WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)] -) +@uses_worker_configs(SLEEP_FIXTURES) def test_graceful_shutdown(worker): """ On shutdown, the worker should finish running the current prediction, and diff --git a/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml new file mode 100644 index 0000000000..04d04bf7c8 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/cog.yaml @@ -0,0 +1,5 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" +concurrency: + max: 5 diff --git a/test-integration/test_integration/fixtures/async-sleep-project/predict.py b/test-integration/test_integration/fixtures/async-sleep-project/predict.py new file mode 100644 index 0000000000..e6c65797a0 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-sleep-project/predict.py @@ -0,0 +1,9 @@ +import asyncio + +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str, sleep: float) -> str: + await asyncio.sleep(sleep) + return f"wake up {s}" diff --git a/test-integration/test_integration/fixtures/async-string-project/cog.yaml b/test-integration/test_integration/fixtures/async-string-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/async-string-project/predict.py b/test-integration/test_integration/fixtures/async-string-project/predict.py new file mode 100644 index 0000000000..fb2805c794 --- /dev/null +++ b/test-integration/test_integration/fixtures/async-string-project/predict.py @@ -0,0 +1,6 @@ +from cog import BasePredictor + + +class Predictor(BasePredictor): + async def predict(self, s: str) -> str: + return "hello " + s diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 4065b69927..459f09f03e 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -1,6 +1,8 @@ +import asyncio import pathlib import shutil import subprocess +import time from pathlib import Path import httpx @@ -27,6 +29,20 @@ def test_predict_takes_string_inputs_and_returns_strings_to_stdout(): assert "falling back to slow loader" in result.stderr +def test_predict_supports_async_predictors(): + project_dir = Path(__file__).parent / "fixtures/async-string-project" + result = subprocess.run( + ["cog", "predict", "--debug", "-i", "s=world"], + cwd=project_dir, + check=True, + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + # stdout should be clean without any log messages so it can be piped to other commands + assert result.stdout == "hello world\n" + + def test_predict_takes_int_inputs_and_returns_ints_to_stdout(): project_dir = Path(__file__).parent / "fixtures/int-project" result = subprocess.run( @@ -322,3 +338,34 @@ def test_predict_with_subprocess_in_setup(fixture_name): assert response.status_code == 200, str(response) assert busy_count < 10 + + +@pytest.mark.asyncio +async def test_concurrent_predictions(): + async def make_request(i: int) -> httpx.Response: + return await client.post( + f"{addr}/predictions", + json={ + "id": f"id-{i}", + "input": {"s": f"sleepyhead{i}", "sleep": 1.0}, + }, + ) + + with cog_server_http_run( + Path(__file__).parent / "fixtures" / "async-sleep-project" + ) as addr: + async with httpx.AsyncClient() as client: + tasks = [] + start = time.perf_counter() + async with asyncio.TaskGroup() as tg: + for i in range(5): + tasks.append(tg.create_task(make_request(i))) + # give time for all of the predictions to be accepted, but not completed + await asyncio.sleep(0.2) + # we shut the server down, but expect all running predictions to complete + await client.post(f"{addr}/shutdown") + end = time.perf_counter() + assert (end - start) < 3.0 # ensure the predictions ran concurrently + for i, task in enumerate(tasks): + assert task.result().status_code == 200 + assert task.result().json()["output"] == f"wake up sleepyhead{i}"