|
14 | 14 |
|
15 | 15 | import os
|
16 | 16 | import sys
|
| 17 | +from typing import Optional |
17 | 18 | from unittest import mock
|
18 | 19 |
|
19 | 20 | from google.adk import version as adk_version
|
|
23 | 24 | from google.adk.models.google_llm import Gemini
|
24 | 25 | from google.adk.models.llm_request import LlmRequest
|
25 | 26 | from google.adk.models.llm_response import LlmResponse
|
| 27 | +from google.adk.utils.variant_utils import GoogleLLMVariant |
26 | 28 | from google.genai import types
|
27 | 29 | from google.genai import version as genai_version
|
28 | 30 | from google.genai.types import Content
|
@@ -337,3 +339,84 @@ async def __aexit__(self, *args):
|
337 | 339 | ):
|
338 | 340 | async with gemini_llm.connect(llm_request) as connection:
|
339 | 341 | assert connection is mock_connection
|
| 342 | + |
| 343 | + |
@pytest.mark.parametrize(
    (
        "api_backend, "
        "expected_file_display_name, "
        "expected_inline_display_name, "
        "expected_labels"
    ),
    [
        (
            GoogleLLMVariant.GEMINI_API,
            None,
            None,
            None,
        ),
        (
            GoogleLLMVariant.VERTEX_AI,
            "My Test PDF",
            "My Test Image",
            {"key": "value"},
        ),
    ],
)
def test_preprocess_request_handles_backend_specific_fields(
    gemini_llm: Gemini,
    api_backend: GoogleLLMVariant,
    expected_file_display_name: Optional[str],
    expected_inline_display_name: Optional[str],
    # Fixed annotation: labels is a dict mapping, not a str — the VERTEX_AI
    # case parametrizes {"key": "value"} and the assertion compares dicts.
    expected_labels: Optional[dict],
):
  """Tests that _preprocess_request sanitizes fields based on the API backend.

  - For GEMINI_API, it should remove 'display_name' from file/inline data
    and remove 'labels' from the config.
  - For VERTEX_AI, it should leave these fields untouched.
  """
  # Arrange: Create a request with fields that need to be preprocessed.
  llm_request_with_files = LlmRequest(
      model="gemini-1.5-flash",
      contents=[
          Content(
              role="user",
              parts=[
                  Part(
                      file_data=types.FileData(
                          file_uri="gs://bucket/file.pdf",
                          mime_type="application/pdf",
                          display_name="My Test PDF",
                      )
                  ),
                  Part(
                      inline_data=types.Blob(
                          data=b"some_bytes",
                          mime_type="image/png",
                          display_name="My Test Image",
                      )
                  ),
              ],
          )
      ],
      config=types.GenerateContentConfig(labels={"key": "value"}),
  )

  # Mock the _api_backend property to control the test scenario
  with mock.patch.object(
      Gemini, "_api_backend", new_callable=mock.PropertyMock
  ) as mock_backend:
    mock_backend.return_value = api_backend

    # Act: Run the preprocessing method
    gemini_llm._preprocess_request(llm_request_with_files)

    # Assert: Check if the fields were correctly processed
    file_part = llm_request_with_files.contents[0].parts[0]
    inline_part = llm_request_with_files.contents[0].parts[1]

    assert file_part.file_data.display_name == expected_file_display_name
    assert inline_part.inline_data.display_name == expected_inline_display_name
    assert llm_request_with_files.config.labels == expected_labels
0 commit comments