Skip to content

Commit 61548e0

Browse files
authored
Fix decompose for controlled CZ gates with phase shift (#7071)
* Fix decompose for controlled gates with phase shift * Fix test * Fix type check, int != complex * Remove decomposition to Z * Fix param name * Fix controlled op qasm, reformat tests * Fix test * Fix test
1 parent 2e5c8a2 commit 61548e0

5 files changed

Lines changed: 94 additions & 17 deletions

File tree

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
3030

3131
from cirq import protocols, value, _import
3232
from cirq.ops import (
33-
raw_types,
33+
control_values as cv,
3434
controlled_operation as cop,
35-
op_tree,
35+
diagonal_gate as dg,
36+
global_phase_op as gp,
3637
matrix_gates,
37-
control_values as cv,
38+
op_tree,
39+
raw_types,
3840
)
3941

4042
if TYPE_CHECKING:
@@ -157,13 +159,13 @@ def _decompose_(
157159
def _decompose_with_context_(
158160
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
159161
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
162+
control_qubits = list(qubits[: self.num_controls()])
160163
if (
161164
protocols.has_unitary(self.sub_gate)
162165
and protocols.num_qubits(self.sub_gate) == 1
163166
and self._qid_shape_() == (2,) * len(self._qid_shape_())
164167
and isinstance(self.control_values, cv.ProductOfSums)
165168
):
166-
control_qubits = list(qubits[: self.num_controls()])
167169
invert_ops: List['cirq.Operation'] = []
168170
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
169171
if set(cvals) == {0}:
@@ -174,11 +176,20 @@ def _decompose_with_context_(
174176
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
175177
)
176178
return invert_ops + decomposed_ops + invert_ops
177-
179+
if isinstance(self.sub_gate, gp.GlobalPhaseGate):
180+
# A controlled global phase is a diagonal gate, where each active control value index
181+
# is set equal to the phase angle.
182+
shape = self.control_qid_shape
183+
if protocols.is_parameterized(self.sub_gate) or set(shape) != {2}:
184+
# Could work in theory, but DiagonalGate decompose does not support them.
185+
return NotImplemented
186+
angle = np.angle(complex(self.sub_gate.coefficient))
187+
rads = np.zeros(shape=shape)
188+
for hot in self.control_values.expand():
189+
rads[hot] = angle
190+
return dg.DiagonalGate(diag_angles_radians=[*rads.flatten()]).on(*qubits)
178191
if isinstance(self.sub_gate, common_gates.CZPowGate):
179-
z_sub_gate = common_gates.ZPowGate(
180-
exponent=self.sub_gate.exponent, global_shift=self.sub_gate.global_shift
181-
)
192+
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
182193
num_controls = self.num_controls() + 1
183194
control_values = self.control_values & cv.ProductOfSums(((1,),))
184195
control_qid_shape = self.control_qid_shape + (2,)
@@ -197,9 +208,18 @@ def _decompose_with_context_(
197208
)
198209
)
199210
if self != controlled_z:
200-
return protocols.decompose_once_with_qubits(
201-
controlled_z, qubits, NotImplemented, context=context
202-
)
211+
result = controlled_z.on(*qubits)
212+
if self.sub_gate.global_shift == 0:
213+
return result
214+
# Reconstruct the controlled global shift of the subgate.
215+
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift
216+
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift))
217+
controlled_phase_op = phase_gate.controlled(
218+
num_controls=self.num_controls(),
219+
control_values=self.control_values,
220+
control_qid_shape=self.control_qid_shape,
221+
).on(*control_qubits)
222+
return [result, controlled_phase_op]
203223

204224
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
205225
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
@@ -328,7 +348,7 @@ def __str__(self) -> str:
328348
return str(self.control_values) + str(self.sub_gate)
329349

330350
def __repr__(self) -> str:
331-
if self.num_controls() == 1 and self.control_values.is_trivial:
351+
if self.control_qid_shape == (2,) and self.control_values.is_trivial:
332352
return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'
333353

334354
if self.control_values.is_trivial and set(self.control_qid_shape) == {2}:

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from types import NotImplementedType
16-
from typing import Union, Tuple, cast
16+
from typing import Any, cast, Optional, Sequence, Tuple, Union
1717

1818
import numpy as np
1919
import pytest
@@ -408,6 +408,11 @@ def test_unitary():
408408
),
409409
True,
410410
),
411+
(cirq.GlobalPhaseGate(-1), True),
412+
(cirq.GlobalPhaseGate(1j**0.7), True),
413+
(cirq.GlobalPhaseGate(sympy.Symbol("s")), False),
414+
(cirq.CZPowGate(exponent=1.2, global_shift=0.3), True),
415+
(cirq.CZPowGate(exponent=sympy.Symbol("s"), global_shift=0.3), False),
411416
# Single qudit gate with dimension 4.
412417
(cirq.MatrixGate(np.kron(*(cirq.unitary(cirq.H),) * 2), qid_shape=(4,)), False),
413418
(cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
@@ -420,11 +425,61 @@ def test_unitary():
420425
],
421426
)
422427
def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
423-
cgate = cirq.ControlledGate(gate)
428+
_test_controlled_gate_is_consistent(gate, should_decompose_to_target)
429+
430+
431+
@pytest.mark.parametrize(
432+
'gate',
433+
[
434+
cirq.GlobalPhaseGate(1j**0.7),
435+
cirq.ZPowGate(exponent=1.2, global_shift=0.3),
436+
cirq.CZPowGate(exponent=1.2, global_shift=0.3),
437+
cirq.CCZPowGate(exponent=1.2, global_shift=0.3),
438+
],
439+
)
440+
@pytest.mark.parametrize(
441+
'control_qid_shape, control_values, should_decompose_to_target',
442+
[
443+
([2, 2], None, True),
444+
([2, 2], xor_control_values, False),
445+
([3], None, False),
446+
([3, 4], xor_control_values, False),
447+
],
448+
)
449+
def test_nontrivial_controlled_gate_is_consistent(
450+
gate: cirq.Gate,
451+
control_qid_shape: Sequence[int],
452+
control_values: Any,
453+
should_decompose_to_target: bool,
454+
):
455+
_test_controlled_gate_is_consistent(
456+
gate, should_decompose_to_target, control_qid_shape, control_values
457+
)
458+
459+
460+
def _test_controlled_gate_is_consistent(
461+
gate: cirq.Gate,
462+
should_decompose_to_target: bool,
463+
control_qid_shape: Optional[Sequence[int]] = None,
464+
control_values: Any = None,
465+
):
466+
cgate = cirq.ControlledGate(
467+
gate, control_qid_shape=control_qid_shape, control_values=control_values
468+
)
424469
cirq.testing.assert_implements_consistent_protocols(cgate)
425470
cirq.testing.assert_decompose_ends_at_default_gateset(
426471
cgate, ignore_known_gates=not should_decompose_to_target
427472
)
473+
# The above only decompose once, which doesn't check that the sub-gate's phase is handled.
474+
# We need to check full decomposition here.
475+
if not cirq.is_parameterized(gate):
476+
shape = cirq.qid_shape(cgate)
477+
qids = cirq.LineQid.for_qid_shape(shape)
478+
decomposed = cirq.decompose(cgate.on(*qids))
479+
if len(decomposed) < 1000: # CCCCCZ rounding error explodes
480+
first_op = cirq.IdentityGate(qid_shape=shape).on(*qids) # To ensure same qid order
481+
circuit = cirq.Circuit(first_op, *decomposed)
482+
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-1)
428483

429484

430485
def test_pow_inverse():

cirq-core/cirq/ops/controlled_operation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
212212
hasattr(self._sub_operation, "gate")
213213
and len(self._controls) == 1
214214
and self.control_values == cv.ProductOfSums(((1,),))
215+
and all(q.dimension == 2 for q in self.qubits)
215216
):
216217
gate = self.sub_operation.gate
217218
if (

cirq-core/cirq/ops/global_phase_op.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import cirq
2323
from cirq import value, protocols
24+
from cirq._compat import proper_repr
2425
from cirq.ops import raw_types, controlled_gate, control_values as cv
2526

2627

@@ -68,10 +69,10 @@ def __str__(self) -> str:
6869
return str(self.coefficient)
6970

7071
def __repr__(self) -> str:
71-
return f'cirq.GlobalPhaseGate({self.coefficient!r})'
72+
return f'cirq.GlobalPhaseGate({proper_repr(self.coefficient)})'
7273

7374
def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
74-
return f'cirq.global_phase_operation({self.coefficient!r})'
75+
return f'cirq.global_phase_operation({proper_repr(self.coefficient)})'
7576

7677
def _json_dict_(self) -> Dict[str, Any]:
7778
return protocols.obj_to_dict_helper(self, ['coefficient'])

cirq-core/cirq/ops/global_phase_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_init():
3333

3434

3535
def test_protocols():
36-
for p in [1, 1j, -1]:
36+
for p in [1, 1j, -1, sympy.Symbol('s')]:
3737
cirq.testing.assert_implements_consistent_protocols(cirq.global_phase_operation(p))
3838

3939
np.testing.assert_allclose(

0 commit comments

Comments
 (0)