|
6 | 6 | import pathlib |
7 | 7 | import sys |
8 | 8 | import textwrap |
9 | | -from typing import TYPE_CHECKING, cast |
10 | | -from unittest.mock import Mock, patch |
| 9 | +from contextlib import ExitStack |
| 10 | +from typing import TYPE_CHECKING, Any, cast |
| 11 | +from unittest.mock import MagicMock, Mock, patch |
11 | 12 |
|
12 | 13 | import pytest |
13 | 14 |
|
|
45 | 46 | from marimo._runtime.dataflow import EdgeWithVar |
46 | 47 | from marimo._runtime.patches import create_main_module |
47 | 48 | from marimo._runtime.runner.hooks import create_default_hooks |
48 | | -from marimo._runtime.runtime import Kernel, notebook_dir, notebook_location |
| 49 | +from marimo._runtime.runtime import ( |
| 50 | + Kernel, |
| 51 | + launch_kernel, |
| 52 | + notebook_dir, |
| 53 | + notebook_location, |
| 54 | +) |
49 | 55 | from marimo._runtime.scratch import SCRATCH_CELL_ID |
50 | 56 | from marimo._session.model import SessionMode |
51 | 57 | from marimo._utils.parse_dataclass import parse_raw |
52 | 58 | from tests._messaging.mocks import MockStderr, MockStream |
53 | 59 | from tests.conftest import ExecReqProvider, MockedKernel |
54 | 60 |
|
55 | 61 | if TYPE_CHECKING: |
56 | | - from collections.abc import Sequence |
| 62 | + from collections.abc import Coroutine, Sequence |
57 | 63 |
|
58 | 64 |
|
59 | 65 | def _check_edges(error: Error, expected_edges: Sequence[EdgeWithVar]) -> None: |
@@ -4175,3 +4181,148 @@ async def test_request_handler_only_created_once( |
4175 | 4181 | assert handler1 is handler2 |
4176 | 4182 | assert handler2 is handler3 |
4177 | 4183 | assert handler1 is handler3 |
| 4184 | + |
| 4185 | + |
| 4186 | +class TestLaunchKernelEventLoop: |
| 4187 | + """Event-loop policy / factory selection in launch_kernel. |
| 4188 | +
|
| 4189 | + The kernel subprocess must run on the Windows ProactorEventLoop so |
| 4190 | + user code can use asyncio.create_subprocess_exec() and other APIs |
| 4191 | + the SelectorEventLoop does not implement. The server keeps the |
| 4192 | + SelectorEventLoop because ConnectionDistributor relies on |
| 4193 | + loop.add_reader(). |
| 4194 | +
|
| 4195 | + These tests stub out everything after the event-loop setup so only |
| 4196 | + the policy / loop_factory decision is exercised. |
| 4197 | + """ |
| 4198 | + |
| 4199 | + _HEAVY_DEPENDENCY_TARGETS = [ |
| 4200 | + "marimo._runtime.runtime.restore_signals", |
| 4201 | + "marimo._runtime.runtime.ThreadSafeStream", |
| 4202 | + "marimo._runtime.runtime.ThreadSafeStdout", |
| 4203 | + "marimo._runtime.runtime.ThreadSafeStderr", |
| 4204 | + "marimo._runtime.runtime.ThreadSafeStdin", |
| 4205 | + "marimo._runtime.runtime.marimo_pdb.MarimoPdb", |
| 4206 | + "marimo._runtime.runtime.Kernel", |
| 4207 | + "marimo._runtime.runtime.initialize_kernel_context", |
| 4208 | + "marimo._runtime.runtime.patches.patch_main_module", |
| 4209 | + "marimo._output.formatters.formatters.register_formatters", |
| 4210 | + ] |
| 4211 | + |
| 4212 | + class _StopAfterAsyncioRun(Exception): |
| 4213 | + """Sentinel raised from the mocked asyncio.run so we skip the |
| 4214 | + post-run teardown path (which touches a runtime context we |
| 4215 | + haven't initialized).""" |
| 4216 | + |
| 4217 | + @classmethod |
| 4218 | + def _fake_asyncio_run( |
| 4219 | + cls, coro: Coroutine[Any, Any, Any], **_kwargs: Any |
| 4220 | + ) -> None: |
| 4221 | + # Close the never-awaited coroutine to suppress the |
| 4222 | + # RuntimeWarning, then bail so we don't execute the post-run |
| 4223 | + # teardown. |
| 4224 | + coro.close() |
| 4225 | + raise cls._StopAfterAsyncioRun |
| 4226 | + |
| 4227 | + @classmethod |
| 4228 | + def _call_launch_kernel(cls, *, is_edit_mode: bool) -> None: |
| 4229 | + with pytest.raises(cls._StopAfterAsyncioRun): |
| 4230 | + launch_kernel( |
| 4231 | + control_queue=MagicMock(), |
| 4232 | + set_ui_element_queue=MagicMock(), |
| 4233 | + completion_queue=MagicMock(), |
| 4234 | + input_queue=MagicMock(), |
| 4235 | + stream_queue=MagicMock(), |
| 4236 | + socket_addr=None, |
| 4237 | + is_edit_mode=is_edit_mode, |
| 4238 | + configs={}, |
| 4239 | + app_metadata=AppMetadata( |
| 4240 | + query_params={}, cli_args={}, app_config=_AppConfig() |
| 4241 | + ), |
| 4242 | + user_config=DEFAULT_CONFIG, |
| 4243 | + virtual_file_storage=None, |
| 4244 | + redirect_console_to_browser=False, |
| 4245 | + ) |
| 4246 | + |
| 4247 | + @pytest.fixture |
| 4248 | + def harness(self): |
| 4249 | + """Neutralize launch_kernel's heavy dependencies so the test |
| 4250 | + only observes the event-loop policy / loop_factory decision.""" |
| 4251 | + with ExitStack() as stack: |
| 4252 | + for target in self._HEAVY_DEPENDENCY_TARGETS: |
| 4253 | + stack.enter_context(patch(target)) |
| 4254 | + # `signal` is used as `signal.signal(...)` and references |
| 4255 | + # `signal.SIGBREAK`, which only exists on Windows — swap |
| 4256 | + # the whole module ref so non-Windows hosts don't blow up. |
| 4257 | + stack.enter_context( |
| 4258 | + patch("marimo._runtime.runtime.signal", new=MagicMock()) |
| 4259 | + ) |
| 4260 | + run_mock = MagicMock(side_effect=self._fake_asyncio_run) |
| 4261 | + stack.enter_context(patch("asyncio.run", run_mock)) |
| 4262 | + yield run_mock |
| 4263 | + |
| 4264 | + def test_non_windows_does_not_change_event_loop_policy(self, harness): |
| 4265 | + with ( |
| 4266 | + patch("sys.platform", "linux"), |
| 4267 | + patch.object(asyncio, "set_event_loop_policy") as set_policy, |
| 4268 | + ): |
| 4269 | + self._call_launch_kernel(is_edit_mode=True) |
| 4270 | + |
| 4271 | + set_policy.assert_not_called() |
| 4272 | + assert harness.call_count == 1 |
| 4273 | + assert "loop_factory" not in harness.call_args.kwargs |
| 4274 | + |
| 4275 | + def test_windows_pre_314_installs_proactor_event_loop_policy( |
| 4276 | + self, harness |
| 4277 | + ): |
| 4278 | + with ( |
| 4279 | + patch("sys.platform", "win32"), |
| 4280 | + patch("sys.version_info", (3, 12, 0, "final", 0)), |
| 4281 | + patch.object( |
| 4282 | + asyncio, "WindowsProactorEventLoopPolicy", create=True |
| 4283 | + ) as policy_cls, |
| 4284 | + patch.object(asyncio, "set_event_loop_policy") as set_policy, |
| 4285 | + ): |
| 4286 | + self._call_launch_kernel(is_edit_mode=True) |
| 4287 | + |
| 4288 | + policy_cls.assert_called_once_with() |
| 4289 | + set_policy.assert_called_once_with(policy_cls.return_value) |
| 4290 | + # Pre-3.14 uses the policy API, not loop_factory. |
| 4291 | + assert "loop_factory" not in harness.call_args.kwargs |
| 4292 | + |
| 4293 | + def test_windows_314_plus_uses_proactor_loop_factory(self, harness): |
| 4294 | + # Event loop policies are deprecated in 3.14; launch_kernel must |
| 4295 | + # pass ProactorEventLoop as the loop_factory to asyncio.run |
| 4296 | + # instead of mutating the global policy. |
| 4297 | + with ( |
| 4298 | + patch("sys.platform", "win32"), |
| 4299 | + patch("sys.version_info", (3, 14, 0, "final", 0)), |
| 4300 | + patch.object( |
| 4301 | + asyncio, "ProactorEventLoop", create=True |
| 4302 | + ) as proactor_cls, |
| 4303 | + patch.object(asyncio, "set_event_loop_policy") as set_policy, |
| 4304 | + ): |
| 4305 | + self._call_launch_kernel(is_edit_mode=True) |
| 4306 | + |
| 4307 | + set_policy.assert_not_called() |
| 4308 | + assert harness.call_args.kwargs.get("loop_factory") is proactor_cls |
| 4309 | + |
| 4310 | + def test_run_mode_on_windows_does_not_touch_event_loop_policy( |
| 4311 | + self, harness |
| 4312 | + ): |
| 4313 | + # Run mode (not edit, not IPC) runs in-process on the server's |
| 4314 | + # loop and must NOT mutate the event loop policy — the server |
| 4315 | + # uses the Selector loop for ConnectionDistributor.add_reader(). |
| 4316 | + with ( |
| 4317 | + patch("sys.platform", "win32"), |
| 4318 | + patch("sys.version_info", (3, 12, 0, "final", 0)), |
| 4319 | + patch.object( |
| 4320 | + asyncio, "WindowsProactorEventLoopPolicy", create=True |
| 4321 | + ) as policy_cls, |
| 4322 | + patch.object(asyncio, "set_event_loop_policy") as set_policy, |
| 4323 | + ): |
| 4324 | + self._call_launch_kernel(is_edit_mode=False) |
| 4325 | + |
| 4326 | + policy_cls.assert_not_called() |
| 4327 | + set_policy.assert_not_called() |
| 4328 | + assert "loop_factory" not in harness.call_args.kwargs |
0 commit comments