Skip to content

Commit 8c3cdda

Browse files
petim0mpharrigan
andauthored
Fix Product's build_call_graph (#1810)
The old version of Product had a `build_call_graph` that was incorrect when the bloq encodings had ancillas, as the Multi_controlled_X was not consistent between the decomposition and the `build_call_graph`. This PR fixes that, I also took the time to change the split/join behavior to control the XGate to an `Autopartition` and I also made the symbolic decomposition work. This PR also adds tests. Before [output.pdf](https://github.com/user-attachments/files/25289067/output.pdf) Now [outputNow.pdf](https://github.com/user-attachments/files/25289118/outputNow.pdf) --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent d194ec5 commit 8c3cdda

2 files changed

Lines changed: 70 additions & 11 deletions

File tree

qualtran/bloqs/block_encoding/product.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
bloq_example,
2626
BloqBuilder,
2727
BloqDocSpec,
28+
CtrlSpec,
2829
DecomposeTypeError,
2930
QAny,
3031
QBit,
@@ -37,12 +38,11 @@
3738
from qualtran.bloqs.block_encoding import BlockEncoding
3839
from qualtran.bloqs.bookkeeping.auto_partition import AutoPartition, Unused
3940
from qualtran.bloqs.bookkeeping.partition import Partition
40-
from qualtran.bloqs.mcmt import MultiControlX
4141
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
4242
from qualtran.bloqs.state_preparation.black_box_prepare import BlackBoxPrepare
4343
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
4444
from qualtran.resource_counting.generalizers import ignore_split_join
45-
from qualtran.symbolics import HasLength, is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt
45+
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt
4646
from qualtran.symbolics.math_funcs import is_zero
4747

4848

@@ -171,15 +171,45 @@ def constituents(self) -> Sequence[Bloq]:
171171
ret.append(AutoPartition(u, partition, left_only=False))
172172
return ret
173173

174+
def _multCX(self, bitsize) -> Bloq:
175+
return XGate().controlled(ctrl_spec=CtrlSpec(QAny(bitsize), cvs=0))
176+
177+
def _multCX_autopart(self, *, used_bits: int, total_bits: int) -> Bloq:
178+
if used_bits <= 0:
179+
raise ValueError("used_bits must be > 0")
180+
if used_bits > total_bits:
181+
raise ValueError(f"{used_bits=} cannot exceed {total_bits=}")
182+
183+
ctrl_parts = (
184+
["ctrl", Unused(total_bits - used_bits)] if total_bits > used_bits else ["ctrl"]
185+
)
186+
return AutoPartition(
187+
self._multCX(used_bits),
188+
partitions=[
189+
(Register("ctrl", QAny(total_bits)), ctrl_parts),
190+
(Register("q", QBit()), ["q"]),
191+
],
192+
)
193+
174194
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
175195
counts = Counter[Bloq]()
176196
for bloq in self.constituents:
177197
counts[bloq] += 1
178198
n = len(self.block_encodings)
179199
for i, u in enumerate(reversed(self.block_encodings)):
180200
if not is_zero(u.ancilla_bitsize) and n - 1 > 0 and i != n - 1:
181-
counts[MultiControlX(HasLength(u.ancilla_bitsize))] += 1
201+
anc_bits = self.ancilla_bitsize - (n - 1)
202+
if not is_symbolic(u.ancilla_bitsize):
203+
counts[
204+
self._multCX_autopart(used_bits=u.ancilla_bitsize, total_bits=anc_bits)
205+
] += 1
206+
else:
207+
counts[self._multCX(u.ancilla_bitsize)] += 1
182208
counts[XGate()] += 1
209+
210+
if not is_symbolic(self.ancilla_bitsize):
211+
counts[self.anc_part] += 1
212+
counts[self.anc_part.adjoint()] += 1
183213
return counts
184214

185215
def build_composite_bloq(
@@ -226,17 +256,12 @@ def build_composite_bloq(
226256

227257
# set corresponding flag if ancillas are all zero
228258
if u.ancilla_bitsize > 0 and n - 1 > 0 and i != n - 1:
229-
controls = bb.split(cast(Soquet, anc_soq))
230259
# flag_bits_soq will always be assigned based on the following assertion
231260
assert self.ancilla_bitsize > 0
232261
# pylint: disable=used-before-assignment
233-
controls[: u.ancilla_bitsize], flag_bits_soq[i] = bb.add_t( # type: ignore[assignment]
234-
MultiControlX(tuple([0] * u.ancilla_bitsize)),
235-
controls=controls[: u.ancilla_bitsize],
236-
target=flag_bits_soq[i],
237-
)
262+
MultCX = self._multCX_autopart(used_bits=u.ancilla_bitsize, total_bits=anc_bits)
263+
anc_soq, flag_bits_soq[i] = bb.add(MultCX, ctrl=anc_soq, q=flag_bits_soq[i])
238264
flag_bits_soq[i] = bb.add(XGate(), q=flag_bits_soq[i])
239-
anc_soq = bb.join(controls)
240265

241266
out = {"system": system}
242267
if self.resource_bitsize > 0:
@@ -263,6 +288,17 @@ def _product_block_encoding() -> Product:
263288
return product_block_encoding
264289

265290

291+
@bloq_example()
292+
def _product_block_encoding_with_ancillas() -> Product:
293+
from qualtran.bloqs.basic_gates import Hadamard, TGate
294+
from qualtran.bloqs.block_encoding.unitary import Unitary
295+
296+
product_block_encoding = Product(
297+
(Unitary(TGate(), ancilla_bitsize=3), Unitary(Hadamard(), ancilla_bitsize=3))
298+
)
299+
return product_block_encoding
300+
301+
266302
@bloq_example
267303
def _product_block_encoding_properties() -> Product:
268304
from qualtran.bloqs.basic_gates import Hadamard, TGate

qualtran/bloqs/block_encoding/product_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
_product_block_encoding,
3737
_product_block_encoding_properties,
3838
_product_block_encoding_symb,
39+
_product_block_encoding_with_ancillas,
3940
Product,
4041
)
4142
from qualtran.bloqs.block_encoding.unitary import Unitary
@@ -57,6 +58,9 @@ def test_product_signature():
5758
assert _product_block_encoding().signature == Signature(
5859
[Register("system", QAny(1)), Register("ancilla", QAny(1))]
5960
)
61+
assert _product_block_encoding_with_ancillas().signature == Signature(
62+
[Register("system", QAny(1)), Register("ancilla", QAny(4))]
63+
)
6064
assert _product_block_encoding_properties().signature == Signature(
6165
[Register("system", QAny(1)), Register("ancilla", QAny(3)), Register("resource", QAny(1))]
6266
)
@@ -113,6 +117,13 @@ def test_product_params():
113117
assert bloq.ancilla_bitsize == 1
114118
assert bloq.resource_bitsize == 0
115119

120+
bloq = _product_block_encoding_with_ancillas()
121+
assert bloq.system_bitsize == 1
122+
assert bloq.alpha == 1
123+
assert bloq.epsilon == 0
124+
assert bloq.ancilla_bitsize == 4
125+
assert bloq.resource_bitsize == 0
126+
116127
bloq = _product_block_encoding_properties()
117128
assert bloq.system_bitsize == 1
118129
assert bloq.alpha == 0.5 * 0.5
@@ -216,7 +227,19 @@ def test_product_signal_state():
216227

217228

218229
def test_product_counts():
219-
assert_equivalent_bloq_example_counts(_product_block_encoding)
230+
assert_equivalent_bloq_example_counts(_product_block_encoding_with_ancillas)
231+
232+
233+
def test_product_symbolic_call_graph_decomposes():
234+
from qualtran.bloqs.basic_gates import Hadamard, TGate
235+
from qualtran.bloqs.block_encoding.unitary import Unitary
236+
237+
product_block_encoding = Product(
238+
(
239+
Unitary(TGate(), ancilla_bitsize=sympy.symbols("anc")),
240+
Unitary(Hadamard(), ancilla_bitsize=sympy.symbols("anc")),
241+
)
242+
)
220243

221244

222245
@pytest.mark.notebook

0 commit comments

Comments
 (0)