Skip to content

Allow additional tools at runtime #8190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 36 additions & 29 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Callable, Literal
from typing import Any, Callable, Optional

from litellm import ContextWindowExceededError

Expand All @@ -20,8 +20,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
self.signature = signature = ensure_signature(signature)
self.max_iters = max_iters

tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
tools = {tool.name: tool for tool in tools}
tools = self._convert_tools(tools)

inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
Expand All @@ -34,7 +33,6 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
"When writing next_thought, you may reason about the current situation and plan for future steps.",
"When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
]
)

Expand All @@ -45,14 +43,12 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
args={},
)

for idx, tool in enumerate(tools.values()):
instr.append(f"({idx + 1}) {tool}")

react_signature = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
.append("trajectory", dspy.InputField(), type_=str)
.append("tools", dspy.InputField(desc="Tools you select from when selecting the next_tool_name and its next_tool_args"), type_=list[str])
.append("next_thought", dspy.OutputField(), type_=str)
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
.append("next_tool_name", dspy.OutputField(), type_=str)
Copy link
Collaborator Author

@TomeHirata TomeHirata May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part feels a bit downgraded, but at the same time, I wouldn't dynamically update the signature at runtime.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a little scary because it's hard to tell if this is going to cause regression or not.

@okhat From your experience, do you think the LM can reliably choose only from the provided tools without the Literal constraint?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, tool names are passed to the LLM in the same call, so LLMs should have enough context. We could dynamically construct the literal at runtime, but I think it would increase the complexity, so I want to avoid it unless it causes a regression.

.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
)

Expand All @@ -65,17 +61,18 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
self.react = dspy.Predict(react_signature)
self.extract = dspy.ChainOfThought(fallback_signature)

def _format_trajectory(self, trajectory: dict[str, Any]):
adapter = dspy.settings.adapter or dspy.ChatAdapter()
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
return adapter.format_user_message_content(trajectory_signature, trajectory)

def forward(self, **input_args):
def forward(self, additional_tools: Optional[list[Callable]] = None, **input_args):
trajectory = {}
max_iters = input_args.pop("max_iters", self.max_iters)
tools = self.tools | self._convert_tools(additional_tools)
for idx in range(max_iters):
try:
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
pred = self._call_with_potential_trajectory_truncation(
self.react,
trajectory,
tools=self._format_tools_string(tools),
**input_args
)
except ValueError as err:
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
break
Expand All @@ -85,7 +82,7 @@ def forward(self, **input_args):
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
trajectory[f"observation_{idx}"] = tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

Expand All @@ -95,12 +92,18 @@ def forward(self, **input_args):
extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
return dspy.Prediction(trajectory=trajectory, **extract)

async def aforward(self, **input_args):
async def aforward(self, additional_tools: Optional[list[Callable]] = None, **input_args):
trajectory = {}
max_iters = input_args.pop("max_iters", self.max_iters)
tools = self.tools | self._convert_tools(additional_tools)
for idx in range(max_iters):
try:
pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
pred = await self._async_call_with_potential_trajectory_truncation(
self.react,
trajectory,
tools=self._format_tools_string(tools),
**input_args
)
except ValueError as err:
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
break
Expand All @@ -110,7 +113,7 @@ async def aforward(self, **input_args):
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args)
trajectory[f"observation_{idx}"] = await tools[pred.next_tool_name].acall(**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

Expand Down Expand Up @@ -159,6 +162,20 @@ def truncate_trajectory(self, trajectory):
trajectory.pop(key)

return trajectory

def _format_trajectory(self, trajectory: dict[str, Any]):
    """Render the trajectory as user-message content via the active adapter.

    A throwaway signature is built whose input fields are the trajectory's
    keys (with a placeholder output field ``x``); the adapter then formats
    the trajectory values against that signature.

    Args:
        trajectory: Mapping of trajectory field names to their values.

    Returns:
        Whatever the adapter's ``format_user_message_content`` returns.
    """
    # Use the globally configured adapter when one is set, else ChatAdapter.
    adapter = dspy.settings.adapter
    if not adapter:
        adapter = dspy.ChatAdapter()

    field_names = ", ".join(trajectory.keys())
    trajectory_signature = dspy.Signature(f"{field_names} -> x")
    return adapter.format_user_message_content(trajectory_signature, trajectory)

def _convert_tools(self, tools: Optional[list[Callable]]) -> dict[str, Tool]:
"""Convert the tools to a dictionary of name -> tool."""
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools or []]
return {tool.name: tool for tool in tools}

def _format_tools_string(self, tools: dict[str, Tool]) -> list[str]:
"""Format the tools into a list of string."""
return [f"({idx + 1}) {tool}" for idx, tool in enumerate(tools.values())]


def _fmt_exc(err: BaseException, *, limit: int = 5) -> str:
Expand Down Expand Up @@ -199,14 +216,4 @@ def _fmt_exc(err: BaseException, *, limit: int = 5) -> str:


TOPIC 04: Default behavior when the trajectory gets too long.


TOPIC 05: Adding more structure around how the instruction is formatted.
* Concretely, it's now a string, so an optimizer can and does rewrite it freely.
* An alternative would be to add more structure, such that a certain template is fixed but values are variable?


TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls.
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations.
* This is pretty useful for allowing the agent to keep notes or count certain things, etc.
"""
55 changes: 55 additions & 0 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,58 @@ async def foo(a, b):
for i in range(2):
obs = traj[f"observation_{i}"]
assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}"


def test_tool_call_state_management():
    """A stateful tool passed via ``additional_tools`` keeps state within one call.

    The scripted LM invokes the ``Counter`` tool three times (adding 1, 2, 3)
    and then finishes; each observation must reflect the running total held
    by the single ``Counter`` instance supplied for this call.
    """

    # Callable object that accumulates state across invocations.
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, i: int):
            self.count += i
            return self.count

    # No tools are registered at construction time; the counter is supplied
    # only at call time.
    react = dspy.ReAct("max -> sum:int", tools=[])
    lm = DummyLM(
        [
            {"next_thought": "I need to add 1", "next_tool_name": "Counter", "next_tool_args": {"i": 1}},
            {"next_thought": "I need to add 2", "next_tool_name": "Counter", "next_tool_args": {"i": 2}},
            {"next_thought": "I need to add 3", "next_tool_name": "Counter", "next_tool_args": {"i": 3}},
            {"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}},
            {"reasoning": "I added the numbers successfully", "sum": 6},
        ]
    )
    dspy.settings.configure(lm=lm)

    # Pass the tool object as an additional tool. The keyword must match the
    # signature's declared input field (`max`); the original call passed
    # `call_count`, leaving the declared input unset.
    outputs = react(max=3, additional_tools=[Counter()])

    assert outputs.sum == 6

    # Check the state is managed during the forward call: observations are
    # the running totals 1, 3, 6.
    expected_trajectory = {
        "thought_0": "I need to add 1",
        "tool_name_0": "Counter",
        "tool_args_0": {
            "i": 1,
        },
        "observation_0": 1,
        "thought_1": "I need to add 2",
        "tool_name_1": "Counter",
        "tool_args_1": {
            "i": 2,
        },
        "observation_1": 3,
        "thought_2": "I need to add 3",
        "tool_name_2": "Counter",
        "tool_args_2": {
            "i": 3,
        },
        "observation_2": 6,
        "thought_3": "I have the sum, now I can finish.",
        "tool_name_3": "finish",
        "tool_args_3": {},
        "observation_3": "Completed.",
    }
    assert outputs.trajectory == expected_trajectory