From 000cce335569483f701c91b9850a14ec3584c2e0 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Sat, 30 May 2026 09:48:41 -0400 Subject: [PATCH] fix(server): guard stateful MCP sessions --- docs/handler-authoring.md | 44 ++++ src/adcp/server/__init__.py | 3 + src/adcp/server/mcp_sessions.py | 357 +++++++++++++++++++++++++++++ src/adcp/server/serve.py | 40 +++- tests/test_mcp_stateful_session.py | 150 +++++++++++- tests/test_serve_config.py | 33 ++- 6 files changed, 617 insertions(+), 10 deletions(-) create mode 100644 src/adcp/server/mcp_sessions.py diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md index 737e14fdc..f3cf8533a 100644 --- a/docs/handler-authoring.md +++ b/docs/handler-authoring.md @@ -1385,6 +1385,50 @@ For the full constructor reference and a migration table from legacy HMAC / bare [`docs/webhooks/migration-from-fragmented-senders.md`](webhooks/migration-from-fragmented-senders.md). See `examples/hello_seller_with_webhooks.py` for a runnable end-to-end wiring example. +## MCP stateful Streamable HTTP lifecycle + +The default MCP transport is stateful Streamable HTTP. It is fastest when a +client initializes once, reuses the returned `Mcp-Session-Id` for the whole +workflow, and closes the client session when the workflow ends. Do not create a +fresh MCP client for every AdCP operation (`get_products`, `create_media_buy`, +`sync_creatives`, etc.) unless you also close each session promptly. + +The SDK defaults `session_idle_timeout=1800.0`, so abandoned sessions are reaped +after 30 minutes. Public or service-to-service sellers that may see one-shot +callers should tune this lower: + +```python +from adcp.server import serve + +serve( + handler, + session_idle_timeout=300.0, +) +``` + +For a hard ceiling, set `max_active_sessions`. 
New session-creating requests are +rejected with HTTP 429 when the cap is reached, while requests that carry an +existing `Mcp-Session-Id` continue: + +```python +serve( + handler, + session_idle_timeout=300.0, + max_active_sessions=200, +) +``` + +If you build the MCP server yourself, `get_mcp_session_stats()` returns a +snapshot you can export to logs or metrics: + +```python +from adcp.server import create_mcp_server, get_mcp_session_stats + +mcp = create_mcp_server(handler, max_active_sessions=200) +stats = get_mcp_session_stats(mcp).as_dict() +# stats["active_sessions"], stats["session_age_seconds"], etc. +``` + ## Testing The integration test pattern in `tests/test_mcp_middleware_composition.py` diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index cb926f21f..2fc2b9b84 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -108,6 +108,7 @@ async def get_products(params, context=None): valid_actions_for_status, ) from adcp.server.idempotency import IdempotencyStore, MemoryBackend +from adcp.server.mcp_sessions import MCPSessionStats, get_mcp_session_stats from adcp.server.mcp_tools import ( DISCOVERY_METHODS, DISCOVERY_TOOLS, @@ -209,11 +210,13 @@ async def get_products(params, context=None): "DISCOVERY_TOOLS", "LifespanHook", "MCPToolSet", + "MCPSessionStats", "RequestMetadata", "ServeConfig", "create_mcp_tools", "create_mcp_server", "get_tools_for_handler", + "get_mcp_session_stats", "register_handler_tools", "serve", "validate_discovery_set", diff --git a/src/adcp/server/mcp_sessions.py b/src/adcp/server/mcp_sessions.py new file mode 100644 index 000000000..0bf630926 --- /dev/null +++ b/src/adcp/server/mcp_sessions.py @@ -0,0 +1,357 @@ +"""ADCP-managed Streamable HTTP session controls. + +The upstream MCP session manager owns the transport lifecycle. 
This
+module keeps the SDK-specific safety knobs and observability wrapper in
+one place so ``serve.py`` does not have to grow more private-FastMCP
+plumbing at every call site.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from collections import deque
+from dataclasses import dataclass
+from http import HTTPStatus
+from typing import Any
+from uuid import uuid4
+
+import anyio
+from anyio.abc import TaskStatus
+from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
+from mcp.types import INVALID_REQUEST, ErrorData, JSONRPCError
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.types import Message, Receive, Scope, Send
+
+logger = logging.getLogger("adcp.server")
+
+
+@dataclass(frozen=True)
+class MCPSessionStats:
+    """Snapshot of a Streamable HTTP session manager.
+
+    Numeric age and idle values are seconds from the manager's local
+    monotonic clock. They are intended for metrics/debug visibility, not
+    wall-clock audit records.
+    """
+
+    # Sessions currently registered with the manager.
+    active_sessions: int
+    # Configured cap, or None when the guard is disabled.
+    max_active_sessions: int | None
+    # Sessions created since the manager was constructed.
+    total_sessions_created: int
+    # Sessions created within the trailing 60-second window.
+    sessions_created_last_60s: int
+    stateless: bool
+    session_idle_timeout: float | None
+    # Per-session values, ordered by sorted session id.
+    session_age_seconds: tuple[float, ...]
+    session_idle_seconds: tuple[float, ...]
+
+    def as_dict(self) -> dict[str, Any]:
+        """Return a JSON-serializable representation."""
+        return {
+            "active_sessions": self.active_sessions,
+            "max_active_sessions": self.max_active_sessions,
+            "total_sessions_created": self.total_sessions_created,
+            "sessions_created_last_60s": self.sessions_created_last_60s,
+            "stateless": self.stateless,
+            "session_idle_timeout": self.session_idle_timeout,
+            "session_age_seconds": list(self.session_age_seconds),
+            "session_idle_seconds": list(self.session_idle_seconds),
+        }
+
+
+class ADCPStreamableHTTPSessionManager(StreamableHTTPSessionManager):
+    """Streamable HTTP manager with ADCP safety knobs.
+
+    ``max_active_sessions`` is enforced under the same creation lock that
+    guards upstream session creation, so concurrent one-shot clients
+    cannot overshoot the configured cap.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        max_active_sessions: int | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Validate the cap and initialise per-session bookkeeping.
+
+        Args:
+            *args: Forwarded unchanged to the upstream manager.
+            max_active_sessions: Positive integer cap on registered
+                sessions, or ``None`` to disable the guard.
+            **kwargs: Forwarded unchanged to the upstream manager.
+
+        Raises:
+            ValueError: If ``max_active_sessions`` is not ``None`` and is
+                a bool, a non-int, or not strictly positive.
+        """
+        # bool is an int subclass, so it has to be rejected explicitly.
+        if max_active_sessions is not None and (
+            isinstance(max_active_sessions, bool)
+            or not isinstance(max_active_sessions, int)
+            or max_active_sessions <= 0
+        ):
+            raise ValueError(
+                f"max_active_sessions must be a positive integer (got {max_active_sessions!r}); "
+                "set None to disable the guard."
+            )
+        super().__init__(*args, **kwargs)
+        self.max_active_sessions = max_active_sessions
+        self._session_created_at: dict[str, float] = {}
+        self._session_last_seen_at: dict[str, float] = {}
+        self._session_creation_events: deque[float] = deque()
+        self._total_sessions_created = 0
+
+    def session_stats(self) -> MCPSessionStats:
+        """Return a point-in-time session snapshot."""
+        now = time.monotonic()
+        self._prune_tracking(now)
+        # Sort once so ages and idle values are reported in the same order.
+        ordered_ids = sorted(self._server_instances)
+        ages = tuple(
+            max(0.0, now - self._session_created_at[session_id])
+            for session_id in ordered_ids
+            if session_id in self._session_created_at
+        )
+        idle = tuple(
+            max(0.0, now - self._session_last_seen_at[session_id])
+            for session_id in ordered_ids
+            if session_id in self._session_last_seen_at
+        )
+        return MCPSessionStats(
+            active_sessions=len(ordered_ids),
+            max_active_sessions=self.max_active_sessions,
+            total_sessions_created=self._total_sessions_created,
+            sessions_created_last_60s=len(self._session_creation_events),
+            stateless=self.stateless,
+            session_idle_timeout=self.session_idle_timeout,
+            session_age_seconds=ages,
+            session_idle_seconds=idle,
+        )
+
+    async def _handle_stateful_request(
+        self,
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
+        """Process a stateful request with ADCP max-session enforcement.
+
+        This mirrors upstream MCP 1.27.x with the cap check inserted
+        under ``_session_creation_lock`` and bookkeeping attached to the
+        session create / cleanup points.
+
+        Outcomes by request shape:
+
+        * known ``Mcp-Session-Id``: routed to the existing transport;
+        * no session id and an ``initialize`` body: a new session is
+          created, or HTTP 429 is returned when the cap is reached;
+        * no session id and any other body: HTTP 400;
+        * unknown session id: HTTP 404.
+        """
+        request = Request(scope, receive)
+        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
+
+        if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
+            transport = self._server_instances[request_mcp_session_id]
+            logger.debug("Session already exists, handling request directly")
+            self._session_last_seen_at[request_mcp_session_id] = time.monotonic()
+            # Activity on the session pushes its idle deadline forward.
+            if transport.idle_scope is not None and self.session_idle_timeout is not None:
+                transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout
+            await transport.handle_request(scope, receive, send)
+            if transport.is_terminated:
+                self._server_instances.pop(request_mcp_session_id, None)
+                self._forget_session(request_mcp_session_id)
+            return
+
+        if request_mcp_session_id is None:
+            logger.debug("Creating new transport")
+            body = await request.body()
+            if not _is_initialize_request(body):
+                await self._send_missing_session_response(scope, receive, send)
+                return
+            # The body has been consumed above. Replay it for the transport,
+            # then hand control back to the real channel so that
+            # ``http.disconnect`` still reaches anything waiting on it.
+            receive = _replay_body_receive(body, receive)
+            async with self._session_creation_lock:
+                if (
+                    self.max_active_sessions is not None
+                    and len(self._server_instances) >= self.max_active_sessions
+                ):
+                    await self._send_max_sessions_response(scope, receive, send)
+                    return
+
+                new_session_id = uuid4().hex
+                http_transport = StreamableHTTPServerTransport(
+                    mcp_session_id=new_session_id,
+                    is_json_response_enabled=self.json_response,
+                    event_store=self.event_store,
+                    security_settings=self.security_settings,
+                    retry_interval=self.retry_interval,
+                )
+
+                assert http_transport.mcp_session_id is not None
+                self._server_instances[http_transport.mcp_session_id] = http_transport
+                self._remember_session(http_transport.mcp_session_id)
+                logger.info("Created new MCP stateful transport")
+
+                async def run_server(
+                    *,
+                    task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
+                ) -> None:
+                    async with http_transport.connect() as streams:
+                        read_stream, write_stream = streams
+                        task_status.started()
+                        try:
+                            idle_scope = anyio.CancelScope()
+                            if self.session_idle_timeout is not None:
+                                idle_scope.deadline = (
+                                    anyio.current_time() + self.session_idle_timeout
+                                )
+                            http_transport.idle_scope = idle_scope
+
+                            with idle_scope:
+                                await self.app.run(
+                                    read_stream,
+                                    write_stream,
+                                    self.app.create_initialization_options(),
+                                    stateless=False,
+                                )
+
+                            if idle_scope.cancelled_caught:
+                                assert http_transport.mcp_session_id is not None
+                                logger.info("MCP stateful session idle timeout")
+                                self._server_instances.pop(http_transport.mcp_session_id, None)
+                                self._forget_session(http_transport.mcp_session_id)
+                                await http_transport.terminate()
+                        except Exception:
+                            logger.exception("MCP stateful session crashed")
+                        finally:
+                            # Still registered and not terminated: the run
+                            # ended without the cleanup paths above.
+                            if (
+                                http_transport.mcp_session_id
+                                and http_transport.mcp_session_id in self._server_instances
+                                and not http_transport.is_terminated
+                            ):
+                                logger.info(
+                                    "Cleaning up crashed MCP stateful session "
+                                    "from active instances."
+                                )
+                                del self._server_instances[http_transport.mcp_session_id]
+                                self._forget_session(http_transport.mcp_session_id)
+
+                task_group = getattr(self, "_task_group", None)
+                if task_group is None:
+                    raise RuntimeError("Task group is not initialized. Make sure to use run().")
+                await task_group.start(run_server)
+                await http_transport.handle_request(scope, receive, send)
+        else:
+            response = _jsonrpc_error_response(HTTPStatus.NOT_FOUND, "Session not found")
+            await response(scope, receive, send)
+
+    def _remember_session(self, session_id: str) -> None:
+        """Record creation and last-seen times for a newly created session."""
+        now = time.monotonic()
+        self._session_created_at[session_id] = now
+        self._session_last_seen_at[session_id] = now
+        self._session_creation_events.append(now)
+        self._total_sessions_created += 1
+        self._prune_tracking(now)
+
+    def _forget_session(self, session_id: str) -> None:
+        """Drop the timestamps kept for ``session_id``; unknown ids are ignored."""
+        self._session_created_at.pop(session_id, None)
+        self._session_last_seen_at.pop(session_id, None)
+
+    def _prune_tracking(self, now: float) -> None:
+        """Discard timestamps of unregistered sessions and expire the 60s window."""
+        active_ids = set(self._server_instances)
+        for session_id in set(self._session_created_at) - active_ids:
+            self._forget_session(session_id)
+        cutoff = now - 60.0
+        # Events are appended in time order, so the oldest is always leftmost.
+        while self._session_creation_events and self._session_creation_events[0] < cutoff:
+            self._session_creation_events.popleft()
+
+    async def _send_max_sessions_response(
+        self,
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
+        """Reply with HTTP 429 and a JSON-RPC error naming the configured limit."""
+        limit = self.max_active_sessions
+        response = _jsonrpc_error_response(
+            HTTPStatus.TOO_MANY_REQUESTS,
+            f"Too many active MCP sessions (limit {limit})",
+        )
+        await response(scope, receive, send)
+
+    async def _send_missing_session_response(
+        self,
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
+        """Reply with HTTP 400 for a session-less request that is not ``initialize``."""
+        response = _jsonrpc_error_response(
+            HTTPStatus.BAD_REQUEST,
+            "Bad Request: Missing session ID",
+        )
+        await response(scope, receive, send)
+
+
+def get_mcp_session_stats(mcp_or_manager: Any) -> MCPSessionStats:
+    """Return session stats for a FastMCP server or session manager.
+
+    The helper accepts either the object returned by
+    :func:`adcp.server.create_mcp_server` or its ``_session_manager``.
+    For non-ADCP managers, only the fields available from upstream MCP
+    internals are populated.
+    """
+    manager = getattr(mcp_or_manager, "_session_manager", mcp_or_manager)
+    if hasattr(manager, "session_stats"):
+        stats = manager.session_stats()
+        if isinstance(stats, MCPSessionStats):
+            return stats
+
+    # Fallback: read whatever attributes exist; counters and ages are not
+    # available without the ADCP bookkeeping, so they are reported as empty.
+    server_instances = getattr(manager, "_server_instances", {}) or {}
+    return MCPSessionStats(
+        active_sessions=len(server_instances),
+        max_active_sessions=getattr(manager, "max_active_sessions", None),
+        total_sessions_created=0,
+        sessions_created_last_60s=0,
+        stateless=bool(getattr(manager, "stateless", False)),
+        session_idle_timeout=getattr(manager, "session_idle_timeout", None),
+        session_age_seconds=(),
+        session_idle_seconds=(),
+    )
+
+
+def _jsonrpc_error_response(status: HTTPStatus, message: str) -> Response:
+    """Build the JSON-RPC ``INVALID_REQUEST`` error body with HTTP status ``status``."""
+    error_response = JSONRPCError(
+        jsonrpc="2.0",
+        id="server-error",
+        error=ErrorData(
+            code=INVALID_REQUEST,
+            message=message,
+        ),
+    )
+    return Response(
+        content=error_response.model_dump_json(by_alias=True, exclude_none=True),
+        status_code=status,
+        media_type="application/json",
+    )
+
+
+def _is_initialize_request(body: bytes) -> bool:
+    """Return True when ``body`` is a single JSON-RPC ``initialize`` message.
+
+    The body belongs to a request that carries no session id, so any
+    parse failure is reported as "not initialize" rather than raised.
+    """
+    try:
+        raw_message = json.loads(body)
+    except (ValueError, RecursionError):
+        # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
+        # raised for bytes that do not decode; RecursionError covers
+        # pathologically nested documents.
+        return False
+    return isinstance(raw_message, dict) and raw_message.get("method") == "initialize"
+
+
+def _replay_body_receive(body: bytes, downstream: Receive | None = None) -> Receive:
+    """Return an ASGI ``receive`` that yields ``body`` once as a complete request.
+
+    Args:
+        body: The request body that was already read from the real channel.
+        downstream: The original ``receive`` callable. When given, every
+            call after the replay is delegated to it so the consumer can
+            still observe later messages such as ``http.disconnect``.
+            When omitted, later calls return an empty, final
+            ``http.request`` message immediately.
+    """
+    replayed = False
+
+    async def receive() -> Message:
+        nonlocal replayed
+        if not replayed:
+            replayed = True
+            return {"type": "http.request", "body": body, "more_body": False}
+        if downstream is not None:
+            return await downstream()
+        return {"type": "http.request", "body": b"", "more_body": False}
+
+    return receive
+
+
+__all__ = [
+    "ADCPStreamableHTTPSessionManager",
+    "MCPSessionStats",
+    "get_mcp_session_stats",
+]
diff --git
a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 0c7ad35ad..aff59a1e2 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -28,9 +28,8 @@ async def get_adcp_capabilities(self, params, context=None): logger = logging.getLogger("adcp.server") -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager - from adcp.server.base import ADCPHandler, ToolContext +from adcp.server.mcp_sessions import ADCPStreamableHTTPSessionManager from adcp.server.mcp_tools import ( _HANDLER_TOOLS, create_tool_caller, @@ -152,6 +151,7 @@ class ServeConfig: streaming_responses: bool = False stateless_http: bool = False session_idle_timeout: float | None = 1800.0 + max_active_sessions: int | None = None # --- A2A / both --- task_store: TaskStore | None = None @@ -190,7 +190,12 @@ def __post_init__(self) -> None: # spuriously under transport='a2a'. ``stateless_http`` (default # False) and ``streaming_responses`` (default False) work # cleanly with the heuristic. - _mcp_only = ("instructions", "streaming_responses", "stateless_http") + _mcp_only = ( + "instructions", + "streaming_responses", + "stateless_http", + "max_active_sessions", + ) if self.transport == "a2a": mcp_set = sorted(f for f in _mcp_only if getattr(self, f) not in (None, False)) if mcp_set: @@ -598,6 +603,7 @@ def serve( streaming_responses: bool = False, stateless_http: bool = False, session_idle_timeout: float | None = 1800.0, + max_active_sessions: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, pre_validation_hooks: PreValidationHooks | None = None, enable_debug_endpoints: bool = False, @@ -750,6 +756,11 @@ def serve( sessions are terminated and their per-session state freed. Defaults to 1800 (30 min); ``None`` disables reaping. Ignored when ``stateless_http=True``. + max_active_sessions: Optional cap for simultaneously active + stateful MCP sessions. 
When the cap is reached, new + session-creating requests are rejected with HTTP 429 while + requests carrying an existing ``Mcp-Session-Id`` continue. + Ignored when ``stateless_http=True``. enable_debug_endpoints: When ``True``, mount ``GET /_debug/traffic`` on the outer HTTP app. Returns the JSON dict from ``debug_traffic_source()`` — typically wired to the @@ -896,6 +907,7 @@ async def force_account_status(self, account_id, status): streaming_responses = config.streaming_responses stateless_http = config.stateless_http session_idle_timeout = config.session_idle_timeout + max_active_sessions = config.max_active_sessions validation = config.validation pre_validation_hooks = config.pre_validation_hooks enable_debug_endpoints = config.enable_debug_endpoints @@ -986,6 +998,7 @@ async def force_account_status(self, account_id, status): streaming_responses=streaming_responses, stateless_http=stateless_http, session_idle_timeout=session_idle_timeout, + max_active_sessions=max_active_sessions, validation=validation, pre_validation_hooks=pre_validation_hooks, base_url=base_url, @@ -1016,6 +1029,7 @@ async def force_account_status(self, account_id, status): streaming_responses=streaming_responses, stateless_http=stateless_http, session_idle_timeout=session_idle_timeout, + max_active_sessions=max_active_sessions, validation=validation, pre_validation_hooks=pre_validation_hooks, base_url=base_url, @@ -1421,6 +1435,7 @@ def _serve_mcp( streaming_responses: bool = False, stateless_http: bool = False, session_idle_timeout: float | None = 1800.0, + max_active_sessions: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, pre_validation_hooks: PreValidationHooks | None = None, base_url: str | None = None, @@ -1445,6 +1460,7 @@ def _serve_mcp( streaming_responses=streaming_responses, stateless_http=stateless_http, session_idle_timeout=session_idle_timeout, + max_active_sessions=max_active_sessions, validation=validation, 
pre_validation_hooks=pre_validation_hooks, allowed_hosts=allowed_hosts, @@ -1671,6 +1687,7 @@ def _build_mcp_and_a2a_app( streaming_responses: bool = False, stateless_http: bool = False, session_idle_timeout: float | None = 1800.0, + max_active_sessions: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, pre_validation_hooks: PreValidationHooks | None = None, base_url: str | None = None, @@ -1718,6 +1735,7 @@ def _build_mcp_and_a2a_app( streaming_responses=streaming_responses, stateless_http=stateless_http, session_idle_timeout=session_idle_timeout, + max_active_sessions=max_active_sessions, validation=validation, pre_validation_hooks=pre_validation_hooks, allowed_hosts=allowed_hosts, @@ -1922,6 +1940,7 @@ def _serve_mcp_and_a2a( streaming_responses: bool = False, stateless_http: bool = False, session_idle_timeout: float | None = 1800.0, + max_active_sessions: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, pre_validation_hooks: PreValidationHooks | None = None, base_url: str | None = None, @@ -1974,6 +1993,7 @@ def _serve_mcp_and_a2a( streaming_responses=streaming_responses, stateless_http=stateless_http, session_idle_timeout=session_idle_timeout, + max_active_sessions=max_active_sessions, validation=validation, pre_validation_hooks=pre_validation_hooks, base_url=base_url, @@ -2060,6 +2080,7 @@ def create_mcp_server( streaming_responses: bool = False, stateless_http: bool = False, session_idle_timeout: float | None = 1800.0, + max_active_sessions: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, pre_validation_hooks: PreValidationHooks | None = None, allowed_hosts: Sequence[str] | None = None, @@ -2148,6 +2169,13 @@ def create_mcp_server( because without it ``StreamableHTTPSessionManager._server_instances`` grows without bound for clients that disconnect without DELETE. + max_active_sessions: Optional cap for active stateful MCP + sessions. 
When the cap is reached, new session-creating + requests are rejected with HTTP 429; requests that carry an + existing ``Mcp-Session-Id`` continue. Set this on public or + service-to-service sellers that need a hard ceiling against + clients opening one session per operation. Ignored when + ``stateless_http=True``. Returns: A configured FastMCP server instance. Call ``mcp.run()`` to start, @@ -2256,7 +2284,8 @@ def create_mcp_server( pre_validation_hooks=pre_validation_hooks, ) # Pre-create the StreamableHTTPSessionManager so we can pass - # ``session_idle_timeout`` — FastMCP's settings don't expose it as of + # ``session_idle_timeout`` and ADCP's session safety knobs — + # FastMCP's settings don't expose these as of # mcp 1.27.x. ``streamable_http_app()`` lazy-creates the manager only # if ``_session_manager`` is ``None``, so populating it here is the # extension point. Reaches into FastMCP private attrs ``_mcp_server``, @@ -2274,7 +2303,7 @@ def create_mcp_server( # warn on every server boot otherwise. Adopters who explicitly want a # timeout should set ``stateless_http=False``. idle_timeout = None if mcp.settings.stateless_http else session_idle_timeout - mcp._session_manager = StreamableHTTPSessionManager( + mcp._session_manager = ADCPStreamableHTTPSessionManager( app=mcp._mcp_server, event_store=mcp._event_store, retry_interval=mcp._retry_interval, @@ -2282,6 +2311,7 @@ def create_mcp_server( stateless=mcp.settings.stateless_http, security_settings=mcp.settings.transport_security, session_idle_timeout=idle_timeout, + max_active_sessions=max_active_sessions, ) return mcp diff --git a/tests/test_mcp_stateful_session.py b/tests/test_mcp_stateful_session.py index dd22e86fa..90415760c 100644 --- a/tests/test_mcp_stateful_session.py +++ b/tests/test_mcp_stateful_session.py @@ -22,6 +22,8 @@ combine with ``stateless=True`` — we suppress the timeout automatically when adopters opt back into stateless rather than letting the upstream constructor raise. +5. 
Optional active-session guardrails and stats for one-shot clients + that create sessions without closing them. """ from __future__ import annotations @@ -32,7 +34,7 @@ import pytest from asgi_lifespan import LifespanManager -from adcp.server import ADCPHandler, create_mcp_server +from adcp.server import ADCPHandler, create_mcp_server, get_mcp_session_stats class _BareHandler(ADCPHandler[Any]): @@ -50,6 +52,10 @@ def test_default_is_stateful() -> None: assert mcp.settings.json_response is True assert mcp._session_manager.session_idle_timeout == 1800.0 assert mcp._session_manager.stateless is False + stats = get_mcp_session_stats(mcp) + assert stats.active_sessions == 0 + assert stats.max_active_sessions is None + assert stats.session_idle_timeout == 1800.0 def test_stateless_opt_in_drops_idle_timeout() -> None: @@ -74,6 +80,30 @@ def test_stateful_opt_in_explicit_timeout() -> None: assert mcp._session_manager.session_idle_timeout == 600.0 +def test_stateful_opt_in_max_active_sessions() -> None: + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + max_active_sessions=2, + ) + assert mcp._session_manager.max_active_sessions == 2 + assert get_mcp_session_stats(mcp).max_active_sessions == 2 + + +@pytest.mark.parametrize("value", [0, -1, True, 1.5, "2"]) +def test_invalid_max_active_sessions_rejected_at_boundary(value: Any) -> None: + with pytest.raises(ValueError, match="max_active_sessions must be a positive integer"): + create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + max_active_sessions=value, + ) + + def test_stateful_with_disabled_timeout() -> None: mcp = create_mcp_server( _BareHandler(), @@ -143,6 +173,7 @@ async def test_stateful_session_reuses_across_calls() -> None: name="t", advertise_all=True, stateless_http=False, + max_active_sessions=1, allowed_hosts=["localhost", "127.0.0.1"], ) app = mcp.streamable_http_app() @@ -183,6 +214,103 @@ async def 
test_stateful_session_reuses_across_calls() -> None: assert list_resp.status_code == 200, list_resp.text +@pytest.mark.asyncio +async def test_stateful_max_active_sessions_rejects_new_sessions() -> None: + """One-shot clients that repeatedly initialize without closing can + be bounded with ``max_active_sessions`` while existing sessions keep + working.""" + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + max_active_sessions=1, + allowed_hosts=["localhost", "127.0.0.1"], + ) + app = mcp.streamable_http_app() + headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://localhost", + follow_redirects=True, + ) as client: + init_resp = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t", "version": "1"}, + }, + }, + headers=headers, + ) + assert init_resp.status_code == 200, init_resp.text + session_id = init_resp.headers.get("mcp-session-id") + assert session_id + + stats = get_mcp_session_stats(mcp) + assert stats.active_sessions == 1 + assert stats.total_sessions_created == 1 + assert stats.sessions_created_last_60s == 1 + assert len(stats.session_age_seconds) == 1 + + second_init = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t2", "version": "1"}, + }, + }, + headers=headers, + ) + assert second_init.status_code == 429 + assert "Too many active MCP sessions" in second_init.text + + list_resp = await client.post( + "/mcp/", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}, + headers={**headers, "mcp-session-id": session_id}, + ) 
+ assert list_resp.status_code == 200, list_resp.text + + delete_resp = await client.delete( + "/mcp/", + headers={**headers, "mcp-session-id": session_id}, + ) + assert delete_resp.status_code == 200, delete_resp.text + assert get_mcp_session_stats(mcp).active_sessions == 0 + + after_delete = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 3, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t3", "version": "1"}, + }, + }, + headers=headers, + ) + assert after_delete.status_code == 200, after_delete.text + + @pytest.mark.asyncio async def test_stateful_auth_propagates_via_request_state() -> None: """The headline guarantee for the default flip: in stateful mode @@ -303,3 +431,23 @@ async def test_stateful_rejects_request_without_session_id() -> None: ) assert resp.status_code == 400 assert "session" in resp.text.lower() + assert get_mcp_session_stats(mcp).active_sessions == 0 + + init = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t", "version": "1"}, + }, + }, + headers={ + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + ) + assert init.status_code == 200, init.text diff --git a/tests/test_serve_config.py b/tests/test_serve_config.py index a4f1e8c88..6545e57c2 100644 --- a/tests/test_serve_config.py +++ b/tests/test_serve_config.py @@ -44,6 +44,7 @@ def test_serve_config_defaults() -> None: assert cfg.host is None assert cfg.advertise_all is False assert cfg.streaming_responses is False + assert cfg.max_active_sessions is None assert cfg.enable_debug_endpoints is False assert cfg.middleware is None assert cfg.validation is None @@ -85,7 +86,7 @@ def test_serve_config_warns_a2a_only_on_mcp_transport() -> None: def test_serve_config_warns_mcp_only_on_a2a_transport() -> None: with 
warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - ServeConfig(transport="a2a", instructions="hello") + ServeConfig(transport="a2a", instructions="hello", max_active_sessions=10) messages = [str(w.message) for w in caught if issubclass(w.category, UserWarning)] assert any("MCP-only" in m for m in messages), messages @@ -134,9 +135,9 @@ def test_serve_config_kwargs_ignored_when_config_provided() -> None: mock_mcp.assert_called_once() _, kwargs = mock_mcp.call_args - assert kwargs.get("name") == "from-config", ( - "config.name should override the per-kwarg name when config= is provided" - ) + assert ( + kwargs.get("name") == "from-config" + ), "config.name should override the per-kwarg name when config= is provided" def test_serve_without_config_uses_kwargs() -> None: @@ -161,3 +162,27 @@ def test_serve_config_advertise_all_propagates() -> None: mock_mcp.assert_called_once() _, kwargs = mock_mcp.call_args assert kwargs.get("advertise_all") is True + + +def test_serve_config_max_active_sessions_propagates() -> None: + handler = _StubHandler() + cfg = ServeConfig(transport="streamable-http", max_active_sessions=10) + + with patch.object(_serve_mod, "_serve_mcp") as mock_mcp: + _serve_mod.serve(handler, config=cfg) + + mock_mcp.assert_called_once() + _, kwargs = mock_mcp.call_args + assert kwargs.get("max_active_sessions") == 10 + + +def test_serve_config_max_active_sessions_propagates_to_both_transport() -> None: + handler = _StubHandler() + cfg = ServeConfig(transport="both", max_active_sessions=10) + + with patch.object(_serve_mod, "_serve_mcp_and_a2a") as mock_both: + _serve_mod.serve(handler, config=cfg) + + mock_both.assert_called_once() + _, kwargs = mock_both.call_args + assert kwargs.get("max_active_sessions") == 10