Skip to content

Commit 1ee71bf

Browse files
maxglickmpharrigan
andauthored
Add a phased classical action for SelectedMajoranaFermion (#1778)
Add a phased classical action for SelectedMajoranaFermion. See #1699 for details. The classical action only exists for some choices of `target_gate`, and we assume specifically `target_gate=cirq.X` or `target_gate=cirq.Z`. We also assume that there is only 1 control register and 1 selection register. --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent bd97492 commit 1ee71bf

3 files changed

Lines changed: 93 additions & 2 deletions

File tree

qualtran/bloqs/multiplexers/selected_majorana_fermion.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Iterator, Sequence, Tuple, Union
16+
from typing import Dict, Iterator, Sequence, Tuple, Union
1717

1818
import attrs
1919
import cirq
@@ -25,6 +25,7 @@
2525
from qualtran._infra.data_types import BQUInt
2626
from qualtran._infra.gate_with_registers import total_bits
2727
from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate
28+
from qualtran.simulation.classical_sim import ClassicalValT
2829

2930

3031
@attrs.frozen
@@ -137,5 +138,54 @@ def nth_operation( # type: ignore[override]
137138
yield self.target_gate(target[target_idx]).controlled_by(control)
138139
yield cirq.CZ(*accumulator, target[target_idx])
139140

141+
def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']:
142+
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
143+
return NotImplemented
144+
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
145+
return NotImplemented
146+
control_name = self.control_registers[0].name
147+
control = vals[control_name]
148+
selection_name = self.selection_registers[0].name
149+
selection = vals[selection_name]
150+
target = vals['target']
151+
152+
# When target_gate == cirq.X, flip the selection-th bit in target. The ith bit of a
153+
# size N regirster is addressed with the unsigned integer 2^(N - 1 - i) in our big
154+
# endian convention.
155+
if control and self.target_gate == cirq.X:
156+
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
157+
target = (2 ** (max_selection - selection)) ^ target
158+
# When target_gate == cirq.Z, the action is only in the phase.
159+
160+
return {control_name: control, selection_name: selection, 'target': target}
161+
162+
def basis_state_phase(self, **vals) -> Union[complex, None]:
163+
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
164+
return None
165+
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
166+
return None
167+
control_name = self.control_registers[0].name
168+
control = vals[control_name]
169+
selection_name = self.selection_registers[0].name
170+
selection = vals[selection_name]
171+
target = vals['target']
172+
if control:
173+
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
174+
# This gate applies Z in positions 0 through (selection - 1). The effect is
175+
# a phase of plus or minus 1 depending on the parity of the number of ones
176+
# in those positions. For an N-bit big endien integer, the first j bits can
177+
# be isolated by shifting right by N - j.
178+
#
179+
# The target gate X has no additional phase, so calculate as in the
180+
# previous paragraph.
181+
if self.target_gate == cirq.X:
182+
num_phases = (target >> (max_selection - selection + 1)).bit_count()
183+
# The taget gate Z is applied in position selection, so consider the full
184+
# range 0 through selection.
185+
else:
186+
num_phases = (target >> (max_selection - selection)).bit_count()
187+
return 1 if (num_phases % 2) == 0 else -1
188+
return 1
189+
140190
def __str__(self):
141191
return f'SelectedMajoranaFermion({self.target_gate})'

qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
2121
from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion
2222
from qualtran.cirq_interop.testing import GateHelper
23-
from qualtran.testing import assert_valid_bloq_decomposition
23+
from qualtran.testing import (
24+
assert_consistent_phased_classical_action,
25+
assert_valid_bloq_decomposition,
26+
)
2427

2528

2629
@pytest.mark.slow
@@ -148,3 +151,14 @@ def test_selected_majorana_fermion_gate_make_on():
148151
op = gate.on_registers(**get_named_qubits(gate.signature))
149152
op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature))
150153
assert op == op2
154+
155+
156+
@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)])
157+
@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Z])
158+
def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize, target_gate):
159+
gate = SelectedMajoranaFermion(
160+
Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=target_gate
161+
)
162+
assert_consistent_phased_classical_action(
163+
gate, selection=range(target_bitsize), target=range(2**target_bitsize), control=range(2)
164+
)

qualtran/testing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Side,
4040
)
4141
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
42+
from qualtran.simulation.classical_sim import do_phased_classical_simulation
4243
from qualtran.symbolics import is_symbolic
4344

4445
if TYPE_CHECKING:
@@ -714,3 +715,29 @@ def assert_consistent_classical_action(
714715
np.testing.assert_equal(
715716
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
716717
)
718+
719+
720+
def assert_consistent_phased_classical_action(
721+
bloq: Bloq,
722+
**parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]],
723+
):
724+
"""Check that the bloq has a phased classical action consistent with its decomposition.
725+
726+
Args:
727+
bloq: bloq to test.
728+
parameter_ranges: named arguments giving ranges for each of the registers of the bloq.
729+
"""
730+
cb = bloq.decompose_bloq()
731+
parameter_names = tuple(parameter_ranges.keys())
732+
for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]):
733+
call_with = {p: v for p, v in zip(parameter_names, vals)}
734+
bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with)
735+
decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with)
736+
np.testing.assert_equal(
737+
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
738+
)
739+
np.testing.assert_equal(
740+
bloq_phase,
741+
decomposed_phase,
742+
err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}',
743+
)

0 commit comments

Comments
 (0)