@@ -9,7 +9,11 @@ use crate::{
9
9
system_prompt:: SystemPrompt ,
10
10
tools:: { arg_preprocessor:: ArgPreprocessor , control:: Stop } ,
11
11
} ;
12
- use std:: { collections:: HashSet , sync:: Arc } ;
12
+ use std:: {
13
+ collections:: { HashMap , HashSet } ,
14
+ hash:: { DefaultHasher , Hash as _, Hasher as _} ,
15
+ sync:: Arc ,
16
+ } ;
13
17
14
18
use anyhow:: Result ;
15
19
use derive_builder:: Builder ;
@@ -89,9 +93,16 @@ pub struct Agent {
89
93
/// worth while. If the limit is not reached, the agent will send the formatted error back to
90
94
/// the LLM.
91
95
///
92
- /// The limit is on each individual tool call.
96
+ /// Retries are counted per tool call, keyed by a hash of the tool call name and arguments.
97
+ ///
98
+ /// This limit is _not_ reset when the agent is stopped.
93
99
#[ builder( default = 3 ) ]
94
100
pub ( crate ) tool_retry_limit : usize ,
101
+
102
+ /// Internally tracks the number of times a tool has been retried. The key is a hash based on
103
+ /// the name and args of the tool.
104
+ #[ builder( private, default ) ]
105
+ pub ( crate ) tool_retries_counter : HashMap < u64 , usize > ,
95
106
}
96
107
97
108
impl std:: fmt:: Debug for Agent {
@@ -455,7 +466,7 @@ impl Agent {
455
466
handles. push ( ( handle, tool_call) ) ;
456
467
}
457
468
458
- for ( handle, mut tool_call) in handles {
469
+ for ( handle, tool_call) in handles {
459
470
let mut output = handle. await ?;
460
471
461
472
// Invoking hooks feels too verbose and repetitive
@@ -472,26 +483,26 @@ impl Agent {
472
483
}
473
484
}
474
485
475
- if let Err ( error) = & output {
476
- if tool_call. retries < self . tool_retry_limit {
477
- tool_call. retries += 1 ;
478
- tracing:: warn!(
479
- error = error. to_string( ) ,
480
- "Tool call failed, retrying {}/{}" ,
481
- tool_call. retries,
482
- self . tool_retry_limit
486
+ if let Err ( error) = output {
487
+ if self . tool_calls_over_limit ( & tool_call) {
488
+ tracing:: error!(
489
+ "Tool call failed, retry limit reached, stopping agent: {err}" ,
490
+ err = error
483
491
) ;
484
- self . add_message ( ChatMessage :: ToolOutput (
485
- tool_call,
486
- ToolOutput :: Fail ( error. to_string ( ) ) ,
487
- ) )
488
- . await ?;
489
- continue ;
492
+ self . stop ( ) ;
493
+ return Err ( error. into ( ) ) ;
490
494
}
491
- tracing:: error!(
492
- "Tool call failed, retry limit reached, stopping agent: {err}" ,
493
- err = error
495
+ tracing:: warn!(
496
+ error = error. to_string( ) ,
497
+ tool_call = ?tool_call,
498
+ "Tool call failed, retrying" ,
494
499
) ;
500
+ self . add_message ( ChatMessage :: ToolOutput (
501
+ tool_call,
502
+ ToolOutput :: Fail ( error. to_string ( ) ) ,
503
+ ) )
504
+ . await ?;
505
+ continue ;
495
506
}
496
507
497
508
let output = output?;
@@ -525,6 +536,21 @@ impl Agent {
525
536
}
526
537
}
527
538
539
+ fn tool_calls_over_limit ( & mut self , tool_call : & ToolCall ) -> bool {
540
+ let mut s = DefaultHasher :: new ( ) ;
541
+ tool_call. hash ( & mut s) ;
542
+ let hash = s. finish ( ) ;
543
+
544
+ if let Some ( retries) = self . tool_retries_counter . get_mut ( & hash) {
545
+ let val = * retries >= self . tool_retry_limit ;
546
+ * retries += 1 ;
547
+ val
548
+ } else {
549
+ self . tool_retries_counter . insert ( hash, 1 ) ;
550
+ false
551
+ }
552
+ }
553
+
528
554
#[ tracing:: instrument( skip_all, fields( message = message. to_string( ) ) ) ]
529
555
async fn add_message ( & self , mut message : ChatMessage ) -> Result < ( ) > {
530
556
for hook in self . hooks_by_type ( HookTypes :: OnNewMessage ) {
@@ -965,40 +991,43 @@ mod tests {
965
991
let mock_llm = MockChatCompletion :: new ( ) ;
966
992
let mock_tool = MockTool :: new ( "retry_tool" ) ;
967
993
968
- // Configure mock tool to fail twice and succeed on third attempt
994
+ // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
995
+ // error
996
+ mock_tool. expect_invoke (
997
+ Err ( ToolError :: WrongArguments ( serde_json:: Error :: custom (
998
+ "missing `query`" ,
999
+ ) ) ) ,
1000
+ None ,
1001
+ ) ;
969
1002
mock_tool. expect_invoke (
970
1003
Err ( ToolError :: WrongArguments ( serde_json:: Error :: custom (
971
1004
"missing `query`" ,
972
1005
) ) ) ,
973
1006
None ,
974
1007
) ;
975
-
976
- // Expected response for first two failed calls
977
- let retry_response = chat_response ! {
978
- "Attempted Retry" ;
979
- tool_calls = [ "retry_tool" ]
980
- } ;
981
-
982
- // Final response to make the agent stop
983
- let stop_response = chat_response ! {
984
- "Finished execution" ;
985
- tool_calls = [ "stop" ]
986
- } ;
987
1008
988
1009
let chat_request = chat_request ! {
989
1010
user!( prompt) ;
990
1011
tools = [ mock_tool. clone( ) ]
991
1012
} ;
992
- mock_llm. expect_complete ( chat_request. clone ( ) , Ok ( retry_response. clone ( ) ) ) ;
1013
+ let retry_response = chat_response ! {
1014
+ "First failing attempt" ;
1015
+ tool_calls = [ "retry_tool" ]
1016
+ } ;
1017
+ mock_llm. expect_complete ( chat_request. clone ( ) , Ok ( retry_response) ) ;
993
1018
994
1019
let chat_request = chat_request ! {
995
1020
user!( prompt) ,
996
- assistant!( "Attempted Retry " , [ "retry_tool" ] ) ,
1021
+ assistant!( "First failing attempt " , [ "retry_tool" ] ) ,
997
1022
tool_failed!( "retry_tool" , "arguments for tool failed to parse: missing `query`" ) ;
998
1023
999
1024
tools = [ mock_tool. clone( ) ]
1000
1025
} ;
1001
- mock_llm. expect_complete ( chat_request. clone ( ) , Ok ( stop_response) ) ;
1026
+ let will_fail_response = chat_response ! {
1027
+ "Finished execution" ;
1028
+ tool_calls = [ "retry_tool" ]
1029
+ } ;
1030
+ mock_llm. expect_complete ( chat_request. clone ( ) , Ok ( will_fail_response) ) ;
1002
1031
1003
1032
let mut agent = Agent :: builder ( )
1004
1033
. tools ( [ mock_tool] )
@@ -1009,9 +1038,10 @@ mod tests {
1009
1038
. unwrap ( ) ;
1010
1039
1011
1040
// Run the agent
1012
- agent. query ( prompt) . await . unwrap ( ) ;
1041
+ let result = agent. query ( prompt) . await ;
1013
1042
1014
- // Assert that the agent is stopped after the tool succeeds
1043
+ assert ! ( result. is_err( ) ) ;
1044
+ assert ! ( result. unwrap_err( ) . to_string( ) . contains( "missing `query`" ) ) ;
1015
1045
assert ! ( agent. is_stopped( ) ) ;
1016
1046
}
1017
1047
}
0 commit comments