Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@
import itertools
import warnings
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

import cirq
import numpy as np
Expand Down Expand Up @@ -319,7 +330,7 @@ def _ensure_in_reg_exists(bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QR
soqs_to_join[qreg.qubits[0]] = soq
elif len(in_reg_qubits) == 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits:
# Cast single QBit registers to the appropriate single-bit register dtype.
err_msg = "Found non-QBit type register which shouldn't happen: " f"{soq}"
err_msg = f"Found non-QBit type register which shouldn't happen: {soq}"
assert isinstance(soq.dtype, QBit), err_msg
if not isinstance(in_reg.dtype, QBit):
qreg_to_qvar[in_reg] = bb.add(Cast(QBit(), in_reg.dtype), reg=soq)
Expand Down Expand Up @@ -465,6 +476,7 @@ def cirq_optree_to_cbloq(
signature: Optional[Signature] = None,
in_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
out_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
op_conversion_method: Optional[Callable[[cirq.Operation], Bloq]] = None,
) -> CompositeBloq:
"""Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`.

Expand Down Expand Up @@ -495,6 +507,17 @@ def cirq_optree_to_cbloq(

Any qubit in `optree` which is not part of `in_quregs` and `out_quregs` is considered to be
allocated & deallocated inside the CompositeBloq and does not show up in it's signature.

Args:
optree: A Cirq OP_TREE (e.g. a circuit or list of operations).
signature: The signature of the resulting CompositeBloq. If not provided, a default
signature with one thru-register named "qubits" is used.
in_quregs: Mapping from register names to arrays of cirq qubits for LEFT registers.
out_quregs: Mapping from register names to arrays of cirq qubits for RIGHT registers.
op_conversion_method: An optional callable that takes a ``cirq.Operation`` and returns
a ``Bloq``. If provided, this is used instead of the default ``_extract_bloq_from_op``
to convert each operation. This allows callers to attach custom metadata (e.g.
routing costs) to bloqs during conversion.
"""
circuit = cirq.Circuit(optree)
if signature is None:
Expand Down Expand Up @@ -533,7 +556,10 @@ def cirq_optree_to_cbloq(

# 2. Add each operation to the composite Bloq.
for op in circuit.all_operations():
bloq = _extract_bloq_from_op(op)
if op_conversion_method is not None:
bloq = op_conversion_method(op)
else:
bloq = _extract_bloq_from_op(op)
if bloq.signature == Signature([]):
bb.add(bloq)
continue
Expand Down
30 changes: 30 additions & 0 deletions qualtran/cirq_interop/_cirq_to_bloq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,33 @@ def test_cirq_gate_cost_via_decomp():

gc_swappow = get_cost_value(swappow_bloq, QECGatesCost())
assert gc_swappow == GateCounts(clifford=5, rotation=1, and_bloq=1, measurement=1)


def test_cirq_optree_to_cbloq_op_conversion_method():
"""Test the op_conversion_method parameter of cirq_optree_to_cbloq.

When provided, op_conversion_method should be called for each operation
instead of the default _extract_bloq_from_op.
"""
qubits = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.H(qubits[0]), cirq.CNOT(qubits[0], qubits[1]), cirq.T(qubits[2]))

# Track which operations were converted
converted_ops: list[cirq.Operation] = []

def custom_converter(op: cirq.Operation) -> Bloq:
converted_ops.append(op)
# Fall back to the default behavior.
from qualtran.cirq_interop._cirq_to_bloq import _extract_bloq_from_op

return _extract_bloq_from_op(op)

cbloq = cirq_optree_to_cbloq(circuit, op_conversion_method=custom_converter)

# Verify the custom converter was called for each operation.
assert len(converted_ops) == 3

# The resulting CompositeBloq should still produce the correct unitary.
bloq_unitary = cbloq.tensor_contract()
cirq_unitary = circuit.unitary(qubits)
np.testing.assert_allclose(cirq_unitary, bloq_unitary, atol=1e-8)
Loading