Skip to content

feat: Add update_pull_request tool #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 9, 2025
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,17 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description
- `draft`: Create as draft PR (boolean, optional)
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)

- **update_pull_request** - Update an existing pull request in a GitHub repository

- `owner`: Repository owner (string, required)
- `repo`: Repository name (string, required)
- `pullNumber`: Pull request number to update (number, required)
- `title`: New title (string, optional)
- `body`: New description (string, optional)
- `state`: New state ('open' or 'closed') (string, optional)
- `base`: New base branch name (string, optional)
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)

### Repositories

- **create_or_update_file** - Create or update a single file in a repository
Expand Down
112 changes: 112 additions & 0 deletions pkg/github/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,115 @@ func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent {
assert.Equal(t, "text", textContent.Type)
return textContent
}

func TestOptionalParamOK(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
paramName string
expectedVal interface{}
expectedOk bool
expectError bool
errorMsg string
}{
{
name: "present and correct type (string)",
args: map[string]interface{}{"myParam": "hello"},
paramName: "myParam",
expectedVal: "hello",
expectedOk: true,
expectError: false,
},
{
name: "present and correct type (bool)",
args: map[string]interface{}{"myParam": true},
paramName: "myParam",
expectedVal: true,
expectedOk: true,
expectError: false,
},
{
name: "present and correct type (number)",
args: map[string]interface{}{"myParam": float64(123)},
paramName: "myParam",
expectedVal: float64(123),
expectedOk: true,
expectError: false,
},
{
name: "present but wrong type (string expected, got bool)",
args: map[string]interface{}{"myParam": true},
paramName: "myParam",
expectedVal: "", // Zero value for string
expectedOk: true, // ok is true because param exists
expectError: true,
errorMsg: "parameter myParam is not of type string, is bool",
},
{
name: "present but wrong type (bool expected, got string)",
args: map[string]interface{}{"myParam": "true"},
paramName: "myParam",
expectedVal: false, // Zero value for bool
expectedOk: true, // ok is true because param exists
expectError: true,
errorMsg: "parameter myParam is not of type bool, is string",
},
{
name: "parameter not present",
args: map[string]interface{}{"anotherParam": "value"},
paramName: "myParam",
expectedVal: "", // Zero value for string
expectedOk: false,
expectError: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.args)

// Test with string type assertion
if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" {
val, ok, err := OptionalParamOK[string](request, tc.paramName)
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errorMsg)
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}

// Test with bool type assertion
if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" {
val, ok, err := OptionalParamOK[bool](request, tc.paramName)
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errorMsg)
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}

// Test with float64 type assertion (for number case)
if _, isFloat := tc.expectedVal.(float64); isFloat {
val, ok, err := OptionalParamOK[float64](request, tc.paramName)
if tc.expectError {
// This case shouldn't happen for float64 in the defined tests
require.Fail(t, "Unexpected error case for float64")
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}
})
}
}
113 changes: 113 additions & 0 deletions pkg/github/pullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,119 @@ func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc)
}
}

// UpdatePullRequest creates a tool to update an existing pull request.
func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("update_pull_request",
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")),
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
),
mcp.WithString("repo",
mcp.Required(),
mcp.Description("Repository name"),
),
mcp.WithNumber("pullNumber",
mcp.Required(),
mcp.Description("Pull request number to update"),
),
mcp.WithString("title",
mcp.Description("New title"),
),
mcp.WithString("body",
mcp.Description("New description"),
),
mcp.WithString("state",
mcp.Description("New state ('open' or 'closed')"),
mcp.Enum("open", "closed"),
),
mcp.WithString("base",
mcp.Description("New base branch name"),
),
mcp.WithBoolean("maintainer_can_modify",
mcp.Description("Allow maintainer edits"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pullNumber, err := RequiredInt(request, "pullNumber")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Build the update struct only with provided fields
update := &github.PullRequest{}
updateNeeded := false

if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Title = github.Ptr(title)
updateNeeded = true
}

if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Body = github.Ptr(body)
updateNeeded = true
}

if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.State = github.Ptr(state)
updateNeeded = true
}

if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
updateNeeded = true
}

if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
updateNeeded = true
}

if !updateNeeded {
return mcp.NewToolResultError("No update parameters provided."), nil
}

pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
if err != nil {
return nil, fmt.Errorf("failed to update pull request: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
}

r, err := json.Marshal(pr)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// ListPullRequests creates a tool to list and filter repository pull requests.
func ListPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_pull_requests",
Expand Down
Loading