diff --git a/marimo/_ast/codegen.py b/marimo/_ast/codegen.py index d7dabb03eed..88fb5b1b8e4 100644 --- a/marimo/_ast/codegen.py +++ b/marimo/_ast/codegen.py @@ -543,9 +543,11 @@ def generate_filecontents( ) -> str: """Translates a sequences of codes (cells) to a Python file""" - # Update old internal cell names to the new ones + # Normalize internal cell names. Empty names would emit ``def ():`` + # (invalid Python) and fall back to the unparsable-cell path; + # ``"__"`` is a legacy internal marker. for idx, name in enumerate(names): - if name == "__": + if name == "__" or not name: names[idx] = DEFAULT_CELL_NAME setup_cell = pop_setup_cell(codes, names, cell_configs) diff --git a/marimo/_session/extensions/extensions.py b/marimo/_session/extensions/extensions.py index b0dfe6b4e83..cb805457a16 100644 --- a/marimo/_session/extensions/extensions.py +++ b/marimo/_session/extensions/extensions.py @@ -8,14 +8,19 @@ from __future__ import annotations import asyncio +import copy +import html from enum import Enum +from functools import partial from typing import TYPE_CHECKING import msgspec from marimo import _loggers from marimo._cli.print import red +from marimo._messaging.notebook.document import NotebookCell from marimo._messaging.notification import ( + AlertNotification, NotebookDocumentTransactionNotification, NotificationMessage, ) @@ -26,6 +31,7 @@ EventAwareExtension, SessionExtension, ) +from marimo._session.model import SessionMode from marimo._session.state.serialize import ( SessionCacheKey, SessionCacheManager, @@ -41,6 +47,7 @@ QueueDistributor, ) from marimo._utils.print import print_, print_tabbed +from marimo._utils.serial_task_runner import SerialTaskRunner if TYPE_CHECKING: from logging import Logger @@ -218,6 +225,12 @@ def __init__( self.kernel_manager = kernel_manager self.queue_manager = queue_manager self.distributor: Distributor[KernelMessage] | None = None + # Log the unnamed-notebook skip once per session, not per mutation. 
+ self._unnamed_autosave_logged = False + # FIFO so a slow older save never clobbers a newer one. + self._autosave_runner = SerialTaskRunner( + thread_name_prefix="marimo-autosave" + ) def _create_distributor( self, @@ -232,19 +245,22 @@ def _create_distributor( # Edit mode with original kernel manager uses connection return ConnectionDistributor(kernel_manager.kernel_connection) - @staticmethod - def _on_kernel_message(session: Session, msg: KernelMessage) -> None: + def _on_kernel_message(self, session: Session, msg: KernelMessage) -> None: """Route a raw kernel message to the appropriate session method. Document transactions are intercepted and applied to the ``session.document``, then ``session.notify()`` is invoked with the (versioned) result. + Kernel-sourced transactions also trigger an auto-save so agent-driven + mutations via ``code_mode`` land on disk the same way frontend edits do. + Everything else is forwarded verbatim via ``session.notify()``. TODO: if more notification types need server-side interception, consider a middleware chain instead of inline dispatch. """ notif: KernelMessage | NotificationMessage = msg + kernel_transaction_applied = False name = try_deserialize_kernel_notification_name(msg) if name == NotebookDocumentTransactionNotification.name: @@ -257,6 +273,7 @@ def _on_kernel_message(session: Session, msg: KernelMessage) -> None: notif = NotebookDocumentTransactionNotification( transaction=applied ) + kernel_transaction_applied = applied.source == "kernel" except Exception: LOGGER.warning( "Failed to decode/apply kernel document transaction" @@ -264,6 +281,62 @@ def _on_kernel_message(session: Session, msg: KernelMessage) -> None: session.notify(notif, from_consumer_id=None) + if kernel_transaction_applied: + self._maybe_autosave(session) + + def _maybe_autosave(self, session: Session) -> None: + """Best-effort persistence of kernel-driven mutations to disk. + + Skipped in run mode and for unnamed notebooks. 
Failures surface as + an ``AlertNotification`` toast; they never raise out of the + interceptor. + """ + if self.kernel_manager.mode != SessionMode.EDIT: + return + + if session.app_file_manager.path is None: + if not self._unnamed_autosave_logged: + LOGGER.debug( + "Skipping code_mode auto-save for unnamed notebook" + ) + self._unnamed_autosave_logged = True + return + + # Deep-copy on the caller thread. ``NotebookCell`` and + # ``CellConfig`` are mutable and owned by the document, so a + # shallow copy would let the event-loop thread mutate fields + # under the worker thread's feet (torn snapshot). + cells_snapshot: list[NotebookCell] = copy.deepcopy( + session.document.cells + ) + + self._autosave_runner.submit( + partial(session.app_file_manager.save_from_cells, cells_snapshot), + on_error=partial(self._post_autosave_failure, session), + ) + + @staticmethod + def _post_autosave_failure(session: Session, err: Exception) -> None: + # Runs on the event loop thread — the runner routes on_error there + # so session.notify can safely touch the per-consumer asyncio.Queue. + LOGGER.warning( + "Failed to auto-save notebook after kernel mutation: %s", err + ) + try: + session.notify( + AlertNotification( + title="Auto-save failed", + description=html.escape( + f"Could not persist kernel changes to " + f"{session.app_file_manager.path}: {err}" + ), + variant="danger", + ), + from_consumer_id=None, + ) + except Exception: + LOGGER.exception("Failed to broadcast auto-save failure alert") + def on_attach(self, session: Session, event_bus: SessionEventBus) -> None: del event_bus self.distributor = self._create_distributor( @@ -279,6 +352,8 @@ def on_detach(self) -> None: if self.distributor is not None: self.distributor.stop() self.distributor = None + # Don't block session close on disk I/O; kernel still holds state. 
+ self._autosave_runner.shutdown(wait=False) def flush(self) -> None: """Flush any pending messages from the distributor.""" diff --git a/marimo/_session/notebook/file_manager.py b/marimo/_session/notebook/file_manager.py index 61b8222d88c..2c14c09299a 100644 --- a/marimo/_session/notebook/file_manager.py +++ b/marimo/_session/notebook/file_manager.py @@ -2,6 +2,7 @@ from __future__ import annotations import os +import threading from pathlib import Path from typing import TYPE_CHECKING, Any @@ -30,6 +31,9 @@ LOGGER = _loggers.marimo_logger() if TYPE_CHECKING: + from collections.abc import Sequence + + from marimo._messaging.notebook.document import NotebookCell from marimo._server.models.models import ( CopyNotebookRequest, SaveNotebookRequest, @@ -72,6 +76,11 @@ def __init__( # Track the last saved content to avoid reloading our own writes self._last_saved_content: str | None = None + # Serializes concurrent writers. Reentrant so public entry points + # can wrap the full "mutate app + _save_file" sequence while + # ``_save_file`` re-acquires for any direct caller. + self._save_lock = threading.RLock() + @property def filename(self) -> str | None: """Get the current filename as a Path object.""" @@ -176,6 +185,8 @@ def _save_file( ) -> str: """Save notebook to storage using appropriate format handler. + All file writes go through this method under ``_save_lock``. + Args: path: Target file path notebook: Notebook in IR format @@ -187,50 +198,53 @@ def _save_file( """ LOGGER.debug("Saving app to %s", path) - # Get the header in case it was modified by the user (e.g. 
package installation) - handler = get_notebook_serializer(path) - header: str | None = None - if previous_path and previous_path.exists(): - header = handler.extract_header(previous_path) - elif path.exists(): - header = handler.extract_header(path) - - # For new .py files in sandbox mode, generate header with marimo - if header is None and str(path).endswith(".py"): - from marimo._config.settings import GLOBAL_SETTINGS - - if GLOBAL_SETTINGS.MANAGE_SCRIPT_METADATA: - from marimo._utils.scripts import write_pyproject_to_script - - header = write_pyproject_to_script( - with_python_version_requirement( - { - "dependencies": ["marimo"], - } + with self._save_lock: + # Get the header in case it was modified by the user (e.g. package installation) + handler = get_notebook_serializer(path) + header: str | None = None + if previous_path and previous_path.exists(): + header = handler.extract_header(previous_path) + elif path.exists(): + header = handler.extract_header(path) + + # For new .py files in sandbox mode, generate header with marimo + if header is None and str(path).endswith(".py"): + from marimo._config.settings import GLOBAL_SETTINGS + + if GLOBAL_SETTINGS.MANAGE_SCRIPT_METADATA: + from marimo._utils.scripts import ( + write_pyproject_to_script, ) - ) - # Rewrap with header if relevant and set filename. - notebook = NotebookSerializationV1( - app=notebook.app, - header=Header(value=header) if header else notebook.header, - cells=notebook.cells, - violations=notebook.violations, - valid=notebook.valid, - filename=str(path), - ) - contents = handler.serialize(notebook) + header = write_pyproject_to_script( + with_python_version_requirement( + { + "dependencies": ["marimo"], + } + ) + ) - if persist: - self.storage.write(path, contents) - # Record the last saved content to avoid reloading our own writes - self._last_saved_content = contents.strip() + # Rewrap with header if relevant and set filename. 
+ notebook = NotebookSerializationV1( + app=notebook.app, + header=Header(value=header) if header else notebook.header, + cells=notebook.cells, + violations=notebook.violations, + valid=notebook.valid, + filename=str(path), + ) + contents = handler.serialize(notebook) + + if persist: + self.storage.write(path, contents) + # Record the last saved content to avoid reloading our own writes + self._last_saved_content = contents.strip() - # If this is a new unnamed notebook, update the filename - if self._is_unnamed(): - self._filename = path + # If this is a new unnamed notebook, update the filename + if self._is_unnamed(): + self._filename = path - return contents + return contents def _load_app(self, path: str | None) -> InternalApp: """Load app from storage. @@ -287,29 +301,30 @@ def rename(self, new_filename: str | Path) -> str: """ new_path = Path(canonicalize_filename(str(new_filename))) - if self._is_same_path(new_path): - return new_path.name + with self._save_lock: + if self._is_same_path(new_path): + return new_path.name - self._assert_path_does_not_exist(new_path) + self._assert_path_does_not_exist(new_path) - if self._filename is not None: - self.storage.rename(self._filename, new_path) - else: - # Create new file for unnamed notebooks - self.storage.write(new_path, "") + if self._filename is not None: + self.storage.rename(self._filename, new_path) + else: + # Create new file for unnamed notebooks + self.storage.write(new_path, "") - previous_filename = self._filename - self._filename = new_path - self.app._app._filename = str(new_path) + previous_filename = self._filename + self._filename = new_path + self.app._app._filename = str(new_path) - self._save_file( - new_path, - notebook=self.app.to_ir(), - persist=True, - previous_path=previous_filename, - ) + self._save_file( + new_path, + notebook=self.app.to_ir(), + persist=True, + previous_path=previous_filename, + ) - return new_path.name + return new_path.name def read_layout_config(self) -> LayoutConfig 
| None: """Read layout configuration file. @@ -366,14 +381,15 @@ def save_app_config(self, config: dict[str, Any]) -> str: Returns: Serialized notebook content """ - self.app.update_config(config) - if self._filename is not None: - return self._save_file( - self._filename, - notebook=self.app.to_ir(), - persist=True, - ) - return "" + with self._save_lock: + self.app.update_config(config) + if self._filename is not None: + return self._save_file( + self._filename, + notebook=self.app.to_ir(), + persist=True, + ) + return "" def save(self, request: SaveNotebookRequest) -> str: """Save the notebook. @@ -398,37 +414,70 @@ def save(self, request: SaveNotebookRequest) -> str: filename_path = Path(canonicalize_filename(filename)) - # Update app with new cell data - self.app.with_data( - cell_ids=cell_ids, - codes=codes, - names=names, - configs=configs, - ) + with self._save_lock: + # Update app with new cell data + self.app.with_data( + cell_ids=cell_ids, + codes=codes, + names=names, + configs=configs, + ) + + if self.is_notebook_named and not self._is_same_path( + filename_path + ): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Save handler cannot rename files.", + ) + + # Save layout if provided + if layout is not None: + app_dir = filename_path.parent + app_name = filename_path.name + layout_filename = save_layout_config( + app_dir, app_name, LayoutConfig(**layout) + ) + self.app.update_config({"layout_file": layout_filename}) + else: + # Remove the layout from the config + self.app.update_config({"layout_file": None}) + + return self._save_file( + filename_path, + notebook=self.app.to_ir(), + persist=request.persist, + ) + + def save_from_cells(self, cells: Sequence[NotebookCell]) -> str: + """Persist the notebook from a snapshot of document cells. - if self.is_notebook_named and not self._is_same_path(filename_path): + Used by the server-side auto-save path for ``code_mode`` + mutations. 
Unlike ``save()``, this takes cells directly — the + caller is responsible for snapshotting ``session.document.cells`` + on a thread where the document is quiescent. + + Raises: + HTTPException: If the notebook is unnamed or the write fails + """ + if self._filename is None: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail="Save handler cannot rename files.", + detail="Cannot save an unnamed notebook", ) - # Save layout if provided - if layout is not None: - app_dir = filename_path.parent - app_name = filename_path.name - layout_filename = save_layout_config( - app_dir, app_name, LayoutConfig(**layout) + with self._save_lock: + self.app.with_data( + cell_ids=[cell.id for cell in cells], + codes=[cell.code for cell in cells], + names=[cell.name for cell in cells], + configs=[cell.config for cell in cells], + ) + return self._save_file( + self._filename, + notebook=self.app.to_ir(), + persist=True, ) - self.app.update_config({"layout_file": layout_filename}) - else: - # Remove the layout from the config - self.app.update_config({"layout_file": None}) - - return self._save_file( - filename_path, - notebook=self.app.to_ir(), - persist=request.persist, - ) def copy(self, request: CopyNotebookRequest) -> str: """Copy a notebook file. diff --git a/marimo/_utils/serial_task_runner.py b/marimo/_utils/serial_task_runner.py new file mode 100644 index 00000000000..f2101a10deb --- /dev/null +++ b/marimo/_utils/serial_task_runner.py @@ -0,0 +1,150 @@ +# Copyright 2026 Marimo. All rights reserved. +"""FIFO dispatch of blocking work off the asyncio event loop. + +Usage:: + + runner = SerialTaskRunner(thread_name_prefix="autosave") + runner.submit( + lambda: file_manager.save_from_cells(cells), + on_error=lambda err: session.notify(AlertNotification(...)), + ) + runner.shutdown() # on session close + await runner.drain() # in async tests + +``submit`` and ``shutdown`` must be called from the asyncio event loop +thread (or from any thread when no loop is running). 
``work`` runs on +the executor thread; ``on_error`` is routed back to the loop thread via +``call_soon_threadsafe`` so it can safely touch asyncio primitives. +""" + +from __future__ import annotations + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from marimo import _loggers + +if TYPE_CHECKING: + from collections.abc import Callable + +LOGGER = _loggers.marimo_logger() + + +class SerialTaskRunner: + """FIFO-ordered dispatch of blocking work to a dedicated worker thread.""" + + def __init__(self, *, thread_name_prefix: str = "serial-task") -> None: + self._thread_name_prefix = thread_name_prefix + self._pending: list[asyncio.Future[Any]] = [] + self._closed = False + + @cached_property + def _executor(self) -> ThreadPoolExecutor: + return ThreadPoolExecutor( + max_workers=1, thread_name_prefix=self._thread_name_prefix + ) + + @property + def pending(self) -> list[asyncio.Future[Any]]: + """In-flight futures; tests can await these to synchronize.""" + return self._pending + + def submit( + self, + work: Callable[[], object], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: + """Run ``work()`` on the serial worker thread. + + Offloads to the executor when called from the event loop; + otherwise runs inline on the caller thread (e.g. the + ``QueueDistributor`` worker thread). ``on_error`` is invoked + with any exception raised by ``work`` — posted back to the event + loop when off-loop, inline otherwise. A failing ``on_error`` is + logged and swallowed. + + After ``shutdown()``, ``submit`` is a no-op (and logs at debug). + Without this guard, ``cached_property`` would silently + re-materialize a fresh executor — and a new worker thread — for + any late submissions that race session teardown. 
+ """ + if self._closed: + LOGGER.debug( + "SerialTaskRunner.submit called after shutdown; dropping task" + ) + return + + try: + loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() + except RuntimeError: + loop = None + + def _run() -> None: + try: + work() + except Exception as err: + self._handle_error(loop, on_error, err) + + if loop is None: + _run() + return + + fut = loop.run_in_executor(self._executor, _run) + # Prune done futures so the list stays bounded over long sessions. + self._pending[:] = [f for f in self._pending if not f.done()] + self._pending.append(fut) + + @staticmethod + def _handle_error( + loop: asyncio.AbstractEventLoop | None, + on_error: Callable[[Exception], None] | None, + err: Exception, + ) -> None: + if on_error is None: + LOGGER.error( + "SerialTaskRunner task failed with no on_error handler: %s", + err, + exc_info=err, + ) + return + + def _safe_on_error() -> None: + try: + on_error(err) + except Exception as handler_err: + LOGGER.error( + "SerialTaskRunner on_error callback failed: %s", + handler_err, + exc_info=handler_err, + ) + + if loop is None: + _safe_on_error() + else: + loop.call_soon_threadsafe(_safe_on_error) + + async def drain(self) -> None: + """Await every in-flight task, then clear the pending list. + + ``return_exceptions=True`` so a failing task doesn't abort the drain. + """ + if not self._pending: + return + await asyncio.gather(*self._pending, return_exceptions=True) + self._pending.clear() + + def shutdown(self, *, wait: bool = False) -> None: + """Tear down the executor. Idempotent; no-op if never materialized. + + Uses ``__dict__.pop`` so we don't trigger the ``cached_property`` + just to shut it down. Sets ``_closed`` so any subsequent + ``submit`` becomes a no-op instead of re-materializing a new + executor via ``cached_property``. 
+ """ + self._closed = True + executor = self.__dict__.pop("_executor", None) + if executor is not None: + executor.shutdown(wait=wait) diff --git a/tests/_code_mode/test_context_autosave.py b/tests/_code_mode/test_context_autosave.py new file mode 100644 index 00000000000..68a21c6009d --- /dev/null +++ b/tests/_code_mode/test_context_autosave.py @@ -0,0 +1,279 @@ +# Copyright 2026 Marimo. All rights reserved. +"""End-to-end tests: code_mode mutations → interceptor → file on disk. + +These tests connect two well-tested pieces (``code_mode`` producing +transactions and ``NotificationListenerExtension`` consuming them) by +taking the transactions code_mode emits from a real ``Kernel`` and +feeding them through the auto-save interceptor against a real +``AppFileManager`` backed by a temp file. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from typing import TYPE_CHECKING +from unittest.mock import Mock + +import pytest + +from marimo._ast.cell_id import CellIdGenerator +from marimo._code_mode._context import AsyncCodeModeContext +from marimo._messaging.notebook.document import ( + NotebookCell, + NotebookDocument, + notebook_document_context, +) +from marimo._messaging.notification import ( + NotebookDocumentTransactionNotification, +) +from marimo._messaging.serde import serialize_kernel_message +from marimo._session.extensions.extensions import ( + NotificationListenerExtension, +) +from marimo._session.model import SessionMode +from marimo._session.notebook import AppFileManager + +if TYPE_CHECKING: + from collections.abc import Generator + + from marimo._runtime.runtime import Kernel + + +INITIAL_NOTEBOOK_PY = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def _(): + return + +if __name__ == "__main__": + app.run() +""" + +INITIAL_NOTEBOOK_MD = """--- +title: Test +marimo-version: "0.0.1" +--- + +```python {.marimo} +``` +""" + + +@contextmanager +def _ctx(k: Kernel) 
-> Generator[AsyncCodeModeContext, None, None]: + cells = [ + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() + ] + doc = NotebookDocument(cells) + with notebook_document_context(doc): + ctx = AsyncCodeModeContext(k) + ctx._id_generator = CellIdGenerator(seed=7) + ctx._id_generator.seen_ids = set(doc.cell_ids) + yield ctx + + +def _read_disk(app_file_manager: AppFileManager) -> str: + """Sync helper to read a notebook file from disk (avoids ASYNC240 in + async tests driven by ``pytest-asyncio`` + code_mode fixtures).""" + path = app_file_manager.path + assert path is not None + return Path(path).read_text() + + +def _make_notebook_fixture(filename: str, contents: str): + """Build a parametrizable fixture that creates an AppFileManager backed + by ``tmp_path / filename`` pre-populated with ``contents``.""" + + @pytest.fixture + def _fixture(tmp_path: Path) -> AppFileManager: + temp_file = tmp_path / filename + temp_file.write_text(contents) + return AppFileManager(filename=str(temp_file)) + + return _fixture + + +py_notebook = _make_notebook_fixture("notebook.py", INITIAL_NOTEBOOK_PY) +md_notebook = _make_notebook_fixture("notebook.md", INITIAL_NOTEBOOK_MD) + + +def _make_session_for(app_file_manager: AppFileManager) -> Mock: + s = Mock() + s.app_file_manager = app_file_manager + s.document = NotebookDocument() + s.notify = Mock() + return s + + +@pytest.fixture +def ext() -> NotificationListenerExtension: + kernel_manager = Mock() + kernel_manager.mode = SessionMode.EDIT + queue_manager = Mock() + queue_manager.stream_queue = None + return NotificationListenerExtension(kernel_manager, queue_manager) + + +@pytest.fixture +def py_session(py_notebook: AppFileManager) -> Mock: + return _make_session_for(py_notebook) + + +@pytest.fixture +def md_session(md_notebook: AppFileManager) -> Mock: + return _make_session_for(md_notebook) + + +async def _drain( + k: Kernel, + ext: NotificationListenerExtension, + 
session: Mock, +) -> None: + """Forward every NotebookDocumentTransactionNotification on ``k.stream`` + through the interceptor so disk state catches up with the kernel graph. + + Auto-save is dispatched to the runner's executor when a running loop + is detected, so after feeding messages we must drain pending tasks + before asserting on the file contents. + """ + for notif in list(k.stream.operations): + if isinstance(notif, NotebookDocumentTransactionNotification): + ext._on_kernel_message(session, serialize_kernel_message(notif)) + await ext._autosave_runner.drain() + + +class TestCodeModeAutoSavePy: + """code_mode ops land on a ``.py`` file on disk.""" + + async def test_create_cell_persists( + self, + k: Kernel, + py_notebook: AppFileManager, + py_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + nb.create_cell("greeting = 42") + await _drain(k, ext, py_session) + + contents = _read_disk(py_notebook) + assert "greeting = 42" in contents + # Must be serialized as a proper @app.cell, not an unparsable fallback + assert "_unparsable_cell" not in contents + + async def test_edit_cell_persists( + self, + k: Kernel, + py_notebook: AppFileManager, + py_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + cid = nb.create_cell("x = 1") + async with ctx as nb: + nb.edit_cell(cid, code="x = 999") + await _drain(k, ext, py_session) + + contents = _read_disk(py_notebook) + assert "x = 999" in contents + assert "x = 1\n" not in contents + + async def test_delete_cell_persists( + self, + k: Kernel, + py_notebook: AppFileManager, + py_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + nb.create_cell("keep = 1") + drop = nb.create_cell("drop = 2") + async with ctx as nb: + nb.delete_cell(drop) + await _drain(k, ext, py_session) + + contents = _read_disk(py_notebook) + assert "keep = 1" in 
contents + assert "drop = 2" not in contents + + async def test_mixed_batch_persists( + self, + k: Kernel, + py_notebook: AppFileManager, + py_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + """Create + edit + delete in a single context block all land.""" + with _ctx(k) as ctx: + async with ctx as nb: + first = nb.create_cell("first = 1") + second = nb.create_cell("second = 2") + async with ctx as nb: + nb.edit_cell(first, code="first = 100") + nb.create_cell("third = 3") + nb.delete_cell(second) + await _drain(k, ext, py_session) + + contents = _read_disk(py_notebook) + assert "first = 100" in contents + assert "third = 3" in contents + assert "second = 2" not in contents + + +class TestCodeModeAutoSaveMd: + """code_mode ops land on a ``.md`` file on disk.""" + + async def test_create_cell_persists( + self, + k: Kernel, + md_notebook: AppFileManager, + md_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + nb.create_cell("answer = 42") + await _drain(k, ext, md_session) + + assert "answer = 42" in _read_disk(md_notebook) + + +class TestExecutorOrdering: + """A slower earlier save must never overwrite a newer one.""" + + async def test_rapid_mutations_preserve_latest_state( + self, + k: Kernel, + py_notebook: AppFileManager, + py_session: Mock, + ext: NotificationListenerExtension, + ) -> None: + """Rapid-fire kernel mutations should all serialize through the + single-worker executor in FIFO order, leaving the newest snapshot + on disk.""" + with _ctx(k) as ctx: + async with ctx as nb: + cid = nb.create_cell("version = 1") + async with ctx as nb: + nb.edit_cell(cid, code="version = 2") + async with ctx as nb: + nb.edit_cell(cid, code="version = 3") + async with ctx as nb: + nb.edit_cell(cid, code="version = 4") + + await _drain(k, ext, py_session) + + contents = _read_disk(py_notebook) + assert "version = 4" in contents + # None of the earlier snapshots should have clobbered the 
latest + assert "version = 1\n" not in contents + assert "version = 2\n" not in contents + assert "version = 3\n" not in contents diff --git a/tests/_server/test_file_manager.py b/tests/_server/test_file_manager.py index c31cc9b0e93..dfcb597c042 100644 --- a/tests/_server/test_file_manager.py +++ b/tests/_server/test_file_manager.py @@ -228,6 +228,148 @@ def test_save_cannot_rename(app_file_manager: AppFileManager) -> None: assert e.value.status_code == HTTPStatus.BAD_REQUEST +def test_save_from_cells_persists_cells( + app_file_manager: AppFileManager, +) -> None: + """``save_from_cells`` should round-trip cells through the serializer.""" + from marimo._messaging.notebook.document import NotebookCell + + app_file_manager.save_from_cells( + [ + NotebookCell( + id=CellId_t("first"), + code="z = 99", + name="first", + config=CellConfig(), + ), + ] + ) + assert app_file_manager.filename is not None + with open(app_file_manager.filename, encoding="utf-8") as f: + contents = f.read() + assert "z = 99" in contents + assert "def first" in contents + + +def test_save_from_cells_empty_name_normalizes( + app_file_manager: AppFileManager, +) -> None: + """Empty cell names must serialize as the default ``_`` rather than + falling back to the unparsable-cell path.""" + from marimo._messaging.notebook.document import NotebookCell + + app_file_manager.save_from_cells( + [ + NotebookCell( + id=CellId_t("c"), + code="greeting = 42", + name="", + config=CellConfig(), + ), + ] + ) + assert app_file_manager.filename is not None + with open(app_file_manager.filename, encoding="utf-8") as f: + contents = f.read() + assert "greeting = 42" in contents + assert "_unparsable_cell" not in contents + + +def test_save_from_cells_unnamed_raises( + app_file_manager: AppFileManager, +) -> None: + """Unnamed notebooks cannot be persisted from a cell snapshot.""" + app_file_manager.filename = None + with pytest.raises(HTTPException) as e: + app_file_manager.save_from_cells([]) + assert 
e.value.status_code == HTTPStatus.BAD_REQUEST + + +def test_save_from_cells_preserves_layout_file( + app_file_manager: AppFileManager, +) -> None: + """``save_from_cells`` must keep ``layout_file`` in app config.""" + from marimo._messaging.notebook.document import NotebookCell + + app_file_manager.app.update_config({"layout_file": "layouts/x.grid.json"}) + app_file_manager.save_from_cells( + [ + NotebookCell( + id=CellId_t("c"), + code="x = 1", + name="", + config=CellConfig(), + ), + ] + ) + assert app_file_manager.app.config.layout_file == "layouts/x.grid.json" + + +def test_save_and_save_from_cells_serialize_under_lock( + app_file_manager: AppFileManager, +) -> None: + """Concurrent ``save`` + ``save_from_cells`` on the same manager must + produce a valid (non-torn) file. Also regression-tests that the + reentrant lock covers both entry points.""" + import threading + + from marimo._messaging.notebook.document import NotebookCell + + assert app_file_manager.filename is not None + save_request.filename = app_file_manager.filename + errors: list[Exception] = [] + + def _frontend_save() -> None: + try: + for _ in range(20): + app_file_manager.save(save_request) + except Exception as e: # pragma: no cover — should never happen + errors.append(e) + + def _autosave() -> None: + cells = [ + NotebookCell( + id=CellId_t("auto"), + code="auto = 1", + name="", + config=CellConfig(), + ) + ] + try: + for _ in range(20): + app_file_manager.save_from_cells(cells) + except Exception as e: # pragma: no cover + errors.append(e) + + try: + t1 = threading.Thread(target=_frontend_save) + t2 = threading.Thread(target=_autosave) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + assert not t1.is_alive(), ( + "frontend save thread did not terminate within 10s " + "(likely deadlock in AppFileManager write lock)" + ) + assert not t2.is_alive(), ( + "autosave thread did not terminate within 10s " + "(likely deadlock in AppFileManager write lock)" + ) + assert not 
errors, f"unexpected errors: {errors}" + # File ends in a parseable state — the serializer's codegen would + # raise on a torn write, and the final content must be one of the + # two write paths, never a mix. + with open(save_request.filename, encoding="utf-8") as f: + contents = f.read() + assert "import marimo" in contents + assert "app = marimo.App" in contents + finally: + if os.path.exists(save_request.filename): + os.remove(save_request.filename) + + def test_save_with_header( app_file_manager: AppFileManager, tmp_path: Path ) -> None: diff --git a/tests/_session/extensions/test_notification_listener_autosave.py b/tests/_session/extensions/test_notification_listener_autosave.py new file mode 100644 index 00000000000..fbaf8379706 --- /dev/null +++ b/tests/_session/extensions/test_notification_listener_autosave.py @@ -0,0 +1,516 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Tests for code_mode auto-save in NotificationListenerExtension.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from marimo._ast.cell import CellConfig +from marimo._messaging.notebook.changes import ( + CreateCell, + DeleteCell, + DocumentChange, + MoveCell, + SetCode, + SetConfig, + SetName, + Transaction, +) +from marimo._messaging.notebook.document import NotebookCell, NotebookDocument +from marimo._messaging.notification import ( + AlertNotification, + NotebookDocumentTransactionNotification, +) +from marimo._messaging.serde import serialize_kernel_message +from marimo._messaging.types import KernelMessage +from marimo._session.extensions.extensions import ( + NotificationListenerExtension, +) +from marimo._session.model import SessionMode +from marimo._session.notebook import AppFileManager +from marimo._types.ids import CellId_t + +INITIAL_NOTEBOOK = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def _(): + x = 1 + return (x,) + +if __name__ == "__main__": + app.run() 
+""" + + +def _make_extension( + *, mode: SessionMode = SessionMode.EDIT +) -> NotificationListenerExtension: + kernel_manager = Mock() + kernel_manager.mode = mode + queue_manager = Mock() + queue_manager.stream_queue = None + return NotificationListenerExtension(kernel_manager, queue_manager) + + +def _document_from(app_file_manager: AppFileManager) -> NotebookDocument: + return NotebookDocument( + [ + NotebookCell( + id=d.cell_id, code=d.code, name=d.name, config=d.config + ) + for d in app_file_manager.app.cell_manager.cell_data() + ] + ) + + +def _serialize_tx( + *changes: DocumentChange, source: str = "kernel" +) -> KernelMessage: + return serialize_kernel_message( + NotebookDocumentTransactionNotification( + transaction=Transaction(changes=changes, source=source) + ) + ) + + +@pytest.fixture +def app_file_manager(tmp_path: Path) -> AppFileManager: + temp_file = tmp_path / "test_autosave.py" + temp_file.write_text(INITIAL_NOTEBOOK) + return AppFileManager(filename=str(temp_file)) + + +@pytest.fixture +def notebook_path(app_file_manager: AppFileManager) -> Path: + assert app_file_manager.path is not None + return Path(app_file_manager.path) + + +@pytest.fixture +def existing_cell_id(app_file_manager: AppFileManager) -> CellId_t: + return next(iter(app_file_manager.app.cell_manager.cell_ids())) + + +@pytest.fixture +def session(app_file_manager: AppFileManager) -> Mock: + s = Mock() + s.app_file_manager = app_file_manager + s.document = _document_from(app_file_manager) + s.notify = Mock() + return s + + +@pytest.fixture +def ext() -> NotificationListenerExtension: + return _make_extension() + + +class TestKernelSourcedAutoSave: + """Kernel-sourced transactions should persist to disk in edit mode.""" + + def test_create_cell_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + notebook_path: Path, + ) -> None: + ext._on_kernel_message( + session, + _serialize_tx( + CreateCell( + cell_id=CellId_t("new-cell-1"), + code="y = 2", + 
name="", + config=CellConfig(), + ) + ), + ) + contents = notebook_path.read_text() + assert "y = 2" in contents + assert "x = 1" in contents + + def test_set_code_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + ext._on_kernel_message( + session, + _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 42")), + ) + contents = notebook_path.read_text() + assert "x = 42" in contents + assert "x = 1\n" not in contents + + def test_set_name_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + ext._on_kernel_message( + session, + _serialize_tx(SetName(cell_id=existing_cell_id, name="my_cell")), + ) + assert "def my_cell" in notebook_path.read_text() + + def test_set_config_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + ext._on_kernel_message( + session, + _serialize_tx(SetConfig(cell_id=existing_cell_id, hide_code=True)), + ) + contents = notebook_path.read_text() + assert "hide_code=True" in contents or "hide_code: true" in contents + + def test_delete_cell_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + # Add a second cell first so delete doesn't leave us empty + ext._on_kernel_message( + session, + _serialize_tx( + CreateCell( + cell_id=CellId_t("tmp-keep"), + code="keeper = 99", + name="", + config=CellConfig(), + ) + ), + ) + ext._on_kernel_message( + session, _serialize_tx(DeleteCell(cell_id=existing_cell_id)) + ) + contents = notebook_path.read_text() + assert "keeper = 99" in contents + assert "x = 1" not in contents + + def test_move_cell_writes_to_disk( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) 
-> None: + second_id = CellId_t("second") + ext._on_kernel_message( + session, + _serialize_tx( + CreateCell( + cell_id=second_id, + code="y = 2", + name="", + config=CellConfig(), + ) + ), + ) + ext._on_kernel_message( + session, + _serialize_tx(MoveCell(cell_id=existing_cell_id, after=second_id)), + ) + contents = notebook_path.read_text() + assert contents.index("y = 2") < contents.index("x = 1") + + def test_notify_still_called_on_kernel_transaction( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + ) -> None: + """Auto-save must not suppress the frontend broadcast.""" + ext._on_kernel_message( + session, + _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 99")), + ) + assert session.notify.called + notif = session.notify.call_args_list[0].args[0] + assert isinstance(notif, NotebookDocumentTransactionNotification) + + +class TestAutoSaveSkipped: + """Scenarios where auto-save must be a no-op.""" + + def test_client_sourced_transaction_does_not_rewrite( + self, + ext: NotificationListenerExtension, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + before_mtime = notebook_path.stat().st_mtime + ext._on_kernel_message( + session, + _serialize_tx( + SetCode(cell_id=existing_cell_id, code="x = 77"), + source="frontend", + ), + ) + assert session.notify.called + assert notebook_path.stat().st_mtime == before_mtime + assert "x = 77" not in notebook_path.read_text() + + def test_run_mode_skips_autosave( + self, + session: Mock, + existing_cell_id: CellId_t, + notebook_path: Path, + ) -> None: + run_ext = _make_extension(mode=SessionMode.RUN) + before_mtime = notebook_path.stat().st_mtime + run_ext._on_kernel_message( + session, + _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 999")), + ) + assert notebook_path.stat().st_mtime == before_mtime + + +class TestUnnamedNotebook: + """Auto-save is a silent no-op for unnamed notebooks.""" + + @pytest.fixture + def 
unnamed_session(self) -> tuple[Mock, CellId_t]: + mgr = AppFileManager(filename=None) + seed_id = next(iter(mgr.app.cell_manager.cell_ids())) + s = Mock() + s.app_file_manager = mgr + s.document = _document_from(mgr) + s.notify = Mock() + return s, seed_id + + def test_skips_without_raising( + self, + ext: NotificationListenerExtension, + unnamed_session: tuple[Mock, CellId_t], + ) -> None: + sess, seed_id = unnamed_session + ext._on_kernel_message( + sess, _serialize_tx(SetCode(cell_id=seed_id, code="x = 2")) + ) + assert sess.notify.called + + def test_debug_log_flag_flips( + self, + ext: NotificationListenerExtension, + unnamed_session: tuple[Mock, CellId_t], + ) -> None: + sess, seed_id = unnamed_session + ext._on_kernel_message( + sess, _serialize_tx(SetCode(cell_id=seed_id, code="x = 2")) + ) + ext._on_kernel_message( + sess, _serialize_tx(SetCode(cell_id=seed_id, code="x = 3")) + ) + assert ext._unnamed_autosave_logged is True + + +def _get_alerts(session: Mock) -> list[AlertNotification]: + """Extract every ``AlertNotification`` the interceptor broadcast.""" + return [ + call.args[0] + for call in session.notify.call_args_list + if isinstance(call.args[0], AlertNotification) + ] + + +def _get_tx_broadcasts( + session: Mock, +) -> list[NotebookDocumentTransactionNotification]: + return [ + call.args[0] + for call in session.notify.call_args_list + if isinstance(call.args[0], NotebookDocumentTransactionNotification) + ] + + +def _install_failing_write( + app_file_manager: AppFileManager, message: str = "disk full" +) -> None: + def _fail(*_args: object, **_kwargs: object) -> None: + raise OSError(message) + + app_file_manager.storage.write = _fail # type: ignore[method-assign] + + +class TestFailureSurfaces: + """Write failures should surface as an AlertNotification toast.""" + + @pytest.fixture + def failing_storage(self, app_file_manager: AppFileManager) -> None: + _install_failing_write(app_file_manager) + + @pytest.mark.usefixtures("failing_storage") + 
def test_write_failure_broadcasts_alert(
+        self,
+        ext: NotificationListenerExtension,
+        session: Mock,
+        existing_cell_id: CellId_t,
+    ) -> None:
+        # Must not raise out of the interceptor
+        ext._on_kernel_message(
+            session,
+            _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 2")),
+        )
+
+        alerts = _get_alerts(session)
+        assert len(alerts) == 1
+        assert alerts[0].variant == "danger"
+        assert alerts[0].title == "Auto-save failed"
+
+    @pytest.mark.usefixtures("failing_storage")
+    def test_transaction_still_broadcast_when_save_fails(
+        self,
+        ext: NotificationListenerExtension,
+        session: Mock,
+        existing_cell_id: CellId_t,
+    ) -> None:
+        """Even on save failure, the frontend must see the transaction so
+        its local state stays in sync with the kernel graph."""
+        ext._on_kernel_message(
+            session,
+            _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 2")),
+        )
+        assert len(_get_tx_broadcasts(session)) == 1
+
+    def test_alert_description_is_html_escaped(
+        self,
+        ext: NotificationListenerExtension,
+        session: Mock,
+        app_file_manager: AppFileManager,
+        existing_cell_id: CellId_t,
+    ) -> None:
+        """User-controllable path + OS error strings must be HTML-escaped
+        before landing in AlertNotification.description, which the frontend
+        renders via renderHTML (sanitized today, but defense in depth)."""
+        _install_failing_write(
+            app_file_manager, message="<script>alert(1)</script>"
+        )
+        ext._on_kernel_message(
+            session,
+            _serialize_tx(SetCode(cell_id=existing_cell_id, code="x = 2")),
+        )
+
+        alerts = _get_alerts(session)
+        assert len(alerts) == 1
+        desc = alerts[0].description
+        assert "<script>" not in desc
+        assert "&lt;script&gt;" in desc