diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 0e337637..df2f6f58 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -286,13 +286,13 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get assignees - assignees, err := optionalParam[[]string](request, "assignees") + assignees, err := optionalStringArrayParam(request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } // Get labels - labels, err := optionalParam[[]string](request, "labels") + labels, err := optionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -401,7 +401,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to } // Get labels - opts.Labels, err = optionalParam[[]string](request, "labels") + opts.Labels, err = optionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -548,7 +548,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get labels - labels, err := optionalParam[[]string](request, "labels") + labels, err := optionalStringArrayParam(request, "labels") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -557,7 +557,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // Get assignees - assignees, err := optionalParam[[]string](request, "assignees") + assignees, err := optionalStringArrayParam(request, "assignees") if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index f29e2b04..5dab1631 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -436,8 +436,8 @@ func Test_CreateIssue(t *testing.T) { "repo": "repo", "title": "Test Issue", "body": "This is a test issue", - "assignees": []string{"user1", "user2"}, - "labels": []string{"bug", "help wanted"}, + "assignees": []any{"user1", "user2"}, + "labels": []any{"bug", "help wanted"}, "milestone": float64(5), }, expectError: false, @@ -636,7 +636,7 @@ func Test_ListIssues(t *testing.T) { "owner": "owner", "repo": "repo", "state": "open", - "labels": []string{"bug", "enhancement"}, + "labels": []any{"bug", "enhancement"}, "sort": "created", "direction": "desc", "since": "2023-01-01T00:00:00Z", @@ -790,8 +790,8 @@ func Test_UpdateIssue(t *testing.T) { "title": "Updated Issue Title", "body": "Updated issue description", "state": "closed", - "labels": []string{"bug", "priority"}, - "assignees": []string{"assignee1", "assignee2"}, + "labels": []any{"bug", "priority"}, + "assignees": []any{"assignee1", "assignee2"}, "milestone": float64(5), }, expectError: false, diff --git a/pkg/github/server.go b/pkg/github/server.go index f93ca37f..66dbfd1c 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -171,7 +171,7 @@ func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { // Check if the parameter is of the expected type if _, ok := r.Params.Arguments[p].(T); !ok { - return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p]) } return r.Params.Arguments[p].(T), nil @@ -201,3 +201,31 @@ func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e } return v, nil } + +// optionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func optionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return []string{}, nil + } + + switch v := r.Params.Arguments[p].(type) { + case []string: + return v, nil + case []any: + strSlice := make([]string, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + strSlice[i] = s + } + return strSlice, nil + default: + return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, r.Params.Arguments[p]) + } +} diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index ffaa4dd8..beb6ecbb 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -483,3 +483,71 @@ func Test_OptionalBooleanParam(t *testing.T) { }) } } + +func TestOptionalStringArrayParam(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected []string + expectError bool + }{ + { + name: "parameter not in request", + params: map[string]any{}, + paramName: "flag", + expected: []string{}, + expectError: false, + }, + { + name: "valid any array parameter", + params: map[string]any{ + "flag": []any{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "valid string array parameter", + params: map[string]any{ + "flag": []string{"v1", "v2"}, + }, + paramName: "flag", + expected: []string{"v1", "v2"}, + expectError: false, + }, + { + name: "wrong type parameter", + params: map[string]any{ + "flag": 1, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + { + name: "wrong slice type parameter", + params: map[string]any{ + "flag": []any{"foo", 2}, + }, + paramName: "flag", + expected: []string{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := optionalStringArrayParam(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +}