Skip to content

Commit ddea098

Browse files
authored
documentation (#361)
* documentation * foc
1 parent d054876 commit ddea098

2 files changed

Lines changed: 13 additions & 5 deletions

File tree

_unittests/ut_export/test_cf_simple_loop_for.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
104104
def test_simple_loop_for_2(self):
105105
class Model(torch.nn.Module):
106106
def forward(self, n_iter, x):
107-
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
107+
def body(
108+
i: torch.Tensor, x: torch.Tensor
109+
) -> Tuple[torch.Tensor, torch.Tensor]:
108110
return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(1))
109111

110112
return simple_loop_for(n_iter, body, (x,))
@@ -172,8 +174,13 @@ def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
172174
def test_simple_loop_for_2_concatenation_dims(self):
173175
class Model(torch.nn.Module):
174176
def forward(self, n_iter, x):
175-
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
176-
return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
177+
def body(
178+
i: torch.Tensor, x: torch.Tensor
179+
) -> Tuple[torch.Tensor, torch.Tensor]:
180+
return (
181+
x[: i.item() + 1].unsqueeze(1),
182+
x[i.item() + 1 :].unsqueeze(0),
183+
)
177184

178185
return simple_loop_for(n_iter, body, (x,), (0, 1))
179186

onnx_diagnostic/export/cf_simple_loop_for.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def body(i, x):
321321
322322
class Model(torch.nn.Module):
323323
def forward(self, n_iter, x):
324-
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
324+
def body(i, x):
325325
return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
326326
327327
return simple_loop_for(n_iter, body, (x,), (0, 1))
@@ -346,6 +346,7 @@ def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
346346
),
347347
)
348348
torch._check(
349-
isinstance(res, tuple), f"Output of the loop should be a tuple not {type(res)}."
349+
isinstance(res, tuple),
350+
lambda: f"Output of the loop should be a tuple not {type(res)}.",
350351
)
351352
return res[0] if len(res) == 1 else res

0 commit comments

Comments
 (0)