diff --git a/src/google/adk/plugins/save_files_as_artifacts_plugin.py b/src/google/adk/plugins/save_files_as_artifacts_plugin.py index d92d9a7a54..8944f8b6af 100644 --- a/src/google/adk/plugins/save_files_as_artifacts_plugin.py +++ b/src/google/adk/plugins/save_files_as_artifacts_plugin.py @@ -31,6 +31,10 @@ # capabilities. _MODEL_ACCESSIBLE_URI_SCHEMES = {'gs', 'https', 'http'} +# Maximum file size for inline_data (20MB as per Gemini API documentation) +# https://ai.google.dev/gemini-api/docs/files +_MAX_INLINE_DATA_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB + class SaveFilesAsArtifactsPlugin(BasePlugin): """A plugin that saves files embedded in user messages as artifacts. @@ -81,8 +85,28 @@ async def on_user_message_callback( continue try: - # Use display_name if available, otherwise generate a filename + # Check file size before processing inline_data = part.inline_data + file_size = len(inline_data.data) if inline_data.data else 0 + + if file_size > _MAX_INLINE_DATA_SIZE_BYTES: + # File exceeds the inline_data limit + file_size_mb = file_size / (1024 * 1024) + limit_mb = _MAX_INLINE_DATA_SIZE_BYTES / (1024 * 1024) + error_message = ( + f'File size ({file_size_mb:.2f} MB) exceeds the maximum allowed' + f' size for inline uploads ({limit_mb:.0f} MB). Please use the' + f' Files API to upload files larger than {limit_mb:.0f} MB. See' + ' https://ai.google.dev/gemini-api/docs/files for more' + ' information.' + ) + logger.error(error_message) + # Replace with error message part + new_parts.append(types.Part(text=f'[Upload Error: {error_message}]')) + modified = True + continue + + # Use display_name if available, otherwise generate a filename file_name = inline_data.display_name if not file_name: file_name = f'artifact_{invocation_context.invocation_id}_{i}' diff --git a/tests/unittests/plugins/test_save_files_as_artifacts.py b/tests/unittests/plugins/test_save_files_as_artifacts.py index 66ab08098c..a1cf6cbdd2 100644 --- a/tests/unittests/plugins/test_save_files_as_artifacts.py +++ b/tests/unittests/plugins/test_save_files_as_artifacts.py @@ -303,3 +303,127 @@ def test_plugin_name_default(self): """Test that plugin has correct default name.""" plugin = SaveFilesAsArtifactsPlugin() assert plugin.name == "save_files_as_artifacts_plugin" + + @pytest.mark.asyncio + async def test_file_size_exceeds_limit(self): + """Test that files exceeding 20MB limit show error message.""" + # Create a file larger than 20MB (20 * 1024 * 1024 bytes) + large_file_data = b"x" * (21 * 1024 * 1024) # 21 MB + inline_data = types.Blob( + display_name="large_file.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should not try to save the artifact + self.mock_context.artifact_service.save_artifact.assert_not_called() + + # Should return modified content with error message + assert result is not None + assert len(result.parts) == 1 + assert result.parts[0].text is not None + assert "[Upload Error:" in result.parts[0].text + assert "21.00 MB" in result.parts[0].text + assert "20 MB" in result.parts[0].text + assert "Files API" in result.parts[0].text + + @pytest.mark.asyncio + async def test_file_size_at_limit(self): + """Test that files exactly at 20MB limit are processed successfully.""" + # Create a file exactly 20MB (20 * 1024 * 1024 bytes) + file_data = b"x" * (20 * 1024 * 1024) # Exactly 20 MB + inline_data = types.Blob( + display_name="max_size_file.pdf", + data=file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should save the artifact since it's at the limit + self.mock_context.artifact_service.save_artifact.assert_called_once() + assert result is not None + assert len(result.parts) == 2 + assert result.parts[0].text == '[Uploaded Artifact: "max_size_file.pdf"]' + assert result.parts[1].file_data is not None + + @pytest.mark.asyncio + async def test_file_size_just_over_limit(self): + """Test that files just over 20MB limit show error message.""" + # Create a file just over 20MB + large_file_data = b"x" * (20 * 1024 * 1024 + 1) # 20 MB + 1 byte + inline_data = types.Blob( + display_name="slightly_too_large.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should not try to save the artifact + self.mock_context.artifact_service.save_artifact.assert_not_called() + + # Should return error message + assert result is not None + assert len(result.parts) == 1 + assert "[Upload Error:" in result.parts[0].text + + @pytest.mark.asyncio + async def test_mixed_file_sizes(self): + """Test processing multiple files with mixed sizes.""" + # Small file (should succeed) + small_file_data = b"x" * (5 * 1024 * 1024) # 5 MB + small_inline_data = types.Blob( + display_name="small.pdf", + data=small_file_data, + mime_type="application/pdf", + ) + + # Large file (should fail) + large_file_data = b"x" * (25 * 1024 * 1024) # 25 MB + large_inline_data = types.Blob( + display_name="large.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content( + parts=[ + types.Part(inline_data=small_inline_data), + types.Part(inline_data=large_inline_data), + ] + ) + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should only save the small file + self.mock_context.artifact_service.save_artifact.assert_called_once_with( + app_name="test_app", + user_id="test_user", + session_id="test_session", + filename="small.pdf", + artifact=user_message.parts[0], + ) + + # Should return both success and error messages + assert result is not None + assert len(result.parts) == 3 # [success placeholder, file_data, error] + assert '[Uploaded Artifact: "small.pdf"]' in result.parts[0].text + assert result.parts[1].file_data is not None + assert "[Upload Error:" in result.parts[2].text