1313# limitations under the License.
1414
1515from types import NotImplementedType
16- from typing import Union , Tuple , cast
16+ from typing import Any , cast , Optional , Sequence , Tuple , Union
1717
1818import numpy as np
1919import pytest
@@ -408,6 +408,11 @@ def test_unitary():
408408 ),
409409 True ,
410410 ),
411+ (cirq .GlobalPhaseGate (- 1 ), True ),
412+ (cirq .GlobalPhaseGate (1j ** 0.7 ), True ),
413+ (cirq .GlobalPhaseGate (sympy .Symbol ("s" )), False ),
414+ (cirq .CZPowGate (exponent = 1.2 , global_shift = 0.3 ), True ),
415+ (cirq .CZPowGate (exponent = sympy .Symbol ("s" ), global_shift = 0.3 ), False ),
411416 # Single qudit gate with dimension 4.
412417 (cirq .MatrixGate (np .kron (* (cirq .unitary (cirq .H ),) * 2 ), qid_shape = (4 ,)), False ),
413418 (cirq .MatrixGate (cirq .testing .random_unitary (4 , random_state = 1234 )), False ),
@@ -420,11 +425,61 @@ def test_unitary():
420425 ],
421426)
422427def test_controlled_gate_is_consistent (gate : cirq .Gate , should_decompose_to_target ):
423- cgate = cirq .ControlledGate (gate )
428+ _test_controlled_gate_is_consistent (gate , should_decompose_to_target )
429+
430+
431+ @pytest .mark .parametrize (
432+ 'gate' ,
433+ [
434+ cirq .GlobalPhaseGate (1j ** 0.7 ),
435+ cirq .ZPowGate (exponent = 1.2 , global_shift = 0.3 ),
436+ cirq .CZPowGate (exponent = 1.2 , global_shift = 0.3 ),
437+ cirq .CCZPowGate (exponent = 1.2 , global_shift = 0.3 ),
438+ ],
439+ )
440+ @pytest .mark .parametrize (
441+ 'control_qid_shape, control_values, should_decompose_to_target' ,
442+ [
443+ ([2 , 2 ], None , True ),
444+ ([2 , 2 ], xor_control_values , False ),
445+ ([3 ], None , False ),
446+ ([3 , 4 ], xor_control_values , False ),
447+ ],
448+ )
449+ def test_nontrivial_controlled_gate_is_consistent (
450+ gate : cirq .Gate ,
451+ control_qid_shape : Sequence [int ],
452+ control_values : Any ,
453+ should_decompose_to_target : bool ,
454+ ):
455+ _test_controlled_gate_is_consistent (
456+ gate , should_decompose_to_target , control_qid_shape , control_values
457+ )
458+
459+
460+ def _test_controlled_gate_is_consistent (
461+ gate : cirq .Gate ,
462+ should_decompose_to_target : bool ,
463+ control_qid_shape : Optional [Sequence [int ]] = None ,
464+ control_values : Any = None ,
465+ ):
466+ cgate = cirq .ControlledGate (
467+ gate , control_qid_shape = control_qid_shape , control_values = control_values
468+ )
424469 cirq .testing .assert_implements_consistent_protocols (cgate )
425470 cirq .testing .assert_decompose_ends_at_default_gateset (
426471 cgate , ignore_known_gates = not should_decompose_to_target
427472 )
473+ # The above only decompose once, which doesn't check that the sub-gate's phase is handled.
474+ # We need to check full decomposition here.
475+ if not cirq .is_parameterized (gate ):
476+ shape = cirq .qid_shape (cgate )
477+ qids = cirq .LineQid .for_qid_shape (shape )
478+ decomposed = cirq .decompose (cgate .on (* qids ))
479+ if len (decomposed ) < 1000 : # CCCCCZ rounding error explodes
480+ first_op = cirq .IdentityGate (qid_shape = shape ).on (* qids ) # To ensure same qid order
481+ circuit = cirq .Circuit (first_op , * decomposed )
482+ np .testing .assert_allclose (cirq .unitary (cgate ), cirq .unitary (circuit ), atol = 1e-1 )
428483
429484
430485def test_pow_inverse ():
0 commit comments