Skip to content

Close unclosed resources in the whole project #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ packages = ["src/mcp"]
include = ["src/mcp", "tests"]
venvPath = "."
venv = ".venv"
strict = [
"src/mcp/server/fastmcp/tools/base.py",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More should be added here.

]

[tool.ruff.lint]
select = ["E", "F", "I"]
Expand All @@ -85,3 +88,13 @@ members = ["examples/servers/*"]

[tool.uv.sources]
mcp = { workspace = true }

[tool.pytest.ini_options]
xfail_strict = true
filterwarnings = [
"error",
# This should be fixed on Uvicorn's side.
"ignore::DeprecationWarning:websockets",
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
]
10 changes: 7 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ async def _default_list_roots_callback(
)


ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)


class ClientSession(
Expand Down Expand Up @@ -219,7 +221,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
)

async def call_tool(
self, name: str, arguments: dict | None = None
self, name: str, arguments: dict[str, Any] | None = None
) -> types.CallToolResult:
"""Send a tools/call request."""
return await self.send_request(
Expand Down Expand Up @@ -258,7 +260,9 @@ async def get_prompt(
)

async def complete(
self, ref: types.ResourceReference | types.PromptReference, argument: dict
self,
ref: types.ResourceReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""
return await self.send_request(
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
class Tool(BaseModel):
"""Internal tool registration info."""

fn: Callable = Field(exclude=True)
fn: Callable[..., Any] = Field(exclude=True)
name: str = Field(description="Name of the tool")
description: str = Field(description="Description of what the tool does")
parameters: dict = Field(description="JSON schema for tool parameters")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool"
" arguments"
Expand All @@ -34,7 +34,7 @@ class Tool(BaseModel):
@classmethod
def from_function(
cls,
fn: Callable,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/server/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
)


def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
def func_metadata(
func: Callable[..., Any], skip_names: Sequence[str] = ()
) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its
signature.

Expand Down
12 changes: 12 additions & 0 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar

Expand Down Expand Up @@ -180,13 +181,20 @@ def __init__(
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}

self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
Expand All @@ -195,6 +203,7 @@ async def __aenter__(self) -> Self:
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
Expand Down Expand Up @@ -222,6 +231,9 @@ async def send_request(
](1)
self._response_streams[request_id] = response_stream

self._exit_stack.push_async_callback(lambda: response_stream.aclose())
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())

jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
Expand Down
4 changes: 4 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ async def listen_session():
async with (
ClientSession(server_to_client_receive, client_to_server_send) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
tg.start_soon(listen_session)
Expand Down
8 changes: 7 additions & 1 deletion tests/issues/test_192_request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ async def run_server():
)

# Start server task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
client_writer,
client_reader,
server_writer,
server_reader,
):
tg.start_soon(run_server)

# Send initialize request
Expand Down
18 changes: 15 additions & 3 deletions tests/server/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_lowlevel_server_lifespan():
"""Test that lifespan works in low-level server."""

@asynccontextmanager
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]:
"""Test lifespan context that tracks startup/shutdown."""
context = {"started": False, "shutdown": False}
try:
Expand All @@ -50,7 +50,13 @@ async def check_lifespan(name: str, arguments: dict) -> list:
return [{"type": "text", "text": "true"}]

# Run server in background task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
send_stream1,
receive_stream1,
send_stream2,
receive_stream2,
):

async def run_server():
await server.run(
Expand Down Expand Up @@ -147,7 +153,13 @@ def check_lifespan(ctx: Context) -> bool:
return True

# Run server in background task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
send_stream1,
receive_stream1,
send_stream2,
receive_stream2,
):

async def run_server():
await server._mcp_server.run(
Expand Down