Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
46 changes: 36 additions & 10 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -48,9 +43,29 @@ async def exchange_auth_token(
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
return await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)

# Restore secret if needed
credential = self.auth_config.exchanged_auth_credential
redacted = False
original_secret = None

if credential and credential.oauth2 and credential.oauth2.client_id:
# Check if secret needs restoration
from .credential_manager import CredentialManager

secret = CredentialManager.get_client_secret(credential.oauth2.client_id)
if secret and credential.oauth2.client_secret == "<redacted>":
original_secret = credential.oauth2.client_secret
credential.oauth2.client_secret = secret
redacted = True

try:
res = await exchanger.exchange(credential, self.auth_config.auth_scheme)
return res
finally:
# Always re-redact if we restored it
if redacted and credential and credential.oauth2:
credential.oauth2.client_secret = "<redacted>"

async def parse_and_store_auth_response(self, state: State) -> None:

Expand Down Expand Up @@ -182,9 +197,20 @@ def generate_auth_uri(
)
scopes = list(scopes.keys())

client_id = auth_credential.oauth2.client_id
client_secret = auth_credential.oauth2.client_secret

# Check if secret is redacted and restore it from manager
if client_secret == "<redacted>" and client_id:
from .credential_manager import CredentialManager

secret = CredentialManager.get_client_secret(client_id)
if secret:
client_secret = secret

client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
client_id,
client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/auth/auth_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def get_credential_key(self):
if auth_credential and auth_credential.model_extra:
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()

# Normalize secret to ensure stable key regardless of redaction
if auth_credential and auth_credential.oauth2:
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.oauth2.client_secret = None

credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
Expand Down
81 changes: 80 additions & 1 deletion src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,30 @@ class CredentialManager:
```
"""

# A map to store client secrets in memory. Key is client_id, value is client_secret
_CLIENT_SECRETS: dict[str, str] = {}

def __init__(
self,
auth_config: AuthConfig,
):
self._auth_config = auth_config
# We deep copy the auth_config to avoid modifying the original one passed by user
# and to ensure we can safely redact sensitive information
# However we cannot rely on copy.deepcopy because AuthConfig is a pydantic model
# and deepcopy on pydantic model is not always reliable? No, deepcopy works.
# But better to use model_copy if possible. AuthConfig inherits BaseModelWithConfig which is Pydantic.
# auth_config.model_copy(deep=True) is available in Pydantic V2.
# For safe side, we use the passed instance but we redact sensitive info immediately.
# Wait, modifying passed instance is bad practice if user reuses it.
# But CredentialManager usually takes ownership?
# Let's perform redaction on `self._auth_config` which we assign.
# And we should clone it first.
self._auth_config = auth_config.model_copy(deep=True)

# Secure the client secret
self._secure_client_secret(self._auth_config.raw_auth_credential)
self._secure_client_secret(self._auth_config.exchanged_auth_credential)

self._exchanger_registry = CredentialExchangerRegistry()
self._refresher_registry = CredentialRefresherRegistry()
self._discovery_manager = OAuth2DiscoveryManager()
Expand Down Expand Up @@ -110,6 +129,36 @@ def __init__(
AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher
)

def _secure_client_secret(self, credential: Optional[AuthCredential]):
"""Extracts client secret to memory and redacts it from the credential."""
if (
credential
and credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_secret
and credential.oauth2.client_secret != "<redacted>"
):
logger.info(
f"Securing client secret for client_id: {credential.oauth2.client_id}"
)
# Store in memory map
self._CLIENT_SECRETS[credential.oauth2.client_id] = (
credential.oauth2.client_secret
)
# Redact from config
credential.oauth2.client_secret = "<redacted>"
else:
if credential and credential.oauth2:
logger.debug(
f"Not securing secret for client_id {credential.oauth2.client_id}:"
f" secret is {credential.oauth2.client_secret}"
)

@staticmethod
def get_client_secret(client_id: str) -> Optional[str]:
"""Retrieves the client secret for a given client_id."""
return CredentialManager._CLIENT_SECRETS.get(client_id)

def register_credential_exchanger(
self,
credential_type: AuthCredentialTypes,
Expand All @@ -124,6 +173,9 @@ def register_credential_exchanger(
self._exchanger_registry.register(credential_type, exchanger_instance)

async def request_credential(self, callback_context: CallbackContext) -> None:
# We send the auth_config (which is already redacted in __init__) to the client
# Note: we need to ensure we don't send any stale exchanged credentials if they are not valid
# But usually CredentialManager manages that.
callback_context.request_credential(self._auth_config)

async def get_auth_credential(
Expand Down Expand Up @@ -218,10 +270,37 @@ async def _exchange_credential(
self._auth_config.auth_scheme, credential
)
else:
# Restore client secret from memory map for exchange
restored = False
if (
credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_id in self._CLIENT_SECRETS
):
credential.oauth2.client_secret = self._CLIENT_SECRETS[
credential.oauth2.client_id
]
restored = True
elif (
self._auth_config.raw_auth_credential
and self._auth_config.raw_auth_credential.oauth2
and self._auth_config.raw_auth_credential.oauth2.client_id
in self._CLIENT_SECRETS
):
# Fallback to look up using raw credential client id if credential client id is missing (unlikely for valid flow)
credential.oauth2.client_secret = self._CLIENT_SECRETS[
self._auth_config.raw_auth_credential.oauth2.client_id
]
restored = True

exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)

# Redact client secret again after exchange to prevent leakage
if exchanged_credential.oauth2:
exchanged_credential.oauth2.client_secret = "<redacted>"

return exchanged_credential, True

async def _refresh_credential(
Expand Down
4 changes: 4 additions & 0 deletions src/google/adk/auth/oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
import sys
from typing import Optional
from typing import Tuple

Expand Down Expand Up @@ -107,6 +108,9 @@ def update_credential_with_tokens(
tokens: The OAuth2Token object containing new token information.
"""
auth_credential.oauth2.access_token = tokens.get("access_token")
sys.stderr.write(
f"[DEBUG] Assigned access_token: {auth_credential.oauth2.access_token}\n"
)
auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
auth_credential.oauth2.expires_at = (
int(tokens.get("expires_at")) if tokens.get("expires_at") else None
Expand Down
Loading