Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ async def main():

logger = logging.getLogger(__name__)

# Default value of the `read_eof_drain_timeout_seconds` parameter: the maximum
# time, in seconds, to wait for in-flight handlers to drain after read EOF.
DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS = 1.0

LifespanResultT = TypeVar("LifespanResultT", default=Any)

_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel)
Expand Down Expand Up @@ -406,6 +408,13 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
# When True, treat read EOF as a half-close and allow in-flight handlers
# to drain their responses via the still-open write stream (e.g. stdio
# with bash-redirected stdin).
drain_on_read_close: bool = False,
# Maximum time to wait for in-flight handlers to drain after read EOF.
# None means wait indefinitely.
read_eof_drain_timeout_seconds: float | None = DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS,
) -> None:
async with self.lifespan(self) as lifespan_context:
dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
Expand All @@ -416,6 +425,8 @@ async def run(
# the next request (spec says SHOULD NOT, not MUST NOT) sees
# the initialized state instead of failing the init-gate.
inline_methods=frozenset({"initialize"}),
close_write_stream_on_read_close=not drain_on_read_close,
read_eof_drain_timeout_seconds=read_eof_drain_timeout_seconds,
)
runner = ServerRunner(
server=self,
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_on_read_close=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
33 changes: 25 additions & 8 deletions src/mcp/shared/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import contextvars
import logging
from collections.abc import Awaitable, Callable, Mapping
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from typing import Any, Generic, Literal, TypeVar, cast, overload

Expand Down Expand Up @@ -226,6 +227,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None: ...
@overload
def __init__(
Expand All @@ -237,6 +240,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None: ...
def __init__(
self,
Expand All @@ -247,6 +252,8 @@ def __init__(
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -259,6 +266,8 @@ def __init__(
)
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
self._raise_handler_exceptions = raise_handler_exceptions
self._close_write_stream_on_read_close = close_write_stream_on_read_close
self._read_eof_drain_timeout_seconds = read_eof_drain_timeout_seconds
# Request methods handled inline in the read loop (awaited before the
# next message is dequeued) instead of spawned concurrently. Use for
# methods whose side effects must be observable to the next message,
Expand Down Expand Up @@ -400,13 +409,17 @@ async def run(
`await tg.start(dispatcher.run, ...)` resumes when `send_raw_request`
is usable.
"""
normal_eof = False
try:
async with anyio.create_task_group() as tg:
self._tg = tg
self._running = True
task_status.started()
try:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_close:
await stack.enter_async_context(self._write_stream)
try:
async for item in self._read_stream:
# Duck-typed: `_context_streams.ContextReceiveStream`
Expand All @@ -425,20 +438,24 @@ async def run(
# (callers outside this task group) with CONNECTION_CLOSED.
self._running = False
self._fan_out_closed()
normal_eof = True
finally:
# Transport closed: cancel in-flight handlers. Without this
# the task-group join waits for them, and a handler that
# outlives its caller (its request timed out client-side, or
# the client disconnected mid-call) would keep `run()` from
# returning forever. Same behaviour as `Server.run()` before
# the dispatcher rework.
tg.cancel_scope.cancel()
if not normal_eof or self._close_write_stream_on_read_close:
# Abnormal close, or the write stream is closed together with the read
# stream: cancel in-flight handlers. Only on normal EOF with the write
# stream left open do already-received handlers get to drain their
# responses before the task group exits.
tg.cancel_scope.cancel()
elif self._read_eof_drain_timeout_seconds is not None:
tg.cancel_scope.deadline = anyio.current_time() + self._read_eof_drain_timeout_seconds
finally:
# Covers the cancel/crash paths where the inline fan-out above is
# never reached. Idempotent.
self._running = False
self._tg = None
self._fan_out_closed()
if not self._close_write_stream_on_read_close:
with anyio.CancelScope(shield=True):
await self._write_stream.aclose()

async def _dispatch(
self,
Expand Down
15 changes: 14 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,24 @@ def __init__(
write_stream: WriteStream[SessionMessage],
# If none, reading will never time out
read_timeout_seconds: float | None = None,
# When True, closing/EOF on the read stream closes the write stream too.
#
# For full-duplex transports (e.g., stdio), an input EOF can be a
# half-close: the peer is done sending requests but still expects
# responses on the output stream. In that case, callers may opt out so
# in-flight handlers can drain their responses before shutdown.
close_write_stream_on_read_close: bool = True,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._session_read_timeout_seconds = read_timeout_seconds
self._close_write_stream_on_read_close = close_write_stream_on_read_close
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()
self._exit_stack.push_async_callback(self._read_stream.aclose)
self._exit_stack.push_async_callback(self._write_stream.aclose)

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
Expand Down Expand Up @@ -291,7 +301,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_close:
await stack.enter_async_context(self._write_stream)
try:

async def _handle_session_message(message: SessionMessage) -> None:
Expand Down
Loading
Loading