Skip to content

Commit fde100e

Browse files
Add measure_on_full_support to optimize Pauli string readout mitigation (#7760)
This PR introduces a new argument, measure_on_full_support, to the measure_pauli_strings function. This feature significantly reduces the overhead of readout error calibration when measuring a large number of Pauli strings that act on different subsets of qubits. --------- Co-authored-by: eliottrosenberg <61400172+eliottrosenberg@users.noreply.github.com>
1 parent b6c8c09 commit fde100e

2 files changed

Lines changed: 127 additions & 34 deletions

File tree

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py

Lines changed: 106 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _pauli_strings_to_basis_change_ops(
243243

244244

245245
def _pauli_strings_to_basis_change_with_sweep(
246-
pauli_strings: list[ops.PauliString], qid_list: list[ops.Qid]
246+
pauli_strings: Sequence[ops.PauliString], qid_list: Sequence[ops.Qid]
247247
) -> dict[str, float]:
248248
"""Decide single-qubit rotation sweep parameters for basis change.
249249
@@ -274,20 +274,28 @@ def _pauli_strings_to_basis_change_with_sweep(
274274
def _generate_basis_change_circuits(
275275
normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]],
276276
insert_strategy: circuits.InsertStrategy,
277+
qubits_to_measure: Sequence[ops.Qid] | None = None,
277278
) -> list[circuits.Circuit]:
278279
"""Generates basis change circuits for each group of Pauli strings."""
279280
pauli_measurement_circuits: list[circuits.Circuit] = []
280281

281282
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
283+
global_qubits = list(qubits_to_measure) if qubits_to_measure is not None else None
284+
282285
basis_change_circuits = []
283286
input_circuit_unfrozen = input_circuit.unfreeze()
284287
for pauli_strings in pauli_string_groups:
285-
# Extract qubits from Pauli strings
286-
qid_list = _extract_readout_qubits(pauli_strings)
288+
if global_qubits is not None:
289+
# Use the user-provided override
290+
current_qid_list = global_qubits
291+
else:
292+
# Extract qubits from Pauli strings
293+
current_qid_list = _extract_readout_qubits(pauli_strings)
294+
287295
basis_change_circuit = circuits.Circuit(
288296
input_circuit_unfrozen,
289-
_pauli_strings_to_basis_change_ops(pauli_strings, qid_list),
290-
ops.measure(*qid_list, key="result"),
297+
_pauli_strings_to_basis_change_ops(pauli_strings, current_qid_list),
298+
ops.measure(*current_qid_list, key="result"),
291299
strategy=insert_strategy,
292300
)
293301
basis_change_circuits.append(basis_change_circuit)
@@ -299,28 +307,55 @@ def _generate_basis_change_circuits(
299307
def _generate_basis_change_circuits_with_sweep(
300308
normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]],
301309
insert_strategy: circuits.InsertStrategy,
310+
qubits_to_measure: Sequence[ops.Qid] | None = None,
302311
) -> tuple[list[circuits.Circuit], list[cirq.Sweepable]]:
303312
"""Generates basis change circuits for each group of Pauli strings with sweep."""
304313
parameterized_circuits: list[circuits.Circuit] = []
305314
sweep_params: list[cirq.Sweepable] = []
306315
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
307-
for pauli_strings in pauli_string_groups:
308-
# Extract qubits from Pauli strings
309-
qid_list = _extract_readout_qubits(pauli_strings)
310-
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
311-
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")
316+
# If qubits_to_measure is provided, use it
317+
if qubits_to_measure:
318+
phi_symbols = sympy.symbols(f"phi:{len(qubits_to_measure)}")
319+
theta_symbols = sympy.symbols(f"theta:{len(qubits_to_measure)}")
320+
312321
# Create phased gates and measurement operator
313322
phased_gates = [
314323
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
315-
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
324+
for a, b, qubit in zip(phi_symbols, theta_symbols, qubits_to_measure)
316325
]
317-
measurement_op = ops.M(*qid_list, key="result")
326+
measurement_op = ops.M(*qubits_to_measure, key="result")
327+
318328
parameterized_circuit = circuits.Circuit(
319329
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
320330
)
321-
sweep_param = _pauli_strings_to_basis_change_with_sweep(pauli_strings, qid_list)
322-
parameterized_circuits.append(parameterized_circuit)
331+
sweep_param = []
332+
for pauli_strings in pauli_string_groups:
333+
sweep_param.append(
334+
_pauli_strings_to_basis_change_with_sweep(pauli_strings, qubits_to_measure)
335+
)
323336
sweep_params.append(sweep_param)
337+
parameterized_circuits.append(parameterized_circuit)
338+
339+
else:
340+
for pauli_strings in pauli_string_groups:
341+
# Extract qubits from Pauli strings
342+
qid_list = _extract_readout_qubits(pauli_strings)
343+
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
344+
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")
345+
# Create phased gates and measurement operator
346+
phased_gates = [
347+
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
348+
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
349+
]
350+
measurement_op = ops.M(*qid_list, key="result")
351+
parameterized_circuit = circuits.Circuit(
352+
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
353+
)
354+
sweep_param_dict = _pauli_strings_to_basis_change_with_sweep(
355+
pauli_strings, qid_list
356+
)
357+
parameterized_circuits.append(parameterized_circuit)
358+
sweep_params.append(sweep_param_dict)
324359
return parameterized_circuits, sweep_params
325360

326361

@@ -376,6 +411,7 @@ def _process_pauli_measurement_results(
376411
pauli_repetitions: int,
377412
timestamp: float,
378413
disable_readout_mitigation: bool = False,
414+
fixed_calibration_key: tuple[ops.Qid, ...] | None = None,
379415
) -> list[PauliStringMeasurementResult]:
380416
"""Calculates both error-mitigated expectation values and unmitigated expectation values
381417
from measurement results.
@@ -396,22 +432,28 @@ def _process_pauli_measurement_results(
396432
timestamp: The timestamp of the calibration results.
397433
disable_readout_mitigation: If set to True, returns no error-mitigated error
398434
expectation values.
435+
fixed_calibration_key: If provided, uses this key to retrieve the calibration result
436+
from `calibration_results` for all Pauli strings, regardless of their specific
437+
support. This is used when `measure_on_full_support` is True.
399438
400439
Returns:
401440
A list of PauliStringMeasurementResult.
402441
"""
403-
404442
pauli_measurement_results: list[PauliStringMeasurementResult] = []
405443

406444
for pauli_group_index, circuit_result in enumerate(circuit_results):
407445
measurement_results = circuit_result.measurements["result"]
408446
pauli_strs = pauli_string_groups[pauli_group_index]
409-
pauli_readout_qubits = _extract_readout_qubits(pauli_strs)
447+
448+
if fixed_calibration_key is not None:
449+
pauli_readout_qubits = list(fixed_calibration_key)
450+
calibration_key = fixed_calibration_key
451+
else:
452+
pauli_readout_qubits = _extract_readout_qubits(pauli_strs)
453+
calibration_key = tuple(pauli_readout_qubits)
410454

411455
calibration_result = (
412-
calibration_results[tuple(pauli_readout_qubits)]
413-
if not disable_readout_mitigation
414-
else None
456+
calibration_results[calibration_key] if not disable_readout_mitigation else None
415457
)
416458

417459
for pauli_str in pauli_strs:
@@ -488,6 +530,7 @@ def measure_pauli_strings(
488530
rng_or_seed: np.random.Generator | int,
489531
use_sweep: bool = False,
490532
insert_strategy: circuits.InsertStrategy = circuits.InsertStrategy.INLINE,
533+
measure_on_full_support: bool = False,
491534
) -> list[CircuitToPauliStringsMeasurementResult]:
492535
"""Measures expectation values of Pauli strings on given circuits with/without
493536
readout error mitigation.
@@ -521,7 +564,11 @@ def measure_pauli_strings(
521564
use_sweep: If True, uses parameterized circuits and sweeps parameters
522565
for both Pauli measurements and readout benchmarking. Defaults to False.
523566
insert_strategy: The strategy for inserting measurement operations into the circuit.
524-
Defaults to circuits.InsertStrategy.INLINE.
567+
measure_on_full_support: If True, calculates the union of all qubits used in all
568+
Pauli strings (the full support). All circuits will then measure this full set
569+
of qubits, and readout benchmarking will be performed only once on this full set,
570+
rather than for every unique subset of Pauli qubits. This significantly reduces
571+
overhead when measuring many Pauli strings with varying support.
525572
526573
Returns:
527574
A list of CircuitToPauliStringsMeasurementResult objects, where each object contains:
@@ -542,12 +589,24 @@ def measure_pauli_strings(
542589

543590
# Extract unique qubit tuples from input pauli strings
544591
unique_qubit_tuples = set()
545-
for pauli_string_groups in normalized_circuits_to_pauli.values():
546-
for pauli_strings in pauli_string_groups:
547-
unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings)))
592+
if measure_on_full_support:
593+
full_support: set[ops.Qid] = set()
594+
for pauli_string_groups in normalized_circuits_to_pauli.values():
595+
for pauli_strings in pauli_string_groups:
596+
for pauli_string in pauli_strings:
597+
full_support.update(pauli_string.qubits)
598+
# One calibration group
599+
unique_qubit_tuples.add(tuple(sorted(full_support)))
600+
else:
601+
for pauli_string_groups in normalized_circuits_to_pauli.values():
602+
for pauli_strings in pauli_string_groups:
603+
unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings)))
604+
548605
# qubits_list is a list of qubit tuples
549606
qubits_list = sorted(unique_qubit_tuples)
550607

608+
qubits_to_measure_arg = list(qubits_list[0]) if measure_on_full_support else None
609+
551610
# Build the basis-change circuits for each Pauli string group
552611
pauli_measurement_circuits: list[circuits.Circuit] = []
553612
sweep_params: list[cirq.Sweepable] = []
@@ -561,7 +620,7 @@ def measure_pauli_strings(
561620

562621
if use_sweep:
563622
pauli_measurement_circuits, sweep_params = _generate_basis_change_circuits_with_sweep(
564-
normalized_circuits_to_pauli, insert_strategy
623+
normalized_circuits_to_pauli, insert_strategy, qubits_to_measure_arg
565624
)
566625

567626
# Run benchmarking using sweep for readout calibration
@@ -578,7 +637,7 @@ def measure_pauli_strings(
578637

579638
else:
580639
pauli_measurement_circuits = _generate_basis_change_circuits(
581-
normalized_circuits_to_pauli, insert_strategy
640+
normalized_circuits_to_pauli, insert_strategy, qubits_to_measure_arg
582641
)
583642

584643
# Run shuffled benchmarking for readout calibration
@@ -595,26 +654,43 @@ def measure_pauli_strings(
595654
# Process the results to calculate expectation values
596655
results: list[CircuitToPauliStringsMeasurementResult] = []
597656
circuit_result_index = 0
657+
input_circuit_index = 0
658+
598659
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
599660
disable_readout_mitigation = False if num_random_bitstrings != 0 else True
600661

601662
circuits_results_for_group: Sequence[cirq.ResultDict] | Sequence[cirq.Result] = []
602-
results_slice = slice(circuit_result_index, circuit_result_index + len(pauli_string_groups))
603-
if use_sweep:
604-
circuits_results_for_group = [r[0] for r in sweep_circuits_results[results_slice]]
605663

664+
if use_sweep:
665+
if measure_on_full_support:
666+
circuits_results_for_group = sweep_circuits_results[input_circuit_index]
667+
input_circuit_index += 1
668+
else:
669+
results_slice = slice(
670+
circuit_result_index, circuit_result_index + len(pauli_string_groups)
671+
)
672+
circuits_results_for_group = [r[0] for r in sweep_circuits_results[results_slice]]
673+
circuit_result_index += len(pauli_string_groups)
606674
else:
675+
results_slice = slice(
676+
circuit_result_index, circuit_result_index + len(pauli_string_groups)
677+
)
607678
circuits_results_for_group = circuits_results[results_slice]
679+
circuit_result_index += len(pauli_string_groups)
608680

609-
circuit_result_index += len(pauli_string_groups)
610-
681+
fixed_calibration_key = (
682+
tuple(qubits_to_measure_arg)
683+
if measure_on_full_support and qubits_to_measure_arg is not None
684+
else None
685+
)
611686
pauli_measurement_results = _process_pauli_measurement_results(
612687
pauli_string_groups,
613688
circuits_results_for_group,
614689
calibration_results,
615690
pauli_repetitions,
616691
time.time(),
617692
disable_readout_mitigation,
693+
fixed_calibration_key,
618694
)
619695
results.append(
620696
CircuitToPauliStringsMeasurementResult(

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -494,32 +494,49 @@ def test_many_group_pauli_in_circuits_with_coefficient(use_sweep: bool) -> None:
494494
circuit_3 = cirq.FrozenCircuit(_create_ghz(8, qubits_3))
495495

496496
circuits_to_pauli: dict[cirq.FrozenCircuit, list[list[cirq.PauliString]]] = {}
497+
497498
circuits_to_pauli[circuit_1] = [
498499
_generate_qwc_paulis(
499-
_generate_random_pauli_string(qubits_1, enable_coeff=True, allow_pauli_i=False), 4
500+
_generate_random_pauli_string(qubits_1, enable_coeff=True, allow_pauli_i=False), 2
500501
)
501502
]
503+
502504
circuits_to_pauli[circuit_2] = [
503505
_generate_qwc_paulis(
504-
_generate_random_pauli_string(qubits_2, enable_coeff=True, allow_pauli_i=False), 5
506+
_generate_random_pauli_string(qubits_2, enable_coeff=True, allow_pauli_i=False), 2
505507
)
506508
]
509+
507510
circuits_to_pauli[circuit_3] = [
508511
_generate_qwc_paulis(
509-
_generate_random_pauli_string(qubits_3, enable_coeff=True, allow_pauli_i=False), 6
512+
_generate_random_pauli_string(qubits_3, enable_coeff=True, allow_pauli_i=False), 2
510513
)
511514
]
512515

513516
sampler = NoisySingleQubitReadoutSampler(p0=0.03, p1=0.05, seed=1234)
514517
simulator = cirq.Simulator()
515518

516519
circuits_with_pauli_expectations = measure_pauli_strings(
517-
circuits_to_pauli, sampler, 300, 300, 300, np.random.default_rng(), use_sweep
520+
circuits_to_pauli,
521+
sampler,
522+
300,
523+
300,
524+
300,
525+
np.random.default_rng(),
526+
use_sweep,
527+
measure_on_full_support=True,
518528
)
519529

520530
for circuit_with_pauli_expectations in circuits_with_pauli_expectations:
521531
assert isinstance(circuit_with_pauli_expectations.circuit, cirq.FrozenCircuit)
522532

533+
expected_group_count = len(circuits_to_pauli[circuit_with_pauli_expectations.circuit][0])
534+
535+
assert len(circuit_with_pauli_expectations.results) == expected_group_count, (
536+
f"Expected {expected_group_count} results (groups) for circuit, "
537+
f"but got {len(circuit_with_pauli_expectations.results)}."
538+
)
539+
523540
expected_val_simulation = simulator.simulate(
524541
circuit_with_pauli_expectations.circuit.unfreeze()
525542
)

0 commit comments

Comments
 (0)