Skip to content

Commit 7a4c037

Browse files
authored
Strict BloqAsCirqGate and CirqGateAsBloq (#1603)
* strict BloqAsCirqGate and CirqGateAsBloq * fixy * Warn when wrapping a bloq * let cirq handle _unitary_ still
1 parent 94d8026 commit 7a4c037

10 files changed

Lines changed: 149 additions & 91 deletions

File tree

qualtran/_infra/gate_with_registers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
319319

320320
return _wire_symbol_from_gate(self, self.signature, reg, idx)
321321

322-
# Part-2: Cirq-FT style interface can be used to implemented algorithms by Bloq authors.
322+
# Part-2: Cirq-FT style interface can be used to implement algorithms by Bloq authors.
323323

324324
def _num_qubits_(self) -> int:
325325
return total_bits(self.signature)

qualtran/bloqs/basic_gates/global_phase_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_unitary():
2929
for alpha in random_state.random(size=10):
3030
coefficient = np.exp(2j * np.pi * alpha)
3131
bloq = GlobalPhase(exponent=2 * alpha)
32-
np.testing.assert_allclose(cirq.unitary(bloq), coefficient)
32+
np.testing.assert_allclose(bloq.tensor_contract(), coefficient)
3333

3434

3535
@pytest.mark.parametrize("cv", [0, 1])

qualtran/bloqs/basic_gates/rotation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def adjoint(self) -> 'ZPowGate':
160160
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
161161
if reg is None:
162162
return Text('')
163-
return TextBox(str(self))
163+
return TextBox(f'Z^{self.exponent}')
164164

165165
def __str__(self):
166166
return f'Z**{self.exponent}'
@@ -302,7 +302,7 @@ def adjoint(self) -> 'XPowGate':
302302
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
303303
if reg is None:
304304
return Text('')
305-
return TextBox(str(self))
305+
return TextBox(f'X^{self.exponent}')
306306

307307
def __str__(self):
308308
return f'X**{self.exponent}'
@@ -376,7 +376,7 @@ def adjoint(self) -> 'YPowGate':
376376
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
377377
if reg is None:
378378
return Text('')
379-
return TextBox(str(self))
379+
return TextBox(f'Y^{self.exponent}')
380380

381381
def __str__(self):
382382
return f'Y**{self.exponent}'

qualtran/bloqs/phase_estimation/text_book_qpe_test.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
import numpy as np
1616
import pytest
1717

18+
import qualtran.testing as qlt_testing
1819
from qualtran import Signature
1920
from qualtran._infra.gate_with_registers import get_named_qubits
2021
from qualtran.bloqs.basic_gates import ZPowGate
22+
from qualtran.bloqs.chemistry.hubbard_model.qubitization import get_walk_operator_for_hubbard_model
2123
from qualtran.bloqs.for_testing.qubitization_walk_test import get_uniform_pauli_qubitized_walk
22-
from qualtran.bloqs.phase_estimation.lp_resource_state import LPResourceState
23-
from qualtran.bloqs.phase_estimation.qpe_window_state import RectangularWindowState
24-
from qualtran.bloqs.phase_estimation.text_book_qpe import TextbookQPE
24+
from qualtran.bloqs.hamiltonian_simulation.hamiltonian_simulation_by_gqsp import (
25+
HamiltonianSimulationByGQSP,
26+
)
27+
from qualtran.bloqs.phase_estimation import LPResourceState, RectangularWindowState, TextbookQPE
2528
from qualtran.cirq_interop.testing import GateHelper
2629

2730

@@ -109,3 +112,13 @@ def should_decompose(binst):
109112
# 5. Verify that the estimated phase is correct.
110113
phase = theta * 2 * np.pi
111114
np.testing.assert_allclose(eig_val / qubitization_lambda, np.cos(phase), atol=eps)
115+
116+
117+
def test_qpe_of_gqsp():
118+
# This triggered a bug in the cirq interop.
119+
# https://github.com/quantumlib/Qualtran/issues/1570
120+
121+
walk_op = get_walk_operator_for_hubbard_model(2, 2, 1, 1)
122+
hubbard_time_evolution_by_gqsp = HamiltonianSimulationByGQSP(walk_op, t=5, precision=1e-7)
123+
textbook_qpe_w_gqsp = TextbookQPE(hubbard_time_evolution_by_gqsp, RectangularWindowState(3))
124+
qlt_testing.assert_valid_bloq_decomposition(textbook_qpe_w_gqsp)

qualtran/bloqs/qsp/generalized_qsp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ class GeneralizedQSP(GateWithRegisters):
282282
Motlagh and Wiebe. (2023). Theorem 3; Figure 2; Theorem 6.
283283
"""
284284

285-
U: GateWithRegisters
285+
U: 'Bloq'
286286
P: Union[Tuple[complex, ...], Shaped] = field(converter=_to_tuple)
287287
Q: Union[Tuple[complex, ...], Shaped] = field(converter=_to_tuple)
288288
negative_power: SymbolicInt = field(default=0, kw_only=True)
@@ -302,7 +302,7 @@ def signature(self) -> Signature:
302302
@classmethod
303303
def from_qsp_polynomial(
304304
cls,
305-
U: GateWithRegisters,
305+
U: 'Bloq',
306306
P: Union[NDArray[np.number], Sequence[complex], Shaped],
307307
*,
308308
negative_power: SymbolicInt = 0,

qualtran/cirq_interop/_bloq_to_cirq.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414

1515
"""Qualtran Bloqs to Cirq gates/circuits conversion."""
1616

17-
from functools import cached_property
18-
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
17+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
1918

2019
import cirq
2120
import networkx as nx
2221
import numpy as np
23-
from numpy.typing import NDArray
2422

2523
from qualtran import (
2624
Bloq,
@@ -40,7 +38,6 @@
4038
_get_all_and_output_quregs_from_input,
4139
merge_qubits,
4240
split_qubits,
43-
total_bits,
4441
)
4542
from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT
4643
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
@@ -58,7 +55,9 @@ def _cirq_style_decompose_from_decompose_bloq(
5855
# Input qubits can get de-allocated by cbloq.to_cirq_circuit_and_quregs, thus mark them as managed.
5956
qm = InteropQubitManager(context.qubit_manager)
6057
qm.manage_qubits(merge_qubits(bloq.signature.lefts(), **in_quregs))
61-
circuit, out_quregs = cbloq.to_cirq_circuit_and_quregs(qubit_manager=qm, **in_quregs)
58+
circuit, out_quregs = _cbloq_to_cirq_circuit(
59+
cbloq.signature, in_quregs, cbloq._binst_graph, qubit_manager=qm
60+
)
6261
qubit_map = {q: q for q in circuit.all_qubits()}
6362
for reg in bloq.signature.rights():
6463
if reg.side == Side.RIGHT:
@@ -93,11 +92,6 @@ def bloq(self) -> Bloq:
9392
"""The bloq we're wrapping."""
9493
return self._bloq
9594

96-
@cached_property
97-
def signature(self) -> Signature:
98-
"""`GateWithRegisters` registers."""
99-
return self.bloq.signature
100-
10195
@classmethod
10296
def bloq_on(
10397
cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager # type: ignore[type-var]
@@ -120,15 +114,16 @@ def bloq_on(
120114
all_quregs, out_quregs = _get_all_and_output_quregs_from_input(
121115
bloq.signature, qubit_manager, in_quregs=cirq_quregs
122116
)
123-
return BloqAsCirqGate(bloq=bloq).on_registers(**all_quregs), out_quregs
117+
cirq_op = BloqAsCirqGate(bloq=bloq).on(*merge_qubits(bloq.signature, **all_quregs))
118+
return cirq_op, out_quregs
124119

125120
def _num_qubits_(self) -> int:
126-
return total_bits(self.signature)
121+
return self.bloq.signature.n_qubits()
127122

128123
def _decompose_with_context_(
129124
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
130125
) -> cirq.OP_TREE:
131-
quregs = split_qubits(self.signature, qubits)
126+
quregs = split_qubits(self.bloq.signature, qubits)
132127
if context is None:
133128
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
134129
try:
@@ -142,23 +137,24 @@ def _decompose_with_context_(
142137
def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE:
143138
return self._decompose_with_context_(qubits)
144139

140+
def _has_unitary_(self):
141+
return all(reg.side == Side.THRU for reg in self.bloq.signature)
142+
145143
def _unitary_(self):
146-
if all(reg.side == Side.THRU for reg in self.signature):
144+
if all(reg.side == Side.THRU for reg in self.bloq.signature):
147145
try:
148-
_ = self.bloq.decompose_bloq() # check for decomposability
146+
# If decomposable, return NotImplemented to let the cirq protocol
147+
# try its decomposition-based strategies.
148+
_ = self.bloq.decompose_bloq()
149149
return NotImplemented
150150
except (DecomposeNotImplementedError, DecomposeTypeError):
151151
tensor = self.bloq.tensor_contract()
152-
if tensor.ndim != 2:
153-
return NotImplemented
152+
assert tensor.ndim == 2, "All registers should have been checked to be THRU."
154153
return tensor
154+
except NotImplementedError:
155+
return NotImplemented
155156
return NotImplemented
156157

157-
def on_registers(
158-
self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] # type: ignore[type-var]
159-
) -> cirq.Operation:
160-
return self.on(*merge_qubits(self.signature, **qubit_regs))
161-
162158
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
163159
"""Draw cirq diagrams.
164160

qualtran/cirq_interop/_bloq_to_cirq_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from attrs import frozen
2020

2121
from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT
22-
from qualtran._infra.gate_with_registers import get_named_qubits
22+
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
2323
from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate
2424
from qualtran.bloqs.cryptography.rsa import ModExp
2525
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
@@ -235,8 +235,8 @@ def test_bloq_as_cirq_gate_for_mod_exp():
235235
mod_exp = ModExp.make_for_shor(4, 3)
236236
gate = BloqAsCirqGate(mod_exp)
237237
# Use Cirq's infrastructure to construct an operation and corresponding decomposition.
238-
quregs = get_named_qubits(gate.signature)
239-
op = gate.on_registers(**quregs)
238+
quregs = get_named_qubits(mod_exp.signature)
239+
op = gate.on(*merge_qubits(mod_exp.signature, **quregs))
240240
# cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in
241241
# `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note
242242
# how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping
@@ -266,7 +266,7 @@ def test_bloq_as_cirq_gate_for_mod_exp():
266266
decomposed_circuit, out_regs = cbloq.to_cirq_circuit_and_quregs(exponent=quregs['exponent'])
267267
# Whereas when directly applying a cirq gate on qubits to get an operations, we need to
268268
# specify both input and output registers.
269-
circuit = cirq.Circuit(gate.on_registers(**out_regs), decomposed_circuit)
269+
circuit = cirq.Circuit(gate.on(*merge_qubits(cbloq.signature, **out_regs)), decomposed_circuit)
270270
# Notice the newly allocated qubits _C(0) and _C(1) for output register x.
271271
cirq.testing.assert_has_diagram(
272272
circuit,

qualtran/cirq_interop/_cirq_to_bloq.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import abc
1717
import itertools
1818
import numbers
19+
import warnings
1920
from functools import cached_property
2021
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union
2122

@@ -70,17 +71,28 @@ def _get_cirq_quregs(signature: Signature, qm: InteropQubitManager):
7071
return ret
7172

7273

73-
class CirqGateAsBloqBase(GateWithRegisters, metaclass=abc.ABCMeta):
74-
"""A Bloq wrapper around a `cirq.Gate`"""
74+
class CirqGateAsBloqBase(Bloq, metaclass=abc.ABCMeta):
75+
"""A base class to bootstrap a bloq from a `cirq.Gate`.
76+
77+
Bloq authors can inherit from this abstract class and override the `cirq_gate` property
78+
to get a bloq adapted from the cirq gate. Authors can continue to customize the bloq
79+
by overriding methods (like costs, string representations, ...).
80+
81+
Otherwise, this class fulfils the Bloq API by delegating to `cirq.Gate` methods.
82+
83+
This is the base class that provides the functionality for the `CirqGateAsBloq` adapter.
84+
The adapter lets you use any `cirq.Gate` as a bloq immediately (without defining a new class
85+
that inherits from `CirqGateAsBloqBase`), and is used as a fallback in the interoperability
86+
functionality.
87+
"""
7588

7689
@property
7790
@abc.abstractmethod
78-
def cirq_gate(self) -> cirq.Gate: ...
91+
def cirq_gate(self) -> cirq.Gate:
92+
"""The `cirq.Gate` to use as the source of truth."""
7993

8094
@cached_property
8195
def signature(self) -> 'Signature':
82-
if isinstance(self.cirq_gate, Bloq):
83-
return self.cirq_gate.signature
8496
nqubits = cirq.num_qubits(self.cirq_gate)
8597
if nqubits == 1:
8698
return Signature([Register('q', QBit())])
@@ -89,14 +101,13 @@ def signature(self) -> 'Signature':
89101
# else
90102
return Signature([Register('q', QBit(), shape=nqubits)])
91103

104+
def decompose_bloq(self) -> 'CompositeBloq':
105+
return decompose_from_cirq_style_method(self)
106+
92107
def decompose_from_registers(
93108
self, *, context: cirq.DecompositionContext, **quregs: CirqQuregT
94109
) -> cirq.OP_TREE:
95-
op = (
96-
self.cirq_gate.on_registers(**quregs)
97-
if isinstance(self.cirq_gate, GateWithRegisters)
98-
else self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
99-
)
110+
op = self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
100111
try:
101112
return cirq.decompose_once(op)
102113
except TypeError as e:
@@ -112,37 +123,43 @@ def my_tensors(
112123
def as_cirq_op(
113124
self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT'
114125
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
115-
if isinstance(self.cirq_gate, GateWithRegisters):
116-
return self.cirq_gate.as_cirq_op(qubit_manager, **in_quregs)
117126
qubits = in_quregs.get('q', np.array([])).flatten()
118127
return self.cirq_gate.on(*qubits), in_quregs
119128

120-
# Delegate all cirq-style protocols to underlying gate
121-
def _unitary_(self):
122-
return cirq.unitary(self.cirq_gate, default=None)
123-
124-
def _circuit_diagram_info_(
125-
self, args: cirq.CircuitDiagramInfoArgs
126-
) -> Optional[cirq.CircuitDiagramInfo]:
127-
return cirq.circuit_diagram_info(self.cirq_gate, default=None)
128-
129-
def __str__(self):
130-
return str(self.cirq_gate)
131-
132129
def __pow__(self, power):
133130
return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power))
134131

135132
def adjoint(self) -> 'Bloq':
136133
return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate))
137134

135+
def __str__(self):
136+
return f'cirq.{self.cirq_gate}'
137+
138138

139139
@frozen
140140
class CirqGateAsBloq(CirqGateAsBloqBase):
141+
"""An adapter that fulfils the Bloq API by delegating to `cirq.Gate` methods.
142+
143+
- The bloq's signature is one register named "q" of type QBit() with shape (n_qubits,) as
144+
determined by `cirq.num_qubits`.
145+
- Decomposition will go via `cirq.decompose_once`.
146+
- Tensor data is derived from `cirq.unitary`.
147+
- `as_cirq_op` will use the adapted cirq gate directly
148+
- Adjoint and exponentiation go via `cirq.inverse` and `cirq.pow`, respectively.
149+
- The string representation is "cirq.{gate}".
150+
151+
If you'd rather bootstrap your own bloq based on an existing `cirq.Gate`, you can inherit
152+
from `CirqGateAsBloqBase`."""
153+
141154
gate: cirq.Gate
142155

143-
def __str__(self) -> str:
144-
g = min(self.cirq_gate.__class__.__name__, str(self.cirq_gate), key=len)
145-
return f'cirq.{g}'
156+
def __attrs_post_init__(self):
157+
if isinstance(self.gate, GateWithRegisters):
158+
warnings.warn(
159+
f"Tried to use `CirqGateAsBloq` to adapt a `qualtran.GateWithRegisters`, "
160+
f"which already satisfies the Bloq API. Consider using {self.gate} "
161+
f"directly (without the adapter)."
162+
)
146163

147164
@property
148165
def cirq_gate(self) -> cirq.Gate:

0 commit comments

Comments
 (0)