diff --git a/.gitignore b/.gitignore index bd44686d80..d5158d6e99 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ venv.bak/ # IDE .idea/ .vscode/ +.zed/ *.swp *.swo .DS_Store diff --git a/contributing/samples/hello_world_gemma/agent.py b/contributing/samples/hello_world_gemma/agent.py index 3407d721d3..f4273c1949 100644 --- a/contributing/samples/hello_world_gemma/agent.py +++ b/contributing/samples/hello_world_gemma/agent.py @@ -16,7 +16,7 @@ import random from google.adk.agents.llm_agent import Agent -from google.adk.models.gemma_llm import Gemma +from google.adk.models.gemma_llm import Gemma3GeminiAPI from google.genai.types import GenerateContentConfig @@ -61,7 +61,7 @@ async def check_prime(nums: list[int]) -> str: root_agent = Agent( - model=Gemma(model="gemma-3-27b-it"), + model=Gemma3GeminiAPI(model="gemma-3-27b-it"), name="data_processing_agent", description=( "hello world agent that can roll many-sided dice and check if numbers" diff --git a/contributing/samples/hello_world_gemma3_ollama/__init__.py b/contributing/samples/hello_world_gemma3_ollama/__init__.py new file mode 100644 index 0000000000..7d5bb0b1c6 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_gemma3_ollama/agent.py b/contributing/samples/hello_world_gemma3_ollama/agent.py new file mode 100644 index 0000000000..74a9d72f85 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/agent.py @@ -0,0 +1,93 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import random + +from google.adk.agents.llm_agent import Agent +from google.adk.models.gemma_llm import Gemma3Ollama + +litellm_logger = logging.getLogger("LiteLLM") +litellm_logger.setLevel(logging.WARNING) + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. 
+ """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model=Gemma3Ollama(model="ollama/gemma3:12b"), + name="data_processing_agent", + description=( + "hello world agent that can roll a dice of 8 sides and check prime" + " numbers." + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel (in one request and in one round). + It is ok to discuss previous dice rolls, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_gemma3_ollama/main.py b/contributing/samples/hello_world_gemma3_ollama/main.py new file mode 100644 index 0000000000..a383b4f279 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/main.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_1 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_1, 'Hi, introduce yourself.') + await run_prompt( + session_1, 'Roll a die with 100 sides and check if it is prime' + ) + await run_prompt(session_1, 'Roll it again.') + await run_prompt(session_1, 'What numbers did I get?') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 72b3a6e748..c3b5ef7d4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0;python_version>='3.10'", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.10'", # For CrewaiTool tests + "instructor>=1.11.3", # For instructor (Gemma3 parsing) "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent @@ -144,6 +145,7 @@ extensions = [ "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.10'", # For CrewaiTool "docker>=7.0.0", # For ContainerCodeExecutor + "instructor>=1.11.3", # For instructor (Gemma3 parsing) "kubernetes>=29.0.0", # For GkeCodeExecutor "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent "litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. 
TODO: once LiteLlm fix it diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 9f3c2a2c48..7ea349cbe6 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -16,7 +16,8 @@ from .apigee_llm import ApigeeLlm from .base_llm import BaseLlm -from .gemma_llm import Gemma +from .gemma_llm import Gemma3GeminiAPI +from .gemma_llm import Gemma3Ollama from .google_llm import Gemini from .llm_request import LlmRequest from .llm_response import LlmResponse @@ -25,11 +26,13 @@ __all__ = [ 'BaseLlm', 'Gemini', - 'Gemma', + 'Gemma3GeminiAPI', + 'Gemma3Ollama', 'LLMRegistry', ] LLMRegistry.register(Gemini) -LLMRegistry.register(Gemma) LLMRegistry.register(ApigeeLlm) +LLMRegistry.register(Gemma3GeminiAPI) +LLMRegistry.register(Gemma3Ollama) diff --git a/src/google/adk/models/gemma_llm.py b/src/google/adk/models/gemma_llm.py index 3233d66f99..c9247784ed 100644 --- a/src/google/adk/models/gemma_llm.py +++ b/src/google/adk/models/gemma_llm.py @@ -17,11 +17,11 @@ from functools import cached_property import json import logging -import re from typing import Any from typing import AsyncGenerator from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.utils.variant_utils import GoogleLLMVariant @@ -29,38 +29,34 @@ from google.genai.types import Content from google.genai.types import FunctionDeclaration from google.genai.types import Part +import litellm from pydantic import AliasChoices from pydantic import BaseModel from pydantic import Field from pydantic import ValidationError from typing_extensions import override -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) -class GemmaFunctionCallModel(BaseModel): +class Gemma3FunctionCallModel(BaseModel): """Flexible Pydantic model for parsing inline Gemma function call responses.""" - name: str = Field(validation_alias=AliasChoices('name', 'function')) + name: str = Field(validation_alias=AliasChoices("name", "function")) parameters: dict[str, Any] = Field( - validation_alias=AliasChoices('parameters', 'args') + validation_alias=AliasChoices("parameters", "args") ) -class Gemma(Gemini): - """Integration for Gemma models exposed via the Gemini API. +class Gemma3GeminiAPI(Gemini): + """Integration for Gemma 3 models exposed via the Gemini API. - Only Gemma 3 models are supported at this time. For agentic use cases, - use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended. + Only the larger Gemma 3 model sizes are supported (12b, 27b) as + function calling support on the smaller models is not consistent + enough for basic agent usage. For full documentation, see: https://ai.google.dev/gemma/docs/core/ - NOTE: Gemma does **NOT** support system instructions. Any system instructions - will be replaced with an initial *user* prompt in the LLM request. If system - instructions change over the course of agent execution, the initial content - **SHOULD** be replaced. Special care is warranted here. - See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions - NOTE: Gemma's function calling support is limited. It does not have full access to the same built-in tools as Gemini. It also does not have special API support for tools and functions. 
Rather, tools must be passed in via a `user` prompt, and extracted from model @@ -70,9 +66,13 @@ class Gemma(Gemini): usage via the Gemini API. """ - model: str = ( - 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] - ) + def __init__(self, model: str = "gemma-3-12b-it", **kwargs): + if model not in self.supported_models(): + raise ValueError( + f"Model '{model}' not supported. Use one of:" + f" {self.supported_models()=}" + ) + super().__init__(model=model, **kwargs) @classmethod @override @@ -84,119 +84,22 @@ def supported_models(cls) -> list[str]: """ return [ - r'gemma-3.*', + "gemma-3-12b-it", + "gemma-3-27b-it", ] @cached_property def _api_backend(self) -> GoogleLLMVariant: return GoogleLLMVariant.GEMINI_API - def _move_function_calls_into_system_instruction( - self, llm_request: LlmRequest - ): - if llm_request.model is None or not llm_request.model.startswith('gemma-3'): - return - - # Iterate through the existing contents to find and convert function calls and responses - # from text parts, as Gemma models don't directly support function calling. - new_contents: list[Content] = [] - for content_item in llm_request.contents: - ( - new_parts_for_content, - has_function_response_part, - has_function_call_part, - ) = _convert_content_parts_for_gemma(content_item) - - if has_function_response_part: - if new_parts_for_content: - new_contents.append(Content(role='user', parts=new_parts_for_content)) - elif has_function_call_part: - if new_parts_for_content: - new_contents.append( - Content(role='model', parts=new_parts_for_content) - ) - else: - new_contents.append(content_item) - - llm_request.contents = new_contents - - if not llm_request.config.tools: - return - - all_function_declarations: list[FunctionDeclaration] = [] - for tool_item in llm_request.config.tools: - if isinstance(tool_item, types.Tool) and tool_item.function_declarations: - all_function_declarations.extend(tool_item.function_declarations) - - if all_function_declarations: - system_instruction = _build_gemma_function_system_instruction( - all_function_declarations - ) - llm_request.append_instructions([system_instruction]) - - llm_request.config.tools = [] - - def _extract_function_calls_from_response(self, llm_response: LlmResponse): - if llm_response.partial or (llm_response.turn_complete is True): - return - - if not llm_response.content: - return - - if not llm_response.content.parts: - return - - if len(llm_response.content.parts) > 1: - return - - response_text = llm_response.content.parts[0].text - if not response_text: - return - - try: - json_candidate = None - - markdown_code_block_pattern = re.compile( - r'```(?:(json|tool_code))?\s*(.*?)\s*```', re.DOTALL - ) - block_match = markdown_code_block_pattern.search(response_text) - - if block_match: - json_candidate = block_match.group(2).strip() - else: - found, json_text = _get_last_valid_json_substring(response_text) - if found: - json_candidate = json_text - - if not json_candidate: - return - - function_call_parsed = GemmaFunctionCallModel.model_validate_json( - json_candidate - ) - function_call = types.FunctionCall( - name=function_call_parsed.name, - args=function_call_parsed.parameters, - ) - function_call_part = Part(function_call=function_call) - llm_response.content.parts = [function_call_part] - except (json.JSONDecodeError, ValidationError) as e: - logger.debug( - f'Error attempting to parse JSON into function call. Leaving as text' - f' response. 
%s', - e, - ) - except Exception as e: - logger.warning('Error processing Gemma function call response: %s', e) - @override async def _preprocess_request(self, llm_request: LlmRequest) -> None: - self._move_function_calls_into_system_instruction(llm_request=llm_request) + _move_function_calls_into_system_instruction(llm_request=llm_request) if system_instruction := llm_request.config.system_instruction: contents = llm_request.contents instruction_content = Content( - role='user', parts=[Part.from_text(text=system_instruction)] + role="user", parts=[Part.from_text(text=system_instruction)] ) # NOTE: if history is preserved, we must include the system instructions ONLY once at the beginning @@ -214,7 +117,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: - """Sends a request to the Gemma model. + """Sends a request to the Gemma model via Gemini API. Args: llm_request: LlmRequest, the request to send to the Gemini model. @@ -223,17 +126,156 @@ async def generate_content_async( Yields: LlmResponse: The model response. """ - # print(f'{llm_request=}') - assert llm_request.model.startswith('gemma-'), ( - f'Requesting a non-Gemma model ({llm_request.model}) with the Gemma LLM' - ' is not supported.' - ) + async for response in super().generate_content_async(llm_request, stream): + _extract_function_calls_from_response(response) + yield response + + +class Gemma3Ollama(LiteLlm): + """Integration for Gemma 3 models exposed via Ollama. + Only the larger Gemma 3 model sizes are supported (12b, 27b) as + function calling support on the smaller models is not consistent + enough for basic agent usage. + """ + + def __init__(self, model: str = "ollama/gemma3:12b", **kwargs): + if model not in self.supported_models(): + raise ValueError( + f"Model '{model}' not supported. Use one of:" + f" {self.supported_models()=}" + ) + _register_gemma_prompt_template(model) + super().__init__(model, **kwargs) + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + return ["ollama/gemma3:12b", "ollama/gemma3:27b"] + + def _preprocess_request(self, llm_request: LlmRequest) -> None: + _move_function_calls_into_system_instruction(llm_request=llm_request) + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to the Gemma model hosted on Ollama via LiteLLM integration. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + self._preprocess_request(llm_request) async for response in super().generate_content_async(llm_request, stream): - self._extract_function_calls_from_response(response) + _extract_function_calls_from_response(response) yield response +def _move_function_calls_into_system_instruction(llm_request: LlmRequest): + if llm_request.model is None or not ( + "gemma3" in llm_request.model or llm_request.model.startswith("gemma-3") + ): + return + + # Iterate through the existing contents to find and convert function calls and responses + # from text parts, as Gemma models don't directly support function calling. 
+ new_contents: list[Content] = [] + for content_item in llm_request.contents: + ( + new_parts_for_content, + has_function_response_part, + has_function_call_part, + ) = _convert_content_parts_for_gemma(content_item) + + if has_function_response_part: + if new_parts_for_content: + new_contents.append(Content(role="user", parts=new_parts_for_content)) + elif has_function_call_part: + if new_parts_for_content: + new_contents.append(Content(role="model", parts=new_parts_for_content)) + else: + new_contents.append(content_item) + + llm_request.contents = new_contents + + if not llm_request.config.tools: + return + + all_function_declarations: list[FunctionDeclaration] = [] + for tool_item in llm_request.config.tools: + if isinstance(tool_item, types.Tool) and tool_item.function_declarations: + all_function_declarations.extend(tool_item.function_declarations) + + if all_function_declarations: + system_instruction = _build_gemma_function_system_instruction( + all_function_declarations + ) + llm_request.append_instructions([system_instruction]) + + llm_request.config.tools = [] + + +def _extract_function_calls_from_response(llm_response: LlmResponse): + if llm_response.partial or (llm_response.turn_complete is True): + return + + if not llm_response.content: + return + + if not llm_response.content.parts: + return + + if len(llm_response.content.parts) > 1: + return + + response_text = llm_response.content.parts[0].text + if not response_text: + return + + try: + import instructor + except ImportError as e: + logger.warning( + "The 'instructor' package is required for Gemma3 function calling but" + " is not installed. Text response will be returned. To enable function" + ' calling, run: pip install "google-adk[extensions]"' + ) + return + + try: + json_candidate = instructor.utils.extract_json_from_codeblock(response_text) + + if not json_candidate: + return + + function_call_parsed = Gemma3FunctionCallModel.model_validate_json( + json_candidate + ) + function_call = types.FunctionCall( + name=function_call_parsed.name, + args=function_call_parsed.parameters, + ) + function_call_part = Part(function_call=function_call) + llm_response.content.parts = [function_call_part] + except (json.JSONDecodeError, ValidationError) as e: + logger.debug( + f"Error attempting to parse JSON into function call. Leaving as text" + f" response. %s", + e, + ) + except Exception as e: + logger.warning("Error processing Gemma function call response: %s", e) + + def _convert_content_parts_for_gemma( content_item: Content, ) -> tuple[list[Part], bool, bool]: @@ -256,8 +298,8 @@ def _convert_content_parts_for_gemma( if func_response := part.function_response: has_function_response_part = True response_text = ( - f'Invoking tool `{func_response.name}` produced:' - f' `{json.dumps(func_response.response)}`.' + f"Invoking tool `{func_response.name}` produced:" + f" `{json.dumps(func_response.response)}`." 
) new_parts.append(Part.from_text(text=response_text)) elif func_call := part.function_call: @@ -275,57 +317,43 @@ def _build_gemma_function_system_instruction( ) -> str: """Constructs the system instruction string for Gemma function calling.""" if not function_declarations: - return '' + return "" - system_instruction_prefix = 'You have access to the following functions:\n[' + system_instruction_prefix = "You have access to the following functions:\n[" instruction_parts = [] for func in function_declarations: instruction_parts.append(func.model_dump_json(exclude_none=True)) - separator = ',\n' + separator = ",\n" system_instruction = ( - f'{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n' + f"{system_instruction_prefix}{separator.join(instruction_parts)}\n]\n" ) system_instruction += ( - 'When you call a function, you MUST respond in the format of: ' + "When you call a function, you MUST respond in the format of: " """{"name": function name, "parameters": dictionary of argument name and its value}\n""" - 'When you call a function, you MUST NOT include any other text in the' - ' response.\n' + "When you call a function, you MUST NOT include any other text in the" + " response.\n" ) return system_instruction -def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: - """Attempts to find and return the last valid JSON object in a string. - - This function is designed to extract JSON that might be embedded in a larger - text, potentially with introductory or concluding remarks. It will always chose - the last block of valid json found within the supplied text (if it exists). - - Args: - text: The input string to search for JSON objects. - - Returns: - A tuple: - - bool: True if a valid JSON substring was found, False otherwise. - - str | None: The last valid JSON substring found, or None if none was - found. - """ - decoder = json.JSONDecoder() - last_json_str = None - start_pos = 0 - while start_pos < len(text): - try: - first_brace_index = text.index('{', start_pos) - _, end_index = decoder.raw_decode(text[first_brace_index:]) - last_json_str = text[first_brace_index : first_brace_index + end_index] - start_pos = first_brace_index + end_index - except json.JSONDecodeError: - start_pos = first_brace_index + 1 - except ValueError: - break - - if last_json_str: - return True, last_json_str - return False, None +def _register_gemma_prompt_template(model: str): + litellm.register_prompt_template( + model=model, + roles={ + "system": { + "pre_message": "user:\n", + "post_message": "\n", + }, + "user": { + "pre_message": "user:\n", + "post_message": "\n", + }, + "assistant": { + "pre_message": "model:\n", + "post_message": "\n", + }, + }, + final_prompt_value="model:\n", + ) diff --git a/tests/integration/models/test_gemma_llm.py b/tests/integration/models/test_gemma_llm.py index 81b9672a18..b42886e445 100644 --- a/tests/integration/models/test_gemma_llm.py +++ b/tests/integration/models/test_gemma_llm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.adk.models.gemma_llm import Gemma +from google.adk.models.gemma_llm import Gemma3GeminiAPI from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -20,12 +20,12 @@ from google.genai.types import Part import pytest -DEFAULT_GEMMA_MODEL = "gemma-3-1b-it" +DEFAULT_GEMMA_MODEL = "gemma-3-12b-it" @pytest.fixture def gemma_llm(): - return Gemma(model=DEFAULT_GEMMA_MODEL) + return Gemma3GeminiAPI(model=DEFAULT_GEMMA_MODEL) @pytest.fixture diff --git a/tests/unittests/models/test_gemma_llm.py b/tests/unittests/models/test_gemma_llm.py index 2cf98306b9..002d2fa507 100644 --- a/tests/unittests/models/test_gemma_llm.py +++ b/tests/unittests/models/test_gemma_llm.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google.adk.models.gemma_llm import Gemma +from google.adk.models.gemma_llm import _extract_function_calls_from_response +from google.adk.models.gemma_llm import _move_function_calls_into_system_instruction +from google.adk.models.gemma_llm import Gemma3GeminiAPI +from google.adk.models.gemma_llm import Gemma3Ollama from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -24,7 +27,7 @@ @pytest.fixture def llm_request(): return LlmRequest( - model="gemma-3-4b-it", + model="gemma-3-12b-it", contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], config=types.GenerateContentConfig( temperature=0.1, @@ -37,7 +40,7 @@ def llm_request(): @pytest.fixture def llm_request_with_duplicate_instruction(): return LlmRequest( - model="gemma-3-1b-it", + model="gemma-3-12b-it", contents=[ Content( role="user", @@ -55,7 +58,7 @@ def llm_request_with_duplicate_instruction(): @pytest.fixture def llm_request_with_tools(): return LlmRequest( - model="gemma-3-1b-it", + model="gemma-3-12b-it", contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], config=types.GenerateContentConfig( tools=[ @@ -86,15 +89,16 @@ def llm_request_with_tools(): ) -@pytest.mark.asyncio -async def test_not_gemma_model(): - llm = Gemma() - llm_request_bad_model = LlmRequest( - model="not-a-gemma-model", - ) - with pytest.raises(AssertionError, match=r".*model.*"): - async for _ in llm.generate_content_async(llm_request_bad_model): - pass +def test_not_gemma_model(): + with pytest.raises(ValueError, match=r".*model.*"): + _ = Gemma3GeminiAPI( + model="not-a-gemma-model", + ) + + with pytest.raises(ValueError, match=r".*model.*"): + _ = Gemma3Ollama( + model="not-a-gemma-model", + ) @pytest.mark.asyncio @@ -103,8 +107,8 @@ async def test_not_gemma_model(): ["llm_request", "llm_request_with_duplicate_instruction"], indirect=True, ) -async def test_preprocess_request(llm_request): - llm = Gemma() +async def test_gemma_gemini_preprocess_request(llm_request): + llm = Gemma3GeminiAPI() want_content_text = llm_request.config.system_instruction await llm._preprocess_request(llm_request) @@ -119,9 +123,11 @@ async def test_preprocess_request(llm_request): @pytest.mark.asyncio -async def test_preprocess_request_with_tools(llm_request_with_tools): +async def test_gemma_gemini_preprocess_request_with_tools( + llm_request_with_tools, +): - gemma = Gemma() + gemma = Gemma3GeminiAPI() await gemma._preprocess_request(llm_request_with_tools) assert not llm_request_with_tools.config.tools @@ -144,13 +150,13 @@ async def 
test_preprocess_request_with_tools(llm_request_with_tools): @pytest.mark.asyncio -async def test_preprocess_request_with_function_response(): +async def test_gemma_gemini_preprocess_request_with_function_response(): # Simulate an LlmRequest with a function response func_response_data = types.FunctionResponse( name="search_web", response={"results": [{"title": "ADK"}]} ) llm_request = LlmRequest( - model="gemma-3-1b-it", + model="gemma-3-12b-it", contents=[ types.Content( role="model", @@ -160,7 +166,7 @@ async def test_preprocess_request_with_function_response(): config=types.GenerateContentConfig(), ) - gemma = Gemma() + gemma = Gemma3GeminiAPI() await gemma._preprocess_request(llm_request) # Assertions: function response converted to user role text content @@ -178,7 +184,7 @@ async def test_preprocess_request_with_function_response(): @pytest.mark.asyncio -async def test_preprocess_request_with_function_call(): +async def test_gemma_gemini_preprocess_request_with_function_call(): func_call_data = types.FunctionCall(name="get_current_time", args={}) llm_request = LlmRequest( model="gemma-3-1b-it", @@ -189,7 +195,7 @@ async def test_preprocess_request_with_function_call(): ], ) - gemma = Gemma() + gemma = Gemma3GeminiAPI() await gemma._preprocess_request(llm_request) assert len(llm_request.contents) == 1 @@ -203,14 +209,14 @@ async def test_preprocess_request_with_function_call(): @pytest.mark.asyncio -async def test_preprocess_request_with_mixed_content(): +async def test_gemma_gemini_preprocess_request_with_mixed_content(): func_call = types.FunctionCall(name="get_weather", args={"city": "London"}) func_response = types.FunctionResponse( name="get_weather", response={"temp": "15C"} ) llm_request = LlmRequest( - model="gemma-3-1b-it", + model="gemma-3-12b-it", contents=[ types.Content( role="user", parts=[types.Part.from_text(text="Hello!")] @@ -228,7 +234,7 @@ async def test_preprocess_request_with_mixed_content(): ], ) - gemma = Gemma() + gemma = Gemma3GeminiAPI() await gemma._preprocess_request(llm_request) # Assertions @@ -271,8 +277,7 @@ def test_process_response(): ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response=llm_response) + _extract_function_calls_from_response(llm_response=llm_response) # Assert that the content was transformed into a FunctionCall assert llm_response.content @@ -298,8 +303,7 @@ def test_process_response_invalid_json_text(): content=Content(role="model", parts=[Part.from_text(text=original_text)]) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response=llm_response) + _extract_function_calls_from_response(llm_response=llm_response) # Assert that the content remains unchanged assert llm_response.content @@ -317,8 +321,8 @@ def test_process_response_malformed_json(): role="model", parts=[Part.from_text(text=malformed_json_str)] ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response=llm_response) + + _extract_function_calls_from_response(llm_response=llm_response) # Assert that the content remains unchanged because it doesn't match the expected schema assert llm_response.content @@ -329,22 +333,18 @@ def test_process_response_malformed_json(): def test_process_response_empty_content_or_multiple_parts(): - gemma = Gemma() # Test case 1: LlmResponse with no content llm_response_no_content = LlmResponse(content=None) - gemma._extract_function_calls_from_response( - llm_response=llm_response_no_content - ) + + _extract_function_calls_from_response(llm_response=llm_response_no_content) 
assert llm_response_no_content.content is None # Test case 2: LlmResponse with empty parts list llm_response_empty_parts = LlmResponse( content=Content(role="model", parts=[]) ) - gemma._extract_function_calls_from_response( - llm_response=llm_response_empty_parts - ) + _extract_function_calls_from_response(llm_response=llm_response_empty_parts) assert llm_response_empty_parts.content assert not llm_response_empty_parts.content.parts @@ -361,7 +361,7 @@ def test_process_response_empty_content_or_multiple_parts(): original_parts = list( llm_response_multiple_parts.content.parts ) # Copy for comparison - gemma._extract_function_calls_from_response( + _extract_function_calls_from_response( llm_response=llm_response_multiple_parts ) assert llm_response_multiple_parts.content @@ -373,7 +373,7 @@ def test_process_response_empty_content_or_multiple_parts(): llm_response_empty_text_part = LlmResponse( content=Content(role="model", parts=[Part.from_text(text="")]) ) - gemma._extract_function_calls_from_response( + _extract_function_calls_from_response( llm_response=llm_response_empty_text_part ) assert llm_response_empty_text_part.content @@ -394,8 +394,7 @@ def test_process_response_with_markdown_json_block(): ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response) + _extract_function_calls_from_response(llm_response) assert llm_response.content assert llm_response.content.parts @@ -411,7 +410,7 @@ def test_process_response_with_markdown_tool_code_block(): # Simulate a response from Gemma with a JSON function call in a 'tool_code' markdown block json_function_call_str = """ Some text before. -```tool_code +```json {"name": "get_current_time", "parameters": {}} ``` And some text after.""" @@ -421,14 +420,13 @@ def test_process_response_with_markdown_tool_code_block(): ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response) + _extract_function_calls_from_response(llm_response) assert llm_response.content assert llm_response.content.parts assert len(llm_response.content.parts) == 1 part = llm_response.content.parts[0] - assert part.function_call is not None + assert part.function_call is not None, f"{part=}" assert part.function_call.name == "get_current_time" assert part.function_call.args == {} assert part.text is None @@ -446,8 +444,7 @@ def test_process_response_with_embedded_json(): ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response) + _extract_function_calls_from_response(llm_response) assert llm_response.content assert llm_response.content.parts @@ -468,8 +465,7 @@ def test_process_response_flexible_parsing(): ) ) - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response) + _extract_function_calls_from_response(llm_response) assert llm_response.content assert llm_response.content.parts @@ -479,28 +475,3 @@ def test_process_response_flexible_parsing(): assert part.function_call.name == "do_something" assert part.function_call.args == {"value": 123} assert part.text is None - - -def test_process_response_last_json_object(): - # Simulate a response with multiple JSON objects, ensuring the last valid one is picked - multiple_json_str = ( - 'I thought about {"name": "first_call", "parameters": {"a": 1}} but then' - ' decided to call: {"name": "second_call", "parameters": {"b": 2}}' - ) - llm_response = LlmResponse( - content=Content( - role="model", parts=[Part.from_text(text=multiple_json_str)] - ) - ) - - gemma = Gemma() - gemma._extract_function_calls_from_response(llm_response) - - assert 
llm_response.content
-  assert llm_response.content.parts
-  assert len(llm_response.content.parts) == 1
-  part = llm_response.content.parts[0]
-  assert part.function_call is not None
-  assert part.function_call.name == "second_call"
-  assert part.function_call.args == {"b": 2}
-  assert part.text is None
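Usage sketch (not part of the patch): a minimal wiring of the new model classes into an ADK agent and Runner, following the pattern in contributing/samples/hello_world_gemma3_ollama/main.py. It assumes a local Ollama instance serving gemma3:12b; Gemma3GeminiAPI(model="gemma-3-27b-it") could be swapped in for the Gemini API backend. The agent name, app/user ids, and the get_time tool are illustrative placeholders, not identifiers introduced by this change.

import asyncio

from google.adk.agents.llm_agent import Agent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.models.gemma_llm import Gemma3Ollama
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types


def get_time() -> str:
  """Illustrative tool: returns a fixed timestamp string."""
  return "2025-01-01T00:00:00Z"


# Unsupported model names raise ValueError; see supported_models() on each class.
root_agent = Agent(
    model=Gemma3Ollama(model="ollama/gemma3:12b"),
    name="gemma_demo_agent",
    instruction="Answer questions. Call the get_time tool when asked for the time.",
    tools=[get_time],
)


async def main():
  session_service = InMemorySessionService()
  runner = Runner(
      app_name="gemma_demo",
      agent=root_agent,
      artifact_service=InMemoryArtifactService(),
      session_service=session_service,
  )
  session = await session_service.create_session(
      app_name="gemma_demo", user_id="user1"
  )
  message = types.Content(
      role="user", parts=[types.Part.from_text(text="What time is it?")]
  )
  # Tool calls emitted by Gemma as inline JSON are converted to FunctionCall
  # parts by _extract_function_calls_from_response before events are yielded.
  async for event in runner.run_async(
      user_id="user1", session_id=session.id, new_message=message
  ):
    if event.content and event.content.parts and event.content.parts[0].text:
      print(f"{event.author}: {event.content.parts[0].text}")


if __name__ == "__main__":
  asyncio.run(main())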