
Commit 8a3e961

Configure read timeout based on wait parameter
This fixes a bug in the new `wait` implementation where the default read timeout for the HTTP client is shorter than the timeout on the server, so the client errors before the server has had the opportunity to respond with a partial prediction. This commit provides a custom timeout for the `predictions.create` request based on the `wait` parameter, with a 500ms buffer added to account for discrepancies between server and client timings.
1 parent c59bb32 commit 8a3e961
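
For illustration, a minimal usage sketch (not part of the commit) of the behavior this change fixes. The client construction, version id, and input below are placeholders:

    import replicate

    client = replicate.Client()  # reads REPLICATE_API_TOKEN from the environment

    # Hypothetical call: wait=10 asks the server to hold the request open for
    # up to 10 seconds (sent as a `Prefer: wait=10` header) and reply with a
    # possibly partial prediction. After this commit, the client's read timeout
    # is derived from the same value (10s + a 0.5s buffer) instead of the
    # shorter default, so the client no longer errors before the server responds.
    prediction = client.predictions.create(
        version="example-version-id",  # placeholder version identifier
        input={"prompt": "an astronaut riding a horse"},  # placeholder input
        wait=10,
    )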

File tree

3 files changed: +86 −65 lines changed


replicate/deployment.py (+25 −27)

@@ -8,7 +8,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
-    _create_prediction_headers,
+    _create_prediction_request_params,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -421,21 +421,25 @@ def create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = self._client._request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -449,21 +453,24 @@ async def async_create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
         if input is not None:
             input = await async_encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = await self._client._async_request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -484,24 +491,20 @@ def create(
         Create a new prediction with the deployment.
         """
 
-        url = _create_prediction_url_from_deployment(deployment)
-
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
+        url = _create_prediction_url_from_deployment(deployment)
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
-        resp = self._client._request(
-            "POST",
-            url,
-            json=body,
-            headers=headers,
-        )
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(wait=wait)
+        resp = self._client._request("POST", url, json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
 
@@ -515,25 +518,20 @@ async def async_create(
         Create a new prediction with the deployment.
         """
 
-        url = _create_prediction_url_from_deployment(deployment)
-
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
+        url = _create_prediction_url_from_deployment(deployment)
         if input is not None:
             input = await async_encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
 
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(version=None, input=input, **params)
-
-        resp = await self._client._async_request(
-            "POST",
-            url,
-            json=body,
-            headers=headers,
-        )
+        extras = _create_prediction_request_params(wait=wait)
+        resp = await self._client._async_request("POST", url, json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
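
Across all of these call sites the pattern is the same: the request-params helper returns a small mapping that is unpacked into the request. A rough sketch of the equivalence follows; the exact header string is an assumption, produced by the existing `_create_prediction_headers` helper:

    import httpx

    # What _create_prediction_request_params(wait=10) plausibly returns,
    # given the helper defined in replicate/prediction.py below:
    extras = {
        "headers": {"Prefer": "wait=10"},          # assumed header form
        "timeout": httpx.Timeout(5.0, read=10.5),  # 10s wait + 0.5s buffer
    }

    # So a call site like
    #     resp = self._client._request("POST", url, json=body, **extras)
    # is shorthand for passing headers= and timeout= explicitly:
    #     resp = self._client._request(
    #         "POST", url, json=body,
    #         headers={"Prefer": "wait=10"},
    #         timeout=httpx.Timeout(5.0, read=10.5),
    #     )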

replicate/model.py (+15 −22)

@@ -9,7 +9,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
-    _create_prediction_headers,
+    _create_prediction_request_params,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -389,24 +389,20 @@ def create(
         Create a new prediction with the deployment.
         """
 
-        url = _create_prediction_url_from_model(model)
-
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
+        path = _create_prediction_path_from_model(model)
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
-        resp = self._client._request(
-            "POST",
-            url,
-            json=body,
-            headers=headers,
-        )
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(wait=wait)
+        resp = self._client._request("POST", path, json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
 
@@ -420,24 +416,21 @@ async def async_create(
         Create a new prediction with the deployment.
         """
 
-        url = _create_prediction_url_from_model(model)
-
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
+        path = _create_prediction_path_from_model(model)
+
        if input is not None:
             input = await async_encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
-        resp = await self._client._async_request(
-            "POST",
-            url,
-            json=body,
-            headers=headers,
-        )
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(wait=wait)
+        resp = await self._client._async_request("POST", path, json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
 
@@ -522,7 +515,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
     return model
 
 
-def _create_prediction_url_from_model(
+def _create_prediction_path_from_model(
     model: Union[str, Tuple[str, str], "Model"],
 ) -> str:
     owner, name = None, None
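
The rename from `_create_prediction_url_from_model` to `_create_prediction_path_from_model` reflects that the helper builds a relative API path rather than a full URL. The diff truncates its body, so the following is only a hedged sketch of the plausible shape; the endpoint string and the parsing branches are assumptions, not part of the commit:

    def _create_prediction_path_from_model(model) -> str:
        # `model` may be an "owner/name" string, an ("owner", "name") tuple,
        # or a Model resource (see the signature in the diff above).
        owner, name = None, None
        if isinstance(model, tuple):
            owner, name = model
        elif isinstance(model, str):
            owner, name = model.split("/", 1)
        else:
            owner, name = model.owner, model.name
        # Assumed endpoint, mirroring the deployments path seen earlier:
        return f"/v1/models/{owner}/{name}/predictions"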

replicate/prediction.py (+46 −16)

@@ -16,6 +16,7 @@
     overload,
 )
 
+import httpx
 from typing_extensions import NotRequired, TypedDict, Unpack
 
 from replicate.exceptions import ModelError, ReplicateError
@@ -446,6 +447,9 @@ def create( # type: ignore
         Create a new prediction for the specified model, version, or deployment.
         """
 
+        wait = params.pop("wait", None)
+        file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
         if args:
             version = args[0] if len(args) > 0 else None
             input = args[1] if len(args) > 1 else input
@@ -477,26 +481,20 @@ def create( # type: ignore
                 **params,
             )
 
-        file_encoding_strategy = params.pop("file_encoding_strategy", None)
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
+
         body = _create_prediction_body(
             version,
             input,
             **params,
         )
-
-        resp = self._client._request(
-            "POST",
-            "/v1/predictions",
-            headers=headers,
-            json=body,
-        )
+        extras = _create_prediction_request_params(wait=wait)
+        resp = self._client._request("POST", "/v1/predictions", json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
 
@@ -538,6 +536,8 @@ async def async_create( # type: ignore
         """
         Create a new prediction for the specified model, version, or deployment.
         """
+        wait = params.pop("wait", None)
+        file_encoding_strategy = params.pop("file_encoding_strategy", None)
 
         if args:
             version = args[0] if len(args) > 0 else None
@@ -570,25 +570,21 @@ async def async_create( # type: ignore
                 **params,
             )
 
-        file_encoding_strategy = params.pop("file_encoding_strategy", None)
         if input is not None:
             input = await async_encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
+
         body = _create_prediction_body(
             version,
             input,
             **params,
         )
-
+        extras = _create_prediction_request_params(wait=wait)
         resp = await self._client._async_request(
-            "POST",
-            "/v1/predictions",
-            headers=headers,
-            json=body,
+            "POST", "/v1/predictions", json=body, **extras
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -628,6 +624,40 @@ async def async_cancel(self, id: str) -> Prediction:
         return _json_to_prediction(self._client, resp.json())
 
 
+class CreatePredictionRequestParams(TypedDict):
+    headers: NotRequired[Optional[dict]]
+    timeout: NotRequired[Optional[httpx.Timeout]]
+
+
+def _create_prediction_request_params(
+    wait: Optional[Union[int, bool]],
+) -> CreatePredictionRequestParams:
+    timeout = _create_prediction_timeout(wait=wait)
+    headers = _create_prediction_headers(wait=wait)
+
+    return {
+        "headers": headers,
+        "timeout": timeout,
+    }
+
+
+def _create_prediction_timeout(
+    *, wait: Optional[Union[int, bool]] = None
+) -> Union[httpx.Timeout, None]:
+    """
+    Returns an `httpx.Timeout` instance appropriate for the optional
+    `Prefer: wait=x` header that can be provided with the request. This
+    will ensure that we give the server enough time to respond with
+    a partial prediction in the event that the request times out.
+    """
+
+    if not wait:
+        return None
+
+    read_timeout = 60.0 if isinstance(wait, bool) else wait
+    return httpx.Timeout(5.0, read=read_timeout + 0.5)
+
+
 def _create_prediction_headers(
     *,
     wait: Optional[Union[int, bool]] = None,
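
A quick worked example of the new timeout logic, following directly from `_create_prediction_timeout` above (assuming the private helper is imported from replicate.prediction):

    import httpx

    # wait is falsy (None, False, 0): no override; the client's default applies.
    assert _create_prediction_timeout(wait=None) is None

    # wait=True: the server-side wait defaults to 60s, so the read timeout
    # becomes 60.5s (the 500ms buffer from the commit message).
    assert _create_prediction_timeout(wait=True) == httpx.Timeout(5.0, read=60.5)

    # wait=10: read timeout of 10.5s; connect/write/pool remain at 5s.
    assert _create_prediction_timeout(wait=10) == httpx.Timeout(5.0, read=10.5)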
