|
17 | 17 | from litellm import completion as litellm_completion
|
18 | 18 | from pydantic import BaseModel
|
19 | 19 |
|
20 |
| -__version__ = "5.3.1" |
| 20 | +__version__ = "5.4.0" |
21 | 21 |
|
22 | 22 | SystemPrompt = Optional[Union[str, List[str], List[Dict[str, str]]]]
|
23 | 23 |
|
@@ -288,6 +288,8 @@ def _generate_tool_definition(self, fn: Callable) -> dict:
|
288 | 288 | param_info["type"] = "number"
|
289 | 289 | elif param_type == bool:
|
290 | 290 | param_info["type"] = "boolean"
|
| 291 | + elif inspect.isclass(param_type) and issubclass(param_type, BaseModel): |
| 292 | + param_info = param_type.model_json_schema() |
291 | 293 |
|
292 | 294 | parameters["properties"][name] = param_info
|
293 | 295 |
|
@@ -378,6 +380,28 @@ def llm(self, func: Callable = None, **kwargs):
|
378 | 380 |
|
379 | 381 | return new_instance._decorator(func) if func else new_instance._decorator
|
380 | 382 |
|
| 383 | + def _deserialize_pydantic_args(self, fn: Callable, function_args: dict) -> dict: |
| 384 | + """Deserialize any Pydantic model parameters in the function arguments. |
| 385 | +
|
| 386 | + Args: |
| 387 | + fn: The function whose parameters to check |
| 388 | + function_args: The arguments to deserialize |
| 389 | +
|
| 390 | + Returns: |
| 391 | + The function arguments with any Pydantic models deserialized |
| 392 | + """ |
| 393 | + sig = inspect.signature(fn) |
| 394 | + for param_name, param in sig.parameters.items(): |
| 395 | + param_type = param.annotation if param.annotation != inspect._empty else Any |
| 396 | + if ( |
| 397 | + inspect.isclass(param_type) |
| 398 | + and issubclass(param_type, BaseModel) |
| 399 | + and param_name in function_args |
| 400 | + and isinstance(function_args[param_name], dict) |
| 401 | + ): |
| 402 | + function_args[param_name] = param_type(**function_args[param_name]) |
| 403 | + return function_args |
| 404 | + |
381 | 405 | def _decorator(self, func: Callable):
|
382 | 406 | return_type = func.__annotations__.get("return")
|
383 | 407 |
|
@@ -590,9 +614,14 @@ def wrapper(*args, **kwargs):
|
590 | 614 | function_response = f"[DRY RUN] Would have called {function_name = } {function_args = }"
|
591 | 615 | else:
|
592 | 616 | try:
|
593 |
| - function_response = self.tools[function_name]( |
594 |
| - **function_args |
| 617 | + self.logger.debug( |
| 618 | + f"Calling tool {function_name}({function_args}) using {self.model = }" |
| 619 | + ) |
| 620 | + fn = self.tools[function_name] |
| 621 | + function_args = self._deserialize_pydantic_args( |
| 622 | + fn, function_args |
595 | 623 | )
|
| 624 | + function_response = fn(**function_args) |
596 | 625 | except Exception as e:
|
597 | 626 | self.logger.error(
|
598 | 627 | f"Error calling tool {function_name}({function_args}): {e}"
|
@@ -625,6 +654,7 @@ def wrapper(*args, **kwargs):
|
625 | 654 | wrapper.tool = self.tool
|
626 | 655 | wrapper.clear = self.clear
|
627 | 656 | wrapper.message = self.message
|
| 657 | + wrapper.instance = self |
628 | 658 |
|
629 | 659 | # Automatically expose all other attributes from self
|
630 | 660 | for attr_name, attr_value in self.__dict__.items():
|
@@ -683,9 +713,16 @@ def _stream_response(self, response, call=None):
|
683 | 713 | f"[DRY RUN] Would have called {tool_info['name']} with {function_args}"
|
684 | 714 | )
|
685 | 715 | else:
|
686 |
| - self.tools[tool_info["name"]]( |
687 |
| - **function_args |
| 716 | + self.logger.debug( |
| 717 | + f"Calling tool {tool_info['name']}({function_args}) using {self.model = }" |
| 718 | + ) |
| 719 | + fn = self.tools[tool_info["name"]] |
| 720 | + function_args = ( |
| 721 | + self._deserialize_pydantic_args( |
| 722 | + fn, function_args |
| 723 | + ) |
688 | 724 | )
|
| 725 | + fn(**function_args) |
689 | 726 | # Clear after successful execution
|
690 | 727 | del current_tool_calls[current_index]
|
691 | 728 | except json.JSONDecodeError:
|
|
0 commit comments