Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
44 changes: 34 additions & 10 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -48,9 +43,27 @@ async def exchange_auth_token(
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
return await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)

# Restore secret if needed
credential = self.auth_config.exchanged_auth_credential
redacted = False

if credential and credential.oauth2 and credential.oauth2.client_id:
# Check if secret needs restoration
from .credential_manager import CredentialManager

secret = CredentialManager.get_client_secret(credential.oauth2.client_id)
if secret and credential.oauth2.client_secret == "<redacted>":
credential.oauth2.client_secret = secret
redacted = True

try:
res = await exchanger.exchange(credential, self.auth_config.auth_scheme)
return res
finally:
# Always re-redact if we restored it
if redacted and credential and credential.oauth2:
credential.oauth2.client_secret = "<redacted>"

async def parse_and_store_auth_response(self, state: State) -> None:

Expand Down Expand Up @@ -182,9 +195,20 @@ def generate_auth_uri(
)
scopes = list(scopes.keys())

client_id = auth_credential.oauth2.client_id
client_secret = auth_credential.oauth2.client_secret

# Check if secret is redacted and restore it from manager
if client_secret == "<redacted>" and client_id:
from .credential_manager import CredentialManager

secret = CredentialManager.get_client_secret(client_id)
if secret:
client_secret = secret

client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
client_id,
client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
Expand Down
9 changes: 7 additions & 2 deletions src/google/adk/auth/auth_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ def get_credential_key(self):
)

auth_credential = self.raw_auth_credential
if auth_credential and auth_credential.model_extra:
if auth_credential and (auth_credential.model_extra or auth_credential.oauth2):
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
if auth_credential.model_extra:
auth_credential.model_extra.clear()
# Normalize secret to ensure stable key regardless of redaction
if auth_credential.oauth2:
auth_credential.oauth2.client_secret = None

credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
Expand Down
74 changes: 73 additions & 1 deletion src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,23 @@ class CredentialManager:
```
"""

# A map to store client secrets in memory. Key is client_id, value is client_secret
_CLIENT_SECRETS: dict[str, str] = {}

def __init__(
self,
auth_config: AuthConfig,
):
self._auth_config = auth_config
# We deep copy the auth_config to avoid modifying the original object passed
# by the user. This allows for safe redaction of sensitive information without
# causing side effects.

self._auth_config = auth_config.model_copy(deep=True)

# Secure the client secret
self._secure_client_secret(self._auth_config.raw_auth_credential)
self._secure_client_secret(self._auth_config.exchanged_auth_credential)

self._exchanger_registry = CredentialExchangerRegistry()
self._refresher_registry = CredentialRefresherRegistry()
self._discovery_manager = OAuth2DiscoveryManager()
Expand Down Expand Up @@ -110,6 +122,36 @@ def __init__(
AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher
)

def _secure_client_secret(self, credential: Optional[AuthCredential]):
"""Extracts client secret to memory and redacts it from the credential."""
if (
credential
and credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_secret
and credential.oauth2.client_secret != "<redacted>"
):
logger.info(
f"Securing client secret for client_id: {credential.oauth2.client_id}"
)
# Store in memory map
self._CLIENT_SECRETS[credential.oauth2.client_id] = (
credential.oauth2.client_secret
)
# Redact from config
credential.oauth2.client_secret = "<redacted>"
else:
if credential and credential.oauth2:
logger.debug(
f"Not securing secret for client_id {credential.oauth2.client_id}:"
f" secret is {credential.oauth2.client_secret}"
)

@staticmethod
def get_client_secret(client_id: str) -> Optional[str]:
"""Retrieves the client secret for a given client_id."""
return CredentialManager._CLIENT_SECRETS.get(client_id)

def register_credential_exchanger(
self,
credential_type: AuthCredentialTypes,
Expand All @@ -124,6 +166,9 @@ def register_credential_exchanger(
self._exchanger_registry.register(credential_type, exchanger_instance)

async def request_credential(self, callback_context: CallbackContext) -> None:
# We send the auth_config (which is already redacted in __init__) to the client
# Note: we need to ensure we don't send any stale exchanged credentials if they are not valid
# But usually CredentialManager manages that.
callback_context.request_credential(self._auth_config)

async def get_auth_credential(
Expand Down Expand Up @@ -218,10 +263,37 @@ async def _exchange_credential(
self._auth_config.auth_scheme, credential
)
else:
# Restore client secret from memory map for exchange
restored = False
if (
credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_id in self._CLIENT_SECRETS
):
credential.oauth2.client_secret = self._CLIENT_SECRETS[
credential.oauth2.client_id
]
restored = True
elif (
self._auth_config.raw_auth_credential
and self._auth_config.raw_auth_credential.oauth2
and self._auth_config.raw_auth_credential.oauth2.client_id
in self._CLIENT_SECRETS
):
# Fallback to look up using raw credential client id if credential client id is missing (unlikely for valid flow)
credential.oauth2.client_secret = self._CLIENT_SECRETS[
self._auth_config.raw_auth_credential.oauth2.client_id
]
restored = True

exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)

# Redact client secret again after exchange to prevent leakage
if exchanged_credential.oauth2:
exchanged_credential.oauth2.client_secret = "<redacted>"

return exchanged_credential, True

async def _refresh_credential(
Expand Down
127 changes: 127 additions & 0 deletions tests/unittests/auth/test_auth_handler_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_handler import AuthHandler
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.credential_manager import CredentialManager
import pytest


class TestAuthHandlerSecrets:

def setUp(self):
# Clear secret store
CredentialManager._CLIENT_SECRETS = {}

@pytest.mark.asyncio
async def test_exchange_auth_token_restores_and_reredacts_secret(self):
client_id = "test_client_id"
secret = "super_secret_value"

# Setup secure storage
CredentialManager._CLIENT_SECRETS[client_id] = secret

# Create credential with redacted secret
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(client_id=client_id, client_secret="<redacted>"),
)

auth_config = Mock(spec=AuthConfig)
auth_config.exchanged_auth_credential = credential
auth_config.auth_scheme = Mock()

handler = AuthHandler(auth_config)

# Mock exchanger
mock_exchanger = AsyncMock()

# Check secret inside exchange
def check_secret(cred, scheme):
assert cred.oauth2.client_secret == secret
return cred

mock_exchanger.exchange.side_effect = check_secret

with patch(
"google.adk.auth.auth_handler.OAuth2CredentialExchanger",
return_value=mock_exchanger,
):
await handler.exchange_auth_token()

# Verify secret is re-redacted
assert credential.oauth2.client_secret == "<redacted>"

def test_generate_auth_uri_uses_restored_secret(self):
client_id = "test_client_id"
secret = "super_secret_value"

# Setup secure storage
CredentialManager._CLIENT_SECRETS[client_id] = secret

# Create credential with redacted secret
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id=client_id,
client_secret="<redacted>",
redirect_uri="http://localhost/callback",
),
)

auth_config = Mock(spec=AuthConfig)
auth_config.raw_auth_credential = credential
auth_config.auth_scheme = Mock()
# Mock flows for scopes
auth_config.auth_scheme.flows.implicit = None
auth_config.auth_scheme.flows.clientCredentials = None
auth_config.auth_scheme.flows.password = None
auth_config.auth_scheme.flows.authorizationCode.scopes = {"scope": "desc"}
auth_config.auth_scheme.flows.authorizationCode.authorizationUrl = (
"http://auth"
)

handler = AuthHandler(auth_config)

# Mock OAuth2Session
with (
patch("google.adk.auth.auth_handler.OAuth2Session") as mock_session_cls,
patch("google.adk.auth.auth_handler.AUTHLIB_AVAILABLE", True),
):

mock_session = Mock()
mock_session.create_authorization_url.return_value = (
"http://auth?param=1",
"state",
)
mock_session_cls.return_value = mock_session

handler.generate_auth_uri()

# Verify session was created with the REAL secret, not redacted one
mock_session_cls.assert_called_with(
client_id,
secret,
scope="scope",
redirect_uri="http://localhost/callback",
)
Loading