2020from typing_extensions import Self
2121
2222from qualtran import (
23+ AddControlledT ,
24+ Bloq ,
2325 bloq_example ,
2426 BloqBuilder ,
2527 BloqDocSpec ,
2628 BQUInt ,
29+ CtrlSpec ,
30+ DecomposeTypeError ,
2731 QAny ,
2832 Register ,
2933 Signature ,
3034 Soquet ,
3135 SoquetT ,
3236)
33- from qualtran ._infra .bloq import DecomposeTypeError
3437from qualtran .bloqs .block_encoding import BlockEncoding
3538from qualtran .bloqs .block_encoding .lcu_block_encoding import BlackBoxPrepare , BlackBoxSelect
3639from qualtran .bloqs .block_encoding .phase import Phase
@@ -68,6 +71,7 @@ class LinearCombination(BlockEncoding):
6871 (state should be normalized and can have junk).
6972 select: If specified, oracle taking
7073 $|i\rangle|\psi\rangle \mapsto \text{sgn}(\lambda_i) |i\rangle U_i|\psi\rangle$.
74+ is_controlled: if True, implements a controlled version. Defaults to False.
7175
7276 Registers:
7377 system: The system register.
@@ -88,6 +92,8 @@ class LinearCombination(BlockEncoding):
8892 _prepare : Optional [BlackBoxPrepare ] = None
8993 _select : Optional [BlackBoxSelect ] = None
9094
95+ is_controlled : bool = False
96+
9197 def __attrs_post_init__ (self ):
9298 if len (self ._block_encodings ) != len (self ._lambd ):
9399 raise ValueError ("Must provide the same number of block encodings and coefficients." )
@@ -138,6 +144,7 @@ def rescaled_lambd(self):
138144 @cached_property
139145 def signature (self ) -> Signature :
140146 return Signature .build_from_dtypes (
147+ ctrl = QAny (1 if self .is_controlled else 0 ),
141148 system = QAny (self .system_bitsize ),
142149 ancilla = QAny (self .ancilla_bitsize ),
143150 resource = QAny (self .resource_bitsize ),
@@ -298,11 +305,26 @@ def build_composite_bloq(
298305 be_part = Partition (self .select .system_bitsize , tuple (be_regs ))
299306
300307 prepare_soqs = bb .add_d (self .prepare , ** prepare_in_soqs )
301- select_out_soqs = bb .add_d (
302- self .select ,
303- selection = prepare_soqs .pop ("selection" ),
304- system = cast (Soquet , bb .add (evolve (be_part , partition = False ), ** be_system_soqs )),
305- )
308+
309+ if not self .is_controlled :
310+ select_out_soqs = bb .add_d (
311+ self .select ,
312+ selection = prepare_soqs .pop ("selection" ),
313+ system = cast (Soquet , bb .add (evolve (be_part , partition = False ), ** be_system_soqs )),
314+ )
315+ else :
316+ _ , add_ctrl_select = self .select .get_ctrl_system (CtrlSpec ())
317+ (ctrl ,), select_out_soqs_t = add_ctrl_select (
318+ bb ,
319+ [soqs .pop ('ctrl' )],
320+ dict (
321+ selection = prepare_soqs .pop ("selection" ),
322+ system = cast (Soquet , bb .add (evolve (be_part , partition = False ), ** be_system_soqs )),
323+ ),
324+ )
325+ select_out_soqs = dict (zip (["selection" , "system" ], select_out_soqs_t ))
326+ select_out_soqs ["ctrl" ] = ctrl
327+
306328 prep_adj_soqs = bb .add_d (
307329 self .prepare .adjoint (), selection = select_out_soqs .pop ("selection" ), ** prepare_soqs
308330 )
@@ -311,6 +333,9 @@ def build_composite_bloq(
311333 be_soqs = bb .add_d (be_part , x = select_out_soqs .pop ("system" ))
312334 out : Dict [str , SoquetT ] = {"system" : be_soqs .pop ("system" )}
313335
336+ if self .is_controlled :
337+ out ["ctrl" ] = select_out_soqs .pop ("ctrl" )
338+
314339 # merge ancilla registers of block encoding and Prepare oracle
315340 anc_soqs = {"selection" : prep_adj_soqs .pop ("selection" )}
316341 if self .be_ancilla_bitsize > 0 :
@@ -331,6 +356,20 @@ def build_composite_bloq(
331356 def __str__ (self ) -> str :
332357 return f"B[{ '+' .join (str (be )[2 :- 1 ] for be in self .signed_block_encodings )} ]"
333358
359+ def get_ctrl_system (self , ctrl_spec : 'CtrlSpec' ) -> tuple ['Bloq' , 'AddControlledT' ]:
360+ from qualtran .bloqs .mcmt .specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs
361+
362+ return get_ctrl_system_1bit_cv_from_bloqs (
363+ self ,
364+ ctrl_spec ,
365+ current_ctrl_bit = 1 if self .is_controlled else None ,
366+ bloq_with_ctrl = evolve (self , is_controlled = True ),
367+ ctrl_reg_name = 'ctrl' ,
368+ )
369+
370+ def adjoint (self ) -> 'LinearCombination' :
371+ return self
372+
334373
335374@bloq_example
336375def _linear_combination_block_encoding () -> LinearCombination :
0 commit comments