Commit 37ab9b8

fix serialization for DynamicCache with different layer classes (#396)
* fix serialization for DynamicCache
* fix
* fix none
* assert
* fix
1 parent 2485de9 commit 37ab9b8

9 files changed

Lines changed: 170 additions & 46 deletions
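In short: a DynamicCache whose layers are not plain DynamicLayer (for example DynamicSlidingWindowLayer) can now round-trip through the registered pytree serialization. A minimal sketch of the scenario, mirroring the new tests below and assuming the import paths used elsewhere in this repository:

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

# A DynamicCache backed by sliding-window layers (requires a transformers
# version that provides DynamicSlidingWindowLayer).
cache = make_dynamic_cache(
    [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
    cls_layers="DynamicSlidingWindowLayer",
)

with torch_export_patches(patch_transformers=True):
    # Before this commit, flattening asserted that every layer is a DynamicLayer.
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(flat, spec)

# The layer classes survive the round trip.
assert [type(lay) for lay in cache.layers] == [type(lay) for lay in cache2.layers]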

File tree

* .github/workflows/ci.yml
* CHANGELOGS.rst
* _unittests/ut_helpers/test_torch_helper.py
* _unittests/ut_tasks/test_tasks_image_text_to_text.py
* _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
* _unittests/ut_torch_export_patches/test_patch_transformers.py
* onnx_diagnostic/ext_test_case.py
* onnx_diagnostic/helpers/cache_helper.py
* onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
@@ -18,12 +18,12 @@ jobs:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
         transformers: ['4.48.3', '4.51.3', '4.55.4', '4.56.2', '4.57.6', 'main']
-        torch: ['2.9', 'main']
+        torch: ['2.10', 'main']
         exclude:
           - python: '3.10'  # 3.10
             torch: 'main'
           - python: '3.10'
-            torch: '2.9'
+            torch: '2.10'
           - python: '3.10'
             transformers: '4.55.4'
           - python: '3.10'
@@ -43,7 +43,7 @@ jobs:
           - python: '3.11'
             transformers: 'main'
           - python: '3.13'  # 3.11
-            torch: '2.9'
+            torch: '2.10'
           - python: '3.13'
             transformers: '4.48.3'
           - python: '3.13'

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.11
 ++++++
 
+* :pr:`396`: fix serialization for DynamicCache with different layer classes
 * :pr:`394`: add function make_model_with_local_functions to partition a model into local functions
 
 0.8.10

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 19 additions & 0 deletions
@@ -362,6 +362,25 @@ def test_torch_deepcopy_sliding_windon_cache(self):
         self.assertEqual(hash1, hash2)
         self.assertGreater(torch_tensor_size(cache), 1)
 
+    @unittest.skipIf(make_sliding_window_cache is not None, "transformers<5")
+    def test_torch_deepcopy_sliding_windon_cache5(self):
+        cache = make_dynamic_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        at = torch_deepcopy(cache)
+        self.assertEqual(type(cache), type(at))
+        self.assertEqual(max_diff(cache, at)["abs"], 0)
+        hash1 = string_type(at, with_shape=True, with_min_max=True)
+        CacheKeyValue(cache).key_cache[0] += 1000
+        hash2 = string_type(at, with_shape=True, with_min_max=True)
+        self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cache), 1)
+
     def test_torch_deepcopy_none(self):
         self.assertEmpty(torch_deepcopy(None))
         self.assertEqual(torch_tensor_size(None), 0)

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("4.56")
+    @requires_transformers("5.0.99")
    @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"

_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 60 additions & 0 deletions
@@ -192,6 +192,20 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
         cache2 = torch_deepcopy([cache])
         self.assertEqualAny([cache], cache2)
 
+    @ignore_warnings(UserWarning)
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_base_sliding_window_cache_unflatten_flatten5(self):
+        cache = make_dynamic_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        with torch_export_patches(patch_transformers=True):
+            cache2 = torch_deepcopy([cache])
+        self.assertEqualAny([cache], cache2)
+        self.assertEqual(
+            [type(lay) for lay in cache.layers], [type(lay) for lay in cache2[0].layers]
+        )
+
     @ignore_warnings(UserWarning)
     @requires_torch("2.7.99")
     @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
@@ -215,6 +229,30 @@ def forward(self, cache):
         with torch_export_patches(patch_transformers=True):
             torch.export.export(model, (cache,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    @requires_torch("2.7.99")
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_sliding_window_cache_export5(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                dc = CacheKeyValue(cache)
+                return dc.key_cache[0]
+
+        cache = make_dynamic_cache(
+            [
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+            ],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        model = Model()
+        model(cache)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})
+
+        with torch_export_patches(patch_transformers=True):
+            torch.export.export(model, (cache,), dynamic_shapes=(ds,))
+
     @ignore_warnings(UserWarning)
     @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
     def test_sliding_window_cache_flatten(self):
@@ -233,6 +271,28 @@ def test_sliding_window_cache_flatten(self):
             self.string_type(cache2, with_shape=True, with_min_max=True),
         )
 
+    @ignore_warnings(UserWarning)
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_sliding_window_cache_flatten5(self):
+        cache = make_dynamic_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        with torch_export_patches(patch_transformers=True):
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(
+                "#2[T1s4x4x4x4,T1s4x4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(cache, with_shape=True, with_min_max=True),
+                self.string_type(cache2, with_shape=True, with_min_max=True),
+            )
+            self.assertEqual(
+                [type(lay) for lay in cache.layers], [type(lay) for lay in cache2.layers]
+            )
+
     @ignore_warnings(UserWarning)
     @requires_torch("2.7.99")
     def test_static_cache(self):
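For reference, the string asserted in test_sliding_window_cache_flatten5 is onnx_diagnostic's compact type notation: "#2[...]" is a list of two elements and "T1s4x4x4x4" a float32 tensor (ONNX element type 1) of shape 4x4x4x4. A small sketch reproducing it, assuming string_type is importable from onnx_diagnostic.helpers:

import torch
from onnx_diagnostic.helpers import string_type

flat = [torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))]
# Expected output: '#2[T1s4x4x4x4,T1s4x4x4x4]'
print(string_type(flat, with_shape=True))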

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 7 additions & 2 deletions
@@ -402,7 +402,12 @@ def test_patched_qwen2_5_vl_get_window_index(self):
         self.assertEqualArray(torch.tensor(cu_window_seqlens1), cu_window_seqlens2)
 
     @requires_transformers("4.55")
-    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    # @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    # see https://github.com/huggingface/transformers/pull/42564/files#diff-09bc594f9680f1d042fd485106c68022d77b59831697a00b3b38f12a3e40f395
+    @unittest.skip(
+        "vision_outputs = self.visual(pixel_values, "
+        "grid_thw=image_grid_thw, return_dict=True, **kwargs)"
+    )
     def test_patched_qwen2_5_vl_forward(self):
         from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
             patched_Qwen2_5_VisionTransformerPretrainedModel,
@@ -422,7 +427,7 @@ def test_patched_qwen2_5_vl_forward(self):
             instance, hidden_states, grid_thw
         )
         patched_class.get_window_index = f_get_window_index
-        self.assertEqualArray(expected, got)
+        self.assertEqualAny(expected, got)
 
     @classmethod
     def _get_cu_seqlens(cls):

onnx_diagnostic/ext_test_case.py

Lines changed: 31 additions & 14 deletions
@@ -1028,6 +1028,19 @@ def assertEqualAny(
                 rtol=rtol,
                 msg=msg,
             )
+        elif expected.__class__.__name__ == "BaseModelOutputWithPooling":
+            if expected.__class__.__name__ == value.__class__.__name__:
+                self.assertEqual(len(expected), len(value), msg=msg)
+                self.assertEqual(list(expected), list(value), msg=msg)  # checks the order
+                self.assertEqualAny(
+                    {k: v for k, v in expected.items()},  # noqa: C416
+                    {k: v for k, v in value.items()},  # noqa: C416
+                    atol=atol,
+                    rtol=rtol,
+                    msg=msg,
+                )
+            else:
+                self.assertEqualArray(expected.last_hidden_state, value)
         elif isinstance(expected, (tuple, list, dict)):
             self.assertIsInstance(value, type(expected), msg=msg)
             self.assertEqual(len(expected), len(value), msg=msg)
@@ -1043,24 +1056,28 @@
             "SlidingWindowCache",
             "HybridCache",
         ):
+            from .helpers.cache_helper import CacheKeyValue
+
             self.assertEqual(type(expected), type(value), msg=msg)
-            atts = ["key_cache", "value_cache"]
-            self.assertEqualAny(
-                {k: expected.__dict__.get(k, None) for k in atts},
-                {k: value.__dict__.get(k, None) for k in atts},
-                atol=atol,
-                rtol=rtol,
-            )
+            self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
         elif expected.__class__.__name__ == "StaticCache":
+            from .helpers.cache_helper import CacheKeyValue
+
             self.assertEqual(type(expected), type(value), msg=msg)
             self.assertEqual(expected.max_cache_len, value.max_cache_len)
-            atts = ["key_cache", "value_cache"]
-            self.assertEqualAny(
-                {k: expected.__dict__.get(k, None) for k in atts},
-                {k: value.__dict__.get(k, None) for k in atts},
-                atol=atol,
-                rtol=rtol,
-            )
+            self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
+        elif expected.__class__.__name__ == "CacheKeyValue":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            if expected.cls_layers is None:
+                self.assertEqual(expected.cls_layers, value.cls_layers)
+            else:
+                self.assertEqualAny(
+                    [cls.__name__ for cls in expected.cls_layers],
+                    [cls.__name__ for cls in value.cls_layers],
+                    msg=msg,
+                )
+            self.assertEqualAny(expected.key_cache, value.key_cache, msg=msg)
+            self.assertEqualAny(expected.value_cache, value.value_cache, msg=msg)
         elif expected.__class__.__name__ == "EncoderDecoderCache":
             self.assertEqual(type(expected), type(value), msg=msg)
             atts = ["self_attention_cache", "cross_attention_cache"]

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 4 additions & 4 deletions
@@ -348,6 +348,7 @@ def make_dynamic_cache(
 def make_static_cache(
     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     max_cache_len: Optional[int] = None,
+    cls_layers: Optional[Union[str, List[type]]] = None,
 ) -> transformers.cache_utils.DynamicCache:
     """
     Creates an instance of :class:`transformers.cache_utils.StaticCache`.
@@ -379,6 +380,9 @@
         )
         print(string_type(past_key_values, with_shape=True))
     """
+    assert not cls_layers or set(cls_layers) == {
+        transformers.cache_utils.StaticLayer
+    }, f"Not implemented when cls_layers={cls_layers!r}"
     key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
 
     class _config:
@@ -583,13 +587,9 @@ def get_text_config(self, *args, **kwargs):
         )
         return finalize_cache(cache)
 
-    def get_make_hybrid_cache():
-        return make_sliding_window_cache
-
 else:
     make_sliding_window_cache = None  # type: ignore[assignment]
 
-
 if hasattr(transformers.cache_utils, "HybridCache"):

     def make_hybrid_cache(
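A hypothetical usage of the new cls_layers parameter: per the assert added above, only StaticLayer is accepted for now, and any other value raises an AssertionError. This sketch assumes a transformers version that exposes the layer classes on transformers.cache_utils:

import torch
import transformers
from onnx_diagnostic.helpers.cache_helper import make_static_cache

# cls_layers other than StaticLayer is rejected by the new assert.
past = make_static_cache(
    [(torch.rand((2, 4, 8, 16)), torch.rand((2, 4, 8, 16)))],
    max_cache_len=8,
    cls_layers=[transformers.cache_utils.StaticLayer],
)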

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 44 additions & 22 deletions
@@ -1,6 +1,7 @@
 import itertools
 from typing import Any, Callable, List, Set, Tuple
 import torch
+import transformers.cache_utils
 from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 
 try:
@@ -27,16 +28,38 @@
     DynamicCache: "4.50",
     BaseModelOutput: None,
 }
+SHORTEN_LAYER_NAMES = {
+    "DynamicLayer": "D",
+    "DynamicSlidingWindowLayer": "W",
+    "StaticLayer": "S",
+    "StaticSlidingWindowLayer": "X",
+    "D": "DynamicLayer",
+    "W": "DynamicSlidingWindowLayer",
+    "S": "StaticLayer",
+    "X": "StaticSlidingWindowLayer",
+}
 
 
 def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
     ca = CacheKeyValue(cache)
     flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
-    keys = list(
-        itertools.chain.from_iterable(
-            (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
-        )
-    )
+    unique = set(ca.cls_layers) if ca.cls_layers else None
+    if (
+        cache.__class__.__name__ != "DynamicCache"
+        or unique is None
+        or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
+    ):
+        keys = list(
+            itertools.chain.from_iterable(
+                (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+            )
+        )
+        return flat, keys
+
+    keys = []
+    for i in range(len(ca.key_cache)):
+        letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
+        keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"])
     return flat, keys
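The context keys now encode the layer class as a single letter from SHORTEN_LAYER_NAMES, but only for a DynamicCache with non-default layers; otherwise the legacy key_{i}/value_{i} names are kept so existing specs stay compatible. A standalone illustration of the naming scheme (encode_keys is a hypothetical helper, not part of the module):

SHORTEN = {"DynamicLayer": "D", "DynamicSlidingWindowLayer": "W",
           "StaticLayer": "S", "StaticSlidingWindowLayer": "X"}

def encode_keys(layer_class_names):
    # Homogeneous DynamicLayer caches keep the legacy names.
    if all(n == "DynamicLayer" for n in layer_class_names):
        return [k for i in range(len(layer_class_names)) for k in (f"key_{i}", f"value_{i}")]
    return [
        k
        for i, n in enumerate(layer_class_names)
        for k in (f"key_{SHORTEN[n]}{i}", f"value_{SHORTEN[n]}{i}")
    ]

print(encode_keys(["DynamicLayer", "DynamicLayer"]))
# ['key_0', 'value_0', 'key_1', 'value_1']
print(encode_keys(["DynamicSlidingWindowLayer", "DynamicSlidingWindowLayer"]))
# ['key_W0', 'value_W0', 'key_W1', 'value_W1']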
@@ -54,7 +77,20 @@ def _unflatten_cache(
     output_type=None,
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    res = make_cache(list(zip(values[::2], values[1::2])))
+    expected = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
+        )
+    )
+    if expected == context:
+        res = make_cache(list(zip(values[::2], values[1::2])))
+    else:
+        cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
+        cls_layers = [
+            getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
+        ]
+        res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)
+
     assert output_type is None or isinstance(
         res, output_type
     ), f"Type mismatch between {output_type} (expected) and {type(res)}"
@@ -70,29 +106,13 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    assert (
-        not hasattr(dynamic_cache, "layers")
-        or not dynamic_cache.layers
-        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-    ), (
-        f"The serialization does not work yet on other layers "
-        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-    )
     return _flatten_key_value_cache(dynamic_cache)
 
 
 def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    assert (
-        not hasattr(dynamic_cache, "layers")
-        or not dynamic_cache.layers
-        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-    ), (
-        f"The serialization does not work yet on other layers "
-        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-    )
     return _flatten_with_keys_cache(dynamic_cache)
@@ -160,7 +180,9 @@ def unflatten_static_cache(
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
     return _unflatten_cache(
-        lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+        lambda *args, **kwargs: make_static_cache(
+            *args, max_cache_len=values[0].shape[2], **kwargs
+        ),
         values,
         context,
         output_type=output_type,
