Skip to content

Commit acade9c

Browse files
authored
improves documentation (#419)
* improves documentation
* supports integers for value_if_missing
* fixes spelling
* disables more tests
1 parent f9fb41b commit acade9c

8 files changed

Lines changed: 308 additions & 50 deletions

File tree

.github/workflows/check-release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest, macOS-latest, windows-latest]
1818
python: ['3.13']
19-
transformers: ['5.1.0', 'main']
19+
transformers: ['5.2.0', 'main']
2020
torch: ['2.10', 'main']
2121

2222
steps:

.github/workflows/ci.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.1.0', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.2.0', 'main']
2121
torch: ['2.10', 'main']
2222
exclude:
2323
# 3.10 - torch
@@ -29,7 +29,7 @@ jobs:
2929
- python: '3.10'
3030
transformers: '4.57.6'
3131
- python: '3.10'
32-
transformers: '5.1.0'
32+
transformers: '5.2.0'
3333
- python: '3.10'
3434
transformers: 'main'
3535
# 3.11 - torch
@@ -41,7 +41,7 @@ jobs:
4141
- python: '3.11'
4242
transformers: '4.57.6'
4343
- python: '3.11'
44-
transformers: '5.1.0'
44+
transformers: '5.2.0'
4545
- python: '3.11'
4646
transformers: 'main'
4747
# 3.13 - torch
@@ -54,6 +54,8 @@ jobs:
5454
transformers: '4.51.3'
5555
- python: '3.13'
5656
transformers: '4.55.4'
57+
- python: '3.13'
58+
transformers: '4.57.6'
5759
steps:
5860
- uses: actions/checkout@v3
5961

_doc/final/plot_export_gemma3_tiny_input_observer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# %%
5454
# Captures inputs and outputs for the model.
5555
observer = InputObserver(
56-
missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16))
56+
value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16))
5757
)
5858
with (
5959
register_additional_serialization_functions(patch_transformers=True),

_unittests/ut_investigate/test_input_observer.py

Lines changed: 140 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ def forward(self, x=None, y=None):
887887
# self.assertEqual(2, len(args))
888888
# self.assertEqual(len([v for v in args.values() if v is not None]), 2)
889889

890-
def test_infer_dynamic_shapes_missing(self):
890+
def test_infer_dynamic_shapes_missing_kwargs(self):
891891
class Model(torch.nn.Module):
892892
def forward(
893893
self,
@@ -903,33 +903,35 @@ def forward(
903903

904904
inputs = [
905905
dict(
906-
input_ids=torch.ones((1, 282), dtype=torch.int64),
907-
pixel_values=torch.ones((1, 3, 896, 896), dtype=torch.int64),
908-
attention_mask=torch.ones((1, 282), dtype=torch.int64),
909-
position_ids=torch.ones((1, 282), dtype=torch.int64),
910-
token_type_ids=torch.ones((1, 282), dtype=torch.int64),
911-
cache_position=torch.ones((282,), dtype=torch.int64),
906+
input_ids=torch.ones((1, 28), dtype=torch.int64),
907+
pixel_values=torch.ones((1, 3, 112, 112), dtype=torch.int64),
908+
attention_mask=torch.ones((1, 28), dtype=torch.int64),
909+
position_ids=torch.ones((1, 28), dtype=torch.int64),
910+
token_type_ids=torch.ones((1, 28), dtype=torch.int64),
911+
cache_position=torch.ones((28,), dtype=torch.int64),
912912
),
913913
dict(
914914
input_ids=torch.ones((1, 1), dtype=torch.int64),
915-
attention_mask=torch.ones((1, 283), dtype=torch.int64),
915+
attention_mask=torch.ones((1, 29), dtype=torch.int64),
916916
position_ids=torch.ones((1, 1), dtype=torch.int64),
917-
past_key_values=torch.rand((1, 1, 282, 32)),
917+
past_key_values=torch.rand((1, 1, 28, 32)),
918918
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
919919
cache_position=torch.ones((1,), dtype=torch.int64),
920920
),
921921
dict(
922922
input_ids=torch.ones((1, 1), dtype=torch.int64),
923-
attention_mask=torch.ones((1, 284), dtype=torch.int64),
923+
attention_mask=torch.ones((1, 30), dtype=torch.int64),
924924
position_ids=torch.ones((1, 1), dtype=torch.int64),
925-
past_key_values=torch.rand((1, 1, 283, 32)),
925+
past_key_values=torch.rand((1, 1, 29, 32)),
926926
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
927927
cache_position=torch.ones((1,), dtype=torch.int64),
928928
),
929929
]
930930

931931
model = Model()
932-
observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896))))
932+
observer = InputObserver(
933+
value_if_missing=dict(pixel_values=torch.empty((0, 3, 112, 112)))
934+
)
933935
with observer(model):
934936
for kwargs in inputs:
935937
model(**kwargs)
@@ -946,6 +948,132 @@ def forward(
946948
"cache_position": {0: cst},
947949
}
948950
self.assertEqual(expected, shapes)
951+
kwargs = observer.infer_arguments()
952+
self.assertEqual(list(expected), list(kwargs))
953+
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"].shape)
954+
955+
def test_infer_dynamic_shapes_missing_args(self):
956+
class Model(torch.nn.Module):
957+
def forward(
958+
self,
959+
input_ids=None,
960+
pixel_values=None,
961+
attention_mask=None,
962+
past_key_values=None,
963+
):
964+
return input_ids
965+
966+
inputs = [
967+
(
968+
torch.ones((1, 28), dtype=torch.int64),
969+
torch.ones((1, 3, 112, 112), dtype=torch.int64),
970+
torch.ones((1, 28), dtype=torch.int64),
971+
),
972+
(
973+
torch.ones((1, 1), dtype=torch.int64),
974+
None,
975+
torch.ones((1, 29), dtype=torch.int64),
976+
torch.rand((1, 1, 28, 32)),
977+
),
978+
(
979+
torch.ones((1, 1), dtype=torch.int64),
980+
None,
981+
torch.ones((1, 30), dtype=torch.int64),
982+
torch.rand((1, 1, 29, 32)),
983+
),
984+
]
985+
986+
model = Model()
987+
observer = InputObserver(
988+
value_if_missing={1: torch.empty((0, 3, 112, 112), dtype=torch.int64)}
989+
)
990+
with observer(model):
991+
for args in inputs:
992+
model(*args)
993+
994+
shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
995+
cst = torch.export.Dim.DYNAMIC
996+
expected = ({0: cst, 1: cst}, {0: cst}, {0: cst, 1: cst}, {0: cst, 2: cst})
997+
self.assertEqual(expected, shapes)
998+
args = observer.infer_arguments()
999+
self.assertEqual(len(expected), len(args))
1000+
self.assertEqual((0, 3, 112, 112), args[1].shape)
1001+
1002+
def test_infer_dynamic_shapes_missing_kwargs_nested(self):
1003+
class Model(torch.nn.Module):
1004+
def forward(
1005+
self,
1006+
input_ids=None,
1007+
pixel_values=None,
1008+
attention_mask=None,
1009+
position_ids=None,
1010+
past_key_values=None,
1011+
token_type_ids=None,
1012+
cache_position=None,
1013+
):
1014+
return input_ids
1015+
1016+
inputs = [
1017+
dict(
1018+
input_ids=torch.ones((1, 28), dtype=torch.int64),
1019+
pixel_values=(
1020+
torch.ones((1, 3, 112, 112), dtype=torch.int64),
1021+
torch.ones((1, 3, 112, 112), dtype=torch.int64),
1022+
),
1023+
attention_mask=torch.ones((1, 28), dtype=torch.int64),
1024+
position_ids=torch.ones((1, 28), dtype=torch.int64),
1025+
token_type_ids=torch.ones((1, 28), dtype=torch.int64),
1026+
cache_position=torch.ones((28,), dtype=torch.int64),
1027+
),
1028+
dict(
1029+
input_ids=torch.ones((1, 1), dtype=torch.int64),
1030+
attention_mask=torch.ones((1, 29), dtype=torch.int64),
1031+
position_ids=torch.ones((1, 1), dtype=torch.int64),
1032+
past_key_values=torch.rand((1, 1, 28, 32)),
1033+
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
1034+
cache_position=torch.ones((1,), dtype=torch.int64),
1035+
),
1036+
dict(
1037+
input_ids=torch.ones((1, 1), dtype=torch.int64),
1038+
attention_mask=torch.ones((1, 30), dtype=torch.int64),
1039+
position_ids=torch.ones((1, 1), dtype=torch.int64),
1040+
past_key_values=torch.rand((1, 1, 29, 32)),
1041+
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
1042+
cache_position=torch.ones((1,), dtype=torch.int64),
1043+
),
1044+
]
1045+
1046+
model = Model()
1047+
observer = InputObserver(
1048+
value_if_missing=dict(
1049+
pixel_values=(
1050+
torch.empty((0, 3, 112, 112), dtype=torch.int64),
1051+
torch.empty((0, 3, 112, 112), dtype=torch.int64),
1052+
)
1053+
)
1054+
)
1055+
with observer(model):
1056+
for kwargs in inputs:
1057+
model(**kwargs)
1058+
1059+
shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
1060+
cst = torch.export.Dim.DYNAMIC
1061+
expected = {
1062+
"input_ids": {0: cst, 1: cst},
1063+
"pixel_values": ({0: cst}, {0: cst}),
1064+
"attention_mask": {0: cst, 1: cst},
1065+
"position_ids": {0: cst, 1: cst},
1066+
"past_key_values": {0: cst, 2: cst},
1067+
"token_type_ids": {0: cst, 1: cst},
1068+
"cache_position": {0: cst},
1069+
}
1070+
self.assertEqual(expected, shapes)
1071+
kwargs = observer.infer_arguments()
1072+
self.assertEqual(list(expected), list(kwargs))
1073+
self.assertIsInstance(kwargs["pixel_values"], tuple)
1074+
self.assertEqual(2, len(kwargs["pixel_values"]))
1075+
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][0].shape)
1076+
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][1].shape)
9491077

9501078
def test_io_captured_kwargs_kwargs(self):
9511079
class Model(torch.nn.Module):

_unittests/ut_investigate/test_input_observer_transformers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,11 @@ def forward(
279279
]
280280

281281
model = Model()
282-
observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896))))
282+
observer = InputObserver(
283+
value_if_missing=dict(
284+
pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.int64)
285+
)
286+
)
283287
with (
284288
register_additional_serialization_functions(patch_transformers=True),
285289
observer(model),

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self):
266266
model(**inputs)
267267
model(**data["inputs2"])
268268
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
269-
if not has_transformers("5.2.99"):
269+
if not has_transformers("5.3.99"):
270270
raise unittest.SkipTest("The model has control flow.")
271271
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
272272
torch.export.export(

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class TestTasksImageTextToText(ExtTestCase):
1717
@hide_stdout()
18-
@requires_transformers("5.2.99")
18+
@requires_transformers("5.3.99")
1919
@requires_torch("2.7.99")
2020
def test_image_text_to_text_idefics(self):
2121
mid = "HuggingFaceM4/tiny-random-idefics"
@@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self):
3232
self.assertEqualAny(expected, ep.module()(**inputs), atol=1)
3333

3434
@hide_stdout()
35-
@requires_transformers("5.2.99")
35+
@requires_transformers("5.3.99")
3636
@requires_torch("2.7.99")
3737
def test_image_text_to_text_tiny_gemma3(self):
3838
"""
@@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self):
8888
self.assertEqualAny(expected, ep.module()(**inputs))
8989

9090
@hide_stdout()
91-
@requires_transformers("5.2.99")
91+
@requires_transformers("5.3.99")
9292
@requires_torch("2.7.99")
9393
def test_image_text_to_text_zai_glm(self):
9494
"""

0 commit comments

Comments (0)