Skip to content

Commit 7491bd9

Browse files
authored
feat: Add update_pull_request tool (github#122)
* feat: add update_pull_request tool * refactor: address feedback on optionalParamOK helper * docs: add update_pull_request tool documentation * refactor: update optionalParamsOK as exported member * fix: rename to exported function
1 parent 3a50f1c commit 7491bd9

File tree

5 files changed

+443
-0
lines changed

5 files changed

+443
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,17 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description
287287
- `draft`: Create as draft PR (boolean, optional)
288288
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
289289

290+
- **update_pull_request** - Update an existing pull request in a GitHub repository
291+
292+
- `owner`: Repository owner (string, required)
293+
- `repo`: Repository name (string, required)
294+
- `pullNumber`: Pull request number to update (number, required)
295+
- `title`: New title (string, optional)
296+
- `body`: New description (string, optional)
297+
- `state`: New state ('open' or 'closed') (string, optional)
298+
- `base`: New base branch name (string, optional)
299+
- `maintainer_can_modify`: Allow maintainer edits (boolean, optional)
300+
290301
### Repositories
291302

292303
- **create_or_update_file** - Create or update a single file in a repository

pkg/github/helper_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,115 @@ func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent {
9393
assert.Equal(t, "text", textContent.Type)
9494
return textContent
9595
}
96+
97+
func TestOptionalParamOK(t *testing.T) {
98+
tests := []struct {
99+
name string
100+
args map[string]interface{}
101+
paramName string
102+
expectedVal interface{}
103+
expectedOk bool
104+
expectError bool
105+
errorMsg string
106+
}{
107+
{
108+
name: "present and correct type (string)",
109+
args: map[string]interface{}{"myParam": "hello"},
110+
paramName: "myParam",
111+
expectedVal: "hello",
112+
expectedOk: true,
113+
expectError: false,
114+
},
115+
{
116+
name: "present and correct type (bool)",
117+
args: map[string]interface{}{"myParam": true},
118+
paramName: "myParam",
119+
expectedVal: true,
120+
expectedOk: true,
121+
expectError: false,
122+
},
123+
{
124+
name: "present and correct type (number)",
125+
args: map[string]interface{}{"myParam": float64(123)},
126+
paramName: "myParam",
127+
expectedVal: float64(123),
128+
expectedOk: true,
129+
expectError: false,
130+
},
131+
{
132+
name: "present but wrong type (string expected, got bool)",
133+
args: map[string]interface{}{"myParam": true},
134+
paramName: "myParam",
135+
expectedVal: "", // Zero value for string
136+
expectedOk: true, // ok is true because param exists
137+
expectError: true,
138+
errorMsg: "parameter myParam is not of type string, is bool",
139+
},
140+
{
141+
name: "present but wrong type (bool expected, got string)",
142+
args: map[string]interface{}{"myParam": "true"},
143+
paramName: "myParam",
144+
expectedVal: false, // Zero value for bool
145+
expectedOk: true, // ok is true because param exists
146+
expectError: true,
147+
errorMsg: "parameter myParam is not of type bool, is string",
148+
},
149+
{
150+
name: "parameter not present",
151+
args: map[string]interface{}{"anotherParam": "value"},
152+
paramName: "myParam",
153+
expectedVal: "", // Zero value for string
154+
expectedOk: false,
155+
expectError: false,
156+
},
157+
}
158+
159+
for _, tc := range tests {
160+
t.Run(tc.name, func(t *testing.T) {
161+
request := createMCPRequest(tc.args)
162+
163+
// Test with string type assertion
164+
if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" {
165+
val, ok, err := OptionalParamOK[string](request, tc.paramName)
166+
if tc.expectError {
167+
require.Error(t, err)
168+
assert.Contains(t, err.Error(), tc.errorMsg)
169+
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
170+
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
171+
} else {
172+
require.NoError(t, err)
173+
assert.Equal(t, tc.expectedOk, ok)
174+
assert.Equal(t, tc.expectedVal, val)
175+
}
176+
}
177+
178+
// Test with bool type assertion
179+
if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" {
180+
val, ok, err := OptionalParamOK[bool](request, tc.paramName)
181+
if tc.expectError {
182+
require.Error(t, err)
183+
assert.Contains(t, err.Error(), tc.errorMsg)
184+
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
185+
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
186+
} else {
187+
require.NoError(t, err)
188+
assert.Equal(t, tc.expectedOk, ok)
189+
assert.Equal(t, tc.expectedVal, val)
190+
}
191+
}
192+
193+
// Test with float64 type assertion (for number case)
194+
if _, isFloat := tc.expectedVal.(float64); isFloat {
195+
val, ok, err := OptionalParamOK[float64](request, tc.paramName)
196+
if tc.expectError {
197+
// This case shouldn't happen for float64 in the defined tests
198+
require.Fail(t, "Unexpected error case for float64")
199+
} else {
200+
require.NoError(t, err)
201+
assert.Equal(t, tc.expectedOk, ok)
202+
assert.Equal(t, tc.expectedVal, val)
203+
}
204+
}
205+
})
206+
}
207+
}

pkg/github/pullrequests.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,119 @@ func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc)
6767
}
6868
}
6969

70+
// UpdatePullRequest creates a tool to update an existing pull request.
71+
func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72+
return mcp.NewTool("update_pull_request",
73+
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")),
74+
mcp.WithString("owner",
75+
mcp.Required(),
76+
mcp.Description("Repository owner"),
77+
),
78+
mcp.WithString("repo",
79+
mcp.Required(),
80+
mcp.Description("Repository name"),
81+
),
82+
mcp.WithNumber("pullNumber",
83+
mcp.Required(),
84+
mcp.Description("Pull request number to update"),
85+
),
86+
mcp.WithString("title",
87+
mcp.Description("New title"),
88+
),
89+
mcp.WithString("body",
90+
mcp.Description("New description"),
91+
),
92+
mcp.WithString("state",
93+
mcp.Description("New state ('open' or 'closed')"),
94+
mcp.Enum("open", "closed"),
95+
),
96+
mcp.WithString("base",
97+
mcp.Description("New base branch name"),
98+
),
99+
mcp.WithBoolean("maintainer_can_modify",
100+
mcp.Description("Allow maintainer edits"),
101+
),
102+
),
103+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
104+
owner, err := requiredParam[string](request, "owner")
105+
if err != nil {
106+
return mcp.NewToolResultError(err.Error()), nil
107+
}
108+
repo, err := requiredParam[string](request, "repo")
109+
if err != nil {
110+
return mcp.NewToolResultError(err.Error()), nil
111+
}
112+
pullNumber, err := RequiredInt(request, "pullNumber")
113+
if err != nil {
114+
return mcp.NewToolResultError(err.Error()), nil
115+
}
116+
117+
// Build the update struct only with provided fields
118+
update := &github.PullRequest{}
119+
updateNeeded := false
120+
121+
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
122+
return mcp.NewToolResultError(err.Error()), nil
123+
} else if ok {
124+
update.Title = github.Ptr(title)
125+
updateNeeded = true
126+
}
127+
128+
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
129+
return mcp.NewToolResultError(err.Error()), nil
130+
} else if ok {
131+
update.Body = github.Ptr(body)
132+
updateNeeded = true
133+
}
134+
135+
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
136+
return mcp.NewToolResultError(err.Error()), nil
137+
} else if ok {
138+
update.State = github.Ptr(state)
139+
updateNeeded = true
140+
}
141+
142+
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
143+
return mcp.NewToolResultError(err.Error()), nil
144+
} else if ok {
145+
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
146+
updateNeeded = true
147+
}
148+
149+
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
150+
return mcp.NewToolResultError(err.Error()), nil
151+
} else if ok {
152+
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
153+
updateNeeded = true
154+
}
155+
156+
if !updateNeeded {
157+
return mcp.NewToolResultError("No update parameters provided."), nil
158+
}
159+
160+
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
161+
if err != nil {
162+
return nil, fmt.Errorf("failed to update pull request: %w", err)
163+
}
164+
defer func() { _ = resp.Body.Close() }()
165+
166+
if resp.StatusCode != http.StatusOK {
167+
body, err := io.ReadAll(resp.Body)
168+
if err != nil {
169+
return nil, fmt.Errorf("failed to read response body: %w", err)
170+
}
171+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
172+
}
173+
174+
r, err := json.Marshal(pr)
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to marshal response: %w", err)
177+
}
178+
179+
return mcp.NewToolResultText(string(r)), nil
180+
}
181+
}
182+
70183
// ListPullRequests creates a tool to list and filter repository pull requests.
71184
func ListPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72185
return mcp.NewTool("list_pull_requests",

0 commit comments

Comments
 (0)