-
Notifications
You must be signed in to change notification settings - Fork 0
[pull] main from modelcontextprotocol:main #80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,14 @@ class Meta(BaseModel): | |
meta: Meta | None = Field(alias="_meta", default=None) | ||
|
||
|
||
class PaginatedRequestParams(RequestParams): | ||
cursor: Cursor | None = None | ||
""" | ||
An opaque token representing the current pagination position. | ||
If provided, the server should return results starting after this cursor. | ||
""" | ||
|
||
|
||
class NotificationParams(BaseModel): | ||
class Meta(BaseModel): | ||
model_config = ConfigDict(extra="allow") | ||
|
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): | |
model_config = ConfigDict(extra="allow") | ||
|
||
|
||
class PaginatedRequest(Request[RequestParamsT, MethodT]): | ||
cursor: Cursor | None = None | ||
""" | ||
An opaque token representing the current pagination position. | ||
If provided, the server should return results starting after this cursor. | ||
""" | ||
class PaginatedRequest( | ||
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT] | ||
): | ||
"""Base class for paginated requests, | ||
matching the schema's PaginatedRequest interface.""" | ||
|
||
params: PaginatedRequestParams | None = None | ||
|
||
|
||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): | ||
|
@@ -358,13 +367,10 @@ class ProgressNotification( | |
params: ProgressNotificationParams | ||
|
||
|
||
class ListResourcesRequest( | ||
PaginatedRequest[RequestParams | None, Literal["resources/list"]] | ||
): | ||
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): | ||
"""Sent from the client to request a list of resources the server has.""" | ||
|
||
method: Literal["resources/list"] | ||
params: RequestParams | None = None | ||
|
||
Comment on lines
+370
to
374
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
class Annotations(BaseModel): | ||
|
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult): | |
|
||
|
||
class ListResourceTemplatesRequest( | ||
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]] | ||
PaginatedRequest[Literal["resources/templates/list"]] | ||
): | ||
"""Sent from the client to request a list of resource templates the server has.""" | ||
|
||
method: Literal["resources/templates/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class ListResourceTemplatesResult(PaginatedResult): | ||
|
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification( | |
params: ResourceUpdatedNotificationParams | ||
|
||
|
||
class ListPromptsRequest( | ||
PaginatedRequest[RequestParams | None, Literal["prompts/list"]] | ||
): | ||
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): | ||
"""Sent from the client to request a list of prompts and prompt templates.""" | ||
|
||
method: Literal["prompts/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class PromptArgument(BaseModel): | ||
|
@@ -703,11 +705,10 @@ class PromptListChangedNotification( | |
params: NotificationParams | None = None | ||
|
||
|
||
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): | ||
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): | ||
"""Sent from the client to request a list of tools the server has.""" | ||
|
||
method: Literal["tools/list"] | ||
params: RequestParams | None = None | ||
|
||
|
||
class ToolAnnotations(BaseModel): | ||
|
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel): | |
|
||
idempotentHint: bool | None = None | ||
""" | ||
If true, calling the tool repeatedly with the same arguments | ||
If true, calling the tool repeatedly with the same arguments | ||
will have no additional effect on the its environment. | ||
(This property is meaningful only when `readOnlyHint == false`) | ||
Default: false | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from contextlib import asynccontextmanager | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
import mcp.shared.memory | ||
from mcp.shared.message import SessionMessage | ||
from mcp.types import ( | ||
JSONRPCNotification, | ||
JSONRPCRequest, | ||
) | ||
|
||
|
||
class SpyMemoryObjectSendStream: | ||
def __init__(self, original_stream): | ||
self.original_stream = original_stream | ||
self.sent_messages: list[SessionMessage] = [] | ||
|
||
async def send(self, message): | ||
self.sent_messages.append(message) | ||
await self.original_stream.send(message) | ||
|
||
async def aclose(self): | ||
await self.original_stream.aclose() | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, *args): | ||
await self.aclose() | ||
|
||
|
||
class StreamSpyCollection: | ||
def __init__( | ||
self, | ||
client_spy: SpyMemoryObjectSendStream, | ||
server_spy: SpyMemoryObjectSendStream, | ||
): | ||
self.client = client_spy | ||
self.server = server_spy | ||
|
||
def clear(self) -> None: | ||
"""Clear all captured messages.""" | ||
self.client.sent_messages.clear() | ||
self.server.sent_messages.clear() | ||
|
||
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: | ||
"""Get client-sent requests, optionally filtered by method.""" | ||
return [ | ||
req.message.root | ||
for req in self.client.sent_messages | ||
if isinstance(req.message.root, JSONRPCRequest) | ||
and (method is None or req.message.root.method == method) | ||
] | ||
|
||
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: | ||
"""Get server-sent requests, optionally filtered by method.""" | ||
return [ | ||
req.message.root | ||
for req in self.server.sent_messages | ||
if isinstance(req.message.root, JSONRPCRequest) | ||
and (method is None or req.message.root.method == method) | ||
] | ||
|
||
def get_client_notifications( | ||
self, method: str | None = None | ||
) -> list[JSONRPCNotification]: | ||
"""Get client-sent notifications, optionally filtered by method.""" | ||
return [ | ||
notif.message.root | ||
for notif in self.client.sent_messages | ||
if isinstance(notif.message.root, JSONRPCNotification) | ||
and (method is None or notif.message.root.method == method) | ||
] | ||
|
||
def get_server_notifications( | ||
self, method: str | None = None | ||
) -> list[JSONRPCNotification]: | ||
"""Get server-sent notifications, optionally filtered by method.""" | ||
return [ | ||
notif.message.root | ||
for notif in self.server.sent_messages | ||
if isinstance(notif.message.root, JSONRPCNotification) | ||
and (method is None or notif.message.root.method == method) | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def stream_spy(): | ||
"""Fixture that provides spies for both client and server write streams. | ||
|
||
Example usage: | ||
async def test_something(stream_spy): | ||
# ... set up server and client ... | ||
|
||
spies = stream_spy() | ||
|
||
# Run some operation that sends messages | ||
await client.some_operation() | ||
|
||
# Check the messages | ||
requests = spies.get_client_requests(method="some/method") | ||
assert len(requests) == 1 | ||
|
||
# Clear for the next operation | ||
spies.clear() | ||
""" | ||
client_spy = None | ||
server_spy = None | ||
|
||
# Store references to our spy objects | ||
def capture_spies(c_spy, s_spy): | ||
nonlocal client_spy, server_spy | ||
client_spy = c_spy | ||
server_spy = s_spy | ||
|
||
# Create patched version of stream creation | ||
original_create_streams = mcp.shared.memory.create_client_server_memory_streams | ||
|
||
@asynccontextmanager | ||
async def patched_create_streams(): | ||
async with original_create_streams() as (client_streams, server_streams): | ||
client_read, client_write = client_streams | ||
server_read, server_write = server_streams | ||
|
||
# Create spy wrappers | ||
spy_client_write = SpyMemoryObjectSendStream(client_write) | ||
spy_server_write = SpyMemoryObjectSendStream(server_write) | ||
|
||
# Capture references for the test to use | ||
capture_spies(spy_client_write, spy_server_write) | ||
|
||
yield (client_read, spy_client_write), (server_read, spy_server_write) | ||
|
||
# Apply the patch for the duration of the test | ||
with patch( | ||
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams | ||
): | ||
# Return a collection with helper methods | ||
def get_spy_collection() -> StreamSpyCollection: | ||
assert client_spy is not None, "client_spy was not initialized" | ||
assert server_spy is not None, "server_spy was not initialized" | ||
return StreamSpyCollection(client_spy, server_spy) | ||
|
||
yield get_spy_collection |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The conditional formatting for the
params
parameter could be simplified for better readability. Consider using a one-liner or a separate variable assignment before the method call.