Skip to content

Commit

Permalink
Fix a couple of bugs in the base64 file_encoding_strategy
Browse files Browse the repository at this point in the history
This commit adds tests for the `file_encoding_strategy` argument for
`replicate.run()` and fixes two bugs that surfaced:

 1. `replicate.run()` would convert the file provided into base64
    encoded data but not a valid data URL. We now use the
    `base64_encode_file` function used for outputs.
 2. `replicate.async_run()` accepted but did not use the
    `file_encoding_strategy` flag at all. This is fixed, though
    it is worth noting that `base64_encode_file` is not optimized
    for async workflows and will block. This migth be okay as the
    file sizes expected for data URL paylaods should be very
    small.
  • Loading branch information
aron committed Nov 13, 2024
1 parent 4fdd78f commit d405512
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 4 deletions.
10 changes: 7 additions & 3 deletions replicate/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def encode_json(
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
if file_encoding_strategy == "base64":
return base64.b64encode(obj.read()).decode("utf-8")
return base64_encode_file(obj)
else:
return client.files.create(obj).urls["get"]
if HAS_NUMPY:
Expand Down Expand Up @@ -77,9 +77,13 @@ async def async_encode_json(
]
if isinstance(obj, Path):
with obj.open("rb") as file:
return encode_json(file, client, file_encoding_strategy)
return await async_encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
return (await client.files.async_create(obj)).urls["get"]
if file_encoding_strategy == "base64":
# TODO: This should ideally use an async based file reader path.
return base64_encode_file(obj)
else:
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down
133 changes: 132 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import io
import asyncio
import sys
from typing import AsyncIterator, Iterator, Optional, cast
from typing import AsyncIterator, Iterator, Optional, cast, Type

import httpx
import pytest
Expand All @@ -11,6 +13,11 @@
from replicate.exceptions import ModelError, ReplicateError
from replicate.helpers import FileOutput

import email
from email.message import EmailMessage
from email.parser import BytesParser
from email.policy import HTTP


@pytest.mark.vcr("run.yaml")
@pytest.mark.asyncio
Expand Down Expand Up @@ -581,6 +588,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
assert excinfo.value.prediction.status == "failed"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
mock_files_create = router.route(
method="POST",
path="/files",
).mock(
return_value=httpx.Response(
200,
json={
"id": "file1",
"name": "file.png",
"content_type": "image/png",
"size": 10,
"etag": "123",
"checksums": {},
"metadata": {},
"created_at": "",
"expires_at": "",
"urls": {"get": "https://api.replicate.com/files/file.txt"},
},
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "https://api.replicate.com/files/file.txt"
)

# Validate the Files API request
req = mock_files_create.calls[0].request
body = req.content
content_type = req.headers["Content-Type"]

# Parse the multipart data
parser = BytesParser(EmailMessage, policy=HTTP)
headers = f"Content-Type: {content_type}\n\n".encode("utf-8")
parsed_message_generator = parser.parsebytes(headers + body).walk()
next(parsed_message_generator) # wrapper
input_file = next(parsed_message_generator)
assert mock_files_create.called
assert input_file.get_content() == b"hello world"
assert input_file.get_content_type() == "application/octet-stream"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=_prediction_with_status("processing"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
200,
json=_version_with_schema(),
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
await client.async_run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)
else:
client.run(
"test/example:v1",
input={"file": io.BytesIO(initial_bytes=b"hello world")},
file_encoding_strategy="base64",
)

assert mock_predictions_create.called
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
assert (
prediction_payload.get("input", {}).get("file")
== "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
)


@pytest.mark.asyncio
async def test_run_with_file_output(mock_replicate_api_token):
router = respx.Router(base_url="https://api.replicate.com/v1")
Expand Down

0 comments on commit d405512

Please sign in to comment.