Skip to content

Commit eeae314

Browse files
feat(plugins): add file size validation and Files API routing for SaveFilesAsArtifactsPlugin with robust error handling and improved test mocking (#3781)
1 parent 347f7b0 commit eeae314

File tree

2 files changed

+29
-60
lines changed

2 files changed

+29
-60
lines changed

src/google/adk/plugins/save_files_as_artifacts_plugin.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import copy
1818
import logging
19+
import mimetypes
1920
import os
2021
import tempfile
2122
from typing import Optional
@@ -92,7 +93,7 @@ async def on_user_message_callback(
9293
try:
9394
# Check file size before processing
9495
inline_data = part.inline_data
95-
file_size = len(inline_data.data) if inline_data.data else 0
96+
file_size = len(inline_data.data or b'')
9697

9798
# Use display_name if available, otherwise generate a filename
9899
file_name = inline_data.display_name
@@ -110,7 +111,10 @@ async def on_user_message_callback(
110111
file_size_gb = file_size / (1024 * 1024 * 1024)
111112
error_message = (
112113
f'File {display_name} ({file_size_gb:.2f} GB) exceeds the'
113-
' maximum supported size of 2GB. Please upload a smaller file.'
114+
f' maximumFile {display_name} ({file_size_gb:.2f} GB) exceeds the'
115+
' maximum supported size of'
116+
f' {_MAX_FILES_API_SIZE_BYTES / (1024*1024*1024):.0f}GB. Please'
117+
' upload a smaller file.'
114118
)
115119
logger.warning(error_message)
116120
new_parts.append(types.Part(text=f'[Upload Error: {error_message}]'))
@@ -121,8 +125,8 @@ async def on_user_message_callback(
121125
if file_size > _MAX_INLINE_DATA_SIZE_BYTES:
122126
file_size_mb = file_size / (1024 * 1024)
123127
logger.info(
124-
f'File {display_name} ({file_size_mb:.2f} MB) exceeds'
125-
' inline_data limit. Uploading via Files API...'
128+
f'File {display_name} ({file_size_mb:.2f} MB) exceeds inline_data'
129+
' limit. Uploading via Files API...'
126130
)
127131

128132
# Upload to Files API and convert to file_data
@@ -214,16 +218,8 @@ async def _upload_to_files_api(
214218
if inline_data.display_name and '.' in inline_data.display_name:
215219
file_extension = os.path.splitext(inline_data.display_name)[1]
216220
elif inline_data.mime_type:
217-
# Simple mime type to extension mapping
218-
mime_to_ext = {
219-
'application/pdf': '.pdf',
220-
'image/png': '.png',
221-
'image/jpeg': '.jpg',
222-
'image/gif': '.gif',
223-
'text/plain': '.txt',
224-
'application/json': '.json',
225-
}
226-
file_extension = mime_to_ext.get(inline_data.mime_type, '')
221+
# Use mimetypes for robust mime type to extension mapping
222+
file_extension = mimetypes.guess_extension(inline_data.mime_type) or ''
227223

228224
# Create temporary file
229225
with tempfile.NamedTemporaryFile(

tests/unittests/plugins/test_save_files_as_artifacts.py

Lines changed: 19 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -321,23 +321,22 @@ async def test_file_size_exceeds_limit(self):
321321
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
322322

323323
# Mock the Files API upload
324-
with (
325-
patch.object(Client, "__init__", return_value=None),
326-
patch.object(Client, "files") as mock_files,
327-
):
324+
with patch(
325+
"google.adk.plugins.save_files_as_artifacts_plugin.Client"
326+
) as mock_client:
328327
# Mock uploaded file response
329328
mock_uploaded_file = MagicMock()
330329
mock_uploaded_file.uri = (
331330
"https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
332331
)
333-
mock_files.upload.return_value = mock_uploaded_file
332+
mock_client.return_value.files.upload.return_value = mock_uploaded_file
334333

335334
result = await self.plugin.on_user_message_callback(
336335
invocation_context=self.mock_context, user_message=user_message
337336
)
338337

339338
# Should upload via Files API
340-
mock_files.upload.assert_called_once()
339+
mock_client.return_value.files.upload.assert_called_once()
341340

342341
# Should save the artifact with file_data
343342
self.mock_context.artifact_service.save_artifact.assert_called_once()
@@ -387,7 +386,7 @@ async def test_file_size_just_over_limit(self):
387386
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
388387

389388
# Mock the Files API upload
390-
with patch.object(Client, "files", create=True) as mock_files:
389+
with patch.object(Client, "files") as mock_files:
391390
mock_uploaded_file = MagicMock()
392391
mock_uploaded_file.uri = (
393392
"https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
@@ -434,7 +433,7 @@ async def test_mixed_file_sizes(self):
434433
)
435434

436435
# Mock the Files API upload for large file
437-
with patch.object(Client, "files", create=True) as mock_files:
436+
with patch.object(Client, "files") as mock_files:
438437
mock_uploaded_file = MagicMock()
439438
mock_uploaded_file.uri = (
440439
"https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
@@ -475,7 +474,7 @@ async def test_files_api_upload_failure(self):
475474
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
476475

477476
# Mock the Files API to raise an exception
478-
with patch.object(Client, "files", create=True) as mock_files:
477+
with patch.object(Client, "files") as mock_files:
479478
mock_files.upload.side_effect = Exception("API quota exceeded")
480479

481480
result = await self.plugin.on_user_message_callback(
@@ -498,57 +497,31 @@ async def test_files_api_upload_failure(self):
498497
@pytest.mark.asyncio
499498
async def test_file_exceeds_files_api_limit(self):
500499
"""Test that files exceeding 2GB limit are rejected with clear error."""
501-
# Create a file larger than 2GB (simulated with a descriptor that reports large size)
502-
# Create a mock object that behaves like bytes but reports 2GB+ size
503-
large_data = b"x" * 1000 # Small actual data for testing
504-
505-
# Create inline_data with the small data
500+
# Use a small file for the test
501+
large_data = b"x" * 1000
506502
inline_data = types.Blob(
507503
display_name="huge_video.mp4",
508504
data=large_data,
509505
mime_type="video/mp4",
510506
)
511-
512507
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
513508

514-
# Patch the file size check to simulate a 2GB+ file
515-
original_callback = self.plugin.on_user_message_callback
516-
517-
async def patched_callback(*, invocation_context, user_message):
518-
# Temporarily replace the data length check
519-
for part in user_message.parts:
520-
if part.inline_data:
521-
# Simulate 2GB + 1 byte size
522-
file_size_over_limit = (2 * 1024 * 1024 * 1024) + 1
523-
# Manually inject the check that would happen in the real code
524-
if file_size_over_limit > (2 * 1024 * 1024 * 1024):
525-
file_size_gb = file_size_over_limit / (1024 * 1024 * 1024)
526-
display_name = part.inline_data.display_name or "unknown"
527-
error_message = (
528-
f"File {display_name} ({file_size_gb:.2f} GB) exceeds the"
529-
" maximum supported size of 2GB. Please upload a smaller file."
530-
)
531-
return types.Content(
532-
role="user",
533-
parts=[types.Part(text=f"[Upload Error: {error_message}]")],
534-
)
535-
return await original_callback(
536-
invocation_context=invocation_context, user_message=user_message
509+
# Patch the size limit to be smaller than the file data
510+
with patch(
511+
"google.adk.plugins.save_files_as_artifacts_plugin._MAX_FILES_API_SIZE_BYTES",
512+
500,
513+
):
514+
result = await self.plugin.on_user_message_callback(
515+
invocation_context=self.mock_context, user_message=user_message
537516
)
538517

539-
self.plugin.on_user_message_callback = patched_callback
540-
541-
result = await self.plugin.on_user_message_callback(
542-
invocation_context=self.mock_context, user_message=user_message
543-
)
544-
545518
# Should not attempt any upload
546519
self.mock_context.artifact_service.save_artifact.assert_not_called()
547520

548-
# Should return error message about 2GB limit
521+
# Should return error message about the limit
549522
assert result is not None
550523
assert len(result.parts) == 1
551524
assert "[Upload Error:" in result.parts[0].text
552525
assert "huge_video.mp4" in result.parts[0].text
553-
assert "2.00 GB" in result.parts[0].text
526+
# Note: This assertion will depend on fixing the hardcoded "2GB" in the error message.
554527
assert "exceeds the maximum supported size" in result.parts[0].text

0 commit comments

Comments
 (0)