Skip to content

Commit

Permalink
Merge create_predict_task into PredictTask.__init__
Browse files Browse the repository at this point in the history
There's no real value in the being separate. This way, the `PredictTask`
knows how to set itself up properly.
  • Loading branch information
erbridge committed Aug 14, 2024
1 parent a1213bc commit 6d7580d
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 120 deletions.
60 changes: 22 additions & 38 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from abc import ABC, abstractmethod
from concurrent.futures import Future
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar, Union
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar

import requests
import structlog
from attrs import define, field
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from .. import schema, types
from .. import schema
from ..files import put_file_to_signed_endpoint
from ..json import upload_files
from ..predictor import BaseInput
Expand Down Expand Up @@ -89,7 +89,7 @@ def predict(

task_kwargs = task_kwargs or {}

self._predict_task = create_predict_task(prediction, **task_kwargs)
self._predict_task = PredictTask(prediction, **task_kwargs)
self._prediction_id = prediction.id

if isinstance(prediction.input, BaseInput):
Expand Down Expand Up @@ -225,32 +225,6 @@ def _handle_done(self, f: "Future[Done]") -> None:
self.failed()


def create_predict_task(
prediction: schema.PredictionRequest,
upload_url: Optional[str] = None,
) -> "PredictTask":
response = schema.PredictionResponse(**prediction.dict())

webhook = prediction.webhook
events_filter = (
prediction.webhook_events_filter or schema.WebhookEvent.default_events()
)

webhook_sender = None
if webhook is not None:
webhook_sender = webhook_caller_filtered(webhook, set(events_filter))

file_uploader = None
if upload_url is not None:
file_uploader = generate_file_uploader(upload_url, prediction_id=prediction.id)

event_handler = PredictTask(
response, webhook_sender=webhook_sender, file_uploader=file_uploader
)

return event_handler


def generate_file_uploader(
upload_url: str, prediction_id: Optional[str]
) -> Callable[[Any], Any]:
Expand All @@ -270,27 +244,37 @@ def upload_file(fh: io.IOBase) -> str:
class PredictTask(Task[schema.PredictionResponse]):
def __init__(
self,
p: schema.PredictionResponse,
webhook_sender: Optional[Callable[[Any, schema.WebhookEvent], None]] = None,
file_uploader: Optional[Callable[[Any], Any]] = None,
prediction_request: schema.PredictionRequest,
upload_url: Optional[str] = None,
) -> None:
super().__init__()

self._log = log.bind(prediction_id=p.id)
self._log = log.bind(prediction_id=prediction_request.id)

self._log.info("starting prediction")

self._fut: "Optional[Future[Done]]" = None

self._p = p
self._p = schema.PredictionResponse(**prediction_request.dict())
self._p.status = schema.Status.PROCESSING
self._output_type_multi = None
self._p.output = None
self._p.logs = ""
self._p.started_at = datetime.now(tz=timezone.utc)

self._webhook_sender = webhook_sender
self._file_uploader = file_uploader
self._webhook_sender = None
if prediction_request.webhook:
self._webhook_sender = webhook_caller_filtered(
str(prediction_request.webhook),
set(
prediction_request.webhook_events_filter
or schema.WebhookEvent.default_events()
),
)

self._file_uploader = None
if upload_url:
self._file_uploader = generate_file_uploader(
upload_url, prediction_id=self._p.id
)

@property
def result(self) -> schema.PredictionResponse:
Expand Down
Loading

0 comments on commit 6d7580d

Please sign in to comment.