diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 7278357913..35b5ba1b84 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from datetime import datetime, timezone -from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar import requests import structlog @@ -11,7 +11,7 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from .. import schema, types +from .. import schema from ..files import put_file_to_signed_endpoint from ..json import upload_files from ..predictor import BaseInput @@ -89,7 +89,7 @@ def predict( task_kwargs = task_kwargs or {} - self._predict_task = create_predict_task(prediction, **task_kwargs) + self._predict_task = PredictTask(prediction, **task_kwargs) self._prediction_id = prediction.id if isinstance(prediction.input, BaseInput): @@ -225,32 +225,6 @@ def _handle_done(self, f: "Future[Done]") -> None: self.failed() -def create_predict_task( - prediction: schema.PredictionRequest, - upload_url: Optional[str] = None, -) -> "PredictTask": - response = schema.PredictionResponse(**prediction.dict()) - - webhook = prediction.webhook - events_filter = ( - prediction.webhook_events_filter or schema.WebhookEvent.default_events() - ) - - webhook_sender = None - if webhook is not None: - webhook_sender = webhook_caller_filtered(webhook, set(events_filter)) - - file_uploader = None - if upload_url is not None: - file_uploader = generate_file_uploader(upload_url, prediction_id=prediction.id) - - event_handler = PredictTask( - response, webhook_sender=webhook_sender, file_uploader=file_uploader - ) - - return event_handler - - def generate_file_uploader( upload_url: str, prediction_id: Optional[str] ) -> Callable[[Any], Any]: @@ -270,27 +244,37 @@ def upload_file(fh: io.IOBase) -> str: class PredictTask(Task[schema.PredictionResponse]): def __init__( self, - p: schema.PredictionResponse, - webhook_sender: Optional[Callable[[Any, schema.WebhookEvent], None]] = None, - file_uploader: Optional[Callable[[Any], Any]] = None, + prediction_request: schema.PredictionRequest, + upload_url: Optional[str] = None, ) -> None: - super().__init__() - - self._log = log.bind(prediction_id=p.id) + self._log = log.bind(prediction_id=prediction_request.id) self._log.info("starting prediction") self._fut: "Optional[Future[Done]]" = None - self._p = p + self._p = schema.PredictionResponse(**prediction_request.dict()) self._p.status = schema.Status.PROCESSING self._output_type_multi = None self._p.output = None self._p.logs = "" self._p.started_at = datetime.now(tz=timezone.utc) - self._webhook_sender = webhook_sender - self._file_uploader = file_uploader + self._webhook_sender = None + if prediction_request.webhook: + self._webhook_sender = webhook_caller_filtered( + str(prediction_request.webhook), + set( + prediction_request.webhook_events_filter + or schema.WebhookEvent.default_events() + ), + ) + + self._file_uploader = None + if upload_url: + self._file_uploader = generate_file_uploader( + upload_url, prediction_id=self._p.id + ) @property def result(self) -> schema.PredictionResponse: diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index d12d336cdd..e2485c63c7 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -6,7 +6,7 @@ import pytest -from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent +from cog.schema import PredictionRequest, Status, WebhookEvent from cog.server.eventtypes import Done, Log from cog.server.runner import ( PredictionRunner, @@ -387,55 +387,78 @@ def test_setup_task(log, result): def test_predict_task(): - p = PredictionResponse(input={"hello": "there"}) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + ) t = PredictTask(p) - assert p.status == Status.PROCESSING - assert p.output is None - assert p.logs == "" - assert isinstance(p.started_at, datetime) + assert t.result.status == Status.PROCESSING + assert t.result.output is None + assert t.result.logs == "" + assert isinstance(t.result.started_at, datetime) t.set_output_type(multi=False) t.append_output("giraffes") - assert p.output == "giraffes" + assert t.result.output == "giraffes" def test_predict_task_multi(): - p = PredictionResponse(input={"hello": "there"}) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + ) t = PredictTask(p) - assert p.status == Status.PROCESSING - assert p.output is None - assert p.logs == "" - assert isinstance(p.started_at, datetime) + assert t.result.status == Status.PROCESSING + assert t.result.output is None + assert t.result.logs == "" + assert isinstance(t.result.started_at, datetime) t.set_output_type(multi=True) t.append_output("elephant") t.append_output("duck") - assert p.output == ["elephant", "duck"] + assert t.result.output == ["elephant", "duck"] t.append_logs("running a prediction\n") t.append_logs("still running\n") - assert p.logs == "running a prediction\nstill running\n" + assert t.result.logs == "running a prediction\nstill running\n" t.succeeded() - assert p.status == Status.SUCCEEDED - assert isinstance(p.completed_at, datetime) + assert t.result.status == Status.SUCCEEDED + assert isinstance(t.result.completed_at, datetime) t.failed("oops") - assert p.status == Status.FAILED - assert p.error == "oops" - assert isinstance(p.completed_at, datetime) + assert t.result.status == Status.FAILED + assert t.result.error == "oops" + assert isinstance(t.result.completed_at, datetime) t.canceled() - assert p.status == Status.CANCELED - assert isinstance(p.completed_at, datetime) + assert t.result.status == Status.CANCELED + assert isinstance(t.result.completed_at, datetime) def test_predict_task_webhook_sender(): - s = mock.Mock() - p = PredictionResponse(input={"hello": "there"}) - t = PredictTask(p, webhook_sender=s) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook="https://a.url.honest", + ) + t = PredictTask(p) + t._webhook_sender = mock.Mock() + t.track(Future()) + + t._webhook_sender.assert_called_once_with(mock.ANY, WebhookEvent.START) + actual = t._webhook_sender.call_args[0][0] + assert actual.status == "processing" t.set_output_type(multi=True) t.append_output("elephant") @@ -444,14 +467,14 @@ def test_predict_task_webhook_sender(): t.append_logs("running a prediction\n") t.append_logs("still running\n") - s.reset_mock() + t._webhook_sender.reset_mock() t.succeeded() - s.assert_called_once_with( + t._webhook_sender.assert_called_once_with( mock.ANY, WebhookEvent.COMPLETED, ) - actual = s.call_args[0][0] + actual = t._webhook_sender.call_args[0][0] assert actual.input == {"hello": "there"} assert actual.output == ["elephant", "duck"] assert actual.logs == "running a prediction\nstill running\n" @@ -460,111 +483,140 @@ def test_predict_task_webhook_sender(): def test_predict_task_webhook_sender_intermediate(): - s = mock.Mock() - p = PredictionResponse(input={"hello": "there"}) - t = PredictTask(p, webhook_sender=s) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook="https://a.url.honest", + ) + t = PredictTask(p) + t._webhook_sender = mock.Mock() + t.track(Future()) - s.assert_called_once_with(mock.ANY, WebhookEvent.START) - actual = s.call_args[0][0] + t._webhook_sender.assert_called_once_with(mock.ANY, WebhookEvent.START) + actual = t._webhook_sender.call_args[0][0] assert actual.status == "processing" - s.reset_mock() + t._webhook_sender.reset_mock() t.set_output_type(multi=False) t.append_output("giraffes") - assert s.call_count == 0 + assert t._webhook_sender.call_count == 0 def test_predict_task_webhook_sender_intermediate_multi(): - s = mock.Mock() - p = PredictionResponse(input={"hello": "there"}) - t = PredictTask(p, webhook_sender=s) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook="https://a.url.honest", + ) + t = PredictTask(p) + t._webhook_sender = mock.Mock() + t.track(Future()) - s.assert_called_once_with(mock.ANY, WebhookEvent.START) - actual = s.call_args[0][0] + t._webhook_sender.assert_called_once_with(mock.ANY, WebhookEvent.START) + actual = t._webhook_sender.call_args[0][0] assert actual.status == "processing" - s.reset_mock() + t._webhook_sender.reset_mock() t.set_output_type(multi=True) t.append_output("elephant") - print(s.call_args_list) - assert s.call_count == 1 - actual = s.call_args_list[0][0][0] + print(t._webhook_sender.call_args_list) + assert t._webhook_sender.call_count == 1 + actual = t._webhook_sender.call_args_list[0][0][0] assert actual.output == ["elephant"] - assert s.call_args_list[0][0][1] == WebhookEvent.OUTPUT + assert t._webhook_sender.call_args_list[0][0][1] == WebhookEvent.OUTPUT - s.reset_mock() + t._webhook_sender.reset_mock() t.append_output("duck") - assert s.call_count == 1 - actual = s.call_args_list[0][0][0] + assert t._webhook_sender.call_count == 1 + actual = t._webhook_sender.call_args_list[0][0][0] assert actual.output == ["elephant", "duck"] - assert s.call_args_list[0][0][1] == WebhookEvent.OUTPUT + assert t._webhook_sender.call_args_list[0][0][1] == WebhookEvent.OUTPUT - s.reset_mock() + t._webhook_sender.reset_mock() t.append_logs("running a prediction\n") - assert s.call_count == 1 - actual = s.call_args_list[0][0][0] + assert t._webhook_sender.call_count == 1 + actual = t._webhook_sender.call_args_list[0][0][0] assert actual.logs == "running a prediction\n" - assert s.call_args_list[0][0][1] == WebhookEvent.LOGS + assert t._webhook_sender.call_args_list[0][0][1] == WebhookEvent.LOGS - s.reset_mock() + t._webhook_sender.reset_mock() t.append_logs("still running\n") - assert s.call_count == 1 - actual = s.call_args_list[0][0][0] + assert t._webhook_sender.call_count == 1 + actual = t._webhook_sender.call_args_list[0][0][0] assert actual.logs == "running a prediction\nstill running\n" - assert s.call_args_list[0][0][1] == WebhookEvent.LOGS + assert t._webhook_sender.call_args_list[0][0][1] == WebhookEvent.LOGS - s.reset_mock() + t._webhook_sender.reset_mock() t.succeeded() - s.assert_called_once() - actual = s.call_args[0][0] + t._webhook_sender.assert_called_once() + actual = t._webhook_sender.call_args[0][0] assert actual.status == "succeeded" - assert s.call_args[0][1] == WebhookEvent.COMPLETED + assert t._webhook_sender.call_args[0][1] == WebhookEvent.COMPLETED - s.reset_mock() + t._webhook_sender.reset_mock() t.failed("oops") - s.assert_called_once() - actual = s.call_args[0][0] + t._webhook_sender.assert_called_once() + actual = t._webhook_sender.call_args[0][0] assert actual.status == "failed" assert actual.error == "oops" - assert s.call_args[0][1] == WebhookEvent.COMPLETED + assert t._webhook_sender.call_args[0][1] == WebhookEvent.COMPLETED - s.reset_mock() + t._webhook_sender.reset_mock() t.canceled() - s.assert_called_once() - actual = s.call_args[0][0] + t._webhook_sender.assert_called_once() + actual = t._webhook_sender.call_args[0][0] assert actual.status == "canceled" - assert s.call_args[0][1] == WebhookEvent.COMPLETED + assert t._webhook_sender.call_args[0][1] == WebhookEvent.COMPLETED def test_predict_task_file_uploads(): - u = mock.Mock() - p = PredictionResponse(input={"hello": "there"}) - t = PredictTask(p, file_uploader=u) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + ) + t = PredictTask(p, upload_url="https://a.url.honest") + t._file_uploader = mock.Mock() # in reality this would be a Path object, but in this test we just care it # passes the output into the upload files function and uses whatever comes # back as final output. - u.return_value = "http://example.com/output-image.png" + t._file_uploader.return_value = "http://example.com/output-image.png" t.set_output_type(multi=False) t.append_output("Path(to/my/file)") - u.assert_called_once_with("Path(to/my/file)") - assert p.output == "http://example.com/output-image.png" + t._file_uploader.assert_called_once_with("Path(to/my/file)") + assert t.result.output == "http://example.com/output-image.png" def test_predict_task_file_uploads_multi(): - u = mock.Mock() - p = PredictionResponse(input={"hello": "there"}) - t = PredictTask(p, file_uploader=u) + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + ) + t = PredictTask(p, upload_url="https://a.url.honest") + t._file_uploader = mock.Mock() - u.return_value = [] + t._file_uploader.return_value = [] t.set_output_type(multi=True) - u.return_value = "http://example.com/hello.jpg" + t._file_uploader.return_value = "http://example.com/hello.jpg" t.append_output("hello.jpg") - u.return_value = "http://example.com/world.jpg" + t._file_uploader.return_value = "http://example.com/world.jpg" t.append_output("world.jpg") - u.assert_has_calls([mock.call("hello.jpg"), mock.call("world.jpg")]) - assert p.output == ["http://example.com/hello.jpg", "http://example.com/world.jpg"] + t._file_uploader.assert_has_calls([mock.call("hello.jpg"), mock.call("world.jpg")]) + assert t.result.output == [ + "http://example.com/hello.jpg", + "http://example.com/world.jpg", + ]