Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
Expand Down Expand Up @@ -155,6 +156,14 @@ def assert_valid_val_array(self, val_array: NDArray, debug_str: str = 'val') ->
self.qdtype.assert_valid_classical_val(val)


@attrs.frozen
class ShapedQCDType:
qcdtype: 'QCDType'
shape: Tuple[int, ...] = attrs.field(
default=tuple(), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
)


class QCDType(Generic[T], metaclass=abc.ABCMeta):
"""The abstract interface for quantum/classical quantum computing data types."""

Expand Down Expand Up @@ -245,6 +254,10 @@ def iteration_length_or_zero(self) -> SymbolicInt:
# TODO: remove https://github.com/quantumlib/Qualtran/issues/1716
return getattr(self, 'iteration_length', 0)

def __getitem__(self, shape):
"""QInt(8)[20] returns a size-20 array of QInt(8)"""
return ShapedQCDType(qcdtype=self, shape=shape)

@classmethod
def _pkg_(cls):
return 'qualtran'
Expand Down
137 changes: 121 additions & 16 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union

import attrs
import sympy
from attrs import field, frozen

from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicInt

from .data_types import QAny, QBit, QCDType
from .data_types import QAny, QBit, QCDType, ShapedQCDType


class Side(enum.Flag):
Expand All @@ -53,6 +52,12 @@ def __repr__(self):
return f'{self.__class__.__name__}.{self._name_}'


def _consume_register_dtype(dtype: Union[QCDType, ShapedQCDType]) -> QCDType:
# In __attrs_post_init__, we actually handle the ShapedQCDType case, which isn't accounted
# for in attrs type checking.
return cast(QCDType, dtype)


@frozen
class Register:
"""A register serves as the input/output quantum data specifications in a bloq's `Signature`.
Expand All @@ -72,7 +77,7 @@ class Register:
"""

name: str
dtype: QCDType
dtype: QCDType = field(converter=_consume_register_dtype)
_shape: Tuple[SymbolicInt, ...] = field(
default=tuple(), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
)
Expand All @@ -86,6 +91,15 @@ def __repr__(self):
return f'Register({self.name!r}, dtype={self.dtype!r}, shape={self._shape!r}, side={self.side!r})'

def __attrs_post_init__(self):
if isinstance(self.dtype, ShapedQCDType):
if self._shape != ():
raise ValueError(
f"for Register {self.name}, use either a shaped dtype {self.dtype} "
f"or an explicit shape argument {self._shape}, not both."
)
object.__setattr__(self, '_shape', self.dtype.shape)
object.__setattr__(self, 'dtype', self.dtype.qcdtype)

if not isinstance(self.dtype, QCDType):
raise ValueError(f'dtype must be a QCDType: found {type(self.dtype)}')

Expand Down Expand Up @@ -193,13 +207,15 @@ def __init__(self, registers: Iterable[Register]):
self._rights = _dedupe((reg.name, reg) for reg in self._registers if reg.side & Side.RIGHT)

@classmethod
def build(cls, **registers: Union[int, sympy.Expr]) -> 'Signature':
"""Construct a Signature comprised of untyped thru registers of the given bitsizes.
def build(cls, *args, **kwargs) -> 'Signature':
"""Construct a Signature using a more natural syntax.

For rapid prorotyping or simple gates, this syntactic sugar can be used.
This builder constructs a `Signature` flexibly from a mix of types, positional elements,
and named keyword arguments. For rapid prototyping or simple gates, you can quickly define
registers without manually instantiating `Register` objects.

Examples:
The following constructors are equivalent
The following constructors are equivalent:

>>> sig1 = Signature.build(a=32, b=1)
>>> sig2 = Signature([
Expand All @@ -209,13 +225,104 @@ def build(cls, **registers: Union[int, sympy.Expr]) -> 'Signature':
>>> sig1 == sig2
True

We can also build signatures with fully instantiated `QCDType` arguments, including
shaped multidimensional registers:

>>> from qualtran import QBit, QUInt
>>> sig = Signature.build(ctrl=QBit()[5, 5], system=QUInt(32))
>>> sig == Signature([
... Register('ctrl', QBit(), shape=(5, 5)),
... Register('system', QUInt(32))
... ])
True

Left and Right registers can be specified with a 2-tuple `(LEFT, RIGHT)`.
Here, we allocate `b` as a right register.

>>> sig = Signature.build(a=(QBit(), QBit()), b=(None, QBit()))
>>> sig == Signature([
... Register('a', QBit(), side=Side.THRU),
... Register('b', QBit(), side=Side.RIGHT)
... ])
True

Positional arguments can be used to join previously defined components:

>>> sig1 = Signature.build(a=1)
>>> extra = [Register('c', QAny(5))]
>>> sig2 = Signature.build(sig1, extra)

Args:
**registers: Keyword arguments mapping register names to bitsizes. All registers
will be 0-dimensional, THRU, and of type QAny/QBit.
*args: Positional arguments must be instances of `Register`, `Signature`, or iterables
thereof, which will be concatenated in order of layout.
**kwargs: Keyword arguments mapping register names to data types or sizes.
Values can be integer bitsizes (where 1 maps to `QBit` and n to `QAny(n)`),
`QCDType` instances, `ShapedQCDType` instances, or 2-tuples of
`(left_dtype, right_dtype)` to explicitly specify sides.
"""
return cls(
Register(name=k, dtype=QBit() if v == 1 else QAny(v)) for k, v in registers.items() if v
)
if args and kwargs:
raise ValueError(
f"When using `Signature.build`, you must either specify a mapping "
f"from register names to data types or positional Signature and "
f"Register arguments, not both. Found positional {args} and keyword {kwargs}"
)

registers = []

def _flat_add(arg):
# add positional Signature, Register, or iterables thereof.
nonlocal registers
if isinstance(arg, Register):
registers.append(arg)
elif isinstance(arg, Signature):
registers.extend(arg)
elif isinstance(arg, Iterable) and not isinstance(arg, str):
for a2 in arg:
_flat_add(a2)
else:
raise ValueError(
f"Unknown type for positional argument to Signature.build: {arg!r}"
)
Comment thread
mpharrigan marked this conversation as resolved.

if args:
for arg in args:
_flat_add(arg)
return cls(registers)

for k, v in kwargs.items():
if not v:
continue

if isinstance(v, (QCDType, ShapedQCDType)):
registers.append(Register(name=k, dtype=v))
elif isinstance(v, tuple):
if len(v) != 2:
raise ValueError(
f"When using Signature.build with a tuple of data types, "
f"you must specify a tuple of length 2. For LEFT registers, "
f"the tuple is (dtype, None). For RIGHT registers, "
f"the tuple is (None, dtype). You provided {v}"
)
ldt, rdt = v
if ldt is not None:
registers.append(Register(name=k, dtype=ldt, side=Side.LEFT))
if rdt is not None:
registers.append(Register(name=k, dtype=rdt, side=Side.RIGHT))

elif isinstance(v, (Register, Signature)):
# mild defensiveness against common errors, but duck typing in the `else` clause.
raise ValueError(
f"Invalid data type for Signature.build keyword argument '{k}': {v}"
)
else:
dt: QCDType
if v == 1:
dt = QBit()
else:
dt = QAny(v)
registers.append(Register(name=k, dtype=dt))

return cls(registers)

@classmethod
def build_from_dtypes(cls, **registers: QCDType) -> 'Signature':
Expand Down Expand Up @@ -323,12 +430,10 @@ def __repr__(self):
return f'Signature({repr(self._registers)})'

@overload
def __getitem__(self, key: int) -> Register:
pass
def __getitem__(self, key: int) -> Register: ...

@overload
def __getitem__(self, key: slice) -> Tuple[Register, ...]:
pass
def __getitem__(self, key: slice) -> Tuple[Register, ...]: ...

def __getitem__(self, key):
return self._registers[key]
Expand Down
131 changes: 130 additions & 1 deletion qualtran/_infra/registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,87 @@ def test_signature_build():
sig1 = Signature([Register("r1", QAny(5)), Register("r2", QAny(2))])
sig2 = Signature.build(r1=5, r2=2)
assert sig1 == sig2
assert sig1.n_qubits() == 7
assert sig2.n_qubits() == 7


def test_signature_build_drops_falsey():
should_be = Signature([Register('x', QBit())])
assert Signature.build(x=1, y=0) == should_be
assert Signature.build(x=1, y=None) == should_be


def test_signature_build_dtypes():
should_be = Signature([Register('system', QUInt(8))])
assert Signature.build(system=QUInt(8)) == should_be


def test_signature_build_shaped():
should_be = Signature([Register('qubits', QBit(), shape=(5, 5))])
assert Signature.build(qubits=QBit()[5, 5]) == should_be

should_be = Signature([Register('ctrl', QBit()), Register('ints', QInt(8), shape=(5,))])
assert Signature.build(ctrl=1, ints=QInt(8)[5]) == should_be


def test_signature_build_sided():
should_be = Signature(
[Register('x_in', QAny(3), side=Side.LEFT), Register('x_out', QAny(3), side=Side.RIGHT)]
)
assert Signature.build(x_in=(QAny(3), None), x_out=(None, QAny(3))) == should_be


def test_signature_build_grouped_sided():
should_be = Signature(
[Register('x', QAny(3), side=Side.LEFT), Register('x', QBit(), shape=(3,), side=Side.RIGHT)]
)
assert Signature.build(x=(QAny(3), QBit()[3])) == should_be


def test_signature_build_signature():
should_be = Signature(
[Register('x', QAny(3), side=Side.LEFT), Register('x', QBit(), shape=(3,), side=Side.RIGHT)]
)
assert Signature.build(should_be) == should_be


def test_signature_build_registers():
should_be = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
assert Signature.build(Register('ctrl', QBit()), Register('system', QAny(5))) == should_be


def test_signature_build_signature_registers():
should_be = Signature(
[
Register('ctrl', QBit()),
Register('system', QAny(5)),
Register('x_in', QAny(3), side=Side.LEFT),
Register('x_out', QAny(3), side=Side.RIGHT),
Register('x', QBit()),
]
)

first_signature = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
regs = [Register('x_in', QAny(3), side=Side.LEFT), Register('x_out', QAny(3), side=Side.RIGHT)]
last_signature = Signature([Register('x', QBit())])
assert Signature.build(first_signature, regs, last_signature) == should_be


def test_signature_build_mixed_args_kwargs():
first_signature = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
with pytest.raises(ValueError, match=r'either.*not both.*'):
Signature.build(first_signature, y=QBit())


def test_signature_build_kwregs():
with pytest.raises(ValueError, match=r"Invalid data type.*'x'.*"):
Signature.build(x=Register('x', QBit()))


def test_signature_build_from_dtypes():
sig1 = Signature([Register("r1", QInt(7)), Register("r2", QBit())])
sig2 = Signature.build_from_dtypes(r1=QInt(7), r2=QBit())
assert sig1 == sig2

sig1 = Signature([Register("r1", QInt(7))])
sig2 = Signature.build_from_dtypes(r1=QInt(7), r2=QAny(0))
assert sig1 == sig2
Expand Down Expand Up @@ -235,3 +312,55 @@ def test_is_symbolic():
assert is_symbolic(r)
r = Register("my_reg", QAny(2), shape=sympy.symbols("x y"))
assert is_symbolic(r)


def test_register_pkg():
assert Register._pkg_() == 'qualtran'


def test_register_shape_error():
with pytest.raises(ValueError, match="use either a shaped dtype.*or an explicit shape"):
Register("my_reg", QBit()[2], shape=(2,))


def test_register_invalid_dtype():
with pytest.raises(ValueError, match="dtype must be a QCDType"):
Register("my_reg", 5) # type: ignore


def test_register_adjoint_side():
r2 = Register("my_reg", QBit(), side=Side.RIGHT)
assert r2.adjoint().side == Side.LEFT

r3 = Register("my_reg", QBit(), side=Side.LEFT)
assert r3.adjoint().side == Side.RIGHT


def test_signature_build_positional_errors():
with pytest.raises(ValueError, match="Unknown type for positional argument"):
Signature.build("not_a_register_or_signature")


def test_signature_build_tuple_error():
with pytest.raises(ValueError, match="you must specify a tuple of length 2"):
Signature.build(a=(QBit(),))


def test_signature_thru_registers_only():
sig = Signature.build(a=1)
assert sig.thru_registers_only
sig2 = Signature([Register('a', QBit(), side=Side.LEFT)])
assert not sig2.thru_registers_only


def test_signature_get_left_right():
sig = Signature([Register('a', QBit(), side=Side.LEFT), Register('b', QBit(), side=Side.RIGHT)])
assert sig.get_left('a').name == 'a'
assert sig.get_right('b').name == 'b'


def test_signature_contains_and_hash():
r = Register('a', QBit())
sig = Signature([r])
assert r in sig
assert hash(sig) == hash(Signature([Register('a', QBit())]))
Loading