Skip to content

Commit a68adcd

Browse files
authored
Compensate for global phase in MatrixGate decomposition (#7118)
* Change `assert_allclose` to `equal_upto_global_phase` in test, due to MatrixGate decomposition. * Add compensating global phase op if needed to MatrixGate decomposition. * Fix ionq test * address PR comments * Extract phase_delta
1 parent 5a81b3d commit a68adcd

8 files changed

Lines changed: 64 additions & 19 deletions

File tree

cirq-core/cirq/linalg/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,16 @@
8282
from cirq.linalg.transformations import (
8383
apply_matrix_to_slices as apply_matrix_to_slices,
8484
density_matrix_kronecker_product as density_matrix_kronecker_product,
85+
can_numpy_support_shape as can_numpy_support_shape,
8586
match_global_phase as match_global_phase,
8687
partial_trace as partial_trace,
8788
partial_trace_of_state_vector_as_mixture as partial_trace_of_state_vector_as_mixture,
89+
phase_delta as phase_delta,
8890
reflection_matrix_pow as reflection_matrix_pow,
8991
state_vector_kronecker_product as state_vector_kronecker_product,
9092
sub_state_vector as sub_state_vector,
9193
targeted_conjugate_about as targeted_conjugate_about,
9294
targeted_left_multiply as targeted_left_multiply,
9395
to_special as to_special,
9496
transpose_flattened_array as transpose_flattened_array,
95-
can_numpy_support_shape as can_numpy_support_shape,
9697
)

cirq-core/cirq/linalg/transformations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,3 +814,17 @@ def _can_numpy_support_dims(num_dims: int) -> bool:
814814
def can_numpy_support_shape(shape: Sequence[int]) -> bool:
815815
"""Returns whether numpy supports the given shape or not numpy/numpy#5744."""
816816
return min(shape, default=0) >= 0 and _can_numpy_support_dims(len(shape))
817+
818+
819+
def phase_delta(u1: np.ndarray, u2: np.ndarray) -> complex:
820+
"""Calculates the phase delta of two unitaries.
821+
822+
The delta is from u1 to u2. i.e. u1 * phase_delta(u1, u2) == u2.
823+
824+
Assumes but does not verify that inputs are valid unitaries and differ only
825+
by phase.
826+
"""
827+
# All cells will have the same phase difference. Just choose the cell with the largest
828+
# absolute value, to minimize rounding error.
829+
max_index = np.unravel_index(np.abs(u1).argmax(), u1.shape)
830+
return u2[max_index] / u1[max_index]

cirq-core/cirq/linalg/transformations_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,11 @@ def test_transpose_flattened_array(num_dimensions):
653653
@pytest.mark.parametrize('shape, result', [((), True), (30 * (1,), True), ((-3, 1, -1), False)])
654654
def test_can_numpy_support_shape(shape: tuple[int, ...], result: bool) -> None:
655655
assert linalg.can_numpy_support_shape(shape) is result
656+
657+
658+
@pytest.mark.parametrize('coeff', [1, 1j, -1, -1j, 1j**0.5, 1j**0.3])
659+
def test_phase_delta(coeff):
660+
u1 = cirq.testing.random_unitary(4)
661+
u2 = u1 * coeff
662+
np.testing.assert_almost_equal(linalg.phase_delta(u1, u2), coeff)
663+
np.testing.assert_almost_equal(u1 * linalg.phase_delta(u1, u2), u2)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
controlled_operation as cop,
3535
diagonal_gate as dg,
3636
global_phase_op as gp,
37-
matrix_gates,
3837
op_tree,
3938
raw_types,
4039
)
@@ -220,12 +219,6 @@ def _decompose_with_context_(
220219
control_qid_shape=self.control_qid_shape,
221220
).on(*control_qubits)
222221
return [result, controlled_phase_op]
223-
224-
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
225-
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
226-
# local phase in the controlled variant and hence cannot be ignored.
227-
return NotImplemented
228-
229222
result = protocols.decompose_once_with_qubits(
230223
self.sub_gate,
231224
qubits[self.num_controls() :],

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,23 @@ def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_targ
431431
@pytest.mark.parametrize(
432432
'gate',
433433
[
434+
cirq.I,
435+
cirq.GlobalPhaseGate(1),
436+
cirq.GlobalPhaseGate(-1),
437+
cirq.GlobalPhaseGate(1j),
434438
cirq.GlobalPhaseGate(1j**0.7),
439+
cirq.Z,
435440
cirq.ZPowGate(exponent=1.2, global_shift=0.3),
441+
cirq.CZ,
436442
cirq.CZPowGate(exponent=1.2, global_shift=0.3),
443+
cirq.CCZ,
437444
cirq.CCZPowGate(exponent=1.2, global_shift=0.3),
445+
cirq.X,
446+
cirq.XPowGate(exponent=1.2, global_shift=0.3),
447+
cirq.CX,
448+
cirq.CXPowGate(exponent=1.2, global_shift=0.3),
449+
cirq.CCX,
450+
cirq.CCXPowGate(exponent=1.2, global_shift=0.3),
438451
],
439452
)
440453
@pytest.mark.parametrize(
@@ -476,10 +489,9 @@ def _test_controlled_gate_is_consistent(
476489
shape = cirq.qid_shape(cgate)
477490
qids = cirq.LineQid.for_qid_shape(shape)
478491
decomposed = cirq.decompose(cgate.on(*qids))
479-
if len(decomposed) < 1000: # CCCCCZ rounding error explodes
480-
first_op = cirq.IdentityGate(qid_shape=shape).on(*qids) # To ensure same qid order
481-
circuit = cirq.Circuit(first_op, *decomposed)
482-
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-1)
492+
first_op = cirq.IdentityGate(qid_shape=shape).on(*qids) # To ensure same qid order
493+
circuit = cirq.Circuit(first_op, *decomposed)
494+
np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13)
483495

484496

485497
def test_pow_inverse():

cirq-core/cirq/ops/matrix_gates.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
"""Quantum gates defined by a matrix."""
1616

17-
from typing import Any, Dict, Iterable, Optional, Tuple, TYPE_CHECKING
17+
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING
1818

1919
import numpy as np
2020

2121
from cirq import _import, linalg, protocols
2222
from cirq._compat import proper_repr
23-
from cirq.ops import phased_x_z_gate, raw_types
23+
from cirq.ops import global_phase_op, identity, phased_x_z_gate, raw_types
2424

2525
if TYPE_CHECKING:
2626
import cirq
@@ -148,18 +148,34 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate':
148148
return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape)
149149

150150
def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE':
151+
from cirq.circuits import Circuit
152+
153+
decomposed: List['cirq.Operation'] = NotImplemented
151154
if self._qid_shape == (2,):
152-
return [
155+
decomposed = [
153156
g.on(qubits[0])
154157
for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix)
155158
]
156159
if self._qid_shape == (2,) * 2:
157-
return two_qubit_to_cz.two_qubit_matrix_to_cz_operations(
160+
decomposed = two_qubit_to_cz.two_qubit_matrix_to_cz_operations(
158161
*qubits, self._matrix, allow_partial_czs=True
159162
)
160163
if self._qid_shape == (2,) * 3:
161-
return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix)
162-
return NotImplemented
164+
decomposed = three_qubit_decomposition.three_qubit_matrix_to_operations(
165+
*qubits, self._matrix
166+
)
167+
if decomposed is NotImplemented:
168+
return NotImplemented
169+
# The above algorithms ignore phase, but phase is important to maintain if the gate is
170+
# controlled. Here, we add it back in with a global phase op.
171+
ident = identity.IdentityGate(qid_shape=self._qid_shape).on(*qubits) # Preserve qid order
172+
u = protocols.unitary(Circuit(ident, *decomposed)).reshape(self._matrix.shape)
173+
phase_delta = linalg.phase_delta(u, self._matrix)
174+
# Phase delta is on the complex unit circle, so if real(phase_delta) >= 1, that means
175+
# no phase delta. (>1 is rounding error).
176+
if phase_delta.real < 1:
177+
decomposed.append(global_phase_op.global_phase_operation(phase_delta))
178+
return decomposed
163179

164180
def _has_unitary_(self) -> bool:
165181
return True

cirq-core/cirq/transformers/analytical_decompositions/three_qubit_decomposition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def three_qubit_matrix_to_operations(
2626
q0: ops.Qid, q1: ops.Qid, q2: ops.Qid, u: np.ndarray, atol: float = 1e-8
27-
) -> Sequence[ops.Operation]:
27+
) -> List[ops.Operation]:
2828
"""Returns operations for a 3 qubit unitary.
2929
3030
The algorithm is described in Shende et al.:

cirq-ionq/cirq_ionq/ionq_gateset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, *, atol: float = 1e-8):
5151
cirq.YYPowGate,
5252
cirq.ZZPowGate,
5353
cirq.MeasurementGate,
54+
cirq.GlobalPhaseGate,
5455
unroll_circuit_op=False,
5556
)
5657
self.atol = atol

0 commit comments

Comments
 (0)