Skip to content

Commit 252237e

Browse files
committed
Rename AfterAny --> OnSuccess to clarify convention
1 parent 0f06109 commit 252237e

File tree

5 files changed

+52
-69
lines changed

5 files changed

+52
-69
lines changed

examples/everything/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ func NewMCPServer() *server.MCPServer {
3535
hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
3636
fmt.Printf("beforeAny: %s, %v, %v\n", method, id, message)
3737
})
38-
hooks.AddAfterAny(func(id any, method mcp.MCPMethod, message any, result any) {
39-
fmt.Printf("afterAny: %s, %v, %v, %v\n", method, id, message, result)
38+
hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
39+
fmt.Printf("onSuccess: %s, %v, %v, %v\n", method, id, message, result)
4040
})
4141
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
4242
fmt.Printf("onError: %s, %v, %v, %v\n", method, id, message, err)

server/hooks.go

Lines changed: 20 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/internal/gen/hooks.go.tmpl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import (
1111
"github.com/mark3labs/mcp-go/mcp"
1212
)
1313

14-
// OnBeforeAnyHookFunc is a function that is called after the request is
14+
// BeforeAnyHookFunc is a function that is called after the request is
1515
// parsed but before the method is called.
16-
type OnBeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
16+
type BeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
1717

18-
// OnAfterAnyHookFunc is a hook that will be called after the request
18+
// OnSuccessHookFunc is a hook that will be called after the request
1919
// successfully generates a result, but before the result is sent to the client.
20-
type OnAfterAnyHookFunc func(id any, method mcp.MCPMethod, message any, result any)
20+
type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result any)
2121

2222
// OnErrorHookFunc is a hook that will be called when an error occurs,
2323
// either during the request parsing or the method execution.
@@ -59,21 +59,21 @@ type OnAfter{{.HookName}}Func func(id any, message *mcp.{{.ParamType}}, result *
5959
{{end}}
6060

6161
type Hooks struct {
62-
OnBeforeAny []OnBeforeAnyHookFunc
63-
OnAfterAny []OnAfterAnyHookFunc
62+
OnBeforeAny []BeforeAnyHookFunc
63+
OnSuccess []OnSuccessHookFunc
6464
OnError []OnErrorHookFunc
6565
{{- range .}}
6666
OnBefore{{.HookName}} []OnBefore{{.HookName}}Func
6767
OnAfter{{.HookName}} []OnAfter{{.HookName}}Func
6868
{{- end}}
6969
}
7070

71-
func (c *Hooks) AddBeforeAny(hook OnBeforeAnyHookFunc) {
71+
func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) {
7272
c.OnBeforeAny = append(c.OnBeforeAny, hook)
7373
}
7474

75-
func (c *Hooks) AddAfterAny(hook OnAfterAnyHookFunc) {
76-
c.OnAfterAny = append(c.OnAfterAny, hook)
75+
func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
76+
c.OnSuccess = append(c.OnSuccess, hook)
7777
}
7878

7979
// AddOnError registers a hook function that will be called when an error occurs.
@@ -133,11 +133,11 @@ func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
133133
}
134134
}
135135

136-
func (c *Hooks) afterAny(id any, method mcp.MCPMethod, message any, result any) {
136+
func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any) {
137137
if c == nil {
138138
return
139139
}
140-
for _, hook := range c.OnAfterAny {
140+
for _, hook := range c.OnSuccess {
141141
hook(id, method, message, result)
142142
}
143143
}
@@ -157,7 +157,6 @@ func (c *Hooks) afterAny(id any, method mcp.MCPMethod, message any, result any)
157157
// - ErrPromptNotFound: When a prompt is not found
158158
// - ErrToolNotFound: When a tool is not found
159159
func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
160-
c.afterAny(id, method, message, err)
161160
if c == nil {
162161
return
163162
}
@@ -186,7 +185,7 @@ func (c *Hooks) before{{.HookName}}(id any, message *mcp.{{.ParamType}}) {
186185
}
187186

188187
func (c *Hooks) after{{.HookName}}(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
189-
c.afterAny(id, mcp.{{.MethodName}}, message, result)
188+
c.onSuccess(id, mcp.{{.MethodName}}, message, result)
190189
if c == nil {
191190
return
192191
}

server/server.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,6 @@ func (s *MCPServer) handleGetPrompt(
646646
handler, ok := s.promptHandlers[request.Params.Name]
647647
s.mu.RUnlock()
648648

649-
fmt.Println("request.Params.Name", request.Params.Name)
650-
651649
if !ok {
652650
return nil, &requestError{
653651
id: id,

server/server_test.go

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
750750
hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
751751
beforeResults = append(beforeResults, beforeResult{method, message})
752752
})
753-
hooks.AddAfterAny(func(id any, method mcp.MCPMethod, message any, result any) {
753+
hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
754754
afterResults = append(afterResults, afterResult{method, message, result})
755755
})
756756

@@ -777,7 +777,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
777777
name string
778778
message string
779779
expectedErr int
780-
validateCallbacks func(t *testing.T, err error, beforeResults beforeResult, afterResults afterResult)
780+
validateCallbacks func(t *testing.T, err error, beforeResults beforeResult)
781781
}{
782782
{
783783
name: "Undefined tool",
@@ -791,12 +791,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
791791
}
792792
}`,
793793
expectedErr: mcp.INVALID_PARAMS,
794-
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult, afterResults afterResult) {
794+
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult) {
795795
assert.Equal(t, mcp.MethodToolsCall, beforeResults.method)
796-
assert.Equal(t, mcp.MethodToolsCall, afterResults.method)
797-
afterResultErr, ok := afterResults.result.(error)
798-
assert.True(t, ok)
799-
assert.Same(t, err, afterResultErr)
800796
assert.True(t, errors.Is(err, ErrToolNotFound))
801797
},
802798
},
@@ -812,12 +808,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
812808
}
813809
}`,
814810
expectedErr: mcp.INVALID_PARAMS,
815-
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult, afterResults afterResult) {
811+
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult) {
816812
assert.Equal(t, mcp.MethodPromptsGet, beforeResults.method)
817-
assert.Equal(t, mcp.MethodPromptsGet, afterResults.method)
818-
afterResultErr, ok := afterResults.result.(error)
819-
assert.True(t, ok)
820-
assert.Same(t, err, afterResultErr)
821813
assert.True(t, errors.Is(err, ErrPromptNotFound))
822814
},
823815
},
@@ -832,12 +824,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
832824
}
833825
}`,
834826
expectedErr: mcp.INVALID_PARAMS,
835-
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult, afterResults afterResult) {
827+
validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult) {
836828
assert.Equal(t, mcp.MethodResourcesRead, beforeResults.method)
837-
assert.Equal(t, mcp.MethodResourcesRead, afterResults.method)
838-
afterResultErr, ok := afterResults.result.(error)
839-
assert.True(t, ok)
840-
assert.Same(t, err, afterResultErr)
841829
assert.True(t, errors.Is(err, ErrResourceNotFound))
842830
},
843831
},
@@ -847,7 +835,6 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
847835
t.Run(tt.name, func(t *testing.T) {
848836
errs = nil // Reset errors for each test case
849837
beforeResults = nil
850-
afterResults = nil
851838
response := server.HandleMessage(
852839
context.Background(),
853840
[]byte(tt.message),
@@ -861,8 +848,8 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) {
861848
if tt.validateCallbacks != nil {
862849
require.Len(t, errs, 1, "Expected exactly one error")
863850
require.Len(t, beforeResults, 1, "Expected exactly one before result")
864-
require.Len(t, afterResults, 1, "Expected exactly one after result")
865-
tt.validateCallbacks(t, errs[0], beforeResults[0], afterResults[0])
851+
require.Len(t, afterResults, 0, "Expected no after results because these calls generate errors")
852+
tt.validateCallbacks(t, errs[0], beforeResults[0])
866853
}
867854
})
868855
}
@@ -1165,7 +1152,7 @@ func TestMCPServer_WithHooks(t *testing.T) {
11651152
// Create hook counters to verify calls
11661153
var (
11671154
beforeAnyCount int
1168-
afterAnyCount int
1155+
onSuccessCount int
11691156
onErrorCount int
11701157
beforePingCount int
11711158
afterPingCount int
@@ -1175,7 +1162,7 @@ func TestMCPServer_WithHooks(t *testing.T) {
11751162

11761163
// Collectors for message and result types
11771164
var beforeAnyMessages []any
1178-
var afterAnyData []struct {
1165+
var onSuccessData []struct {
11791166
msg any
11801167
res any
11811168
}
@@ -1197,11 +1184,11 @@ func TestMCPServer_WithHooks(t *testing.T) {
11971184
}
11981185
})
11991186

1200-
hooks.AddAfterAny(func(id any, method mcp.MCPMethod, message any, result any) {
1201-
afterAnyCount++
1187+
hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
1188+
onSuccessCount++
12021189
// Only collect ping responses for our test
12031190
if method == mcp.MethodPing {
1204-
afterAnyData = append(afterAnyData, struct {
1191+
onSuccessData = append(onSuccessData, struct {
12051192
msg any
12061193
res any
12071194
}{message, result})
@@ -1301,8 +1288,8 @@ func TestMCPServer_WithHooks(t *testing.T) {
13011288
// General hooks should be called for all methods
13021289
// beforeAny is called for all 4 methods (initialize, ping, tools/list, tools/call)
13031290
assert.Equal(t, 4, beforeAnyCount, "beforeAny should be called for each method")
1304-
// afterAny is called for all 3 success methods (initialize, ping, tools/list) plus 1 error (tools/call)
1305-
assert.Equal(t, 4, afterAnyCount, "afterAny should be called for all methods including errors")
1291+
// onSuccess is called for all 3 success methods (initialize, ping, tools/list)
1292+
assert.Equal(t, 3, onSuccessCount, "onSuccess should be called after all successful invocations")
13061293

13071294
// Error hook should be called once for the failed tools/call
13081295
assert.Equal(t, 1, onErrorCount, "onError should be called once")
@@ -1312,9 +1299,9 @@ func TestMCPServer_WithHooks(t *testing.T) {
13121299
require.Len(t, beforeAnyMessages, 1, "Expected one BeforeAny Ping message")
13131300
assert.IsType(t, beforePingMessages[0], beforeAnyMessages[0], "BeforeAny message should be same type as BeforePing message")
13141301

1315-
// Verify type matching between AfterAny and AfterPing
1302+
// Verify type matching between OnSuccess and AfterPing
13161303
require.Len(t, afterPingData, 1, "Expected one AfterPing message/result pair")
1317-
require.Len(t, afterAnyData, 1, "Expected one AfterAny Ping message/result pair")
1318-
assert.IsType(t, afterPingData[0].msg, afterAnyData[0].msg, "AfterAny message should be same type as AfterPing message")
1319-
assert.IsType(t, afterPingData[0].res, afterAnyData[0].res, "AfterAny result should be same type as AfterPing result")
1304+
require.Len(t, onSuccessData, 1, "Expected one OnSuccess Ping message/result pair")
1305+
assert.IsType(t, afterPingData[0].msg, onSuccessData[0].msg, "OnSuccess message should be same type as AfterPing message")
1306+
assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result")
13201307
}

0 commit comments

Comments
 (0)