Skip to content

Commit 6d0f4d1

Browse files
committed
feat(auth): Add native support for id_token in OAuth2 credentials
**Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** **2. Or, if no issue exists, describe the change:** When performing authentication flows via `OAUTH2` or `OPEN_ID_CONNECT`, the native `OAuth2Token` response from identity providers (like Google OAuth) often includes an `id_token` alongside the `access_token` and `refresh_token`. However, the ADK's `update_credential_with_tokens` utility explicitly drops the `id_token`, preventing agents and tools from verifying user identity or extracting OIDC claims securely. Furthermore, the `OAuth2Auth` model does not have a designated field for `id_token`. 1. Added an `id_token: Optional[str] = None` field to the `OAuth2Auth` pydantic model in `auth_credential.py`. 2. Updated `update_credential_with_tokens` in `oauth2_credential_util.py` to correctly extract and map `tokens.get("id_token")` into the `OAuth2Auth` credential object. 3. Updated the relevant unit tests to ensure `id_token` is asserted and preserved during credential updates. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. Summary of passed `pytest` results: ```bash $ pytest tests/unittests/auth/test_oauth2_credential_util.py ======================= test session starts ======================= platform darwin -- Python 3.11.9, pytest-9.0.1, pluggy-1.6.0 collected 9 items tests/unittests/auth/test_oauth2_credential_util.py ......... [100%] ======================== 9 passed in 0.05s ========================
1 parent a2e43aa commit 6d0f4d1

3 files changed

Lines changed: 23 additions & 8 deletions

File tree

src/google/adk/auth/auth_credential.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class OAuth2Auth(BaseModelWithConfig):
7979
auth_code: Optional[str] = None
8080
access_token: Optional[str] = None
8181
refresh_token: Optional[str] = None
82+
id_token: Optional[str] = None
8283
expires_at: Optional[int] = None
8384
expires_in: Optional[int] = None
8485
audience: Optional[str] = None

src/google/adk/auth/oauth2_credential_util.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ def update_credential_with_tokens(
107107
auth_credential: The authentication credential to update.
108108
tokens: The OAuth2Token object containing new token information.
109109
"""
110-
auth_credential.oauth2.access_token = tokens.get("access_token")
111-
auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
112-
auth_credential.oauth2.expires_at = (
113-
int(tokens.get("expires_at")) if tokens.get("expires_at") else None
114-
)
115-
auth_credential.oauth2.expires_in = (
116-
int(tokens.get("expires_in")) if tokens.get("expires_in") else None
117-
)
110+
if auth_credential.oauth2:
111+
auth_credential.oauth2.access_token = tokens.get("access_token")
112+
auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
113+
auth_credential.oauth2.id_token = tokens.get("id_token")
114+
auth_credential.oauth2.expires_at = (
115+
int(tokens.get("expires_at")) if tokens.get("expires_at") else None
116+
)
117+
auth_credential.oauth2.expires_in = (
118+
int(tokens.get("expires_in")) if tokens.get("expires_in") else None
119+
)

tests/unittests/auth/test_oauth2_credential_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def test_update_credential_with_tokens(self):
222222
tokens = OAuth2Token({
223223
"access_token": "new_access_token",
224224
"refresh_token": "new_refresh_token",
225+
"id_token": "new_id_token",
225226
"expires_at": expected_expires_at,
226227
"expires_in": 3600,
227228
})
@@ -230,5 +231,16 @@ def test_update_credential_with_tokens(self):
230231

231232
assert credential.oauth2.access_token == "new_access_token"
232233
assert credential.oauth2.refresh_token == "new_refresh_token"
234+
assert credential.oauth2.id_token == "new_id_token"
233235
assert credential.oauth2.expires_at == expected_expires_at
234236
assert credential.oauth2.expires_in == 3600
237+
238+
def test_update_credential_with_tokens_none(self):
239+
credential = AuthCredential(
240+
auth_type=AuthCredentialTypes.API_KEY,
241+
)
242+
tokens = OAuth2Token({"access_token": "new_access_token"})
243+
244+
# Should not raise any exceptions when oauth2 is None
245+
update_credential_with_tokens(credential, tokens)
246+
assert credential.oauth2 is None

0 commit comments

Comments
 (0)