Skip to content

Commit 1e22652

Browse files
Add vendor_id and finish_reason to Gemini/Google model responses (#1800)
1 parent 3b12530 commit 1e22652

File tree

5 files changed

+73
-8
lines changed

5 files changed

+73
-8
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ class ModelResponse:
589589
kind: Literal['response'] = 'response'
590590
"""Message type identifier, this is available on all parts as a discriminator."""
591591

592-
vendor_details: dict[str, Any] | None = field(default=None, repr=False)
592+
vendor_details: dict[str, Any] | None = field(default=None)
593593
"""Additional vendor-specific details in a serializable format.
594594
595595
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ async def _make_request(
259259
yield r
260260

261261
def _process_response(self, response: _GeminiResponse) -> ModelResponse:
262+
vendor_details: dict[str, Any] | None = None
263+
262264
if len(response['candidates']) != 1:
263265
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover
264266
if 'content' not in response['candidates'][0]:
@@ -269,9 +271,19 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
269271
'Content field missing from Gemini response', str(response)
270272
)
271273
parts = response['candidates'][0]['content']['parts']
274+
vendor_id = response.get('vendor_id', None)
275+
finish_reason = response['candidates'][0].get('finish_reason')
276+
if finish_reason:
277+
vendor_details = {'finish_reason': finish_reason}
272278
usage = _metadata_as_usage(response)
273279
usage.requests = 1
274-
return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
280+
return _process_response_from_parts(
281+
parts,
282+
response.get('model_version', self._model_name),
283+
usage,
284+
vendor_id=vendor_id,
285+
vendor_details=vendor_details,
286+
)
275287

276288
async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
277289
"""Process a streamed response, and prepare a streaming response to return."""
@@ -611,7 +623,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
611623

612624

613625
def _process_response_from_parts(
614-
parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
626+
parts: Sequence[_GeminiPartUnion],
627+
model_name: GeminiModelName,
628+
usage: usage.Usage,
629+
vendor_id: str | None,
630+
vendor_details: dict[str, Any] | None = None,
615631
) -> ModelResponse:
616632
items: list[ModelResponsePart] = []
617633
for part in parts:
@@ -623,7 +639,9 @@ def _process_response_from_parts(
623639
raise UnexpectedModelBehavior(
624640
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
625641
)
626-
return ModelResponse(parts=items, usage=usage, model_name=model_name)
642+
return ModelResponse(
643+
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
644+
)
627645

628646

629647
class _GeminiFunctionCall(TypedDict):
@@ -735,6 +753,7 @@ class _GeminiResponse(TypedDict):
735753
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
736754
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
737755
model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
756+
vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
738757

739758

740759
class _GeminiCandidates(TypedDict):

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import asynccontextmanager
77
from dataclasses import dataclass, field, replace
88
from datetime import datetime
9-
from typing import Literal, Union, cast, overload
9+
from typing import Any, Literal, Union, cast, overload
1010
from uuid import uuid4
1111

1212
from typing_extensions import assert_never
@@ -291,9 +291,16 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
291291
'Content field missing from Gemini response', str(response)
292292
) # pragma: no cover
293293
parts = response.candidates[0].content.parts or []
294+
vendor_id = response.response_id or None
295+
vendor_details: dict[str, Any] | None = None
296+
finish_reason = response.candidates[0].finish_reason
297+
if finish_reason: # pragma: no branch
298+
vendor_details = {'finish_reason': finish_reason.value}
294299
usage = _metadata_as_usage(response)
295300
usage.requests = 1
296-
return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
301+
return _process_response_from_parts(
302+
parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
303+
)
297304

298305
async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
299306
"""Process a streamed response, and prepare a streaming response to return."""
@@ -439,7 +446,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
439446
return ContentDict(role='model', parts=parts)
440447

441448

442-
def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
449+
def _process_response_from_parts(
450+
parts: list[Part],
451+
model_name: GoogleModelName,
452+
usage: usage.Usage,
453+
vendor_id: str | None,
454+
vendor_details: dict[str, Any] | None = None,
455+
) -> ModelResponse:
443456
items: list[ModelResponsePart] = []
444457
for part in parts:
445458
if part.text:
@@ -454,7 +467,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
454467
raise UnexpectedModelBehavior(
455468
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
456469
)
457-
return ModelResponse(parts=items, model_name=model_name, usage=usage)
470+
return ModelResponse(
471+
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
472+
)
458473

459474

460475
def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:

tests/models/test_gemini.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
540540
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
541541
model_name='gemini-1.5-flash-123',
542542
timestamp=IsNow(tz=timezone.utc),
543+
vendor_details={'finish_reason': 'STOP'},
543544
),
544545
]
545546
)
@@ -555,13 +556,15 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
555556
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
556557
model_name='gemini-1.5-flash-123',
557558
timestamp=IsNow(tz=timezone.utc),
559+
vendor_details={'finish_reason': 'STOP'},
558560
),
559561
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
560562
ModelResponse(
561563
parts=[TextPart(content='Hello world')],
562564
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
563565
model_name='gemini-1.5-flash-123',
564566
timestamp=IsNow(tz=timezone.utc),
567+
vendor_details={'finish_reason': 'STOP'},
565568
),
566569
]
567570
)
@@ -585,6 +588,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
585588
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
586589
model_name='gemini-1.5-flash-123',
587590
timestamp=IsNow(tz=timezone.utc),
591+
vendor_details={'finish_reason': 'STOP'},
588592
),
589593
ModelRequest(
590594
parts=[
@@ -647,6 +651,7 @@ async def get_location(loc_name: str) -> str:
647651
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
648652
model_name='gemini-1.5-flash-123',
649653
timestamp=IsNow(tz=timezone.utc),
654+
vendor_details={'finish_reason': 'STOP'},
650655
),
651656
ModelRequest(
652657
parts=[
@@ -666,6 +671,7 @@ async def get_location(loc_name: str) -> str:
666671
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
667672
model_name='gemini-1.5-flash-123',
668673
timestamp=IsNow(tz=timezone.utc),
674+
vendor_details={'finish_reason': 'STOP'},
669675
),
670676
ModelRequest(
671677
parts=[
@@ -688,6 +694,7 @@ async def get_location(loc_name: str) -> str:
688694
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
689695
model_name='gemini-1.5-flash-123',
690696
timestamp=IsNow(tz=timezone.utc),
697+
vendor_details={'finish_reason': 'STOP'},
691698
),
692699
]
693700
)
@@ -1099,6 +1106,7 @@ async def get_image() -> BinaryContent:
10991106
usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
11001107
model_name='gemini-2.5-pro-preview-03-25',
11011108
timestamp=IsDatetime(),
1109+
vendor_details={'finish_reason': 'STOP'},
11021110
),
11031111
ModelRequest(
11041112
parts=[
@@ -1122,6 +1130,7 @@ async def get_image() -> BinaryContent:
11221130
usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
11231131
model_name='gemini-2.5-pro-preview-03-25',
11241132
timestamp=IsDatetime(),
1133+
vendor_details={'finish_reason': 'STOP'},
11251134
),
11261135
]
11271136
)
@@ -1244,6 +1253,7 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
12441253
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
12451254
model_name='gemini-1.5-flash',
12461255
timestamp=IsDatetime(),
1256+
vendor_details={'finish_reason': 'STOP'},
12471257
),
12481258
]
12491259
)
@@ -1284,3 +1294,18 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pra
12841294
assert result.output == snapshot(
12851295
'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n'
12861296
)
1297+
1298+
1299+
async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient):
1300+
response = gemini_response(
1301+
_content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None
1302+
)
1303+
gemini_client = get_gemini_client(response)
1304+
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
1305+
agent = Agent(m)
1306+
1307+
result = await agent.run('Hello World')
1308+
1309+
for message in result.all_messages():
1310+
if isinstance(message, ModelResponse):
1311+
assert message.vendor_details is None

tests/models/test_google.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
8585
usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
8686
model_name='gemini-1.5-flash',
8787
timestamp=IsDatetime(),
88+
vendor_details={'finish_reason': 'STOP'},
8889
),
8990
]
9091
)
@@ -138,6 +139,7 @@ async def temperature(city: str, date: datetime.date) -> str:
138139
usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
139140
model_name='gemini-1.5-flash',
140141
timestamp=IsDatetime(),
142+
vendor_details={'finish_reason': 'STOP'},
141143
),
142144
ModelRequest(
143145
parts=[
@@ -157,6 +159,7 @@ async def temperature(city: str, date: datetime.date) -> str:
157159
usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
158160
model_name='gemini-1.5-flash',
159161
timestamp=IsDatetime(),
162+
vendor_details={'finish_reason': 'STOP'},
160163
),
161164
ModelRequest(
162165
parts=[
@@ -215,6 +218,7 @@ async def get_capital(country: str) -> str:
215218
),
216219
model_name='models/gemini-2.5-pro-preview-05-06',
217220
timestamp=IsDatetime(),
221+
vendor_details={'finish_reason': 'STOP'},
218222
),
219223
ModelRequest(
220224
parts=[
@@ -235,6 +239,7 @@ async def get_capital(country: str) -> str:
235239
usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
236240
model_name='models/gemini-2.5-pro-preview-05-06',
237241
timestamp=IsDatetime(),
242+
vendor_details={'finish_reason': 'STOP'},
238243
),
239244
]
240245
)
@@ -491,6 +496,7 @@ def instructions() -> str:
491496
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
492497
model_name='gemini-2.0-flash',
493498
timestamp=IsDatetime(),
499+
vendor_details={'finish_reason': 'STOP'},
494500
),
495501
]
496502
)

0 commit comments

Comments (0)