Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,16 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None:
async with self.lifespan(self) as lifespan_context:
dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
read_stream,
write_stream,
raise_handler_exceptions=raise_exceptions,
drain_in_flight_on_read_eof=drain_in_flight_on_read_eof,
read_eof_response_drain_timeout=read_eof_response_drain_timeout,
# Handle `initialize` inline so a client that pipelines it with
# the next request (spec says SHOULD NOT, not MUST NOT) sees
# the initialized state instead of failing the init-gate.
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
23 changes: 23 additions & 0 deletions src/mcp/shared/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None: ...
@overload
def __init__(
Expand All @@ -237,6 +239,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None: ...
def __init__(
self,
Expand All @@ -247,6 +251,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -259,6 +265,8 @@ def __init__(
)
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
self._raise_handler_exceptions = raise_handler_exceptions
self._drain_in_flight_on_read_eof = drain_in_flight_on_read_eof
self._read_eof_response_drain_timeout = read_eof_response_drain_timeout
# Request methods handled inline in the read loop (awaited before the
# next message is dequeued) instead of spawned concurrently. Use for
# methods whose side effects must be observable to the next message,
Expand All @@ -272,6 +280,7 @@ def __init__(
self._next_id = 0
self._pending: dict[RequestId, _Pending] = {}
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
self._responses_in_flight: set[RequestId] = set()
self._tg: anyio.abc.TaskGroup | None = None
self._running = False

Expand Down Expand Up @@ -421,6 +430,12 @@ async def run(
# back to `__anext__` on the now-closed stream
# (stateless SHTTP teardown). Same as EOF.
logger.debug("read stream closed by transport; treating as EOF")
if self._drain_in_flight_on_read_eof:
with anyio.move_on_after(self._read_eof_response_drain_timeout) as scope:
while self._in_flight or self._responses_in_flight:
await anyio.sleep(0)
if scope.cancelled_caught:
logger.debug("timed out draining in-flight responses after read EOF")
# Read stream EOF: wake any blocked `send_raw_request` waiters
# (callers outside this task group) with CONNECTION_CLOSED.
self._running = False
Expand Down Expand Up @@ -716,16 +731,24 @@ async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))

async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None:
key = _coerce_id(request_id)
self._responses_in_flight.add(key)
try:
await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result))
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped result for %r: write stream closed", request_id)
finally:
self._responses_in_flight.discard(key)

async def _write_error(self, request_id: RequestId, error: ErrorData) -> None:
key = _coerce_id(request_id)
self._responses_in_flight.add(key)
try:
await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error))
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped error for %r: write stream closed", request_id)
finally:
self._responses_in_flight.discard(key)

async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None:
# Thread `related_request_id` so streamable-HTTP routes the cancel onto
Expand Down
65 changes: 65 additions & 0 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,71 @@ async def run_server():
assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
"""A bounded read-EOF drain still cancels handlers that never finish."""
handler_started = anyio.Event()
handler_cancelled = anyio.Event()
server_run_returned = anyio.Event()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
try:
await anyio.sleep_forever()
finally:
handler_cancelled.set()
raise AssertionError # pragma: no cover

server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(
server_read,
server_write,
server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
read_eof_response_drain_timeout=0.01,
)
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(5):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive()
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
await to_server.aclose()

await server_run_returned.wait()

assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
"""When the transport closes while handlers are blocked on server→client
Expand Down
57 changes: 56 additions & 1 deletion tests/server/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import sys
import threading
from collections.abc import AsyncIterator
Expand All @@ -7,11 +8,12 @@

import anyio
import pytest
from anyio.lowlevel import checkpoint

from mcp.server.mcpserver import MCPServer
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter


@pytest.mark.anyio
Expand Down Expand Up @@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})


def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""stdin EOF must not drop responses for requests the server already accepted."""
server = MCPServer(name="DrainStdioServer")

@server.tool()
async def slow_echo(text: str) -> str:
await checkpoint()
return text

payload_lines = [
JSONRPCRequest(
jsonrpc="2.0",
id=0,
method="initialize",
params={
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "stdio-replay", "version": "0.1"},
},
).model_dump_json(by_alias=True, exclude_none=True),
JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json(
by_alias=True, exclude_none=True
),
JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="tools/call",
params={"name": "slow_echo", "arguments": {"text": "first"}},
).model_dump_json(by_alias=True, exclude_none=True),
JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "slow_echo", "arguments": {"text": "second"}},
).model_dump_json(by_alias=True, exclude_none=True),
]
stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode())
captured = _KeepOpenBytesIO()
monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8"))
monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))

_run_stdio_bounded(server)

output = captured.getvalue().decode()
responses = [json.loads(line) for line in output.splitlines() if line]

assert [response["id"] for response in responses] == [0, 1, 2]
assert responses[1]["result"]["content"][0]["text"] == "first"
assert responses[2]["result"]["content"][0]["text"] == "second"


def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
"""Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.

Expand Down
Loading