Skip to content

Commit 9617c81

Browse files
committed
strict BloqAsCirqGate and CirqGateAsBloq
1 parent 1ed1444 commit 9617c81

9 files changed

Lines changed: 133 additions & 90 deletions

File tree

qualtran/_infra/gate_with_registers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
319319

320320
return _wire_symbol_from_gate(self, self.signature, reg, idx)
321321

322-
# Part-2: Cirq-FT style interface can be used to implemented algorithms by Bloq authors.
322+
# Part-2: Cirq-FT style interface can be used to implement algorithms by Bloq authors.
323323

324324
def _num_qubits_(self) -> int:
325325
return total_bits(self.signature)

qualtran/bloqs/basic_gates/global_phase_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_unitary():
2929
for alpha in random_state.random(size=10):
3030
coefficient = np.exp(2j * np.pi * alpha)
3131
bloq = GlobalPhase(exponent=2 * alpha)
32-
np.testing.assert_allclose(cirq.unitary(bloq), coefficient)
32+
np.testing.assert_allclose(bloq.tensor_contract(), coefficient)
3333

3434

3535
@pytest.mark.parametrize("cv", [0, 1])

qualtran/bloqs/basic_gates/rotation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def adjoint(self) -> 'ZPowGate':
160160
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
161161
if reg is None:
162162
return Text('')
163-
return TextBox(str(self))
163+
return TextBox(f'Z^{self.exponent}')
164164

165165
def __str__(self):
166166
return f'Z**{self.exponent}'
@@ -302,7 +302,7 @@ def adjoint(self) -> 'XPowGate':
302302
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
303303
if reg is None:
304304
return Text('')
305-
return TextBox(str(self))
305+
return TextBox(f'X^{self.exponent}')
306306

307307
def __str__(self):
308308
return f'X**{self.exponent}'
@@ -376,7 +376,7 @@ def adjoint(self) -> 'YPowGate':
376376
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
377377
if reg is None:
378378
return Text('')
379-
return TextBox(str(self))
379+
return TextBox(f'Y^{self.exponent}')
380380

381381
def __str__(self):
382382
return f'Y**{self.exponent}'

qualtran/bloqs/phase_estimation/text_book_qpe_test.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
import numpy as np
1616
import pytest
1717

18+
import qualtran.testing as qlt_testing
1819
from qualtran import Signature
1920
from qualtran._infra.gate_with_registers import get_named_qubits
2021
from qualtran.bloqs.basic_gates import ZPowGate
22+
from qualtran.bloqs.chemistry.hubbard_model.qubitization import get_walk_operator_for_hubbard_model
2123
from qualtran.bloqs.for_testing.qubitization_walk_test import get_uniform_pauli_qubitized_walk
22-
from qualtran.bloqs.phase_estimation.lp_resource_state import LPResourceState
23-
from qualtran.bloqs.phase_estimation.qpe_window_state import RectangularWindowState
24-
from qualtran.bloqs.phase_estimation.text_book_qpe import TextbookQPE
24+
from qualtran.bloqs.hamiltonian_simulation.hamiltonian_simulation_by_gqsp import (
25+
HamiltonianSimulationByGQSP,
26+
)
27+
from qualtran.bloqs.phase_estimation import LPResourceState, RectangularWindowState, TextbookQPE
2528
from qualtran.cirq_interop.testing import GateHelper
2629

2730

@@ -109,3 +112,13 @@ def should_decompose(binst):
109112
# 5. Verify that the estimated phase is correct.
110113
phase = theta * 2 * np.pi
111114
np.testing.assert_allclose(eig_val / qubitization_lambda, np.cos(phase), atol=eps)
115+
116+
117+
def test_qpe_of_gqsp():
118+
# This triggered a bug in the cirq interop.
119+
# https://github.com/quantumlib/Qualtran/issues/1570
120+
121+
walk_op = get_walk_operator_for_hubbard_model(2, 2, 1, 1)
122+
hubbard_time_evolution_by_gqsp = HamiltonianSimulationByGQSP(walk_op, t=5, precision=1e-7)
123+
textbook_qpe_w_gqsp = TextbookQPE(hubbard_time_evolution_by_gqsp, RectangularWindowState(3))
124+
qlt_testing.assert_valid_bloq_decomposition(textbook_qpe_w_gqsp)

qualtran/cirq_interop/_bloq_to_cirq.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414

1515
"""Qualtran Bloqs to Cirq gates/circuits conversion."""
1616

17-
from functools import cached_property
18-
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
17+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
1918

2019
import cirq
2120
import networkx as nx
2221
import numpy as np
23-
from numpy.typing import NDArray
2422

2523
from qualtran import (
2624
Bloq,
@@ -40,7 +38,6 @@
4038
_get_all_and_output_quregs_from_input,
4139
merge_qubits,
4240
split_qubits,
43-
total_bits,
4441
)
4542
from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT
4643
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
@@ -58,7 +55,9 @@ def _cirq_style_decompose_from_decompose_bloq(
5855
# Input qubits can get de-allocated by cbloq.to_cirq_circuit_and_quregs, thus mark them as managed.
5956
qm = InteropQubitManager(context.qubit_manager)
6057
qm.manage_qubits(merge_qubits(bloq.signature.lefts(), **in_quregs))
61-
circuit, out_quregs = cbloq.to_cirq_circuit_and_quregs(qubit_manager=qm, **in_quregs)
58+
circuit, out_quregs = _cbloq_to_cirq_circuit(
59+
cbloq.signature, in_quregs, cbloq._binst_graph, qubit_manager=qm
60+
)
6261
qubit_map = {q: q for q in circuit.all_qubits()}
6362
for reg in bloq.signature.rights():
6463
if reg.side == Side.RIGHT:
@@ -93,11 +92,6 @@ def bloq(self) -> Bloq:
9392
"""The bloq we're wrapping."""
9493
return self._bloq
9594

96-
@cached_property
97-
def signature(self) -> Signature:
98-
"""`GateWithRegisters` registers."""
99-
return self.bloq.signature
100-
10195
@classmethod
10296
def bloq_on(
10397
cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager # type: ignore[type-var]
@@ -120,15 +114,16 @@ def bloq_on(
120114
all_quregs, out_quregs = _get_all_and_output_quregs_from_input(
121115
bloq.signature, qubit_manager, in_quregs=cirq_quregs
122116
)
123-
return BloqAsCirqGate(bloq=bloq).on_registers(**all_quregs), out_quregs
117+
cirq_op = BloqAsCirqGate(bloq=bloq).on(*merge_qubits(bloq.signature, **all_quregs))
118+
return cirq_op, out_quregs
124119

125120
def _num_qubits_(self) -> int:
126-
return total_bits(self.signature)
121+
return self.bloq.signature.n_qubits()
127122

128123
def _decompose_with_context_(
129124
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
130125
) -> cirq.OP_TREE:
131-
quregs = split_qubits(self.signature, qubits)
126+
quregs = split_qubits(self.bloq.signature, qubits)
132127
if context is None:
133128
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
134129
try:
@@ -142,23 +137,20 @@ def _decompose_with_context_(
142137
def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE:
143138
return self._decompose_with_context_(qubits)
144139

140+
def _has_unitary_(self):
141+
return all(reg.side == Side.THRU for reg in self.bloq.signature)
142+
145143
def _unitary_(self):
146-
if all(reg.side == Side.THRU for reg in self.signature):
144+
if all(reg.side == Side.THRU for reg in self.bloq.signature):
147145
try:
148-
_ = self.bloq.decompose_bloq() # check for decomposability
149-
return NotImplemented
150-
except (DecomposeNotImplementedError, DecomposeTypeError):
151146
tensor = self.bloq.tensor_contract()
152147
if tensor.ndim != 2:
153148
return NotImplemented
154149
return tensor
150+
except NotImplementedError:
151+
return NotImplemented
155152
return NotImplemented
156153

157-
def on_registers(
158-
self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] # type: ignore[type-var]
159-
) -> cirq.Operation:
160-
return self.on(*merge_qubits(self.signature, **qubit_regs))
161-
162154
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
163155
"""Draw cirq diagrams.
164156

qualtran/cirq_interop/_bloq_to_cirq_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from attrs import frozen
2020

2121
from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT
22-
from qualtran._infra.gate_with_registers import get_named_qubits
22+
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
2323
from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate
2424
from qualtran.bloqs.cryptography.rsa import ModExp
2525
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
@@ -235,8 +235,8 @@ def test_bloq_as_cirq_gate_for_mod_exp():
235235
mod_exp = ModExp.make_for_shor(4, 3)
236236
gate = BloqAsCirqGate(mod_exp)
237237
# Use Cirq's infrastructure to construct an operation and corresponding decomposition.
238-
quregs = get_named_qubits(gate.signature)
239-
op = gate.on_registers(**quregs)
238+
quregs = get_named_qubits(mod_exp.signature)
239+
op = gate.on(*merge_qubits(mod_exp.signature, **quregs))
240240
# cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in
241241
# `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note
242242
# how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping
@@ -266,7 +266,7 @@ def test_bloq_as_cirq_gate_for_mod_exp():
266266
decomposed_circuit, out_regs = cbloq.to_cirq_circuit_and_quregs(exponent=quregs['exponent'])
267267
# Whereas when directly applying a cirq gate on qubits to get an operations, we need to
268268
# specify both input and output registers.
269-
circuit = cirq.Circuit(gate.on_registers(**out_regs), decomposed_circuit)
269+
circuit = cirq.Circuit(gate.on(*merge_qubits(cbloq.signature, **out_regs)), decomposed_circuit)
270270
# Notice the newly allocated qubits _C(0) and _C(1) for output register x.
271271
cirq.testing.assert_has_diagram(
272272
circuit,

qualtran/cirq_interop/_cirq_to_bloq.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,28 @@ def _get_cirq_quregs(signature: Signature, qm: InteropQubitManager):
7070
return ret
7171

7272

73-
class CirqGateAsBloqBase(GateWithRegisters, metaclass=abc.ABCMeta):
74-
"""A Bloq wrapper around a `cirq.Gate`"""
73+
class CirqGateAsBloqBase(Bloq, metaclass=abc.ABCMeta):
74+
"""A base class to bootstrap a bloq from a `cirq.Gate`.
75+
76+
Bloq authors can inherit from this abstract class and override the `cirq_gate` property
77+
to get a bloq adapted from the cirq gate. Authors can continue to customize the bloq
78+
by overriding methods (like costs, string representations, ...).
79+
80+
Otherwise, this class fulfils the Bloq API by delegating to `cirq.Gate` methods.
81+
82+
This is the base class that provides the functionality for the `CirqGateAsBloq` adapter.
83+
The adapter lets you use any `cirq.Gate` as a bloq immediately (without defining a new class
84+
that inherits from `CirqGateAsBloqBase`), and is used as a fallback in the interoperability
85+
functionality.
86+
"""
7587

7688
@property
7789
@abc.abstractmethod
78-
def cirq_gate(self) -> cirq.Gate: ...
90+
def cirq_gate(self) -> cirq.Gate:
91+
"""The `cirq.Gate` to use as the source of truth."""
7992

8093
@cached_property
8194
def signature(self) -> 'Signature':
82-
if isinstance(self.cirq_gate, Bloq):
83-
return self.cirq_gate.signature
8495
nqubits = cirq.num_qubits(self.cirq_gate)
8596
if nqubits == 1:
8697
return Signature([Register('q', QBit())])
@@ -89,14 +100,13 @@ def signature(self) -> 'Signature':
89100
# else
90101
return Signature([Register('q', QBit(), shape=nqubits)])
91102

103+
def decompose_bloq(self) -> 'CompositeBloq':
104+
return decompose_from_cirq_style_method(self)
105+
92106
def decompose_from_registers(
93107
self, *, context: cirq.DecompositionContext, **quregs: CirqQuregT
94108
) -> cirq.OP_TREE:
95-
op = (
96-
self.cirq_gate.on_registers(**quregs)
97-
if isinstance(self.cirq_gate, GateWithRegisters)
98-
else self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
99-
)
109+
op = self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
100110
try:
101111
return cirq.decompose_once(op)
102112
except TypeError as e:
@@ -112,37 +122,35 @@ def my_tensors(
112122
def as_cirq_op(
113123
self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT'
114124
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
115-
if isinstance(self.cirq_gate, GateWithRegisters):
116-
return self.cirq_gate.as_cirq_op(qubit_manager, **in_quregs)
117125
qubits = in_quregs.get('q', np.array([])).flatten()
118126
return self.cirq_gate.on(*qubits), in_quregs
119127

120-
# Delegate all cirq-style protocols to underlying gate
121-
def _unitary_(self):
122-
return cirq.unitary(self.cirq_gate, default=None)
123-
124-
def _circuit_diagram_info_(
125-
self, args: cirq.CircuitDiagramInfoArgs
126-
) -> Optional[cirq.CircuitDiagramInfo]:
127-
return cirq.circuit_diagram_info(self.cirq_gate, default=None)
128-
129-
def __str__(self):
130-
return str(self.cirq_gate)
131-
132128
def __pow__(self, power):
133129
return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power))
134130

135131
def adjoint(self) -> 'Bloq':
136132
return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate))
137133

134+
def __str__(self):
135+
return f'cirq.{self.cirq_gate}'
136+
138137

139138
@frozen
140139
class CirqGateAsBloq(CirqGateAsBloqBase):
141-
gate: cirq.Gate
140+
"""An adapter that fulfils the Bloq API by delegating to `cirq.Gate` methods.
141+
142+
- The bloq's signature is one register named "q" of type QBit() with shape (n_qubits,) as
143+
determined by `cirq.num_qubits`.
144+
- Decomposition will go via `cirq.decompose_once`.
145+
- Tensor data is derived from `cirq.unitary`.
146+
- `as_cirq_op` will use the adapted cirq gate directly
147+
- Adjoint and exponentiation go via `cirq.inverse` and `cirq.pow`, respectively.
148+
- The string representation is "cirq.{gate}".
142149
143-
def __str__(self) -> str:
144-
g = min(self.cirq_gate.__class__.__name__, str(self.cirq_gate), key=len)
145-
return f'cirq.{g}'
150+
If you'd rather bootstrap your own bloq based on an existing `cirq.Gate`, you can inherit
151+
from `CirqGateAsBloqBase`."""
152+
153+
gate: cirq.Gate
146154

147155
@property
148156
def cirq_gate(self) -> cirq.Gate:

0 commit comments

Comments
 (0)