Skip to content

Commit 309a656

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: extract credentail key building logic to auth_config
PiperOrigin-RevId: 768124459
1 parent 9abb841 commit 309a656

File tree

4 files changed

+131
-56
lines changed

4 files changed

+131
-56
lines changed

src/google/adk/auth/auth_handler.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def exchange_auth_token(
112112

113113
def parse_and_store_auth_response(self, state: State) -> None:
114114

115-
credential_key = self.get_credential_key()
115+
credential_key = "temp:" + self.auth_config.get_credential_key()
116116

117117
state[credential_key] = self.auth_config.exchanged_auth_credential
118118
if not isinstance(
@@ -130,7 +130,7 @@ def _validate(self) -> None:
130130
raise ValueError("auth_scheme is empty.")
131131

132132
def get_auth_response(self, state: State) -> AuthCredential:
133-
credential_key = self.get_credential_key()
133+
credential_key = "temp:" + self.auth_config.get_credential_key()
134134
return state.get(credential_key, None)
135135

136136
def generate_auth_request(self) -> AuthConfig:
@@ -192,29 +192,6 @@ def generate_auth_request(self) -> AuthConfig:
192192
exchanged_auth_credential=exchanged_credential,
193193
)
194194

195-
def get_credential_key(self) -> str:
196-
"""Generates a unique key for the given auth scheme and credential."""
197-
auth_scheme = self.auth_config.auth_scheme
198-
auth_credential = self.auth_config.raw_auth_credential
199-
if auth_scheme.model_extra:
200-
auth_scheme = auth_scheme.model_copy(deep=True)
201-
auth_scheme.model_extra.clear()
202-
scheme_name = (
203-
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
204-
if auth_scheme
205-
else ""
206-
)
207-
if auth_credential.model_extra:
208-
auth_credential = auth_credential.model_copy(deep=True)
209-
auth_credential.model_extra.clear()
210-
credential_name = (
211-
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
212-
if auth_credential
213-
else ""
214-
)
215-
216-
return f"temp:adk_{scheme_name}_{credential_name}"
217-
218195
def generate_auth_uri(
219196
self,
220197
) -> AuthCredential:

src/google/adk/auth/auth_tool.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from .auth_credential import AuthCredential
1618
from .auth_credential import BaseModelWithConfig
1719
from .auth_schemes import AuthScheme
@@ -43,6 +45,34 @@ class AuthConfig(BaseModelWithConfig):
4345
this field to guide the user through the OAuth2 flow and fill auth response in
4446
this field"""
4547

48+
def get_credential_key(self):
49+
"""Generates a hash key based on auth_scheme and raw_auth_credential. This
50+
hash key can be used to store / retrieve exchanged_auth_credential in a
51+
credentials store.
52+
"""
53+
auth_scheme = self.auth_scheme
54+
55+
if auth_scheme.model_extra:
56+
auth_scheme = auth_scheme.model_copy(deep=True)
57+
auth_scheme.model_extra.clear()
58+
scheme_name = (
59+
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
60+
if auth_scheme
61+
else ""
62+
)
63+
64+
auth_credential = self.raw_auth_credential
65+
if auth_credential.model_extra:
66+
auth_credential = auth_credential.model_copy(deep=True)
67+
auth_credential.model_extra.clear()
68+
credential_name = (
69+
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
70+
if auth_credential
71+
else ""
72+
)
73+
74+
return f"adk_{scheme_name}_{credential_name}"
75+
4676

4777
class AuthToolArguments(BaseModelWithConfig):
4878
"""the arguments for the special long running function tool that is used to
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from fastapi.openapi.models import OAuth2
16+
from fastapi.openapi.models import OAuthFlowAuthorizationCode
17+
from fastapi.openapi.models import OAuthFlows
18+
from google.adk.auth.auth_credential import AuthCredential
19+
from google.adk.auth.auth_credential import AuthCredentialTypes
20+
from google.adk.auth.auth_credential import OAuth2Auth
21+
from google.adk.auth.auth_tool import AuthConfig
22+
import pytest
23+
24+
25+
class TestAuthConfig:
26+
"""Tests for the AuthConfig method."""
27+
28+
29+
@pytest.fixture
30+
def oauth2_auth_scheme():
31+
"""Create an OAuth2 auth scheme for testing."""
32+
# Create the OAuthFlows object first
33+
flows = OAuthFlows(
34+
authorizationCode=OAuthFlowAuthorizationCode(
35+
authorizationUrl="https://example.com/oauth2/authorize",
36+
tokenUrl="https://example.com/oauth2/token",
37+
scopes={"read": "Read access", "write": "Write access"},
38+
)
39+
)
40+
41+
# Then create the OAuth2 object with the flows
42+
return OAuth2(flows=flows)
43+
44+
45+
@pytest.fixture
46+
def oauth2_credentials():
47+
"""Create OAuth2 credentials for testing."""
48+
return AuthCredential(
49+
auth_type=AuthCredentialTypes.OAUTH2,
50+
oauth2=OAuth2Auth(
51+
client_id="mock_client_id",
52+
client_secret="mock_client_secret",
53+
redirect_uri="https://example.com/callback",
54+
),
55+
)
56+
57+
58+
@pytest.fixture
59+
def auth_config(oauth2_auth_scheme, oauth2_credentials):
60+
"""Create an AuthConfig for testing."""
61+
# Create a copy of the credentials for the exchanged_auth_credential
62+
exchanged_credential = oauth2_credentials.model_copy(deep=True)
63+
64+
return AuthConfig(
65+
auth_scheme=oauth2_auth_scheme,
66+
raw_auth_credential=oauth2_credentials,
67+
exchanged_auth_credential=exchanged_credential,
68+
)
69+
70+
71+
def test_get_credential_key(auth_config):
72+
"""Test generating a unique credential key."""
73+
74+
key = auth_config.get_credential_key()
75+
assert key.startswith("adk_oauth2_")
76+
assert "_oauth2_" in key
77+
78+
79+
def test_get_credential_key_with_extras(auth_config):
80+
"""Test generating a key when model_extra exists."""
81+
# Add model_extra to test cleanup
82+
83+
original_key = auth_config.get_credential_key()
84+
key = auth_config.get_credential_key()
85+
86+
auth_config.auth_scheme.model_extra["extra_field"] = "value"
87+
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
88+
89+
assert original_key == key
90+
assert "extra_field" in auth_config.auth_scheme.model_extra
91+
assert "extra_field" in auth_config.raw_auth_credential.model_extra

tests/unittests/auth/test_auth_handler.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -209,31 +209,6 @@ def test_init(self, auth_config):
209209
assert handler.auth_config == auth_config
210210

211211

212-
class TestGetCredentialKey:
213-
"""Tests for the get_credential_key method."""
214-
215-
def test_get_credential_key(self, auth_config):
216-
"""Test generating a unique credential key."""
217-
handler = AuthHandler(auth_config)
218-
key = handler.get_credential_key()
219-
assert key.startswith("temp:adk_oauth2_")
220-
assert "_oauth2_" in key
221-
222-
def test_get_credential_key_with_extras(self, auth_config):
223-
"""Test generating a key when model_extra exists."""
224-
# Add model_extra to test cleanup
225-
226-
original_key = AuthHandler(auth_config).get_credential_key()
227-
key = AuthHandler(auth_config).get_credential_key()
228-
229-
auth_config.auth_scheme.model_extra["extra_field"] = "value"
230-
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
231-
232-
assert original_key == key
233-
assert "extra_field" in auth_config.auth_scheme.model_extra
234-
assert "extra_field" in auth_config.raw_auth_credential.model_extra
235-
236-
237212
class TestGenerateAuthUri:
238213
"""Tests for the generate_auth_uri method."""
239214

@@ -412,8 +387,8 @@ def test_get_auth_response_exists(
412387
state = MockState()
413388

414389
# Store a credential in the state
415-
credential_key = handler.get_credential_key()
416-
state[credential_key] = oauth2_credentials_with_auth_uri
390+
credential_key = auth_config.get_credential_key()
391+
state["temp:" + credential_key] = oauth2_credentials_with_auth_uri
417392

418393
result = handler.get_auth_response(state)
419394
assert result == oauth2_credentials_with_auth_uri
@@ -443,8 +418,10 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged):
443418

444419
handler.parse_and_store_auth_response(state)
445420

446-
credential_key = handler.get_credential_key()
447-
assert state[credential_key] == auth_config.exchanged_auth_credential
421+
credential_key = auth_config.get_credential_key()
422+
assert (
423+
state["temp:" + credential_key] == auth_config.exchanged_auth_credential
424+
)
448425

449426
@patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
450427
def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
@@ -459,8 +436,8 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
459436

460437
handler.parse_and_store_auth_response(state)
461438

462-
credential_key = handler.get_credential_key()
463-
assert state[credential_key] == mock_exchange_token.return_value
439+
credential_key = auth_config_with_exchanged.get_credential_key()
440+
assert state["temp:" + credential_key] == mock_exchange_token.return_value
464441
assert mock_exchange_token.called
465442

466443

0 commit comments

Comments
 (0)