diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 950c4baedb..748219ab7f 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING from google.adk import version as adk_version +import google.auth from google.genai import types import httpx import tenacity @@ -52,6 +53,11 @@ _PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT' _LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION' +_APIGEE_SCOPES = [ + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/userinfo.email', +] + _CUSTOM_METADATA_FIELDS = ( 'id', 'created', @@ -232,6 +238,8 @@ def api_client(self) -> Client: **kwargs_for_http_options, ) + credentials, _ = google.auth.default(scopes=_APIGEE_SCOPES) + kwargs_for_client = {} kwargs_for_client['vertexai'] = self._isvertexai if self._isvertexai: @@ -239,6 +247,7 @@ def api_client(self) -> Client: kwargs_for_client['location'] = self._location return Client( + credentials=credentials, http_options=http_options, **kwargs_for_client, ) diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index f48039fadd..b22f615b89 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -18,6 +18,7 @@ from unittest import mock from unittest.mock import AsyncMock +from google.adk.models.apigee_llm import _APIGEE_SCOPES from google.adk.models.apigee_llm import ApigeeLlm from google.adk.models.apigee_llm import CompletionsHTTPClient from google.adk.models.llm_request import LlmRequest @@ -33,6 +34,16 @@ PROXY_URL = 'https://test.apigee.net' +@pytest.fixture(autouse=True) +def mock_google_auth_default(): + """Mocks google.auth.default to avoid requiring real credentials in tests.""" + with mock.patch( + 'google.adk.models.apigee_llm.google.auth.default' + ) as mock_auth: + mock_auth.return_value = (mock.Mock(), 'test-project') + yield mock_auth + + @pytest.fixture def llm_request(): """Provides a sample LlmRequest for testing.""" @@ -649,3 +660,39 @@ def test_parse_response_usage_metadata(): assert llm_response.usage_metadata.candidates_token_count == 5 assert llm_response.usage_metadata.total_token_count == 15 assert llm_response.usage_metadata.thoughts_token_count == 4 + + +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +async def test_api_client_requests_userinfo_email_scope( + mock_client_constructor, llm_request, mock_google_auth_default +): + """Tests that api_client requests userinfo.email scope for Apigee Gateway tokeninfo.""" + mock_credentials = mock.Mock() + mock_google_auth_default.return_value = (mock_credentials, 'test-project') + + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + mock_google_auth_default.assert_called_once_with(scopes=_APIGEE_SCOPES) + + _, kwargs = mock_client_constructor.call_args + assert kwargs['credentials'] is mock_credentials