diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py
index 0af1838989..ac1c7fe5fa 100644
--- a/python/cog/server/worker.py
+++ b/python/cog/server/worker.py
@@ -13,6 +13,7 @@
 import uuid
 import warnings
 import weakref
+from concurrent import futures
 from concurrent.futures import Future, ThreadPoolExecutor
 from enum import Enum, auto, unique
 from multiprocessing.connection import Connection
@@ -130,14 +131,16 @@ def __init__(
         self._request_send_conn = send_conn
         self._request_recv_conn = recv_conn
 
-        self._pool = ThreadPoolExecutor(max_workers=1)
+        self._event_consumer_pool = ThreadPoolExecutor(max_workers=1)
+        self._prediction_prepare_pool = ThreadPoolExecutor(max_workers=max_concurrency)
+        self._input_download_pool = ThreadPoolExecutor(max_workers=8)
         self._event_consumer = None
 
     def setup(self) -> "Future[Done]":
         self._assert_state(WorkerState.NEW)
         self._state = WorkerState.STARTING
         self._child.start()
-        self._event_consumer = self._pool.submit(self._consume_events)
+        self._event_consumer = self._event_consumer_pool.submit(self._consume_events)
         return self._setup_result
 
     def predict(
@@ -163,9 +166,92 @@ def predict(
         self._assert_state(WorkerState.READY)
         result = Future()
         self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
-        self._request_send_conn.send(PredictionRequest(tag))
+
+        # Prepare payload asynchronously (download URLPath objects)
+        fut = self._prediction_prepare_pool.submit(self._prepare_payload(payload))
+        # then start the prediction
+        fut.add_done_callback(self._start_prediction(tag))
         return result
 
+    def _prepare_payload(self, payload: Dict[str, Any]) -> Callable[[], Dict[str, Any]]:
+        """
+        Return a callable that downloads the URLPath inputs in `payload`.
+
+        URLPath values, and lists made up entirely of URLPath items, are
+        converted concurrently on the input download pool and replaced in
+        place by the converted results. The callable returns the same
+        payload dict. If a download fails, downloads still waiting in the
+        queue are cancelled, the payload is left unmodified and the error
+        is raised to the caller.
+        """
+
+        def prepare_payload() -> Dict[str, Any]:
+            to_await = []
+            futs = {}
+            for k, v in payload.items():
+                # Check if v is an instance of URLPath
+                if isinstance(v, URLPath):
+                    futs[k] = self._input_download_pool.submit(v.convert)
+                    to_await.append(futs[k])
+                # Check if v is a list of URLPath instances
+                elif isinstance(v, list) and all(
+                    isinstance(item, URLPath) for item in v
+                ):
+                    futs[k] = [
+                        self._input_download_pool.submit(item.convert) for item in v
+                    ]
+                    to_await += futs[k]
+            done, not_done = futures.wait(
+                to_await, return_when=futures.FIRST_EXCEPTION
+            )
+            # wait() only returns with futures still pending when a download
+            # has failed, so don't start the ones that are still queued
+            for fut in not_done:
+                fut.cancel()
+            # raise the first failure (in submission order) before touching
+            # the payload, so a failed payload is never half-converted
+            for fut in to_await:
+                if fut in done:
+                    fut.result()
+            for k, v in futs.items():
+                if isinstance(v, list):
+                    payload[k] = [f.result() for f in v]
+                else:
+                    payload[k] = v.result()
+            return payload
+
+        return prepare_payload
+
+    def _start_prediction(
+        self, tag: Optional[str]
+    ) -> Callable[["Future[Dict[str,Any]]"], None]:
+        """
+        Return the done-callback for a payload preparation future.
+
+        If preparation failed, the prediction is completed with an error
+        Done event; otherwise a PredictionInput carrying the prepared
+        payload is sent to start the prediction.
+        """
+
+        def payload_callback(fut: "Future[Dict[str,Any]]") -> None:
+            error = fut.exception()
+            if error is not None:
+                done = Done(error=True, error_detail=str(error))
+                self._publish(Envelope(done, tag))
+                self._complete_prediction(done, tag)
+                return
+            payload = fut.result()
+            # NOTE(review): done-callbacks normally run on prepare-pool threads,
+            # so several may reach this send at once - confirm that is safe
+            self._events.send(
+                Envelope(
+                    event=PredictionInput(payload=payload),
+                    tag=tag,
+                )
+            )
+
+        return payload_callback
+
     def subscribe(
         self,
         subscriber: Callable[[_PublicEventType], None],
@@ -195,7 +281,10 @@ def shutdown(self, timeout: Optional[float] = None) -> None:
         if self._event_consumer:
             self._event_consumer.result(timeout=timeout)
 
-        self._pool.shutdown()
+        self._event_consumer_pool.shutdown()
+        # the prepare pool submits work to the download pool, so stop it first
+        self._prediction_prepare_pool.shutdown()
+        self._input_download_pool.shutdown()
 
     def terminate(self) -> None:
         """
@@ -209,7 +298,9 @@ def terminate(self) -> None:
             self._child.terminate()
             self._child.join()
 
-        self._pool.shutdown(wait=False)
+        self._event_consumer_pool.shutdown(wait=False)
+        self._prediction_prepare_pool.shutdown(wait=False)
+        self._input_download_pool.shutdown(wait=False)
 
     def cancel(self, tag: Optional[str] = None) -> None:
         self._request_send_conn.send(CancelRequest(tag))
@@ -275,27 +366,7 @@ def _consume_events_inner(self) -> None:
             )
             if self._request_recv_conn in read_socks:
                 ev = self._request_recv_conn.recv()
-                if isinstance(ev, PredictionRequest):
-                    with self._predictions_lock:
-                        state = self._predictions_in_flight[ev.tag]
-
-                    # Prepare payload (download URLPath objects)
-                    # FIXME this blocks the event loop, which is bad in concurrent mode
-                    try:
-                        _prepare_payload(state.payload)
-                    except Exception as e:
-                        done = Done(error=True, error_detail=str(e))
-                        self._publish(Envelope(done, state.tag))
-                        self._complete_prediction(done, state.tag)
-                    else:
-                        # Start the prediction
-                        self._events.send(
-                            Envelope(
-                                event=PredictionInput(payload=state.payload),
-                                tag=state.tag,
-                            )
-                        )
-                elif isinstance(ev, CancelRequest):
+                if isinstance(ev, CancelRequest):
                     with self._predictions_lock:
                         predict_state = self._predictions_in_flight.get(ev.tag)
                         if predict_state and not predict_state.cancel_sent:
@@ -844,13 +915,3 @@ def make_worker(
     )
     parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
     return parent
-
-
-def _prepare_payload(payload: Dict[str, Any]) -> None:
-    for k, v in payload.items():
-        # Check if v is an instance of URLPath
-        if isinstance(v, URLPath):
-            payload[k] = v.convert()
-        # Check if v is a list of URLPath instances
-        elif isinstance(v, list) and all(isinstance(item, URLPath) for item in v):
-            payload[k] = [item.convert() for item in v]