@@ -192,6 +192,20 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
             cache2 = torch_deepcopy([cache])
             self.assertEqualAny([cache], cache2)
 
+    @ignore_warnings(UserWarning)
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_base_sliding_window_cache_unflatten_flatten5(self):
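+        # Verifies torch_deepcopy on a cache built with DynamicSlidingWindowLayer
+        # round-trips the values and preserves the layer classes.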
+        cache = make_dynamic_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        with torch_export_patches(patch_transformers=True):
+            cache2 = torch_deepcopy([cache])
+            self.assertEqualAny([cache], cache2)
+            self.assertEqual(
+                [type(lay) for lay in cache.layers], [type(lay) for lay in cache2[0].layers]
+            )
+
     @ignore_warnings(UserWarning)
     @requires_torch("2.7.99")
     @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
@@ -215,6 +229,30 @@ def forward(self, cache):
         with torch_export_patches(patch_transformers=True):
             torch.export.export(model, (cache,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    @requires_torch("2.7.99")
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_sliding_window_cache_export5(self):
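+        # Exports a model that reads key_cache[0] from a sliding-window dynamic
+        # cache, with dimension 0 marked dynamic, under the transformers patches.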
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                dc = CacheKeyValue(cache)
+                return dc.key_cache[0]
+
+        cache = make_dynamic_cache(
+            [
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+                (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
+            ],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        model = Model()
+        model(cache)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})
+
+        with torch_export_patches(patch_transformers=True):
+            torch.export.export(model, (cache,), dynamic_shapes=(ds,))
+
     @ignore_warnings(UserWarning)
     @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
     def test_sliding_window_cache_flatten(self):
@@ -233,6 +271,28 @@ def test_sliding_window_cache_flatten(self):
                 self.string_type(cache2, with_shape=True, with_min_max=True),
             )
 
+    @ignore_warnings(UserWarning)
+    @unittest.skipIf(make_sliding_window_cache, "transformers<5")
+    def test_sliding_window_cache_flatten5(self):
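+        # Checks that tree_flatten yields the cache tensors and that
+        # tree_unflatten rebuilds an equal cache with the same layer classes.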
+        cache = make_dynamic_cache(
+            [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
+            cls_layers="DynamicSlidingWindowLayer",
+        )
+        with torch_export_patches(patch_transformers=True):
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(
+                "#2[T1s4x4x4x4,T1s4x4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(cache, with_shape=True, with_min_max=True),
+                self.string_type(cache2, with_shape=True, with_min_max=True),
+            )
+            self.assertEqual(
+                [type(lay) for lay in cache.layers], [type(lay) for lay in cache2.layers]
+            )
+
     @ignore_warnings(UserWarning)
     @requires_torch("2.7.99")
     def test_static_cache(self):