diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 2268d4da8..fd1d8e582 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -23,6 +23,7 @@
 from typing import AsyncGenerator
 from typing import cast
 from typing import TYPE_CHECKING
+from typing import Union
 
 from google.genai import Client
 from google.genai import types
@@ -244,9 +245,18 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
 
 
   def _preprocess_request(self, llm_request: LlmRequest) -> None:
-    if llm_request.config and self._api_backend == GoogleLLMVariant.GEMINI_API:
+    if self._api_backend == GoogleLLMVariant.GEMINI_API:
       # Using API key from Google AI Studio to call model doesn't support labels.
-      llm_request.config.labels = None
+      if llm_request.config:
+        llm_request.config.labels = None
+
+      if llm_request.contents:
+        for content in llm_request.contents:
+          if not content.parts:
+            continue
+          for part in content.parts:
+            _remove_display_name_if_present(part.inline_data)
+            _remove_display_name_if_present(part.file_data)
 
 
 def _build_function_declaration_log(
@@ -324,3 +334,15 @@ def _build_response_log(resp: types.GenerateContentResponse) -> str:
 {resp.model_dump_json(exclude_none=True)}
 -----------------------------------------------------------
 """
+
+
+def _remove_display_name_if_present(
+    data_obj: Union[types.Blob, types.FileData, None],
+):
+  """Sets display_name to None for the Gemini API (non-Vertex) backend.
+
+  This backend does not support the display_name parameter for file uploads,
+  so it must be removed to prevent request failures.
+  """
+  if data_obj and data_obj.display_name:
+    data_obj.display_name = None
diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py
index 349dd13b8..9278bee54 100644
--- a/tests/unittests/models/test_google_llm.py
+++ b/tests/unittests/models/test_google_llm.py
@@ -14,6 +14,7 @@
 
 import os
 import sys
+from typing import Optional
 from unittest import mock
 
 from google.adk import version as adk_version
@@ -23,6 +24,7 @@
 from google.adk.models.google_llm import Gemini
 from google.adk.models.llm_request import LlmRequest
 from google.adk.models.llm_response import LlmResponse
+from google.adk.utils.variant_utils import GoogleLLMVariant
 from google.genai import types
 from google.genai import version as genai_version
 from google.genai.types import Content
@@ -337,3 +339,84 @@ async def __aexit__(self, *args):
     ):
       async with gemini_llm.connect(llm_request) as connection:
         assert connection is mock_connection
+
+
+@pytest.mark.parametrize(
+    (
+        "api_backend, "
+        "expected_file_display_name, "
+        "expected_inline_display_name, "
+        "expected_labels"
+    ),
+    [
+        (
+            GoogleLLMVariant.GEMINI_API,
+            None,
+            None,
+            None,
+        ),
+        (
+            GoogleLLMVariant.VERTEX_AI,
+            "My Test PDF",
+            "My Test Image",
+            {"key": "value"},
+        ),
+    ],
+)
+def test_preprocess_request_handles_backend_specific_fields(
+    gemini_llm: Gemini,
+    api_backend: GoogleLLMVariant,
+    expected_file_display_name: Optional[str],
+    expected_inline_display_name: Optional[str],
+    expected_labels: Optional[dict],
+):
+  """
+  Tests that _preprocess_request correctly sanitizes fields based on the API backend.
+
+  - For GEMINI_API, it should remove 'display_name' from file/inline data
+    and remove 'labels' from the config.
+  - For VERTEX_AI, it should leave these fields untouched.
+  """
+  # Arrange: Create a request with fields that need to be preprocessed.
+  llm_request_with_files = LlmRequest(
+      model="gemini-1.5-flash",
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part(
+                      file_data=types.FileData(
+                          file_uri="gs://bucket/file.pdf",
+                          mime_type="application/pdf",
+                          display_name="My Test PDF",
+                      )
+                  ),
+                  Part(
+                      inline_data=types.Blob(
+                          data=b"some_bytes",
+                          mime_type="image/png",
+                          display_name="My Test Image",
+                      )
+                  ),
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(labels={"key": "value"}),
+  )
+
+  # Mock the _api_backend property to control the test scenario
+  with mock.patch.object(
+      Gemini, "_api_backend", new_callable=mock.PropertyMock
+  ) as mock_backend:
+    mock_backend.return_value = api_backend
+
+    # Act: Run the preprocessing method
+    gemini_llm._preprocess_request(llm_request_with_files)
+
+    # Assert: Check if the fields were correctly processed
+    file_part = llm_request_with_files.contents[0].parts[0]
+    inline_part = llm_request_with_files.contents[0].parts[1]
+
+    assert file_part.file_data.display_name == expected_file_display_name
+    assert inline_part.inline_data.display_name == expected_inline_display_name
+    assert llm_request_with_files.config.labels == expected_labels
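
Reviewer note (not part of the diff): a minimal, self-contained sketch of the sanitization in isolation, assuming a google-genai version in which `types.Blob` and `types.FileData` expose an optional `display_name` field (as this PR relies on). The helper body is copied from the diff above; the surrounding driver code is illustrative only.

```python
# Standalone sketch of the display_name sanitization added above.
# Assumes google-genai is installed and its Blob/FileData types carry
# an optional display_name attribute.
from typing import Union

from google.genai import types


def _remove_display_name_if_present(
    data_obj: Union[types.Blob, types.FileData, None],
) -> None:
  # The Gemini API (AI Studio) backend rejects display_name on uploads,
  # so it is cleared in place before the request is sent.
  if data_obj and data_obj.display_name:
    data_obj.display_name = None


part = types.Part(
    inline_data=types.Blob(
        data=b"some_bytes", mime_type="image/png", display_name="My Test Image"
    )
)
# None-safe: this Part has no file_data, and the helper tolerates that.
_remove_display_name_if_present(part.file_data)
_remove_display_name_if_present(part.inline_data)
assert part.inline_data.display_name is None
```

Mutating the parts in place keeps `_preprocess_request` cheap and leaves every other `Part` field untouched, which is why the test above can assert directly on the same request object it passed in.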