-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Allow additional tools at runtime #8190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
bc257fe
71d52d5
46a0556
1f55ae0
3e8a9c7
1d335e2
fcdeea7
0c53e81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import logging | ||
from typing import Any, Callable, Literal | ||
from typing import Any, Callable, Optional | ||
|
||
from litellm import ContextWindowExceededError | ||
|
||
|
@@ -20,8 +20,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): | |
self.signature = signature = ensure_signature(signature) | ||
self.max_iters = max_iters | ||
|
||
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools] | ||
tools = {tool.name: tool for tool in tools} | ||
tools = self._convert_tools(tools) | ||
|
||
inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) | ||
outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) | ||
|
@@ -34,7 +33,6 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): | |
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", | ||
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", | ||
"When writing next_thought, you may reason about the current situation and plan for future steps.", | ||
"When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", | ||
] | ||
) | ||
|
||
|
@@ -45,14 +43,12 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): | |
args={}, | ||
) | ||
|
||
for idx, tool in enumerate(tools.values()): | ||
instr.append(f"({idx + 1}) {tool}") | ||
|
||
react_signature = ( | ||
dspy.Signature({**signature.input_fields}, "\n".join(instr)) | ||
.append("trajectory", dspy.InputField(), type_=str) | ||
.append("tools", dspy.InputField(desc="Tools you select from when selecting the next_tool_name and its next_tool_args"), type_=list[str]) | ||
.append("next_thought", dspy.OutputField(), type_=str) | ||
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) | ||
.append("next_tool_name", dspy.OutputField(), type_=str) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This part feels like a bit of a downgrade, but at the same time, I wouldn't dynamically update the signature at runtime. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This feels a little scary because it's hard to tell whether this is going to cause a regression or not. @okhat From your experience, do you think the LM can reliably choose only from the provided tools without the `Literal` type constraint? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, tool names are passed to the LLM in the same call, so LLMs should have enough context. We could dynamically construct the `Literal` at runtime, but I think it would increase the complexity, so I want to avoid it unless this change causes a regression. |
||
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) | ||
) | ||
|
||
|
@@ -65,17 +61,18 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): | |
self.react = dspy.Predict(react_signature) | ||
self.extract = dspy.ChainOfThought(fallback_signature) | ||
|
||
def _format_trajectory(self, trajectory: dict[str, Any]): | ||
adapter = dspy.settings.adapter or dspy.ChatAdapter() | ||
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") | ||
return adapter.format_user_message_content(trajectory_signature, trajectory) | ||
|
||
def forward(self, **input_args): | ||
def forward(self, additional_tools: Optional[list[Callable]] = None, **input_args): | ||
trajectory = {} | ||
max_iters = input_args.pop("max_iters", self.max_iters) | ||
tools = self.tools | self._convert_tools(additional_tools) | ||
for idx in range(max_iters): | ||
try: | ||
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) | ||
pred = self._call_with_potential_trajectory_truncation( | ||
self.react, | ||
trajectory, | ||
tools=self._format_tools_string(tools), | ||
**input_args | ||
) | ||
except ValueError as err: | ||
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") | ||
break | ||
|
@@ -85,7 +82,7 @@ def forward(self, **input_args): | |
trajectory[f"tool_args_{idx}"] = pred.next_tool_args | ||
|
||
try: | ||
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) | ||
trajectory[f"observation_{idx}"] = tools[pred.next_tool_name](**pred.next_tool_args) | ||
except Exception as err: | ||
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" | ||
|
||
|
@@ -95,12 +92,18 @@ def forward(self, **input_args): | |
extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) | ||
return dspy.Prediction(trajectory=trajectory, **extract) | ||
|
||
async def aforward(self, **input_args): | ||
async def aforward(self, additional_tools: Optional[list[Callable]] = None, **input_args): | ||
trajectory = {} | ||
max_iters = input_args.pop("max_iters", self.max_iters) | ||
tools = self.tools | self._convert_tools(additional_tools) | ||
TomeHirata marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for idx in range(max_iters): | ||
try: | ||
pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) | ||
pred = await self._async_call_with_potential_trajectory_truncation( | ||
self.react, | ||
trajectory, | ||
tools=self._format_tools_string(tools), | ||
**input_args | ||
) | ||
except ValueError as err: | ||
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") | ||
break | ||
|
@@ -110,7 +113,7 @@ async def aforward(self, **input_args): | |
trajectory[f"tool_args_{idx}"] = pred.next_tool_args | ||
|
||
try: | ||
trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) | ||
trajectory[f"observation_{idx}"] = await tools[pred.next_tool_name].acall(**pred.next_tool_args) | ||
except Exception as err: | ||
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" | ||
|
||
|
@@ -159,6 +162,20 @@ def truncate_trajectory(self, trajectory): | |
trajectory.pop(key) | ||
|
||
return trajectory | ||
|
||
def _format_trajectory(self, trajectory: dict[str, Any]): | ||
adapter = dspy.settings.adapter or dspy.ChatAdapter() | ||
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") | ||
return adapter.format_user_message_content(trajectory_signature, trajectory) | ||
|
||
def _convert_tools(self, tools: Optional[list[Callable]]) -> dict[str, Tool]: | ||
"""Convert the tools to a dictionary of name -> tool.""" | ||
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools or []] | ||
return {tool.name: tool for tool in tools} | ||
|
||
def _format_tools_string(self, tools: dict[str, Tool]) -> list[str]: | ||
"""Format the tools into a list of string.""" | ||
TomeHirata marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return [f"({idx + 1}) {tool}" for idx, tool in enumerate(tools.values())] | ||
|
||
|
||
def _fmt_exc(err: BaseException, *, limit: int = 5) -> str: | ||
|
@@ -199,14 +216,4 @@ def _fmt_exc(err: BaseException, *, limit: int = 5) -> str: | |
|
||
|
||
TOPIC 04: Default behavior when the trajectory gets too long. | ||
|
||
|
||
TOPIC 05: Adding more structure around how the instruction is formatted. | ||
* Concretely, it's now a string, so an optimizer can and does rewrite it freely. | ||
* An alternative would be to add more structure, such that a certain template is fixed but values are variable? | ||
|
||
|
||
TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls. | ||
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations. | ||
* This is pretty useful for allowing the agent to keep notes or count certain things, etc. | ||
""" |
Uh oh!
There was an error while loading. Please reload this page.