Skip to content

Commit 4a4d7af

Browse files
authored
Only serialize the requested depth (#1619)
* strict BloqAsCirqGate and CirqGateAsBloq * fixy * Only serialize things that need serialization * Warn when wrapping a bloq * let cirq handle _unitary_ still
1 parent 09b0c4d commit 4a4d7af

6 files changed

Lines changed: 137 additions & 55 deletions

File tree

qualtran/bloqs/phase_estimation/qubitization_qpe_test.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import time
15+
1416
import cirq
1517
import numpy as np
1618
import pytest
@@ -21,25 +23,40 @@
2123
from qualtran.bloqs.phase_estimation.qpe_window_state import RectangularWindowState
2224
from qualtran.bloqs.phase_estimation.qubitization_qpe import (
2325
_qubitization_qpe_chem_thc,
26+
_qubitization_qpe_hubbard_model_large,
2427
_qubitization_qpe_hubbard_model_small,
2528
_qubitization_qpe_ising,
2629
_qubitization_qpe_sparse_chem,
2730
QubitizationQPE,
2831
)
2932
from qualtran.bloqs.phase_estimation.text_book_qpe_test import simulate_theta_estimate
3033
from qualtran.cirq_interop.testing import GateHelper
34+
from qualtran.serialization.bloq import bloqs_to_proto
3135
from qualtran.testing import execute_notebook
3236

3337

3438
def test_ising_example(bloq_autotester):
3539
bloq_autotester(_qubitization_qpe_ising)
3640

3741

38-
@pytest.mark.slow
39-
def test_qubitization_qpe_bloq_autotester(bloq_autotester):
42+
def test_qubitization_qpe_hubbard_model_small_autotester(bloq_autotester):
4043
bloq_autotester(_qubitization_qpe_hubbard_model_small)
4144

4245

46+
def test_serialization_speed():
47+
start = time.perf_counter()
48+
bloqs_to_proto(_qubitization_qpe_hubbard_model_small.make())
49+
end = time.perf_counter()
50+
# Should take substantially less time than this
51+
if (end - start) > 2.0:
52+
assert False, 'Serialization should only check one level; and should be quick.'
53+
54+
55+
@pytest.mark.slow
56+
def test_qubitization_qpe_hubbard_model_large_autotester(bloq_autotester):
57+
bloq_autotester(_qubitization_qpe_hubbard_model_large)
58+
59+
4360
@pytest.mark.slow
4461
def test_qubitization_qpe_chem_thc_bloq_autotester(bloq_autotester):
4562
bloq_autotester(_qubitization_qpe_chem_thc)

qualtran/serialization/bloq.py

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import dataclasses
1616
import inspect
17-
from typing import Any, Callable, Dict, List, Optional, Union
17+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1818

1919
import attrs
2020
import cirq
@@ -47,8 +47,8 @@
4747
ec_point,
4848
registers,
4949
resolver_dict,
50+
sympy_to_proto,
5051
)
51-
from qualtran.serialization.sympy import sympy_expr_from_proto, sympy_expr_to_proto
5252

5353

5454
def arg_to_proto(*, name: str, val: Any) -> bloq_pb2.BloqArg:
@@ -59,7 +59,7 @@ def arg_to_proto(*, name: str, val: Any) -> bloq_pb2.BloqArg:
5959
if isinstance(val, str):
6060
return bloq_pb2.BloqArg(name=name, string_val=val)
6161
if isinstance(val, sympy.Expr):
62-
return bloq_pb2.BloqArg(name=name, sympy_expr=sympy_expr_to_proto(val))
62+
return bloq_pb2.BloqArg(name=name, sympy_expr=sympy_to_proto.sympy_expr_to_proto(val))
6363
if isinstance(val, Register):
6464
return bloq_pb2.BloqArg(name=name, register=registers.register_to_proto(val))
6565
if isinstance(val, tuple) and all(isinstance(x, Register) for x in val):
@@ -90,7 +90,7 @@ def arg_from_proto(arg: bloq_pb2.BloqArg) -> Dict[str, Any]:
9090
if arg.HasField("string_val"):
9191
return {arg.name: arg.string_val}
9292
if arg.HasField("sympy_expr"):
93-
return {arg.name: sympy_expr_from_proto(arg.sympy_expr)}
93+
return {arg.name: sympy_to_proto.sympy_expr_from_proto(arg.sympy_expr)}
9494
if arg.HasField("register"):
9595
return {arg.name: registers.register_from_proto(arg.register)}
9696
if arg.HasField("registers"):
@@ -185,44 +185,63 @@ def bloqs_to_proto(
185185
186186
A `BloqLibrary` contains multiple bloqs and their hierarchical decompositions. Since
187187
decompositions can use bloq objects that are not explicitly listed in the `bloqs` argument to
188-
this function, this routine will recursively add any bloq objects encountered in decompositions
189-
to the bloq library.
188+
this function, this routine will recursively (up to `max_depth`) add any bloq objects
189+
encountered in decompositions to the bloq library.
190+
191+
For bloqs within `max_depth` decompositions of the bloqs passed explicitly to this function,
192+
we perform a full serialization: each bloq is serialized with its decomposition and resource
193+
costs. For bloqs encountered only through references from full decompositions, or through
194+
bloqs included as compiled-time classical parameters to bloqs; we perform a "shallow"
195+
serialization where only the bloq, its signature, and its attributes are included in the
196+
BloqLibrary.
190197
"""
191198

192199
# The bloq library uses a unique integer index as a simple address for each bloq object.
193200
# Set up this mapping and populate it by recursively searching for subbloqs.
194-
bloq_to_id: Dict[Bloq, int] = {}
201+
# Each value is an (id: bool, shallow: bool) tuple, where the second entry can be set to
202+
# `True` for bloqs that need to be referred to but do not need a full serialization.
203+
bloq_to_id_ext: Dict[Bloq, Tuple[int, bool]] = {}
195204
for bloq in bloqs:
196-
_assign_bloq_an_id(bloq, bloq_to_id)
197-
_search_for_subbloqs(bloq, bloq_to_id, pred, max_depth)
205+
_assign_bloq_an_id(bloq, bloq_to_id_ext, shallow=True)
206+
_search_for_subbloqs(bloq, bloq_to_id_ext, pred, max_depth)
207+
208+
bloq_to_id = {bloq: bloq_id for bloq, (bloq_id, shallow) in bloq_to_id_ext.items()}
198209

199210
# Decompose[..]Error is raised if `bloq` does not have a decomposition.
200211
# KeyError is raised if `bloq` has a decomposition, but we do not wish to serialize it
201212
# because of conditions checked by `pred` and `max_depth`.
202213
stop_recursing_exceptions = (DecomposeNotImplementedError, DecomposeTypeError, KeyError)
203214

204-
# `bloq_to_id` would now contain a list of all bloqs that should be serialized.
215+
# `bloq_to_id` contains a list of all bloqs that should be serialized.
205216
library = bloq_pb2.BloqLibrary(name=name)
206-
for bloq, bloq_id in bloq_to_id.items():
207-
try:
208-
cbloq = bloq if isinstance(bloq, CompositeBloq) else bloq.decompose_bloq()
209-
decomposition = [_connection_to_proto(cxn, bloq_to_id) for cxn in cbloq.connections]
210-
except stop_recursing_exceptions:
217+
for bloq, (bloq_id, shallow) in bloq_to_id_ext.items():
218+
if shallow:
211219
decomposition = None
220+
else:
221+
try:
222+
cbloq = bloq if isinstance(bloq, CompositeBloq) else bloq.decompose_bloq()
223+
decomposition = [_connection_to_proto(cxn, bloq_to_id) for cxn in cbloq.connections]
224+
except stop_recursing_exceptions:
225+
decomposition = None
212226

213-
try:
214-
bloq_counts = {
215-
bloq_to_id[b]: args.int_or_sympy_to_proto(c)
216-
for b, c in sorted(bloq.bloq_counts().items(), key=lambda x: type(x[0]).__name__)
217-
}
218-
except stop_recursing_exceptions:
227+
if shallow:
219228
bloq_counts = None
229+
else:
230+
try:
231+
bloq_counts = {
232+
bloq_to_id[b]: args.int_or_sympy_to_proto(c)
233+
for b, c in sorted(
234+
bloq.bloq_counts().items(), key=lambda x: type(x[0]).__name__
235+
)
236+
}
237+
except stop_recursing_exceptions:
238+
bloq_counts = None
220239

221240
library.table.add(
222-
bloq_id=bloq_id,
241+
bloq_id=bloq_to_id[bloq],
223242
decomposition=decomposition,
224243
bloq_counts=bloq_counts,
225-
bloq=_bloq_to_proto(bloq, bloq_to_id=bloq_to_id),
244+
bloq=_bloq_to_proto(bloq, bloq_to_id=bloq_to_id, shallow=shallow),
226245
)
227246
return library
228247

@@ -273,11 +292,15 @@ def _bloq_instance_to_proto(
273292
return bloq_pb2.BloqInstance(instance_id=binst.i, bloq_id=bloq_to_id[binst.bloq])
274293

275294

276-
def _assign_bloq_an_id(bloq: Bloq, bloq_to_id: Dict[Bloq, int]):
295+
def _assign_bloq_an_id(bloq: Bloq, bloq_to_id: Dict[Bloq, Tuple[int, bool]], shallow: bool = False):
277296
"""Assigns a new index for `bloq` and records it into the `bloq_to_id` mapping."""
278-
if bloq not in bloq_to_id:
297+
if bloq in bloq_to_id:
298+
# Keep the same id, but if anyone requests a non-shallow serialization; do it.
299+
bloq_id, existing_shallow = bloq_to_id[bloq]
300+
bloq_to_id[bloq] = (bloq_id, (existing_shallow and shallow))
301+
else:
279302
next_idx = len(bloq_to_id)
280-
bloq_to_id[bloq] = next_idx
303+
bloq_to_id[bloq] = next_idx, shallow
281304

282305

283306
def _cbloq_ordered_bloq_instances(cbloq: CompositeBloq) -> List[BloqInstance]:
@@ -291,7 +314,10 @@ def _cbloq_ordered_bloq_instances(cbloq: CompositeBloq) -> List[BloqInstance]:
291314

292315

293316
def _search_for_subbloqs(
294-
bloq: Bloq, bloq_to_id: Dict[Bloq, int], pred: Callable[[BloqInstance], bool], max_depth: int
317+
bloq: Bloq,
318+
bloq_to_id: Dict[Bloq, Tuple[int, bool]],
319+
pred: Callable[[BloqInstance], bool],
320+
max_depth: int,
295321
) -> None:
296322
"""Recursively finds all bloqs.
297323
@@ -309,14 +335,16 @@ def _search_for_subbloqs(
309335
`pred` is not used when querying the call graph nor when inspecting the bloq's attributes.
310336
"""
311337

312-
assert bloq in bloq_to_id
313338
if max_depth > 0:
339+
# Ensure full serialization of this bloq
340+
_assign_bloq_an_id(bloq, bloq_to_id, shallow=False)
341+
314342
# Search the bloq's decomposition
315343
try:
316344
cbloq = bloq if isinstance(bloq, CompositeBloq) else bloq.decompose_bloq()
317345
for binst in _cbloq_ordered_bloq_instances(cbloq):
318346
subbloq = binst.bloq
319-
_assign_bloq_an_id(subbloq, bloq_to_id)
347+
_assign_bloq_an_id(subbloq, bloq_to_id, shallow=True)
320348
if pred(binst):
321349
_search_for_subbloqs(subbloq, bloq_to_id, pred, max_depth - 1)
322350
else:
@@ -328,7 +356,7 @@ def _search_for_subbloqs(
328356
# Search the bloq's call graph
329357
try:
330358
for subbloq, _ in bloq.bloq_counts().items():
331-
_assign_bloq_an_id(subbloq, bloq_to_id)
359+
_assign_bloq_an_id(subbloq, bloq_to_id, shallow=True)
332360
_search_for_subbloqs(subbloq, bloq_to_id, pred, 0)
333361
except NotImplementedError:
334362
# No call graph, nothing to recurse on.
@@ -339,15 +367,20 @@ def _search_for_subbloqs(
339367
for field in _iter_fields(bloq):
340368
subbloq = getattr(bloq, field.name)
341369
if isinstance(subbloq, Bloq):
342-
_assign_bloq_an_id(subbloq, bloq_to_id)
370+
_assign_bloq_an_id(subbloq, bloq_to_id, shallow=True)
343371
_search_for_subbloqs(subbloq, bloq_to_id, pred, 0)
344372

345373

346-
def _bloq_to_proto(bloq: Bloq, *, bloq_to_id: Dict[Bloq, int]) -> bloq_pb2.Bloq:
347-
try:
348-
t_complexity = annotations.t_complexity_to_proto(bloq.t_complexity())
349-
except (DecomposeTypeError, DecomposeNotImplementedError, TypeError):
374+
def _bloq_to_proto(
375+
bloq: Bloq, *, bloq_to_id: Dict[Bloq, int], shallow: bool = False
376+
) -> bloq_pb2.Bloq:
377+
if shallow:
350378
t_complexity = None
379+
else:
380+
try:
381+
t_complexity = annotations.t_complexity_to_proto(bloq.t_complexity())
382+
except (DecomposeTypeError, DecomposeNotImplementedError, TypeError):
383+
t_complexity = None
351384

352385
name = bloq.__module__ + "." + bloq.__class__.__qualname__
353386
return bloq_pb2.Bloq(

qualtran/serialization/bloq_test.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_cbloq_to_proto_test_two_cswap():
133133
cswap_proto = bloq_serialization.bloqs_to_proto(TestCSwap(100)).table[0].bloq
134134
assert TestCSwap(100).t_complexity().t == 7 * 100
135135
cbloq = TestTwoCSwap(100).decompose_bloq()
136-
proto_lib = bloq_serialization.bloqs_to_proto(cbloq)
136+
proto_lib = bloq_serialization.bloqs_to_proto(cbloq, max_depth=100)
137137
assert len(proto_lib.table) == 2
138138
assert proto_lib.table[1].bloq == cswap_proto
139139
assert proto_lib.table[0].bloq.t_complexity.t == 7 * 100 * 2
@@ -178,18 +178,50 @@ def test_meta_bloq_to_proto():
178178
sub_bloq_one = TestTwoCSwap(20)
179179
sub_bloq_two = TestTwoCSwap(20).decompose_bloq()
180180
bloq = TestMetaBloq(sub_bloq_one, sub_bloq_two)
181-
proto_lib = bloq_serialization.bloqs_to_proto(bloq, name="Meta Bloq Test")
182-
assert proto_lib.name == "Meta Bloq Test"
183-
assert len(proto_lib.table) == 3 # TestMetaBloq, TestTwoCSwap, CompositeBloq
184-
185-
proto_lib = bloq_serialization.bloqs_to_proto(bloq, max_depth=2)
186-
assert len(proto_lib.table) == 4 # TestMetaBloq, TestTwoCSwap, CompositeBloq, TestCSwap
187-
188-
assert proto_lib.table[0].bloq.name.split('.')[-1] == 'TestMetaBloq'
189-
assert len(proto_lib.table[0].decomposition) == 9
190-
191-
assert proto_lib.table[1].bloq.name.split('.')[-1] == 'TestTwoCSwap'
192-
assert len(proto_lib.table[1].decomposition) == 9
193181

194-
assert proto_lib == bloq_serialization.bloqs_to_proto(bloq, bloq, TestTwoCSwap(20), max_depth=2)
195-
assert bloq in bloq_serialization.bloqs_from_proto(proto_lib)
182+
depth_0_lib = bloq_serialization.bloqs_to_proto(bloq, max_depth=0)
183+
assert len(depth_0_lib.table) == 3
184+
for table_entry in depth_0_lib.table:
185+
bloq_name = table_entry.bloq.name.split('.')[-1]
186+
if bloq_name == 'TestMetaBloq':
187+
assert len(table_entry.decomposition) == 0
188+
elif bloq_name == 'TestTwoCSwap':
189+
# This is included solely through TestMetaBloq's attribute
190+
assert len(table_entry.decomposition) == 0
191+
elif bloq_name == 'CompositeBloq':
192+
# This is included solely through TestMetaBloq's attribute
193+
assert len(table_entry.decomposition) == 0
194+
else:
195+
raise AssertionError(f"Unknown {bloq_name}")
196+
197+
depth_1_lib = bloq_serialization.bloqs_to_proto(bloq, max_depth=1)
198+
assert len(depth_1_lib.table) == 3
199+
for table_entry in depth_1_lib.table:
200+
bloq_name = table_entry.bloq.name.split('.')[-1]
201+
if bloq_name == 'TestMetaBloq':
202+
assert len(table_entry.decomposition) > 0
203+
elif bloq_name == 'TestTwoCSwap':
204+
# This is still a "shallow" inclusion, because this only appears in 1 level
205+
# of decomposition
206+
assert len(table_entry.decomposition) == 0
207+
elif bloq_name == 'CompositeBloq':
208+
assert len(table_entry.decomposition) == 0
209+
else:
210+
raise AssertionError(f"Unknown {bloq_name}")
211+
212+
depth_2_lib = bloq_serialization.bloqs_to_proto(bloq, max_depth=2)
213+
assert len(depth_2_lib.table) > 3
214+
for table_entry in depth_2_lib.table:
215+
bloq_name = table_entry.bloq.name.split('.')[-1]
216+
if bloq_name == 'TestMetaBloq':
217+
assert len(table_entry.decomposition) == 9
218+
elif bloq_name == 'TestTwoCSwap':
219+
assert len(table_entry.decomposition) == 9
220+
elif bloq_name == 'CompositeBloq':
221+
assert len(table_entry.decomposition) > 0
222+
elif bloq_name == 'TestCSwap':
223+
assert len(table_entry.decomposition) == 0
224+
else:
225+
raise AssertionError(f"Unknown {bloq_name}")
226+
227+
assert bloq in bloq_serialization.bloqs_from_proto(depth_2_lib)

qualtran/serialization/sympy_test.py renamed to qualtran/serialization/sympy_to_proto_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sympy
1717
from sympy.codegen.cfunctions import log2
1818

19-
from qualtran.serialization.sympy import sympy_expr_from_proto, sympy_expr_to_proto
19+
from qualtran.serialization.sympy_to_proto import sympy_expr_from_proto, sympy_expr_to_proto
2020

2121
x = sympy.Symbol('x', positive=True)
2222
a, b, c = sympy.symbols("a b c")

qualtran/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,11 @@ def assert_bloq_example_serializes(bloq_ex: BloqExample) -> None:
579579
try:
580580
bloq_lib = bloqs_to_proto(bloq)
581581
except Exception as e:
582-
raise BloqCheckException.fail('Serialization Failed:\n' + str(e)) from e
582+
raise BloqCheckException.fail(f'Serialization Failed: {e!r}') from e
583583
try:
584584
bloq_roundtrip = bloqs_from_proto(bloq_lib)[0]
585585
except Exception as e:
586-
raise BloqCheckException.fail('DeSerialization Failed:\n' + str(e)) from e
586+
raise BloqCheckException.fail(f'DeSerialization Failed: {e!r}') from e
587587

588588
try:
589589
assert bloq == bloq_roundtrip

0 commit comments

Comments
 (0)