Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 101 additions & 29 deletions marimo/_save/loaders/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
import queue
import threading
from enum import Enum, auto
from pathlib import Path
from typing import Any, Optional

Expand All @@ -18,22 +19,48 @@
from marimo._save.loaders.loader import BasePersistenceLoader
from marimo._save.stores import FileStore, Store
from marimo._save.stubs import (
LAZY_STUB_LOOKUP,
FunctionStub,
ModuleStub,
)
from marimo._save.stubs.lazy_stub import (
_LAZY_STUB_CACHE,
BLOB_DESERIALIZERS,
BLOB_SERIALIZERS,
LAZY_STUB_LOOKUP,
Cache as CacheSchema,
CacheType,
ImmediateReferenceStub,
Item,
Meta,
ReferenceStub,
)
from marimo._save.stubs.stubs import mro_lookup

LOGGER = _loggers.marimo_logger()


class _BlobStatus(Enum):
"""Sentinel placed in the results queue when a blob is missing."""

MISSING = auto()


def maybe_update_lazy_stub(value: Any) -> str:
"""Return the loader strategy string for *value*, caching the result.

Walks the MRO of ``type(value)`` against ``LAZY_STUB_LOOKUP`` (a
fq-class-name → loader-string registry). Falls back to ``"pickle"``
when no match is found.
"""
value_type = type(value)
if value_type in _LAZY_STUB_CACHE:
return _LAZY_STUB_CACHE[value_type]
result = mro_lookup(value_type, LAZY_STUB_LOOKUP)
loader = result[1] if result else "pickle"
_LAZY_STUB_CACHE[value_type] = loader
return loader


def to_item(
path: Path,
value: Optional[Any],
Expand All @@ -45,11 +72,27 @@ def to_item(
return Item()

if loader is None:
loader = LAZY_STUB_LOOKUP.get(type(value), "pickle")
loader = maybe_update_lazy_stub(value)

type_hint = f"{type(value).__module__}.{type(value).__qualname__}"
Comment thread
dmadisetti marked this conversation as resolved.
Outdated

if loader == "pickle":
return Item(
reference=(path / f"{var_name}.pickle").as_posix(), hash=hash
reference=(path / f"{var_name}.pickle").as_posix(),
hash=hash,
type_hint=type_hint,
)
if loader == "npz":
return Item(
reference=(path / f"{var_name}.npz").as_posix(),
hash=hash,
type_hint=type_hint,
)
if loader == "arrow":
return Item(
reference=(path / f"{var_name}.arrow").as_posix(),
hash=hash,
type_hint=type_hint,
)
if loader == "ui":
return Item(reference=(path / "ui.pickle").as_posix())
Expand All @@ -60,7 +103,11 @@ def to_item(
if isinstance(value, (int, str, float, bool, bytes, type(None))):
return Item(primitive=value)

return Item(reference=(path / f"{var_name}.pickle").as_posix(), hash=hash)
return Item(
reference=(path / f"{var_name}.pickle").as_posix(),
hash=hash,
type_hint=type_hint,
)


def from_item(item: Item) -> Any:
Expand Down Expand Up @@ -117,56 +164,68 @@ def restore_cache(self, _key: HashKey, blob: bytes) -> Cache:

# Collect references to load
ref_vars: dict[str, str] = {}
ref_type_hints: dict[str, Optional[str]] = {}
variable_hashes: dict[str, str] = {}
for var_name, item in cache_data.defs.items():
if var_name in cache_data.ui_defs:
ref_vars[var_name] = (base / "ui.pickle").as_posix()
elif item.reference is not None:
ref_vars[var_name] = item.reference
ref_type_hints[item.reference] = item.type_hint
if item.hash:
variable_hashes[var_name] = item.hash

# Eagerly resolve return value reference alongside defs
return_ref: Optional[str] = None
return_type_hint: Optional[str] = None
if (
cache_data.meta.return_value
and cache_data.meta.return_value.reference
):
return_ref = cache_data.meta.return_value.reference
return_type_hint = cache_data.meta.return_value.type_hint

# Read + unpickle in parallel, stream results via queue
# Read + deserialize in parallel, stream results via queue.
# Every thread unconditionally puts exactly one item — either the
# deserialized value or _BlobStatus.MISSING — so queue.get() needs
# no timeout.
results: queue.Queue[tuple[str, Any]] = queue.Queue()
unique_keys = set(ref_vars.values())
if return_ref:
unique_keys.add(return_ref)

def _load_and_unpickle(key: str) -> None:
def _load_blob(key: str) -> None:
data = self.store.get(key)
if data:
results.put((key, pickle.loads(data)))
ext = Path(key).suffix
deserialize = BLOB_DESERIALIZERS.get(
ext, BLOB_DESERIALIZERS[".pickle"]
)
type_hint = ref_type_hints.get(key) or (
return_type_hint if key == return_ref else None
)
results.put((key, deserialize(data, type_hint)))
else:
results.put((key, _BlobStatus.MISSING))
Comment thread
dmadisetti marked this conversation as resolved.

threads = [
threading.Thread(target=_load_and_unpickle, args=(key,))
threading.Thread(target=_load_blob, args=(key,))
for key in unique_keys
]
for t in threads:
t.start()

# Stream results as they arrive
# N threads → N results guaranteed; no timeout needed.
unpickled: dict[str, Any] = {}
for _ in unique_keys:
try:
key, val = results.get(timeout=30)
unpickled[key] = val
except queue.Empty:
break
key, val = results.get()
if val is _BlobStatus.MISSING:
raise FileNotFoundError("Incomplete cache: missing blobs")
unpickled[key] = val

for t in threads:
t.join()
Comment thread
dmadisetti marked this conversation as resolved.
Outdated

if len(unpickled) < len(unique_keys):
raise FileNotFoundError("Incomplete cache: missing blobs")

# Distribute to defs
defs: dict[str, Any] = {}
for var_name, item in cache_data.defs.items():
Expand Down Expand Up @@ -210,25 +269,28 @@ def save_cache(self, cache: Cache) -> bool:
path, cache.meta.get("return", None), var_name="return"
)
if return_item.reference:
return_item.reference = (path / "return.pickle").as_posix()
# Normalize base name to "return" while preserving format extension.
ext = Path(return_item.reference).suffix
return_item.reference = (path / f"return{ext}").as_posix()

try:
cache_type_enum = CacheType(cache.cache_type)
except ValueError:
cache_type_enum = CacheType.UNKNOWN

pickle_vars: dict[str, Any] = {}
# Separate vars by loader strategy
format_vars: dict[str, dict[str, Any]] = {} # loader → {var: obj}
ui_vars: dict[str, Any] = {}
defs_dict: dict[str, Item] = {}
ui_defs_list: list[str] = []

for var, obj in cache.defs.items():
loader = LAZY_STUB_LOOKUP.get(type(obj), "pickle")
if loader == "pickle":
pickle_vars[var] = obj
elif loader == "ui":
loader = maybe_update_lazy_stub(obj)
if loader == "ui":
ui_vars[var] = obj
ui_defs_list.append(var)
elif loader not in ("inline",):
format_vars.setdefault(loader, {})[var] = obj
defs_dict[var] = to_item(
path,
obj,
Expand All @@ -255,23 +317,33 @@ def save_cache(self, cache: Cache) -> bool:
store = self.store
return_ref = return_item.reference
return_value = cache.meta.get("return", None)
return_loader = (
maybe_update_lazy_stub(return_value)
if return_value is not None
else "pickle"
)
manifest_key = str(self.build_path(cache.key))

def _serialize_and_write() -> None:
"""Serialize and write all blobs + manifest in background."""
try:
if return_ref:
store.put(return_ref, pickle.dumps(return_value))
serialize = BLOB_SERIALIZERS.get(
return_loader, pickle.dumps
)
store.put(return_ref, serialize(return_value))
if ui_vars:
store.put(
(path / "ui.pickle").as_posix(),
pickle.dumps(ui_vars),
)
for var, obj in pickle_vars.items():
store.put(
(path / f"{var}.pickle").as_posix(),
pickle.dumps(obj),
)
for loader, vars_dict in format_vars.items():
serialize = BLOB_SERIALIZERS.get(loader, pickle.dumps)
for var, obj in vars_dict.items():
store.put(
(path / f"{var}.{loader}").as_posix(),
serialize(obj),
)
# Manifest last — readers check for it to detect complete writes
store.put(manifest_key, manifest)
except Exception:
Expand Down
42 changes: 10 additions & 32 deletions marimo/_save/stubs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from typing import Any, Callable

from marimo._save.stubs.function_stub import FunctionStub
from marimo._save.stubs.lazy_stub import ReferenceStub
from marimo._save.stubs.lazy_stub import LAZY_STUB_LOOKUP, ReferenceStub
from marimo._save.stubs.module_stub import ModuleStub
from marimo._save.stubs.pydantic_stub import PydanticStub
from marimo._save.stubs.stubs import (
CUSTOM_STUBS,
CustomStub,
mro_lookup,
register_stub,
)
from marimo._save.stubs.ui_element_stub import UIElementStub
Expand All @@ -24,18 +25,6 @@
"pydantic.main.BaseModel": PydanticStub.register,
}

LAZY_STUB_LOOKUP: dict[type, str] = {
int: "inline",
str: "inline",
float: "inline",
bool: "inline",
bytes: "inline",
type(None): "inline",
FunctionStub: "inline",
ModuleStub: "inline",
UIElementStub: "ui",
}


def maybe_register_stub(value: Any) -> bool:
"""Lazily register a stub for a value's type if not already registered.
Expand All @@ -54,25 +43,13 @@ def maybe_register_stub(value: Any) -> bool:
if value_type in CUSTOM_STUBS:
return True

# Walk MRO to find matching base class
try:
mro_list = value_type.mro()
except BaseException:
# Some exotic metaclasses or broken types may raise when calling mro
mro_list = [value_type]

for cls in mro_list:
if not (hasattr(cls, "__module__") and hasattr(cls, "__name__")):
continue

cls_name = f"{cls.__module__}.{cls.__name__}"

if cls_name in STUB_REGISTRATIONS:
if cls_name not in _REGISTERED_NAMES:
_REGISTERED_NAMES.add(cls_name)
STUB_REGISTRATIONS[cls_name](value)
# After registration attempt, check if now in CUSTOM_STUBS
return value_type in CUSTOM_STUBS
result = mro_lookup(value_type, STUB_REGISTRATIONS)
if result is not None:
cls_name, register_fn = result
if cls_name not in _REGISTERED_NAMES:
_REGISTERED_NAMES.add(cls_name)
register_fn(value)
return value_type in CUSTOM_STUBS

return False

Expand Down Expand Up @@ -104,5 +81,6 @@ def maybe_get_custom_stub(value: Any) -> CustomStub | None:
"UIElementStub",
"maybe_get_custom_stub",
"maybe_register_stub",
"mro_lookup",
"register_stub",
]
Loading
Loading