Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/google/adk/models/apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Optional
from typing import TYPE_CHECKING

import google.auth
from google.adk import version as adk_version
from google.genai import types
import httpx
Expand All @@ -52,6 +53,11 @@
_PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT'
_LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION'

_APIGEE_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
]

_CUSTOM_METADATA_FIELDS = (
'id',
'created',
Expand Down Expand Up @@ -232,13 +238,16 @@ def api_client(self) -> Client:
**kwargs_for_http_options,
)

credentials, _ = google.auth.default(scopes=_APIGEE_SCOPES)

kwargs_for_client = {}
kwargs_for_client['vertexai'] = self._isvertexai
if self._isvertexai:
kwargs_for_client['project'] = self._project
kwargs_for_client['location'] = self._location

return Client(
credentials=credentials,
http_options=http_options,
**kwargs_for_client,
)
Expand Down
38 changes: 38 additions & 0 deletions tests/unittests/models/test_apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from unittest import mock
from unittest.mock import AsyncMock

from google.adk.models.apigee_llm import _APIGEE_SCOPES
from google.adk.models.apigee_llm import ApigeeLlm
from google.adk.models.apigee_llm import CompletionsHTTPClient
from google.adk.models.llm_request import LlmRequest
Expand Down Expand Up @@ -649,3 +650,40 @@ def test_parse_response_usage_metadata():
assert llm_response.usage_metadata.candidates_token_count == 5
assert llm_response.usage_metadata.total_token_count == 15
assert llm_response.usage_metadata.thoughts_token_count == 4


@pytest.mark.asyncio
@mock.patch('google.genai.Client')
@mock.patch('google.adk.models.apigee_llm.google.auth.default')
async def test_api_client_requests_userinfo_email_scope(
mock_auth_default, mock_client_constructor, llm_request
):
"""Tests that api_client requests userinfo.email scope for Apigee Gateway tokeninfo."""
mock_credentials = mock.Mock()
mock_auth_default.return_value = (mock_credentials, 'test-project')

mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance

apigee_llm = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
)
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

mock_auth_default.assert_called_once_with(scopes=_APIGEE_SCOPES)

_, kwargs = mock_client_constructor.call_args
assert kwargs['credentials'] is mock_credentials