|
14 | 14 |
|
15 | 15 | """Quantum gates defined by a matrix.""" |
16 | 16 |
|
17 | | -from typing import Any, Dict, Iterable, Optional, Tuple, TYPE_CHECKING |
| 17 | +from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING |
18 | 18 |
|
19 | 19 | import numpy as np |
20 | 20 |
|
21 | 21 | from cirq import _import, linalg, protocols |
22 | 22 | from cirq._compat import proper_repr |
23 | | -from cirq.ops import phased_x_z_gate, raw_types |
| 23 | +from cirq.ops import global_phase_op, identity, phased_x_z_gate, raw_types |
24 | 24 |
|
25 | 25 | if TYPE_CHECKING: |
26 | 26 | import cirq |
@@ -148,18 +148,34 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate': |
148 | 148 | return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape) |
149 | 149 |
|
150 | 150 | def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE': |
| 151 | + from cirq.circuits import Circuit |
| 152 | + |
| 153 | + decomposed: List['cirq.Operation'] = NotImplemented |
151 | 154 | if self._qid_shape == (2,): |
152 | | - return [ |
| 155 | + decomposed = [ |
153 | 156 | g.on(qubits[0]) |
154 | 157 | for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix) |
155 | 158 | ] |
156 | 159 | if self._qid_shape == (2,) * 2: |
157 | | - return two_qubit_to_cz.two_qubit_matrix_to_cz_operations( |
| 160 | + decomposed = two_qubit_to_cz.two_qubit_matrix_to_cz_operations( |
158 | 161 | *qubits, self._matrix, allow_partial_czs=True |
159 | 162 | ) |
160 | 163 | if self._qid_shape == (2,) * 3: |
161 | | - return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix) |
162 | | - return NotImplemented |
| 164 | + decomposed = three_qubit_decomposition.three_qubit_matrix_to_operations( |
| 165 | + *qubits, self._matrix |
| 166 | + ) |
| 167 | + if decomposed is NotImplemented: |
| 168 | + return NotImplemented |
| 169 | + # The above algorithms ignore phase, but phase is important to maintain if the gate is |
| 170 | + # controlled. Here, we add it back in with a global phase op. |
| 171 | + ident = identity.IdentityGate(qid_shape=self._qid_shape).on(*qubits) # Preserve qid order |
| 172 | + u = protocols.unitary(Circuit(ident, *decomposed)).reshape(self._matrix.shape) |
| 173 | + phase_delta = linalg.phase_delta(u, self._matrix) |
| 174 | + # Phase delta is on the complex unit circle, so if real(phase_delta) >= 1, that means |
| 175 | + # no phase delta. (>1 is rounding error). |
| 176 | + if phase_delta.real < 1: |
| 177 | + decomposed.append(global_phase_op.global_phase_operation(phase_delta)) |
| 178 | + return decomposed |
163 | 179 |
|
164 | 180 | def _has_unitary_(self) -> bool: |
165 | 181 | return True |
|
0 commit comments