Skip to content

Commit 344d020

Browse files
authored
Simplify StatePreparationAliasSampling tests by relying on Cirq simulators to simulate operations allocating ancillas (#6204)
* Simplify StatePreparationAliasSampling tests by relying on Cirq simulators capability to simulate operations allocating ancilla qubits * Update unary iteration implementation to proactively allocate all required ancillas * Fix pylint errors
1 parent 83ede36 commit 344d020

3 files changed

Lines changed: 43 additions & 70 deletions

File tree

cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,33 @@
2222
@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 8), (4, 9)])
2323
@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Y])
2424
def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, target_gate):
25-
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
2625
gate = cirq_ft.SelectedMajoranaFermionGate(
2726
cirq_ft.SelectionRegisters(
2827
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
2928
),
3029
target_gate=target_gate,
3130
)
32-
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
31+
g = cirq_ft.testing.GateHelper(gate)
3332
assert len(g.all_qubits) <= gate.registers.total_bits() + selection_bitsize + 1
3433

3534
sim = cirq.Simulator(dtype=np.complex128)
3635
for n in range(target_bitsize):
3736
# Initial qubit values
38-
qubit_vals = {q: 0 for q in g.all_qubits}
37+
qubit_vals = {q: 0 for q in g.operation.qubits}
3938
# All controls 'on' to activate circuit
4039
qubit_vals.update({c: 1 for c in g.quregs['control']})
4140
# Set selection according to `n`
4241
qubit_vals.update(zip(g.quregs['selection'], iter_bits(n, selection_bitsize)))
4342

44-
initial_state = [qubit_vals[x] for x in g.all_qubits]
43+
initial_state = [qubit_vals[x] for x in g.operation.qubits]
4544

4645
result = sim.simulate(
47-
g.decomposed_circuit, initial_state=initial_state, qubit_order=g.all_qubits
46+
g.circuit, initial_state=initial_state, qubit_order=g.operation.qubits
4847
)
4948

5049
final_target_state = cirq.sub_state_vector(
5150
result.final_state_vector,
52-
keep_indices=[g.all_qubits.index(q) for q in g.quregs['target']],
51+
keep_indices=[g.operation.qubits.index(q) for q in g.quregs['target']],
5352
)
5453

5554
expected_target_state = cirq.Circuit(

cirq-ft/cirq_ft/algos/state_preparation_test.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import itertools
16-
1715
import cirq
1816
import cirq_ft
1917
import numpy as np
@@ -22,49 +20,18 @@
2220
from cirq_ft.infra.jupyter_tools import execute_notebook
2321

2422

25-
def construct_gate_helper_and_qubit_order(data, eps):
26-
gate = cirq_ft.StatePreparationAliasSampling.from_lcu_probs(
27-
lcu_probabilities=data, probability_epsilon=eps
28-
)
29-
g = cirq_ft.testing.GateHelper(gate)
30-
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
31-
32-
def map_func(op: cirq.Operation, _):
33-
gateset = cirq.Gateset(cirq_ft.And, cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate)
34-
return cirq.Circuit(
35-
cirq.decompose(op, on_stuck_raise=None, keep=gateset.validate, context=context)
36-
)
37-
38-
# TODO: Do not decompose {cq.And, cq.LessThanEqualGate, cq.LessThanGate} because the
39-
# `cq.map_clean_and_borrowable_qubits` currently gets confused and is not able to re-map qubits
40-
# optimally; which results in a higher number of ancillas thus the tests fails due to OOO.
41-
decomposed_circuit = cirq.map_operations_and_unroll(
42-
g.circuit, map_func, raise_if_add_qubits=False
43-
)
44-
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", size=25, maximize_reuse=True)
45-
decomposed_circuit = cirq_ft.map_clean_and_borrowable_qubits(decomposed_circuit, qm=greedy_mm)
46-
# We are fine decomposing the `cq.And` gates once the qubit re-mapping is complete. Ideally,
47-
# we shouldn't require this two step process.
48-
arithmetic_gateset = cirq.Gateset(cirq_ft.LessThanEqualGate, cirq_ft.LessThanGate)
49-
decomposed_circuit = cirq.Circuit(
50-
cirq.decompose(decomposed_circuit, keep=arithmetic_gateset.validate, on_stuck_raise=None)
51-
)
52-
ordered_input = list(itertools.chain(*g.quregs.values()))
53-
qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT)
54-
return g, qubit_order, decomposed_circuit
55-
56-
5723
@pytest.mark.parametrize("num_sites, epsilon", [[2, 3e-3], [3, 3.0e-3], [4, 5.0e-3], [7, 8.0e-3]])
5824
def test_state_preparation_via_coherent_alias_sampling(num_sites, epsilon):
5925
lcu_coefficients = get_1d_Ising_lcu_coeffs(num_sites)
60-
g, qubit_order, decomposed_circuit = construct_gate_helper_and_qubit_order(
61-
lcu_coefficients, epsilon
62-
)
63-
# assertion to ensure that simulating the `decomposed_circuit` doesn't run out of memory.
64-
assert len(decomposed_circuit.all_qubits()) < 25
65-
result = cirq.Simulator(dtype=np.complex128).simulate(
66-
decomposed_circuit, qubit_order=qubit_order
26+
gate = cirq_ft.StatePreparationAliasSampling.from_lcu_probs(
27+
lcu_probabilities=lcu_coefficients.tolist(), probability_epsilon=epsilon
6728
)
29+
g = cirq_ft.testing.GateHelper(gate)
30+
qubit_order = g.operation.qubits
31+
32+
# Assertion to ensure that simulating the `decomposed_circuit` doesn't run out of memory.
33+
assert len(g.circuit.all_qubits()) < 20
34+
result = cirq.Simulator(dtype=np.complex128).simulate(g.circuit, qubit_order=qubit_order)
6835
state_vector = result.final_state_vector
6936
# State vector is of the form |l>|temp_{l}>. We trace out the |temp_{l}> part to
7037
# get the coefficients corresponding to |l>.
@@ -82,7 +49,12 @@ def test_state_preparation_via_coherent_alias_sampling(num_sites, epsilon):
8249

8350
def test_state_preparation_via_coherent_alias_sampling_diagram():
8451
data = np.asarray(range(1, 5)) / np.sum(range(1, 5))
85-
g, qubit_order, _ = construct_gate_helper_and_qubit_order(data, 0.05)
52+
gate = cirq_ft.StatePreparationAliasSampling.from_lcu_probs(
53+
lcu_probabilities=data.tolist(), probability_epsilon=0.05
54+
)
55+
g = cirq_ft.testing.GateHelper(gate)
56+
qubit_order = g.operation.qubits
57+
8658
circuit = cirq.Circuit(cirq.decompose_once(g.operation))
8759
cirq.testing.assert_has_diagram(
8860
circuit,

cirq-ft/cirq_ft/algos/unary_iteration_gate.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ def _unary_iteration_segtree(
2828
ops: List[cirq.Operation],
2929
control: cirq.Qid,
3030
selection: Sequence[cirq.Qid],
31+
ancilla: Sequence[cirq.Qid],
3132
sl: int,
3233
l: int,
3334
r: int,
3435
l_iter: int,
3536
r_iter: int,
36-
qm: cirq.QubitManager,
3737
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
3838
"""Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree.
3939
@@ -45,6 +45,8 @@ def _unary_iteration_segtree(
4545
selection: Sequence of selection qubits. The i'th qubit in the list corresponds to the i'th
4646
level in the segment tree.Thus, a total of O(logN) selection qubits are required for a
4747
tree on range `N = (r_iter - l_iter)`.
48+
ancilla: Pre-allocated ancilla qubits to be used for constructing the unary iteration
49+
circuit.
4850
sl: Current depth of the tree. `selection[sl]` gives the selection qubit corresponding to
4951
the current depth.
5052
l: Left index of the range represented by current node of the segment tree.
@@ -76,73 +78,73 @@ def _unary_iteration_segtree(
7678
if r_iter <= m:
7779
# Yield only left sub-tree.
7880
yield from _unary_iteration_segtree(
79-
ops, control, selection, sl + 1, l, m, l_iter, r_iter, qm
81+
ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter
8082
)
8183
return
8284
if l_iter >= m:
8385
# Yield only right sub-tree
8486
yield from _unary_iteration_segtree(
85-
ops, control, selection, sl + 1, m, r, l_iter, r_iter, qm
87+
ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter
8688
)
8789
return
88-
anc, sq = qm.qalloc(1)[0], selection[sl]
90+
anc, sq = ancilla[sl], selection[sl]
8991
ops.append(and_gate.And((1, 0)).on(control, sq, anc))
90-
yield from _unary_iteration_segtree(ops, anc, selection, sl + 1, l, m, l_iter, r_iter, qm)
92+
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter)
9193
ops.append(cirq.CNOT(control, anc))
92-
yield from _unary_iteration_segtree(ops, anc, selection, sl + 1, m, r, l_iter, r_iter, qm)
94+
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter)
9395
ops.append(and_gate.And(adjoint=True).on(control, sq, anc))
94-
qm.qfree([anc])
9596

9697

9798
def _unary_iteration_zero_control(
9899
ops: List[cirq.Operation],
99100
selection: Sequence[cirq.Qid],
101+
ancilla: Sequence[cirq.Qid],
100102
l_iter: int,
101103
r_iter: int,
102-
qm: cirq.QubitManager,
103104
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
104105
sl, l, r = 0, 0, 2 ** len(selection)
105106
m = (l + r) >> 1
106107
ops.append(cirq.X(selection[0]))
107108
yield from _unary_iteration_segtree(
108-
ops, selection[0], selection[1:], sl, l, m, l_iter, r_iter, qm
109+
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter
109110
)
110111
ops.append(cirq.X(selection[0]))
111112
yield from _unary_iteration_segtree(
112-
ops, selection[0], selection[1:], sl, m, r, l_iter, r_iter, qm
113+
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter
113114
)
114115

115116

116117
def _unary_iteration_single_control(
117118
ops: List[cirq.Operation],
118119
control: cirq.Qid,
119120
selection: Sequence[cirq.Qid],
121+
ancilla: Sequence[cirq.Qid],
120122
l_iter: int,
121123
r_iter: int,
122-
qm: cirq.QubitManager,
123124
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
124125
sl, l, r = 0, 0, 2 ** len(selection)
125-
yield from _unary_iteration_segtree(ops, control, selection, sl, l, r, l_iter, r_iter, qm)
126+
yield from _unary_iteration_segtree(ops, control, selection, ancilla, sl, l, r, l_iter, r_iter)
126127

127128

128129
def _unary_iteration_multi_controls(
129130
ops: List[cirq.Operation],
130131
controls: Sequence[cirq.Qid],
131132
selection: Sequence[cirq.Qid],
133+
ancilla: Sequence[cirq.Qid],
132134
l_iter: int,
133135
r_iter: int,
134-
qm: cirq.QubitManager,
135136
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
136137
num_controls = len(controls)
137-
and_ancilla = qm.qalloc(num_controls - 2)
138-
and_target = qm.qalloc(1)[0]
138+
and_ancilla = ancilla[: num_controls - 2]
139+
and_target = ancilla[num_controls - 2]
139140
multi_controlled_and = and_gate.And((1,) * len(controls)).on_registers(
140141
control=np.array(controls), ancilla=np.array(and_ancilla), target=and_target
141142
)
142143
ops.append(multi_controlled_and)
143-
yield from _unary_iteration_single_control(ops, and_target, selection, l_iter, r_iter, qm)
144+
yield from _unary_iteration_single_control(
145+
ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter
146+
)
144147
ops.append(cirq.inverse(multi_controlled_and))
145-
qm.qfree(and_ancilla + [and_target])
146148

147149

148150
def unary_iteration(
@@ -203,18 +205,18 @@ def unary_iteration(
203205
"""
204206
assert 2 ** len(selection) >= r_iter - l_iter
205207
assert len(selection) > 0
208+
ancilla = qubit_manager.qalloc(max(0, len(controls) + len(selection) - 1))
206209
if len(controls) == 0:
207-
yield from _unary_iteration_zero_control(
208-
flanking_ops, selection, l_iter, r_iter, qubit_manager
209-
)
210+
yield from _unary_iteration_zero_control(flanking_ops, selection, ancilla, l_iter, r_iter)
210211
elif len(controls) == 1:
211212
yield from _unary_iteration_single_control(
212-
flanking_ops, controls[0], selection, l_iter, r_iter, qubit_manager
213+
flanking_ops, controls[0], selection, ancilla, l_iter, r_iter
213214
)
214215
else:
215216
yield from _unary_iteration_multi_controls(
216-
flanking_ops, controls, selection, l_iter, r_iter, qubit_manager
217+
flanking_ops, controls, selection, ancilla, l_iter, r_iter
217218
)
219+
qubit_manager.qfree(ancilla)
218220

219221

220222
class UnaryIterationGate(infra.GateWithRegisters):

0 commit comments

Comments
 (0)