Skip to content

Commit 2ad16e8

Browse files
Add Pydantic model support for tool parameters
- Implement _deserialize_pydantic_args() method to handle Pydantic BaseModel parameters in tool calls - Update tool execution to automatically deserialize Pydantic model arguments - Add JSON schema generation for Pydantic model parameters in tool definitions - Bump version to 5.4.0 to reflect the new Pydantic model support - Add test case to verify tool definition and execution with Pydantic models
1 parent 06e0544 commit 2ad16e8

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

promptic.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from litellm import completion as litellm_completion
1818
from pydantic import BaseModel
1919

20-
__version__ = "5.3.1"
20+
__version__ = "5.4.0"
2121

2222
SystemPrompt = Optional[Union[str, List[str], List[Dict[str, str]]]]
2323

@@ -288,6 +288,8 @@ def _generate_tool_definition(self, fn: Callable) -> dict:
288288
param_info["type"] = "number"
289289
elif param_type == bool:
290290
param_info["type"] = "boolean"
291+
elif inspect.isclass(param_type) and issubclass(param_type, BaseModel):
292+
param_info = param_type.model_json_schema()
291293

292294
parameters["properties"][name] = param_info
293295

@@ -378,6 +380,28 @@ def llm(self, func: Callable = None, **kwargs):
378380

379381
return new_instance._decorator(func) if func else new_instance._decorator
380382

383+
def _deserialize_pydantic_args(self, fn: Callable, function_args: dict) -> dict:
384+
"""Deserialize any Pydantic model parameters in the function arguments.
385+
386+
Args:
387+
fn: The function whose parameters to check
388+
function_args: The arguments to deserialize
389+
390+
Returns:
391+
The function arguments with any Pydantic models deserialized
392+
"""
393+
sig = inspect.signature(fn)
394+
for param_name, param in sig.parameters.items():
395+
param_type = param.annotation if param.annotation != inspect._empty else Any
396+
if (
397+
inspect.isclass(param_type)
398+
and issubclass(param_type, BaseModel)
399+
and param_name in function_args
400+
and isinstance(function_args[param_name], dict)
401+
):
402+
function_args[param_name] = param_type(**function_args[param_name])
403+
return function_args
404+
381405
def _decorator(self, func: Callable):
382406
return_type = func.__annotations__.get("return")
383407

@@ -590,9 +614,14 @@ def wrapper(*args, **kwargs):
590614
function_response = f"[DRY RUN] Would have called {function_name = } {function_args = }"
591615
else:
592616
try:
593-
function_response = self.tools[function_name](
594-
**function_args
617+
self.logger.debug(
618+
f"Calling tool {function_name}({function_args}) using {self.model = }"
619+
)
620+
fn = self.tools[function_name]
621+
function_args = self._deserialize_pydantic_args(
622+
fn, function_args
595623
)
624+
function_response = fn(**function_args)
596625
except Exception as e:
597626
self.logger.error(
598627
f"Error calling tool {function_name}({function_args}): {e}"
@@ -625,6 +654,7 @@ def wrapper(*args, **kwargs):
625654
wrapper.tool = self.tool
626655
wrapper.clear = self.clear
627656
wrapper.message = self.message
657+
wrapper.instance = self
628658

629659
# Automatically expose all other attributes from self
630660
for attr_name, attr_value in self.__dict__.items():
@@ -683,9 +713,16 @@ def _stream_response(self, response, call=None):
683713
f"[DRY RUN] Would have called {tool_info['name']} with {function_args}"
684714
)
685715
else:
686-
self.tools[tool_info["name"]](
687-
**function_args
716+
self.logger.debug(
717+
f"Calling tool {tool_info['name']}({function_args}) using {self.model = }"
718+
)
719+
fn = self.tools[tool_info["name"]]
720+
function_args = (
721+
self._deserialize_pydantic_args(
722+
fn, function_args
723+
)
688724
)
725+
fn(**function_args)
689726
# Clear after successful execution
690727
del current_tool_calls[current_index]
691728
except json.JSONDecodeError:

tests/test_promptic.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,3 +1668,46 @@ def child_function(command):
16681668
# It should respond without using the tool
16691669
result = child_function("What time is it?")
16701670
assert "12:00" not in result
1671+
1672+
1673+
@pytest.mark.parametrize("model", REGULAR_MODELS)
1674+
@pytest.mark.parametrize(
1675+
"create_completion_fn", [openai_completion_fn, litellm_completion]
1676+
)
1677+
def test_tool_definition_with_pydantic_param(model, create_completion_fn):
1678+
"""Test that tool definitions properly handle Pydantic BaseModel parameters"""
1679+
if create_completion_fn == openai_completion_fn and not model.startswith("gpt"):
1680+
pytest.skip("Non-GPT models are not supported with OpenAI client")
1681+
if "gemini" in model:
1682+
pytest.skip("The Gemini model is flaky with this test.")
1683+
1684+
class UserProfile(BaseModel):
1685+
name: str
1686+
age: int
1687+
email: str
1688+
1689+
llm = Promptic(
1690+
model=model,
1691+
temperature=0,
1692+
timeout=5,
1693+
create_completion_fn=create_completion_fn,
1694+
)
1695+
1696+
@llm
1697+
def assistant(command):
1698+
"""{command}"""
1699+
1700+
name = None
1701+
1702+
@assistant.tool
1703+
def update_user(profile: UserProfile) -> str:
1704+
"""Update user profile"""
1705+
nonlocal name
1706+
name = profile.name
1707+
return f"Updated profile for {profile.name}"
1708+
1709+
assistant(
1710+
"Update the user's profile with the name 'John Doe', age 30, and email '[email protected]'"
1711+
)
1712+
1713+
assert name == "John Doe"

0 commit comments

Comments
 (0)