2525 bloq_example ,
2626 BloqBuilder ,
2727 BloqDocSpec ,
28+ CtrlSpec ,
2829 DecomposeTypeError ,
2930 QAny ,
3031 QBit ,
3738from qualtran .bloqs .block_encoding import BlockEncoding
3839from qualtran .bloqs .bookkeeping .auto_partition import AutoPartition , Unused
3940from qualtran .bloqs .bookkeeping .partition import Partition
40- from qualtran .bloqs .mcmt import MultiControlX
4141from qualtran .bloqs .reflections .prepare_identity import PrepareIdentity
4242from qualtran .bloqs .state_preparation .black_box_prepare import BlackBoxPrepare
4343from qualtran .resource_counting import BloqCountDictT , SympySymbolAllocator
4444from qualtran .resource_counting .generalizers import ignore_split_join
45- from qualtran .symbolics import HasLength , is_symbolic , prod , smax , ssum , SymbolicFloat , SymbolicInt
45+ from qualtran .symbolics import is_symbolic , prod , smax , ssum , SymbolicFloat , SymbolicInt
4646from qualtran .symbolics .math_funcs import is_zero
4747
4848
@@ -171,15 +171,45 @@ def constituents(self) -> Sequence[Bloq]:
171171 ret .append (AutoPartition (u , partition , left_only = False ))
172172 return ret
173173
174+ def _multCX (self , bitsize ) -> Bloq :
175+ return XGate ().controlled (ctrl_spec = CtrlSpec (QAny (bitsize ), cvs = 0 ))
176+
177+ def _multCX_autopart (self , * , used_bits : int , total_bits : int ) -> Bloq :
178+ if used_bits <= 0 :
179+ raise ValueError ("used_bits must be > 0" )
180+ if used_bits > total_bits :
181+ raise ValueError (f"{ used_bits = } cannot exceed { total_bits = } " )
182+
183+ ctrl_parts = (
184+ ["ctrl" , Unused (total_bits - used_bits )] if total_bits > used_bits else ["ctrl" ]
185+ )
186+ return AutoPartition (
187+ self ._multCX (used_bits ),
188+ partitions = [
189+ (Register ("ctrl" , QAny (total_bits )), ctrl_parts ),
190+ (Register ("q" , QBit ()), ["q" ]),
191+ ],
192+ )
193+
174194 def build_call_graph (self , ssa : SympySymbolAllocator ) -> BloqCountDictT :
175195 counts = Counter [Bloq ]()
176196 for bloq in self .constituents :
177197 counts [bloq ] += 1
178198 n = len (self .block_encodings )
179199 for i , u in enumerate (reversed (self .block_encodings )):
180200 if not is_zero (u .ancilla_bitsize ) and n - 1 > 0 and i != n - 1 :
181- counts [MultiControlX (HasLength (u .ancilla_bitsize ))] += 1
201+ anc_bits = self .ancilla_bitsize - (n - 1 )
202+ if not is_symbolic (u .ancilla_bitsize ):
203+ counts [
204+ self ._multCX_autopart (used_bits = u .ancilla_bitsize , total_bits = anc_bits )
205+ ] += 1
206+ else :
207+ counts [self ._multCX (u .ancilla_bitsize )] += 1
182208 counts [XGate ()] += 1
209+
210+ if not is_symbolic (self .ancilla_bitsize ):
211+ counts [self .anc_part ] += 1
212+ counts [self .anc_part .adjoint ()] += 1
183213 return counts
184214
185215 def build_composite_bloq (
@@ -226,17 +256,12 @@ def build_composite_bloq(
226256
227257 # set corresponding flag if ancillas are all zero
228258 if u .ancilla_bitsize > 0 and n - 1 > 0 and i != n - 1 :
229- controls = bb .split (cast (Soquet , anc_soq ))
230259 # flag_bits_soq will always be assigned based on the following assertion
231260 assert self .ancilla_bitsize > 0
232261 # pylint: disable=used-before-assignment
233- controls [: u .ancilla_bitsize ], flag_bits_soq [i ] = bb .add_t ( # type: ignore[assignment]
234- MultiControlX (tuple ([0 ] * u .ancilla_bitsize )),
235- controls = controls [: u .ancilla_bitsize ],
236- target = flag_bits_soq [i ],
237- )
262+ MultCX = self ._multCX_autopart (used_bits = u .ancilla_bitsize , total_bits = anc_bits )
263+ anc_soq , flag_bits_soq [i ] = bb .add (MultCX , ctrl = anc_soq , q = flag_bits_soq [i ])
238264 flag_bits_soq [i ] = bb .add (XGate (), q = flag_bits_soq [i ])
239- anc_soq = bb .join (controls )
240265
241266 out = {"system" : system }
242267 if self .resource_bitsize > 0 :
@@ -263,6 +288,17 @@ def _product_block_encoding() -> Product:
263288 return product_block_encoding
264289
265290
291+ @bloq_example ()
292+ def _product_block_encoding_with_ancillas () -> Product :
293+ from qualtran .bloqs .basic_gates import Hadamard , TGate
294+ from qualtran .bloqs .block_encoding .unitary import Unitary
295+
296+ product_block_encoding = Product (
297+ (Unitary (TGate (), ancilla_bitsize = 3 ), Unitary (Hadamard (), ancilla_bitsize = 3 ))
298+ )
299+ return product_block_encoding
300+
301+
266302@bloq_example
267303def _product_block_encoding_properties () -> Product :
268304 from qualtran .bloqs .basic_gates import Hadamard , TGate
0 commit comments