Skip to content

Commit 6baa22d

Browse files
authored
feat: add lazy cache blob types (arrow, npy) (#9035)
## 📝 Summary Fixes a bug where a lazy cache blob miss blocked execution. Additionally adds `arrow` and `npy` formats over `pickle` for type-relevant storage.
1 parent c6ed59f commit 6baa22d

5 files changed

Lines changed: 456 additions & 70 deletions

File tree

marimo/_save/loaders/lazy.py

Lines changed: 111 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pickle
55
import queue
66
import threading
7+
from enum import Enum, auto
78
from pathlib import Path
89
from typing import Any
910

@@ -18,22 +19,48 @@
1819
from marimo._save.loaders.loader import BasePersistenceLoader
1920
from marimo._save.stores import FileStore, Store
2021
from marimo._save.stubs import (
21-
LAZY_STUB_LOOKUP,
2222
FunctionStub,
2323
ModuleStub,
2424
)
2525
from marimo._save.stubs.lazy_stub import (
26+
_LAZY_STUB_CACHE,
27+
BLOB_DESERIALIZERS,
28+
BLOB_SERIALIZERS,
29+
LAZY_STUB_LOOKUP,
2630
Cache as CacheSchema,
2731
CacheType,
2832
ImmediateReferenceStub,
2933
Item,
3034
Meta,
3135
ReferenceStub,
3236
)
37+
from marimo._save.stubs.stubs import mro_lookup
3338

3439
LOGGER = _loggers.marimo_logger()
3540

3641

42+
class _BlobStatus(Enum):
43+
"""Sentinel placed in the results queue when a blob is missing."""
44+
45+
MISSING = auto()
46+
47+
48+
def maybe_update_lazy_stub(value: Any) -> str:
    """Resolve and memoize the loader strategy string for *value*.

    Looks up ``type(value)`` (walking its MRO via ``mro_lookup``) in the
    ``LAZY_STUB_LOOKUP`` registry of fully-qualified class names to loader
    strings, defaulting to ``"pickle"`` when nothing matches. The answer is
    cached per concrete type in ``_LAZY_STUB_CACHE`` so repeated calls for
    the same type skip the MRO walk.
    """
    cls = type(value)
    cached = _LAZY_STUB_CACHE.get(cls)
    if cached is not None:
        return cached
    match = mro_lookup(cls, LAZY_STUB_LOOKUP)
    if match:
        strategy = match[1]
    else:
        strategy = "pickle"
    _LAZY_STUB_CACHE[cls] = strategy
    return strategy
62+
63+
3764
def to_item(
3865
path: Path,
3966
value: Any | None,
@@ -45,11 +72,27 @@ def to_item(
4572
return Item()
4673

4774
if loader is None:
48-
loader = LAZY_STUB_LOOKUP.get(type(value), "pickle")
75+
loader = maybe_update_lazy_stub(value)
76+
77+
type_hint = f"{type(value).__module__}.{type(value).__name__}"
4978

5079
if loader == "pickle":
5180
return Item(
52-
reference=(path / f"{var_name}.pickle").as_posix(), hash=hash
81+
reference=(path / f"{var_name}.pickle").as_posix(),
82+
hash=hash,
83+
type_hint=type_hint,
84+
)
85+
if loader == "npy":
86+
return Item(
87+
reference=(path / f"{var_name}.npy").as_posix(),
88+
hash=hash,
89+
type_hint=type_hint,
90+
)
91+
if loader == "arrow":
92+
return Item(
93+
reference=(path / f"{var_name}.arrow").as_posix(),
94+
hash=hash,
95+
type_hint=type_hint,
5396
)
5497
if loader == "ui":
5598
return Item(reference=(path / "ui.pickle").as_posix())
@@ -60,7 +103,11 @@ def to_item(
60103
if isinstance(value, (int, str, float, bool, bytes, type(None))):
61104
return Item(primitive=value)
62105

63-
return Item(reference=(path / f"{var_name}.pickle").as_posix(), hash=hash)
106+
return Item(
107+
reference=(path / f"{var_name}.pickle").as_posix(),
108+
hash=hash,
109+
type_hint=type_hint,
110+
)
64111

65112

66113
def from_item(item: Item) -> Any:
@@ -117,55 +164,72 @@ def restore_cache(self, _key: HashKey, blob: bytes) -> Cache:
117164

118165
# Collect references to load
119166
ref_vars: dict[str, str] = {}
167+
ref_type_hints: dict[str, str | None] = {}
120168
variable_hashes: dict[str, str] = {}
121169
for var_name, item in cache_data.defs.items():
122170
if var_name in cache_data.ui_defs:
123171
ref_vars[var_name] = (base / "ui.pickle").as_posix()
124172
elif item.reference is not None:
125173
ref_vars[var_name] = item.reference
174+
ref_type_hints[item.reference] = item.type_hint
126175
if item.hash:
127176
variable_hashes[var_name] = item.hash
128177

129178
# Eagerly resolve return value reference alongside defs
130179
return_ref: str | None = None
180+
return_type_hint: str | None = None
131181
if (
132182
cache_data.meta.return_value
133183
and cache_data.meta.return_value.reference
134184
):
135185
return_ref = cache_data.meta.return_value.reference
186+
return_type_hint = cache_data.meta.return_value.type_hint
136187

137-
# Read + unpickle in parallel, stream results via queue
188+
# Read + deserialize in parallel, stream results via queue.
189+
# Every thread unconditionally puts exactly one item — either the
190+
# deserialized value or _BlobStatus.MISSING — so queue.get() needs
191+
# no timeout.
138192
results: queue.Queue[tuple[str, Any]] = queue.Queue()
139193
unique_keys = set(ref_vars.values())
140194
if return_ref:
141195
unique_keys.add(return_ref)
142196

143-
def _load_and_unpickle(key: str) -> None:
144-
data = self.store.get(key)
145-
if data:
146-
results.put((key, pickle.loads(data)))
197+
def _load_blob(key: str) -> None:
198+
try:
199+
data = self.store.get(key)
200+
if data:
201+
ext = Path(key).suffix
202+
deserialize = BLOB_DESERIALIZERS.get(
203+
ext, BLOB_DESERIALIZERS[".pickle"]
204+
)
205+
type_hint = ref_type_hints.get(key) or (
206+
return_type_hint if key == return_ref else None
207+
)
208+
results.put((key, deserialize(data, type_hint)))
209+
else:
210+
results.put((key, _BlobStatus.MISSING))
211+
except Exception as e:
212+
LOGGER.warning("Failed to deserialize blob %s: %s", key, e)
213+
results.put((key, _BlobStatus.MISSING))
147214

148215
threads = [
149-
threading.Thread(target=_load_and_unpickle, args=(key,))
216+
threading.Thread(target=_load_blob, args=(key,))
150217
for key in unique_keys
151218
]
152219
for t in threads:
153220
t.start()
154221

155-
# Stream results as they arrive
222+
# N threads → N results guaranteed; no timeout needed.
156223
unpickled: dict[str, Any] = {}
157-
for _ in unique_keys:
158-
try:
159-
key, val = results.get(timeout=30)
224+
try:
225+
for _ in unique_keys:
226+
key, val = results.get()
227+
if val is _BlobStatus.MISSING:
228+
raise FileNotFoundError("Incomplete cache: missing blobs")
160229
unpickled[key] = val
161-
except queue.Empty:
162-
break
163-
164-
for t in threads:
165-
t.join()
166-
167-
if len(unpickled) < len(unique_keys):
168-
raise FileNotFoundError("Incomplete cache: missing blobs")
230+
finally:
231+
for t in threads:
232+
t.join()
169233

170234
# Distribute to defs
171235
defs: dict[str, Any] = {}
@@ -210,25 +274,28 @@ def save_cache(self, cache: Cache) -> bool:
210274
path, cache.meta.get("return", None), var_name="return"
211275
)
212276
if return_item.reference:
213-
return_item.reference = (path / "return.pickle").as_posix()
277+
# Normalize base name to "return" while preserving format extension.
278+
ext = Path(return_item.reference).suffix
279+
return_item.reference = (path / f"return{ext}").as_posix()
214280

215281
try:
216282
cache_type_enum = CacheType(cache.cache_type)
217283
except ValueError:
218284
cache_type_enum = CacheType.UNKNOWN
219285

220-
pickle_vars: dict[str, Any] = {}
286+
# Separate vars by loader strategy
287+
format_vars: dict[str, dict[str, Any]] = {} # loader → {var: obj}
221288
ui_vars: dict[str, Any] = {}
222289
defs_dict: dict[str, Item] = {}
223290
ui_defs_list: list[str] = []
224291

225292
for var, obj in cache.defs.items():
226-
loader = LAZY_STUB_LOOKUP.get(type(obj), "pickle")
227-
if loader == "pickle":
228-
pickle_vars[var] = obj
229-
elif loader == "ui":
293+
loader = maybe_update_lazy_stub(obj)
294+
if loader == "ui":
230295
ui_vars[var] = obj
231296
ui_defs_list.append(var)
297+
elif loader not in ("inline",):
298+
format_vars.setdefault(loader, {})[var] = obj
232299
defs_dict[var] = to_item(
233300
path,
234301
obj,
@@ -255,23 +322,33 @@ def save_cache(self, cache: Cache) -> bool:
255322
store = self.store
256323
return_ref = return_item.reference
257324
return_value = cache.meta.get("return", None)
325+
return_loader = (
326+
maybe_update_lazy_stub(return_value)
327+
if return_value is not None
328+
else "pickle"
329+
)
258330
manifest_key = str(self.build_path(cache.key))
259331

260332
def _serialize_and_write() -> None:
261333
"""Serialize and write all blobs + manifest in background."""
262334
try:
263335
if return_ref:
264-
store.put(return_ref, pickle.dumps(return_value))
336+
serialize = BLOB_SERIALIZERS.get(
337+
return_loader, pickle.dumps
338+
)
339+
store.put(return_ref, serialize(return_value))
265340
if ui_vars:
266341
store.put(
267342
(path / "ui.pickle").as_posix(),
268343
pickle.dumps(ui_vars),
269344
)
270-
for var, obj in pickle_vars.items():
271-
store.put(
272-
(path / f"{var}.pickle").as_posix(),
273-
pickle.dumps(obj),
274-
)
345+
for loader, vars_dict in format_vars.items():
346+
serialize = BLOB_SERIALIZERS.get(loader, pickle.dumps)
347+
for var, obj in vars_dict.items():
348+
store.put(
349+
(path / f"{var}.{loader}").as_posix(),
350+
serialize(obj),
351+
)
275352
# Manifest last — readers check for it to detect complete writes
276353
store.put(manifest_key, manifest)
277354
except Exception:

marimo/_save/stubs/__init__.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from typing import TYPE_CHECKING, Any
77

88
from marimo._save.stubs.function_stub import FunctionStub
9-
from marimo._save.stubs.lazy_stub import ReferenceStub
9+
from marimo._save.stubs.lazy_stub import LAZY_STUB_LOOKUP, ReferenceStub
1010
from marimo._save.stubs.module_stub import ModuleStub
1111
from marimo._save.stubs.pydantic_stub import PydanticStub
1212
from marimo._save.stubs.stubs import (
1313
CUSTOM_STUBS,
1414
CustomStub,
15+
mro_lookup,
1516
register_stub,
1617
)
1718
from marimo._save.stubs.ui_element_stub import UIElementStub
@@ -27,18 +28,6 @@
2728
"pydantic.main.BaseModel": PydanticStub.register,
2829
}
2930

30-
LAZY_STUB_LOOKUP: dict[type, str] = {
31-
int: "inline",
32-
str: "inline",
33-
float: "inline",
34-
bool: "inline",
35-
bytes: "inline",
36-
type(None): "inline",
37-
FunctionStub: "inline",
38-
ModuleStub: "inline",
39-
UIElementStub: "ui",
40-
}
41-
4231

4332
def maybe_register_stub(value: Any) -> bool:
4433
"""Lazily register a stub for a value's type if not already registered.
@@ -57,25 +46,13 @@ def maybe_register_stub(value: Any) -> bool:
5746
if value_type in CUSTOM_STUBS:
5847
return True
5948

60-
# Walk MRO to find matching base class
61-
try:
62-
mro_list = value_type.mro()
63-
except BaseException:
64-
# Some exotic metaclasses or broken types may raise when calling mro
65-
mro_list = [value_type]
66-
67-
for cls in mro_list:
68-
if not (hasattr(cls, "__module__") and hasattr(cls, "__name__")):
69-
continue
70-
71-
cls_name = f"{cls.__module__}.{cls.__name__}"
72-
73-
if cls_name in STUB_REGISTRATIONS:
74-
if cls_name not in _REGISTERED_NAMES:
75-
_REGISTERED_NAMES.add(cls_name)
76-
STUB_REGISTRATIONS[cls_name](value)
77-
# After registration attempt, check if now in CUSTOM_STUBS
78-
return value_type in CUSTOM_STUBS
49+
result = mro_lookup(value_type, STUB_REGISTRATIONS)
50+
if result is not None:
51+
cls_name, register_fn = result
52+
if cls_name not in _REGISTERED_NAMES:
53+
_REGISTERED_NAMES.add(cls_name)
54+
register_fn(value)
55+
return value_type in CUSTOM_STUBS
7956

8057
return False
8158

@@ -107,5 +84,6 @@ def maybe_get_custom_stub(value: Any) -> CustomStub | None:
10784
"UIElementStub",
10885
"maybe_get_custom_stub",
10986
"maybe_register_stub",
87+
"mro_lookup",
11088
"register_stub",
11189
]

0 commit comments

Comments
 (0)