@@ -887,7 +887,7 @@ def forward(self, x=None, y=None):
887887 # self.assertEqual(2, len(args))
888888 # self.assertEqual(len([v for v in args.values() if v is not None]), 2)
889889
890- def test_infer_dynamic_shapes_missing (self ):
890+ def test_infer_dynamic_shapes_missing_kwargs (self ):
891891 class Model (torch .nn .Module ):
892892 def forward (
893893 self ,
@@ -903,33 +903,35 @@ def forward(
903903
904904 inputs = [
905905 dict (
906- input_ids = torch .ones ((1 , 282 ), dtype = torch .int64 ),
907- pixel_values = torch .ones ((1 , 3 , 896 , 896 ), dtype = torch .int64 ),
908- attention_mask = torch .ones ((1 , 282 ), dtype = torch .int64 ),
909- position_ids = torch .ones ((1 , 282 ), dtype = torch .int64 ),
910- token_type_ids = torch .ones ((1 , 282 ), dtype = torch .int64 ),
911- cache_position = torch .ones ((282 ,), dtype = torch .int64 ),
906+ input_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
907+ pixel_values = torch .ones ((1 , 3 , 112 , 112 ), dtype = torch .int64 ),
908+ attention_mask = torch .ones ((1 , 28 ), dtype = torch .int64 ),
909+ position_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
910+ token_type_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
911+ cache_position = torch .ones ((28 ,), dtype = torch .int64 ),
912912 ),
913913 dict (
914914 input_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
915- attention_mask = torch .ones ((1 , 283 ), dtype = torch .int64 ),
915+ attention_mask = torch .ones ((1 , 29 ), dtype = torch .int64 ),
916916 position_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
917- past_key_values = torch .rand ((1 , 1 , 282 , 32 )),
917+ past_key_values = torch .rand ((1 , 1 , 28 , 32 )),
918918 token_type_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
919919 cache_position = torch .ones ((1 ,), dtype = torch .int64 ),
920920 ),
921921 dict (
922922 input_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
923- attention_mask = torch .ones ((1 , 284 ), dtype = torch .int64 ),
923+ attention_mask = torch .ones ((1 , 30 ), dtype = torch .int64 ),
924924 position_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
925- past_key_values = torch .rand ((1 , 1 , 283 , 32 )),
925+ past_key_values = torch .rand ((1 , 1 , 29 , 32 )),
926926 token_type_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
927927 cache_position = torch .ones ((1 ,), dtype = torch .int64 ),
928928 ),
929929 ]
930930
931931 model = Model ()
932- observer = InputObserver (missing = dict (pixel_values = torch .empty ((0 , 3 , 896 , 896 ))))
932+ observer = InputObserver (
933+ value_if_missing = dict (pixel_values = torch .empty ((0 , 3 , 112 , 112 )))
934+ )
933935 with observer (model ):
934936 for kwargs in inputs :
935937 model (** kwargs )
@@ -946,6 +948,132 @@ def forward(
946948 "cache_position" : {0 : cst },
947949 }
948950 self .assertEqual (expected , shapes )
951+ kwargs = observer .infer_arguments ()
952+ self .assertEqual (list (expected ), list (kwargs ))
953+ self .assertEqual ((0 , 3 , 112 , 112 ), kwargs ["pixel_values" ].shape )
954+
955+ def test_infer_dynamic_shapes_missing_args (self ):
956+ class Model (torch .nn .Module ):
957+ def forward (
958+ self ,
959+ input_ids = None ,
960+ pixel_values = None ,
961+ attention_mask = None ,
962+ past_key_values = None ,
963+ ):
964+ return input_ids
965+
966+ inputs = [
967+ (
968+ torch .ones ((1 , 28 ), dtype = torch .int64 ),
969+ torch .ones ((1 , 3 , 112 , 112 ), dtype = torch .int64 ),
970+ torch .ones ((1 , 28 ), dtype = torch .int64 ),
971+ ),
972+ (
973+ torch .ones ((1 , 1 ), dtype = torch .int64 ),
974+ None ,
975+ torch .ones ((1 , 29 ), dtype = torch .int64 ),
976+ torch .rand ((1 , 1 , 28 , 32 )),
977+ ),
978+ (
979+ torch .ones ((1 , 1 ), dtype = torch .int64 ),
980+ None ,
981+ torch .ones ((1 , 30 ), dtype = torch .int64 ),
982+ torch .rand ((1 , 1 , 29 , 32 )),
983+ ),
984+ ]
985+
986+ model = Model ()
987+ observer = InputObserver (
988+ value_if_missing = {1 : torch .empty ((0 , 3 , 112 , 112 ), dtype = torch .int64 )}
989+ )
990+ with observer (model ):
991+ for args in inputs :
992+ model (* args )
993+
994+ shapes = observer .infer_dynamic_shapes (set_batch_dimension_for = True )
995+ cst = torch .export .Dim .DYNAMIC
996+ expected = ({0 : cst , 1 : cst }, {0 : cst }, {0 : cst , 1 : cst }, {0 : cst , 2 : cst })
997+ self .assertEqual (expected , shapes )
998+ args = observer .infer_arguments ()
999+ self .assertEqual (len (expected ), len (args ))
1000+ self .assertEqual ((0 , 3 , 112 , 112 ), args [1 ].shape )
1001+
1002+ def test_infer_dynamic_shapes_missing_kwargs_nested (self ):
1003+ class Model (torch .nn .Module ):
1004+ def forward (
1005+ self ,
1006+ input_ids = None ,
1007+ pixel_values = None ,
1008+ attention_mask = None ,
1009+ position_ids = None ,
1010+ past_key_values = None ,
1011+ token_type_ids = None ,
1012+ cache_position = None ,
1013+ ):
1014+ return input_ids
1015+
1016+ inputs = [
1017+ dict (
1018+ input_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
1019+ pixel_values = (
1020+ torch .ones ((1 , 3 , 112 , 112 ), dtype = torch .int64 ),
1021+ torch .ones ((1 , 3 , 112 , 112 ), dtype = torch .int64 ),
1022+ ),
1023+ attention_mask = torch .ones ((1 , 28 ), dtype = torch .int64 ),
1024+ position_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
1025+ token_type_ids = torch .ones ((1 , 28 ), dtype = torch .int64 ),
1026+ cache_position = torch .ones ((28 ,), dtype = torch .int64 ),
1027+ ),
1028+ dict (
1029+ input_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1030+ attention_mask = torch .ones ((1 , 29 ), dtype = torch .int64 ),
1031+ position_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1032+ past_key_values = torch .rand ((1 , 1 , 28 , 32 )),
1033+ token_type_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1034+ cache_position = torch .ones ((1 ,), dtype = torch .int64 ),
1035+ ),
1036+ dict (
1037+ input_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1038+ attention_mask = torch .ones ((1 , 30 ), dtype = torch .int64 ),
1039+ position_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1040+ past_key_values = torch .rand ((1 , 1 , 29 , 32 )),
1041+ token_type_ids = torch .ones ((1 , 1 ), dtype = torch .int64 ),
1042+ cache_position = torch .ones ((1 ,), dtype = torch .int64 ),
1043+ ),
1044+ ]
1045+
1046+ model = Model ()
1047+ observer = InputObserver (
1048+ value_if_missing = dict (
1049+ pixel_values = (
1050+ torch .empty ((0 , 3 , 112 , 112 ), dtype = torch .int64 ),
1051+ torch .empty ((0 , 3 , 112 , 112 ), dtype = torch .int64 ),
1052+ )
1053+ )
1054+ )
1055+ with observer (model ):
1056+ for kwargs in inputs :
1057+ model (** kwargs )
1058+
1059+ shapes = observer .infer_dynamic_shapes (set_batch_dimension_for = True )
1060+ cst = torch .export .Dim .DYNAMIC
1061+ expected = {
1062+ "input_ids" : {0 : cst , 1 : cst },
1063+ "pixel_values" : ({0 : cst }, {0 : cst }),
1064+ "attention_mask" : {0 : cst , 1 : cst },
1065+ "position_ids" : {0 : cst , 1 : cst },
1066+ "past_key_values" : {0 : cst , 2 : cst },
1067+ "token_type_ids" : {0 : cst , 1 : cst },
1068+ "cache_position" : {0 : cst },
1069+ }
1070+ self .assertEqual (expected , shapes )
1071+ kwargs = observer .infer_arguments ()
1072+ self .assertEqual (list (expected ), list (kwargs ))
1073+ self .assertIsInstance (kwargs ["pixel_values" ], tuple )
1074+ self .assertEqual (2 , len (kwargs ["pixel_values" ]))
1075+ self .assertEqual ((0 , 3 , 112 , 112 ), kwargs ["pixel_values" ][0 ].shape )
1076+ self .assertEqual ((0 , 3 , 112 , 112 ), kwargs ["pixel_values" ][1 ].shape )
9491077
9501078 def test_io_captured_kwargs_kwargs (self ):
9511079 class Model (torch .nn .Module ):
0 commit comments