diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index d2536189d..a92b771ae 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -406,12 +406,16 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: async with self.lifespan(self) as lifespan_context: dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( read_stream, write_stream, raise_handler_exceptions=raise_exceptions, + drain_in_flight_on_read_eof=drain_in_flight_on_read_eof, + read_eof_response_drain_timeout=read_eof_response_drain_timeout, # Handle `initialize` inline so a client that pipelines it with # the next request (spec says SHOULD NOT, not MUST NOT) sees # the initialized state instead of failing the init-gate. diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index fdb69571d..06547db49 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -848,6 +848,7 @@ async def run_stdio_async(self) -> None: read_stream, write_stream, self._lowlevel_server.create_initialization_options(), + drain_in_flight_on_read_eof=True, ) async def run_sse_async( # pragma: no cover diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 457e6b6f7..b7adeeaea 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -226,6 +226,8 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: ... 
@overload def __init__( @@ -237,6 +239,8 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: ... def __init__( self, @@ -247,6 +251,8 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -259,6 +265,8 @@ def __init__( ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions + self._drain_in_flight_on_read_eof = drain_in_flight_on_read_eof + self._read_eof_response_drain_timeout = read_eof_response_drain_timeout # Request methods handled inline in the read loop (awaited before the # next message is dequeued) instead of spawned concurrently. Use for # methods whose side effects must be observable to the next message, @@ -272,6 +280,7 @@ def __init__( self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._responses_in_flight: set[RequestId] = set() self._tg: anyio.abc.TaskGroup | None = None self._running = False @@ -421,6 +430,12 @@ async def run( # back to `__anext__` on the now-closed stream # (stateless SHTTP teardown). Same as EOF. 
logger.debug("read stream closed by transport; treating as EOF") + if self._drain_in_flight_on_read_eof: + with anyio.move_on_after(self._read_eof_response_drain_timeout) as scope: + while self._in_flight or self._responses_in_flight: + await anyio.sleep(0.01) # poll; sleep(0) would busy-spin the loop + if scope.cancelled_caught: + logger.debug("timed out draining in-flight responses after read EOF") # Read stream EOF: wake any blocked `send_raw_request` waiters # (callers outside this task group) with CONNECTION_CLOSED. self._running = False @@ -716,16 +731,24 @@ async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + key = _coerce_id(request_id) + self._responses_in_flight.add(key) try: await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped result for %r: write stream closed", request_id) + finally: + self._responses_in_flight.discard(key) async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + key = _coerce_id(request_id) + self._responses_in_flight.add(key) try: await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped error for %r: write stream closed", request_id) + finally: + self._responses_in_flight.discard(key) async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: # Thread `related_request_id` so streamable-HTTP routes the cancel onto diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index cff5a37c1..17b43c9c9 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -172,6 +172,71 @@ async def run_server(): assert
handler_cancelled.is_set() +@pytest.mark.anyio +async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out(): + """A bounded read-EOF drain still cancels handlers that never finish.""" + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run( + server_read, + server_write, + server.create_initialization_options(), + drain_in_flight_on_read_eof=True, + read_eof_response_drain_timeout=0.01, + ) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + await 
to_server.aclose() + + await server_run_returned.wait() + + assert handler_cancelled.is_set() + + @pytest.mark.anyio async def test_server_handles_transport_close_with_pending_server_to_client_requests(): """When the transport closes while handlers are blocked on server→client diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 054a157b3..544a74412 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,4 +1,5 @@ import io +import json import sys import threading from collections.abc import AsyncIterator @@ -7,11 +8,12 @@ import anyio import pytest +from anyio.lowlevel import checkpoint from mcp.server.mcpserver import MCPServer from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter @pytest.mark.anyio @@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) +def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """stdin EOF must not drop responses for requests the server already accepted.""" + server = MCPServer(name="DrainStdioServer") + + @server.tool() + async def slow_echo(text: str) -> str: + await checkpoint() + return text + + payload_lines = [ + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "stdio-replay", "version": "0.1"}, + }, + ).model_dump_json(by_alias=True, exclude_none=True), + JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json( + by_alias=True, exclude_none=True + ), + JSONRPCRequest( + jsonrpc="2.0", + id=1, + 
method="tools/call", + params={"name": "slow_echo", "arguments": {"text": "first"}}, + ).model_dump_json(by_alias=True, exclude_none=True), + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "slow_echo", "arguments": {"text": "second"}}, + ).model_dump_json(by_alias=True, exclude_none=True), + ] + stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode()) + captured = _KeepOpenBytesIO() + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8")) + + _run_stdio_bounded(server) + + output = captured.getvalue().decode() + responses = [json.loads(line) for line in output.splitlines() if line] + + assert [response["id"] for response in responses] == [0, 1, 2] + assert responses[1]["result"]["content"][0]["text"] == "first" + assert responses[2]["result"]["content"][0]["text"] == "second" + + def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None: """Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.