1414
1515import dataclasses
1616import inspect
17- from typing import Any , Callable , Dict , List , Optional , Union
17+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1818
1919import attrs
2020import cirq
4747 ec_point ,
4848 registers ,
4949 resolver_dict ,
50+ sympy_to_proto ,
5051)
51- from qualtran .serialization .sympy import sympy_expr_from_proto , sympy_expr_to_proto
5252
5353
5454def arg_to_proto (* , name : str , val : Any ) -> bloq_pb2 .BloqArg :
@@ -59,7 +59,7 @@ def arg_to_proto(*, name: str, val: Any) -> bloq_pb2.BloqArg:
5959 if isinstance (val , str ):
6060 return bloq_pb2 .BloqArg (name = name , string_val = val )
6161 if isinstance (val , sympy .Expr ):
62- return bloq_pb2 .BloqArg (name = name , sympy_expr = sympy_expr_to_proto (val ))
62+ return bloq_pb2 .BloqArg (name = name , sympy_expr = sympy_to_proto . sympy_expr_to_proto (val ))
6363 if isinstance (val , Register ):
6464 return bloq_pb2 .BloqArg (name = name , register = registers .register_to_proto (val ))
6565 if isinstance (val , tuple ) and all (isinstance (x , Register ) for x in val ):
@@ -90,7 +90,7 @@ def arg_from_proto(arg: bloq_pb2.BloqArg) -> Dict[str, Any]:
9090 if arg .HasField ("string_val" ):
9191 return {arg .name : arg .string_val }
9292 if arg .HasField ("sympy_expr" ):
93- return {arg .name : sympy_expr_from_proto (arg .sympy_expr )}
93+ return {arg .name : sympy_to_proto . sympy_expr_from_proto (arg .sympy_expr )}
9494 if arg .HasField ("register" ):
9595 return {arg .name : registers .register_from_proto (arg .register )}
9696 if arg .HasField ("registers" ):
@@ -185,44 +185,63 @@ def bloqs_to_proto(
185185
186186 A `BloqLibrary` contains multiple bloqs and their hierarchical decompositions. Since
187187 decompositions can use bloq objects that are not explicitly listed in the `bloqs` argument to
188- this function, this routine will recursively add any bloq objects encountered in decompositions
189- to the bloq library.
188+ this function, this routine will recursively (up to `max_depth`) add any bloq objects
189+ encountered in decompositions to the bloq library.
190+
191+ For bloqs within `max_depth` decompositions of the bloqs passed explicitly to this function,
192+ we perform a full serialization: each bloq is serialized with its decomposition and resource
193+ costs. For bloqs encountered only through references from full decompositions, or through
194+ bloqs included as compile-time classical parameters to bloqs, we perform a "shallow"
195+ serialization where only the bloq, its signature, and its attributes are included in the
196+ BloqLibrary.
190197 """
191198
192199 # The bloq library uses a unique integer index as a simple address for each bloq object.
193200 # Set up this mapping and populate it by recursively searching for subbloqs.
194- bloq_to_id : Dict [Bloq , int ] = {}
201+ # Each value is an (id: int, shallow: bool) tuple, where the second entry can be set to
202+ # `True` for bloqs that need to be referred to but do not need a full serialization.
203+ bloq_to_id_ext : Dict [Bloq , Tuple [int , bool ]] = {}
195204 for bloq in bloqs :
196- _assign_bloq_an_id (bloq , bloq_to_id )
197- _search_for_subbloqs (bloq , bloq_to_id , pred , max_depth )
205+ _assign_bloq_an_id (bloq , bloq_to_id_ext , shallow = True )
206+ _search_for_subbloqs (bloq , bloq_to_id_ext , pred , max_depth )
207+
208+ bloq_to_id = {bloq : bloq_id for bloq , (bloq_id , shallow ) in bloq_to_id_ext .items ()}
198209
199210 # Decompose[..]Error is raised if `bloq` does not have a decomposition.
200211 # KeyError is raised if `bloq` has a decomposition, but we do not wish to serialize it
201212 # because of conditions checked by `pred` and `max_depth`.
202213 stop_recursing_exceptions = (DecomposeNotImplementedError , DecomposeTypeError , KeyError )
203214
204- # `bloq_to_id` would now contain a list of all bloqs that should be serialized.
215+ # `bloq_to_id` contains a list of all bloqs that should be serialized.
205216 library = bloq_pb2 .BloqLibrary (name = name )
206- for bloq , bloq_id in bloq_to_id .items ():
207- try :
208- cbloq = bloq if isinstance (bloq , CompositeBloq ) else bloq .decompose_bloq ()
209- decomposition = [_connection_to_proto (cxn , bloq_to_id ) for cxn in cbloq .connections ]
210- except stop_recursing_exceptions :
217+ for bloq , (bloq_id , shallow ) in bloq_to_id_ext .items ():
218+ if shallow :
211219 decomposition = None
220+ else :
221+ try :
222+ cbloq = bloq if isinstance (bloq , CompositeBloq ) else bloq .decompose_bloq ()
223+ decomposition = [_connection_to_proto (cxn , bloq_to_id ) for cxn in cbloq .connections ]
224+ except stop_recursing_exceptions :
225+ decomposition = None
212226
213- try :
214- bloq_counts = {
215- bloq_to_id [b ]: args .int_or_sympy_to_proto (c )
216- for b , c in sorted (bloq .bloq_counts ().items (), key = lambda x : type (x [0 ]).__name__ )
217- }
218- except stop_recursing_exceptions :
227+ if shallow :
219228 bloq_counts = None
229+ else :
230+ try :
231+ bloq_counts = {
232+ bloq_to_id [b ]: args .int_or_sympy_to_proto (c )
233+ for b , c in sorted (
234+ bloq .bloq_counts ().items (), key = lambda x : type (x [0 ]).__name__
235+ )
236+ }
237+ except stop_recursing_exceptions :
238+ bloq_counts = None
220239
221240 library .table .add (
222- bloq_id = bloq_id ,
241+ bloq_id = bloq_to_id [ bloq ] ,
223242 decomposition = decomposition ,
224243 bloq_counts = bloq_counts ,
225- bloq = _bloq_to_proto (bloq , bloq_to_id = bloq_to_id ),
244+ bloq = _bloq_to_proto (bloq , bloq_to_id = bloq_to_id , shallow = shallow ),
226245 )
227246 return library
228247
@@ -273,11 +292,15 @@ def _bloq_instance_to_proto(
273292 return bloq_pb2 .BloqInstance (instance_id = binst .i , bloq_id = bloq_to_id [binst .bloq ])
274293
275294
def _assign_bloq_an_id(bloq: Bloq, bloq_to_id: Dict[Bloq, Tuple[int, bool]], shallow: bool = False):
    """Record `bloq` in the `bloq_to_id` mapping, assigning a fresh index if it is new.

    Args:
        bloq: The bloq to register.
        bloq_to_id: Mapping from each seen bloq to an `(id, shallow)` tuple. `shallow` is
            True iff the bloq so far only needs a shallow serialization (no decomposition
            or resource counts).
        shallow: Whether this registration requests only a shallow serialization. If the
            bloq is already registered, its id is kept and the shallow flag is AND-ed with
            the existing one, so a single non-shallow request forces a full serialization.
    """
    if bloq in bloq_to_id:
        # Keep the assigned id; upgrade to a full serialization if any caller asks for one.
        bloq_id, existing_shallow = bloq_to_id[bloq]
        bloq_to_id[bloq] = (bloq_id, existing_shallow and shallow)
    else:
        # New bloq: the next unused integer index is simply the current mapping size.
        bloq_to_id[bloq] = (len(bloq_to_id), shallow)
281304
282305
283306def _cbloq_ordered_bloq_instances (cbloq : CompositeBloq ) -> List [BloqInstance ]:
@@ -291,7 +314,10 @@ def _cbloq_ordered_bloq_instances(cbloq: CompositeBloq) -> List[BloqInstance]:
291314
292315
293316def _search_for_subbloqs (
294- bloq : Bloq , bloq_to_id : Dict [Bloq , int ], pred : Callable [[BloqInstance ], bool ], max_depth : int
317+ bloq : Bloq ,
318+ bloq_to_id : Dict [Bloq , Tuple [int , bool ]],
319+ pred : Callable [[BloqInstance ], bool ],
320+ max_depth : int ,
295321) -> None :
296322 """Recursively finds all bloqs.
297323
@@ -309,14 +335,16 @@ def _search_for_subbloqs(
309335 `pred` is not used when querying the call graph nor when inspecting the bloq's attributes.
310336 """
311337
312- assert bloq in bloq_to_id
313338 if max_depth > 0 :
339+ # Ensure full serialization of this bloq
340+ _assign_bloq_an_id (bloq , bloq_to_id , shallow = False )
341+
314342 # Search the bloq's decomposition
315343 try :
316344 cbloq = bloq if isinstance (bloq , CompositeBloq ) else bloq .decompose_bloq ()
317345 for binst in _cbloq_ordered_bloq_instances (cbloq ):
318346 subbloq = binst .bloq
319- _assign_bloq_an_id (subbloq , bloq_to_id )
347+ _assign_bloq_an_id (subbloq , bloq_to_id , shallow = True )
320348 if pred (binst ):
321349 _search_for_subbloqs (subbloq , bloq_to_id , pred , max_depth - 1 )
322350 else :
@@ -328,7 +356,7 @@ def _search_for_subbloqs(
328356 # Search the bloq's call graph
329357 try :
330358 for subbloq , _ in bloq .bloq_counts ().items ():
331- _assign_bloq_an_id (subbloq , bloq_to_id )
359+ _assign_bloq_an_id (subbloq , bloq_to_id , shallow = True )
332360 _search_for_subbloqs (subbloq , bloq_to_id , pred , 0 )
333361 except NotImplementedError :
334362 # No call graph, nothing to recurse on.
@@ -339,15 +367,20 @@ def _search_for_subbloqs(
339367 for field in _iter_fields (bloq ):
340368 subbloq = getattr (bloq , field .name )
341369 if isinstance (subbloq , Bloq ):
342- _assign_bloq_an_id (subbloq , bloq_to_id )
370+ _assign_bloq_an_id (subbloq , bloq_to_id , shallow = True )
343371 _search_for_subbloqs (subbloq , bloq_to_id , pred , 0 )
344372
345373
346- def _bloq_to_proto (bloq : Bloq , * , bloq_to_id : Dict [ Bloq , int ]) -> bloq_pb2 . Bloq :
347- try :
348- t_complexity = annotations . t_complexity_to_proto ( bloq . t_complexity ())
349- except ( DecomposeTypeError , DecomposeNotImplementedError , TypeError ) :
374+ def _bloq_to_proto (
375+ bloq : Bloq , * , bloq_to_id : Dict [ Bloq , int ], shallow : bool = False
376+ ) -> bloq_pb2 . Bloq :
377+ if shallow :
350378 t_complexity = None
379+ else :
380+ try :
381+ t_complexity = annotations .t_complexity_to_proto (bloq .t_complexity ())
382+ except (DecomposeTypeError , DecomposeNotImplementedError , TypeError ):
383+ t_complexity = None
351384
352385 name = bloq .__module__ + "." + bloq .__class__ .__qualname__
353386 return bloq_pb2 .Bloq (
0 commit comments