@@ -104,7 +104,9 @@ def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
104104 def test_simple_loop_for_2 (self ):
105105 class Model (torch .nn .Module ):
106106 def forward (self , n_iter , x ):
107- def body (i : torch .Tensor , x : torch .Tensor ) -> Tuple [torch .Tensor ]:
107+ def body (
108+ i : torch .Tensor , x : torch .Tensor
109+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
108110 return (x [: i .item () + 1 ].unsqueeze (1 ), x [i .item () + 1 :].unsqueeze (1 ))
109111
110112 return simple_loop_for (n_iter , body , (x ,))
@@ -172,8 +174,13 @@ def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
172174 def test_simple_loop_for_2_concatenation_dims (self ):
173175 class Model (torch .nn .Module ):
174176 def forward (self , n_iter , x ):
175- def body (i : torch .Tensor , x : torch .Tensor ) -> Tuple [torch .Tensor ]:
176- return (x [: i .item () + 1 ].unsqueeze (1 ), x [i .item () + 1 :].unsqueeze (0 ))
177+ def body (
178+ i : torch .Tensor , x : torch .Tensor
179+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
180+ return (
181+ x [: i .item () + 1 ].unsqueeze (1 ),
182+ x [i .item () + 1 :].unsqueeze (0 ),
183+ )
177184
178185 return simple_loop_for (n_iter , body , (x ,), (0 , 1 ))
179186
0 commit comments