@@ -750,7 +750,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
750
750
hooks .AddBeforeAny (func (id any , method mcp.MCPMethod , message any ) {
751
751
beforeResults = append (beforeResults , beforeResult {method , message })
752
752
})
753
- hooks .AddAfterAny (func (id any , method mcp.MCPMethod , message any , result any ) {
753
+ hooks .AddOnSuccess (func (id any , method mcp.MCPMethod , message any , result any ) {
754
754
afterResults = append (afterResults , afterResult {method , message , result })
755
755
})
756
756
@@ -777,7 +777,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
777
777
name string
778
778
message string
779
779
expectedErr int
780
- validateCallbacks func (t * testing.T , err error , beforeResults beforeResult , afterResults afterResult )
780
+ validateCallbacks func (t * testing.T , err error , beforeResults beforeResult )
781
781
}{
782
782
{
783
783
name : "Undefined tool" ,
@@ -791,12 +791,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
791
791
}
792
792
}` ,
793
793
expectedErr : mcp .INVALID_PARAMS ,
794
- validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult , afterResults afterResult ) {
794
+ validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult ) {
795
795
assert .Equal (t , mcp .MethodToolsCall , beforeResults .method )
796
- assert .Equal (t , mcp .MethodToolsCall , afterResults .method )
797
- afterResultErr , ok := afterResults .result .(error )
798
- assert .True (t , ok )
799
- assert .Same (t , err , afterResultErr )
800
796
assert .True (t , errors .Is (err , ErrToolNotFound ))
801
797
},
802
798
},
@@ -812,12 +808,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
812
808
}
813
809
}` ,
814
810
expectedErr : mcp .INVALID_PARAMS ,
815
- validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult , afterResults afterResult ) {
811
+ validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult ) {
816
812
assert .Equal (t , mcp .MethodPromptsGet , beforeResults .method )
817
- assert .Equal (t , mcp .MethodPromptsGet , afterResults .method )
818
- afterResultErr , ok := afterResults .result .(error )
819
- assert .True (t , ok )
820
- assert .Same (t , err , afterResultErr )
821
813
assert .True (t , errors .Is (err , ErrPromptNotFound ))
822
814
},
823
815
},
@@ -832,12 +824,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
832
824
}
833
825
}` ,
834
826
expectedErr : mcp .INVALID_PARAMS ,
835
- validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult , afterResults afterResult ) {
827
+ validateCallbacks : func (t * testing.T , err error , beforeResults beforeResult ) {
836
828
assert .Equal (t , mcp .MethodResourcesRead , beforeResults .method )
837
- assert .Equal (t , mcp .MethodResourcesRead , afterResults .method )
838
- afterResultErr , ok := afterResults .result .(error )
839
- assert .True (t , ok )
840
- assert .Same (t , err , afterResultErr )
841
829
assert .True (t , errors .Is (err , ErrResourceNotFound ))
842
830
},
843
831
},
@@ -847,7 +835,6 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
847
835
t .Run (tt .name , func (t * testing.T ) {
848
836
errs = nil // Reset errors for each test case
849
837
beforeResults = nil
850
- afterResults = nil
851
838
response := server .HandleMessage (
852
839
context .Background (),
853
840
[]byte (tt .message ),
@@ -861,8 +848,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
861
848
if tt .validateCallbacks != nil {
862
849
require .Len (t , errs , 1 , "Expected exactly one error" )
863
850
require .Len (t , beforeResults , 1 , "Expected exactly one before result" )
864
- require .Len (t , afterResults , 1 , "Expected exactly one after result " )
865
- tt .validateCallbacks (t , errs [0 ], beforeResults [0 ], afterResults [ 0 ] )
851
+ require .Len (t , afterResults , 0 , "Expected no after results because these calls generate errors " )
852
+ tt .validateCallbacks (t , errs [0 ], beforeResults [0 ])
866
853
}
867
854
})
868
855
}
@@ -1165,7 +1152,7 @@ func TestMCPServer_WithHooks(t *testing.T) {
1165
1152
// Create hook counters to verify calls
1166
1153
var (
1167
1154
beforeAnyCount int
1168
- afterAnyCount int
1155
+ onSuccessCount int
1169
1156
onErrorCount int
1170
1157
beforePingCount int
1171
1158
afterPingCount int
@@ -1175,7 +1162,7 @@ func TestMCPServer_WithHooks(t *testing.T) {
1175
1162
1176
1163
// Collectors for message and result types
1177
1164
var beforeAnyMessages []any
1178
- var afterAnyData []struct {
1165
+ var onSuccessData []struct {
1179
1166
msg any
1180
1167
res any
1181
1168
}
@@ -1197,11 +1184,11 @@ func TestMCPServer_WithHooks(t *testing.T) {
1197
1184
}
1198
1185
})
1199
1186
1200
- hooks .AddAfterAny (func (id any , method mcp.MCPMethod , message any , result any ) {
1201
- afterAnyCount ++
1187
+ hooks .AddOnSuccess (func (id any , method mcp.MCPMethod , message any , result any ) {
1188
+ onSuccessCount ++
1202
1189
// Only collect ping responses for our test
1203
1190
if method == mcp .MethodPing {
1204
- afterAnyData = append (afterAnyData , struct {
1191
+ onSuccessData = append (onSuccessData , struct {
1205
1192
msg any
1206
1193
res any
1207
1194
}{message , result })
@@ -1301,8 +1288,8 @@ func TestMCPServer_WithHooks(t *testing.T) {
1301
1288
// General hooks should be called for all methods
1302
1289
// beforeAny is called for all 4 methods (initialize, ping, tools/list, tools/call)
1303
1290
assert .Equal (t , 4 , beforeAnyCount , "beforeAny should be called for each method" )
1304
- // afterAny is called for all 3 success methods (initialize, ping, tools/list) plus 1 error (tools/call )
1305
- assert .Equal (t , 4 , afterAnyCount , "afterAny should be called for all methods including errors " )
1291
+ // onSuccess is called for all 3 success methods (initialize, ping, tools/list)
1292
+ assert .Equal (t , 3 , onSuccessCount , "onSuccess should be called after all successful invocations " )
1306
1293
1307
1294
// Error hook should be called once for the failed tools/call
1308
1295
assert .Equal (t , 1 , onErrorCount , "onError should be called once" )
@@ -1312,9 +1299,9 @@ func TestMCPServer_WithHooks(t *testing.T) {
1312
1299
require .Len (t , beforeAnyMessages , 1 , "Expected one BeforeAny Ping message" )
1313
1300
assert .IsType (t , beforePingMessages [0 ], beforeAnyMessages [0 ], "BeforeAny message should be same type as BeforePing message" )
1314
1301
1315
- // Verify type matching between AfterAny and AfterPing
1302
+ // Verify type matching between OnSuccess and AfterPing
1316
1303
require .Len (t , afterPingData , 1 , "Expected one AfterPing message/result pair" )
1317
- require .Len (t , afterAnyData , 1 , "Expected one AfterAny Ping message/result pair" )
1318
- assert .IsType (t , afterPingData [0 ].msg , afterAnyData [0 ].msg , "AfterAny message should be same type as AfterPing message" )
1319
- assert .IsType (t , afterPingData [0 ].res , afterAnyData [0 ].res , "AfterAny result should be same type as AfterPing result" )
1304
+ require .Len (t , onSuccessData , 1 , "Expected one OnSuccess Ping message/result pair" )
1305
+ assert .IsType (t , afterPingData [0 ].msg , onSuccessData [0 ].msg , "OnSuccess message should be same type as AfterPing message" )
1306
+ assert .IsType (t , afterPingData [0 ].res , onSuccessData [0 ].res , "OnSuccess result should be same type as AfterPing result" )
1320
1307
}
0 commit comments