Skip to content

Commit 4c04e26

Browse files
authored
add an example about simple loop (#368)
* add an example about simple loop * graph
1 parent b4360cb commit 4c04e26

1 file changed

Lines changed: 136 additions & 0 deletions

File tree

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
.. _l-plot-simple-for-loop:
3+
4+
Export with loops
5+
=================
6+
7+
This is a simple example of loop which cannot be efficiently rewritten
8+
with ``scan``.
9+
"""
10+
11+
import torch
12+
from onnx_diagnostic import doc
13+
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
14+
15+
16+
class Model(torch.nn.Module):
17+
def __init__(self, crop_size):
18+
super().__init__()
19+
self.crop_size = crop_size
20+
21+
def forward(self, W, splits):
22+
crop_size = self.crop_size
23+
starts = splits[:-1]
24+
ends = splits[1:]
25+
cropped = []
26+
for start, end in zip(starts, ends):
27+
extract = W[:, start:end]
28+
if extract.shape[1] < crop_size:
29+
cropped.append(extract)
30+
else:
31+
cropped.append(extract[:, :crop_size])
32+
return torch.cat(cropped, axis=1)
33+
34+
35+
model = Model(4)
36+
args = (torch.rand((2, 22)), torch.tensor([0, 5, 15, 20, 22], dtype=torch.int64))
37+
38+
expected = model(*args)
39+
print(f"-- exected shape: {expected.shape}")
40+
41+
42+
# %%
43+
# Rewrite with higher order ops scan
44+
# ++++++++++++++++++++++++++++++++++
45+
#
46+
# The loop cannot be exported as is. It needs to be rewritten.
47+
48+
49+
class ModelWithScan(Model):
50+
def forward(self, W, splits):
51+
crop_size = self.crop_size
52+
starts = splits[:-1]
53+
ends = splits[1:]
54+
55+
def body_scan(init, split, W):
56+
extract = W[:, split[0].item() : split[1].item()]
57+
cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
58+
carried = torch.cat([init, cropped], axis=1)
59+
return carried
60+
61+
starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
62+
return torch.ops.higher_order.scan(
63+
body_scan, [torch.empty((W.shape[0], 0), dtype=W.dtype)], [starts_ends], [W]
64+
)
65+
66+
67+
rewritten_model_with_scan = ModelWithScan(4)
68+
(results,) = rewritten_model_with_scan(*args)
69+
70+
print(f"-- max discrepancies with scan: { torch.abs(expected - results).max()}")
71+
72+
# %%
73+
# This approach has one flows, the variable carried grows at every
74+
# iteration and the cost of the copy is quadratic when the same operation
75+
# in the first model is linear.
76+
# We cannot simply return variable ``cropped`` because its shape
77+
# is not always the same.
78+
#
79+
# Introduce of a new higher order ops: simple_loop_for
80+
# ++++++++++++++++++++++++++++++++++++++++++++++++++++
81+
#
82+
# ``simple_loop_for`` was designed to support this specific case.
83+
# It takes all the outputs coming from the body function and stores
84+
# them in list. Then it contenates them according to ``concatenation_dims``.
85+
86+
87+
class ModelWithLoop(Model):
88+
def forward(self, W, splits):
89+
crop_size = self.crop_size
90+
starts = splits[:-1]
91+
ends = splits[1:]
92+
93+
def body_loop(i, splits, W):
94+
split = splits[i.item() : (i + 1).item()][0] # [i.item()] fails
95+
extract = W[:, split[0].item() : split[1].item()]
96+
cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
97+
return (cropped,)
98+
99+
starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
100+
n_iterations = torch.tensor(starts_ends.shape[0], dtype=torch.int64)
101+
return simple_loop_for(
102+
n_iterations, body_loop, (starts_ends, W), concatenation_dims=[1]
103+
)
104+
105+
106+
rewritten_model_with_loop = ModelWithLoop(4)
107+
results = rewritten_model_with_loop(*args)
108+
109+
print(f"-- max discrepancies with loop: { torch.abs(expected - results).max()}")
110+
111+
112+
# %%
113+
# torch.export.export?
114+
# ++++++++++++++++++++
115+
116+
dynamic_shapes = (
117+
{0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
118+
{0: torch.export.Dim.DYNAMIC},
119+
)
120+
try:
121+
ep = torch.export.export(rewritten_model_with_scan, args, dynamic_shapes=dynamic_shapes)
122+
print("----- exported program with scan:")
123+
print(ep)
124+
except Exception as e:
125+
print(f"export failed due to {e}")
126+
127+
# %%
128+
# And loops?
129+
130+
131+
ep = torch.export.export(rewritten_model_with_loop, args, dynamic_shapes=dynamic_shapes)
132+
print(ep)
133+
134+
# %%
135+
136+
doc.plot_legend("export a loop\nreturning\ndifferent shapes", "loops", "green")

0 commit comments

Comments
 (0)