Skip to content

Commit 72e1d20

Browse files
authored
Clean up redundant complex type unions (#7041)
* Clean up redundant complex type unions * format * fix numpy ufunc, one mypy err * Additional cleanup
1 parent 14d61c8 commit 72e1d20

19 files changed

Lines changed: 50 additions & 58 deletions

cirq-core/cirq/interop/quirk/cells/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def classify(e: str) -> Union[str, float]:
8686
return _merge_scientific_float_tokens(g for g in result if g.strip())
8787

8888

89-
_ResolvedToken = Union[sympy.Expr, int, float, complex]
89+
_ResolvedToken = Union[sympy.Expr, complex]
9090

9191

9292
class _CustomQuirkOperationToken:

cirq-core/cirq/linalg/combinators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from numpy.typing import DTypeLike, ArrayLike
2626

2727

28-
def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.ndarray:
28+
def kron(*factors: Union[np.ndarray, complex], shape_len: int = 2) -> np.ndarray:
2929
"""Computes the kronecker product of a sequence of values.
3030
3131
A *args version of lambda args: functools.reduce(np.kron, args).
@@ -56,7 +56,7 @@ def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.
5656
)
5757

5858

59-
def kron_with_controls(*factors: Union[np.ndarray, complex, float]) -> np.ndarray:
59+
def kron_with_controls(*factors: Union[np.ndarray, complex]) -> np.ndarray:
6060
"""Computes the kronecker product of a sequence of values and control tags.
6161
6262
Use `cirq.CONTROL_TAG` to represent controls. Any entry of the output

cirq-core/cirq/linalg/decompositions.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,7 @@ class AxisAngleDecomposition:
283283
rotation axis, and g is the global phase.
284284
"""
285285

286-
def __init__(
287-
self,
288-
*,
289-
angle: float,
290-
axis: Tuple[float, float, float],
291-
global_phase: Union[int, float, complex],
292-
):
286+
def __init__(self, *, angle: float, axis: Tuple[float, float, float], global_phase: complex):
293287
if not np.isclose(np.linalg.norm(axis, 2), 1, atol=1e-8):
294288
raise ValueError('Axis vector must be normalized.')
295289
self.global_phase = complex(global_phase)
@@ -634,7 +628,7 @@ def scatter_plot_normalized_kak_interaction_coefficients(
634628
ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d'))
635629

636630
def coord_transform(
637-
pts: Union[List[Tuple[int, int, int]], np.ndarray]
631+
pts: Union[List[Tuple[int, int, int]], np.ndarray],
638632
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
639633
if len(pts) == 0:
640634
return np.array([]), np.array([]), np.array([])

cirq-core/cirq/linalg/tolerance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def all_near_zero(a: 'ArrayLike', *, atol: float = 1e-8) -> bool:
3333

3434

3535
def all_near_zero_mod(
36-
a: Union[float, complex, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
36+
a: Union[float, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
3737
) -> bool:
3838
"""Checks if the tensor's elements are all near multiples of the period.
3939

cirq-core/cirq/ops/dense_pauli_string.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def pauli_mask(self) -> np.ndarray:
116116
return self._pauli_mask
117117

118118
@property
119-
def coefficient(self) -> Union[sympy.Expr, complex]:
119+
def coefficient(self) -> 'cirq.TParamValComplex':
120120
"""A complex coefficient or symbol."""
121121
return self._coefficient
122122

@@ -359,7 +359,7 @@ def __str__(self) -> str:
359359
coef = '+'
360360
elif self.coefficient == -1:
361361
coef = '-'
362-
elif isinstance(self.coefficient, (complex, sympy.Symbol)):
362+
elif isinstance(self.coefficient, (numbers.Complex, sympy.Symbol)):
363363
coef = f'{self.coefficient}*'
364364
else:
365365
coef = f'({self.coefficient})*'
@@ -403,7 +403,7 @@ def mutable_copy(self) -> 'MutableDensePauliString':
403403
@abc.abstractmethod
404404
def copy(
405405
self,
406-
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
406+
coefficient: Optional['cirq.TParamValComplex'] = None,
407407
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
408408
) -> Self:
409409
"""Returns a copy with possibly modified contents.
@@ -459,7 +459,7 @@ def frozen(self) -> 'DensePauliString':
459459

460460
def copy(
461461
self,
462-
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
462+
coefficient: Optional['cirq.TParamValComplex'] = None,
463463
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
464464
) -> 'DensePauliString':
465465
if pauli_mask is None and (coefficient is None or coefficient == self.coefficient):
@@ -559,7 +559,7 @@ def __imul__(self, other):
559559

560560
def copy(
561561
self,
562-
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
562+
coefficient: Optional['cirq.TParamValComplex'] = None,
563563
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
564564
) -> 'MutableDensePauliString':
565565
return MutableDensePauliString(

cirq-core/cirq/ops/eigen_gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def _parameter_names_(self) -> AbstractSet[str]:
359359

360360
def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'EigenGate':
361361
exponent = resolver.value_of(self._exponent, recursive)
362-
if isinstance(exponent, (complex, numbers.Complex)):
362+
if isinstance(exponent, numbers.Complex):
363363
if isinstance(exponent, numbers.Real):
364364
exponent = float(exponent)
365365
else:

cirq-core/cirq/ops/identity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""IdentityGate."""
1515

16+
import numbers
1617
from types import NotImplementedType
1718
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Sequence, Union
1819

@@ -72,7 +73,7 @@ def num_qubits(self) -> int:
7273
return len(self._qid_shape)
7374

7475
def __pow__(self, power: Any) -> Any:
75-
if isinstance(power, (int, float, complex, sympy.Basic)):
76+
if isinstance(power, (numbers.Complex, sympy.Basic)):
7677
return self
7778
return NotImplemented
7879

@@ -126,7 +127,7 @@ def _json_dict_(self) -> Dict[str, Any]:
126127
def _mul_with_qubits(self, qubits: Tuple['cirq.Qid', ...], other):
127128
if isinstance(other, raw_types.Operation):
128129
return other
129-
if isinstance(other, (complex, float, int)):
130+
if isinstance(other, numbers.Complex):
130131
from cirq.ops.pauli_string import PauliString
131132

132133
return PauliString(coefficient=other)

cirq-core/cirq/ops/linear_combinations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
UnitPauliStringT = FrozenSet[Tuple[raw_types.Qid, pauli_gates.Pauli]]
5050
PauliSumLike = Union[
51-
int, float, complex, PauliString, 'PauliSum', pauli_string.SingleQubitPauliStringGateOperation
51+
complex, PauliString, 'PauliSum', pauli_string.SingleQubitPauliStringGateOperation
5252
]
5353
document(
5454
PauliSumLike,

cirq-core/cirq/ops/pauli_string.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
Optional,
2929
overload,
3030
Sequence,
31-
SupportsComplex,
3231
Tuple,
3332
TYPE_CHECKING,
3433
TypeVar,
@@ -271,9 +270,7 @@ def __mul__(self, other: 'cirq.Operation') -> 'cirq.PauliString[Union[TKey, cirq
271270
pass
272271

273272
@overload
274-
def __mul__(
275-
self, other: Union[complex, int, float, numbers.Number]
276-
) -> 'cirq.PauliString[TKey]':
273+
def __mul__(self, other: complex) -> 'cirq.PauliString[TKey]':
277274
pass
278275

279276
def __mul__(self, other):
@@ -308,10 +305,9 @@ def gate(self) -> 'cirq.DensePauliString':
308305
)
309306

310307
def __rmul__(self, other) -> 'PauliString':
311-
if isinstance(other, numbers.Number):
308+
if isinstance(other, numbers.Complex):
312309
return PauliString(
313-
qubit_pauli_map=self._qubit_pauli_map,
314-
coefficient=self._coefficient * complex(cast(SupportsComplex, other)),
310+
qubit_pauli_map=self._qubit_pauli_map, coefficient=self._coefficient * other
315311
)
316312

317313
if isinstance(other, raw_types.Operation) and isinstance(other.gate, identity.IdentityGate):
@@ -321,10 +317,9 @@ def __rmul__(self, other) -> 'PauliString':
321317
return NotImplemented
322318

323319
def __truediv__(self, other):
324-
if isinstance(other, numbers.Number):
320+
if isinstance(other, numbers.Complex):
325321
return PauliString(
326-
qubit_pauli_map=self._qubit_pauli_map,
327-
coefficient=self._coefficient / complex(cast(SupportsComplex, other)),
322+
qubit_pauli_map=self._qubit_pauli_map, coefficient=self._coefficient / other
328323
)
329324
return NotImplemented
330325

@@ -518,7 +513,7 @@ def _unitary_(self) -> Optional[np.ndarray]:
518513
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs'):
519514
if not self._has_unitary_():
520515
return None
521-
assert isinstance(self.coefficient, complex)
516+
assert isinstance(self.coefficient, numbers.Complex)
522517
if self.coefficient != 1:
523518
args.target_tensor *= self.coefficient
524519
return protocols.apply_unitaries([self[q].on(q) for q in self.qubits], self.qubits, args)
@@ -792,9 +787,11 @@ def __pos__(self) -> 'PauliString':
792787
return self
793788

794789
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
795-
"""Override behavior of numpy's exp method."""
790+
"""Override numpy behavior."""
796791
if ufunc == np.exp and len(inputs) == 1 and inputs[0] is self:
797792
return math.e**self
793+
if ufunc == np.multiply and len(inputs) == 2 and inputs[1] is self:
794+
return self * inputs[0]
798795
return NotImplemented
799796

800797
def __pow__(self, power):
@@ -1174,14 +1171,14 @@ def _as_pauli_string(self) -> PauliString:
11741171
def __mul__(self, other):
11751172
if isinstance(other, SingleQubitPauliStringGateOperation):
11761173
return self._as_pauli_string() * other._as_pauli_string()
1177-
if isinstance(other, (PauliString, complex, float, int)):
1174+
if isinstance(other, (PauliString, numbers.Complex)):
11781175
return self._as_pauli_string() * other
11791176
if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None:
11801177
return self * as_pauli_string
11811178
return NotImplemented
11821179

11831180
def __rmul__(self, other):
1184-
if isinstance(other, (PauliString, complex, float, int)):
1181+
if isinstance(other, (PauliString, numbers.Complex)):
11851182
return other * self._as_pauli_string()
11861183
if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None:
11871184
return as_pauli_string * self
@@ -1430,8 +1427,8 @@ def _imul_helper(self, other: 'cirq.PAULI_STRING_LIKE', sign: int):
14301427
pauli_int = _pauli_like_to_pauli_int(qubit, pauli_gate_like)
14311428
phase_log_i += self._imul_atom_helper(cast(TKey, qubit), pauli_int, sign)
14321429
self.coefficient *= 1j ** (phase_log_i & 3)
1433-
elif isinstance(other, numbers.Number):
1434-
self.coefficient *= complex(cast(SupportsComplex, other))
1430+
elif isinstance(other, numbers.Complex):
1431+
self.coefficient *= other
14351432
elif isinstance(other, raw_types.Operation) and isinstance(
14361433
other.gate, identity.IdentityGate
14371434
):

cirq-core/cirq/ops/pauli_string_phasor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,14 @@ def _resolve_parameters_(
392392
) -> 'PauliStringPhasorGate':
393393
exponent_neg = resolver.value_of(self.exponent_neg, recursive)
394394
exponent_pos = resolver.value_of(self.exponent_pos, recursive)
395-
if isinstance(exponent_neg, (complex, numbers.Complex)):
395+
if isinstance(exponent_neg, numbers.Complex):
396396
if isinstance(exponent_neg, numbers.Real):
397397
exponent_neg = float(exponent_neg)
398398
else:
399399
raise ValueError(
400400
f'PauliStringPhasorGate does not support complex exponent {exponent_neg}'
401401
)
402-
if isinstance(exponent_pos, (complex, numbers.Complex)):
402+
if isinstance(exponent_pos, numbers.Complex):
403403
if isinstance(exponent_pos, numbers.Real):
404404
exponent_pos = float(exponent_pos)
405405
else:

0 commit comments

Comments
 (0)