Skip to content

Commit 2388f18

Browse files
authored
fix(agents): Use name/arg hash on tool retries (#612)
1 parent 14f4778 commit 2388f18

File tree

3 files changed

+75
-42
lines changed

3 files changed

+75
-42
lines changed

swiftide-agents/src/agent.rs

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ use crate::{
99
system_prompt::SystemPrompt,
1010
tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
1111
};
12-
use std::{collections::HashSet, sync::Arc};
12+
use std::{
13+
collections::{HashMap, HashSet},
14+
hash::{DefaultHasher, Hash as _, Hasher as _},
15+
sync::Arc,
16+
};
1317

1418
use anyhow::Result;
1519
use derive_builder::Builder;
@@ -89,9 +93,16 @@ pub struct Agent {
8993
/// worth while. If the limit is not reached, the agent will send the formatted error back to
9094
/// the LLM.
9195
///
92-
/// The limit is on each individual tool call.
96+
/// The limit is hashed based on the tool call name and arguments, so the limit is per tool call.
97+
///
98+
/// This limit is _not_ reset when the agent is stopped.
9399
#[builder(default = 3)]
94100
pub(crate) tool_retry_limit: usize,
101+
102+
/// Internally tracks the amount of times a tool has been retried. The key is a hash based on
103+
/// the name and args of the tool.
104+
#[builder(private, default)]
105+
pub(crate) tool_retries_counter: HashMap<u64, usize>,
95106
}
96107

97108
impl std::fmt::Debug for Agent {
@@ -455,7 +466,7 @@ impl Agent {
455466
handles.push((handle, tool_call));
456467
}
457468

458-
for (handle, mut tool_call) in handles {
469+
for (handle, tool_call) in handles {
459470
let mut output = handle.await?;
460471

461472
// Invoking hooks feels too verbose and repetitive
@@ -472,26 +483,26 @@ impl Agent {
472483
}
473484
}
474485

475-
if let Err(error) = &output {
476-
if tool_call.retries < self.tool_retry_limit {
477-
tool_call.retries += 1;
478-
tracing::warn!(
479-
error = error.to_string(),
480-
"Tool call failed, retrying {}/{}",
481-
tool_call.retries,
482-
self.tool_retry_limit
486+
if let Err(error) = output {
487+
if self.tool_calls_over_limit(&tool_call) {
488+
tracing::error!(
489+
"Tool call failed, retry limit reached, stopping agent: {err}",
490+
err = error
483491
);
484-
self.add_message(ChatMessage::ToolOutput(
485-
tool_call,
486-
ToolOutput::Fail(error.to_string()),
487-
))
488-
.await?;
489-
continue;
492+
self.stop();
493+
return Err(error.into());
490494
}
491-
tracing::error!(
492-
"Tool call failed, retry limit reached, stopping agent: {err}",
493-
err = error
495+
tracing::warn!(
496+
error = error.to_string(),
497+
tool_call = ?tool_call,
498+
"Tool call failed, retrying",
494499
);
500+
self.add_message(ChatMessage::ToolOutput(
501+
tool_call,
502+
ToolOutput::Fail(error.to_string()),
503+
))
504+
.await?;
505+
continue;
495506
}
496507

497508
let output = output?;
@@ -525,6 +536,21 @@ impl Agent {
525536
}
526537
}
527538

539+
fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
540+
let mut s = DefaultHasher::new();
541+
tool_call.hash(&mut s);
542+
let hash = s.finish();
543+
544+
if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
545+
let val = *retries >= self.tool_retry_limit;
546+
*retries += 1;
547+
val
548+
} else {
549+
self.tool_retries_counter.insert(hash, 1);
550+
false
551+
}
552+
}
553+
528554
#[tracing::instrument(skip_all, fields(message = message.to_string()))]
529555
async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
530556
for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
@@ -965,40 +991,43 @@ mod tests {
965991
let mock_llm = MockChatCompletion::new();
966992
let mock_tool = MockTool::new("retry_tool");
967993

968-
// Configure mock tool to fail twice and succeed on third attempt
994+
// Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
995+
// error
996+
mock_tool.expect_invoke(
997+
Err(ToolError::WrongArguments(serde_json::Error::custom(
998+
"missing `query`",
999+
))),
1000+
None,
1001+
);
9691002
mock_tool.expect_invoke(
9701003
Err(ToolError::WrongArguments(serde_json::Error::custom(
9711004
"missing `query`",
9721005
))),
9731006
None,
9741007
);
975-
976-
// Expected response for first two failed calls
977-
let retry_response = chat_response! {
978-
"Attempted Retry";
979-
tool_calls = ["retry_tool"]
980-
};
981-
982-
// Final response to make the agent stop
983-
let stop_response = chat_response! {
984-
"Finished execution";
985-
tool_calls = ["stop"]
986-
};
9871008

9881009
let chat_request = chat_request! {
9891010
user!(prompt);
9901011
tools = [mock_tool.clone()]
9911012
};
992-
mock_llm.expect_complete(chat_request.clone(), Ok(retry_response.clone()));
1013+
let retry_response = chat_response! {
1014+
"First failing attempt";
1015+
tool_calls = ["retry_tool"]
1016+
};
1017+
mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
9931018

9941019
let chat_request = chat_request! {
9951020
user!(prompt),
996-
assistant!("Attempted Retry", ["retry_tool"]),
1021+
assistant!("First failing attempt", ["retry_tool"]),
9971022
tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
9981023

9991024
tools = [mock_tool.clone()]
10001025
};
1001-
mock_llm.expect_complete(chat_request.clone(), Ok(stop_response));
1026+
let will_fail_response = chat_response! {
1027+
"Finished execution";
1028+
tool_calls = ["retry_tool"]
1029+
};
1030+
mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
10021031

10031032
let mut agent = Agent::builder()
10041033
.tools([mock_tool])
@@ -1009,9 +1038,10 @@ mod tests {
10091038
.unwrap();
10101039

10111040
// Run the agent
1012-
agent.query(prompt).await.unwrap();
1041+
let result = agent.query(prompt).await;
10131042

1014-
// Assert that the agent is stopped after the tool succeeds
1043+
assert!(result.is_err());
1044+
assert!(result.unwrap_err().to_string().contains("missing `query`"));
10151045
assert!(agent.is_stopped());
10161046
}
10171047
}

swiftide-agents/src/test_utils.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ macro_rules! tool_failed {
9090
ToolCall::builder()
9191
.name($tool_name)
9292
.id("1")
93-
.retries(1 as usize)
9493
.build()
9594
.unwrap(),
9695
ToolOutput::Fail($message.to_string()),

swiftide-core/src/chat_completion/tools.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,14 @@ pub struct ToolCall {
4646
name: String,
4747
#[builder(default)]
4848
args: Option<String>,
49+
}
4950

50-
/// How often this tool call has been retried
51-
#[builder(default)]
52-
pub retries: usize,
51+
/// Hash is used for finding tool calls that have been retried by agents
52+
impl std::hash::Hash for &ToolCall {
53+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54+
self.name.hash(state);
55+
self.args.hash(state);
56+
}
5357
}
5458

5559
impl std::fmt::Display for ToolCall {

0 commit comments

Comments
 (0)