@@ -243,7 +243,7 @@ def _pauli_strings_to_basis_change_ops(
243243
244244
245245def _pauli_strings_to_basis_change_with_sweep (
246- pauli_strings : list [ops .PauliString ], qid_list : list [ops .Qid ]
246+ pauli_strings : Sequence [ops .PauliString ], qid_list : Sequence [ops .Qid ]
247247) -> dict [str , float ]:
248248 """Decide single-qubit rotation sweep parameters for basis change.
249249
@@ -274,20 +274,28 @@ def _pauli_strings_to_basis_change_with_sweep(
274274def _generate_basis_change_circuits (
275275 normalized_circuits_to_pauli : dict [circuits .FrozenCircuit , list [list [ops .PauliString ]]],
276276 insert_strategy : circuits .InsertStrategy ,
277+ qubits_to_measure : Sequence [ops .Qid ] | None = None ,
277278) -> list [circuits .Circuit ]:
278279 """Generates basis change circuits for each group of Pauli strings."""
279280 pauli_measurement_circuits : list [circuits .Circuit ] = []
280281
281282 for input_circuit , pauli_string_groups in normalized_circuits_to_pauli .items ():
283+ global_qubits = list (qubits_to_measure ) if qubits_to_measure is not None else None
284+
282285 basis_change_circuits = []
283286 input_circuit_unfrozen = input_circuit .unfreeze ()
284287 for pauli_strings in pauli_string_groups :
285- # Extract qubits from Pauli strings
286- qid_list = _extract_readout_qubits (pauli_strings )
288+ if global_qubits is not None :
289+ # Use the user-provided override
290+ current_qid_list = global_qubits
291+ else :
292+ # Extract qubits from Pauli strings
293+ current_qid_list = _extract_readout_qubits (pauli_strings )
294+
287295 basis_change_circuit = circuits .Circuit (
288296 input_circuit_unfrozen ,
289- _pauli_strings_to_basis_change_ops (pauli_strings , qid_list ),
290- ops .measure (* qid_list , key = "result" ),
297+ _pauli_strings_to_basis_change_ops (pauli_strings , current_qid_list ),
298+ ops .measure (* current_qid_list , key = "result" ),
291299 strategy = insert_strategy ,
292300 )
293301 basis_change_circuits .append (basis_change_circuit )
@@ -299,28 +307,55 @@ def _generate_basis_change_circuits(
299307def _generate_basis_change_circuits_with_sweep (
300308 normalized_circuits_to_pauli : dict [circuits .FrozenCircuit , list [list [ops .PauliString ]]],
301309 insert_strategy : circuits .InsertStrategy ,
310+ qubits_to_measure : Sequence [ops .Qid ] | None = None ,
302311) -> tuple [list [circuits .Circuit ], list [cirq .Sweepable ]]:
303312 """Generates basis change circuits for each group of Pauli strings with sweep."""
304313 parameterized_circuits : list [circuits .Circuit ] = []
305314 sweep_params : list [cirq .Sweepable ] = []
306315 for input_circuit , pauli_string_groups in normalized_circuits_to_pauli .items ():
307- for pauli_strings in pauli_string_groups :
308- # Extract qubits from Pauli strings
309- qid_list = _extract_readout_qubits ( pauli_strings )
310- phi_symbols = sympy .symbols (f"phi :{ len (qid_list )} " )
311- theta_symbols = sympy . symbols ( f"theta: { len ( qid_list ) } " )
316+ # If qubits_to_measure is provided, use it
317+ if qubits_to_measure :
318+ phi_symbols = sympy . symbols ( f"phi: { len ( qubits_to_measure ) } " )
319+ theta_symbols = sympy .symbols (f"theta :{ len (qubits_to_measure )} " )
320+
312321 # Create phased gates and measurement operator
313322 phased_gates = [
314323 ops .PhasedXPowGate (phase_exponent = (a - 1 ) / 2 , exponent = b )(qubit )
315- for a , b , qubit in zip (phi_symbols , theta_symbols , qid_list )
324+ for a , b , qubit in zip (phi_symbols , theta_symbols , qubits_to_measure )
316325 ]
317- measurement_op = ops .M (* qid_list , key = "result" )
326+ measurement_op = ops .M (* qubits_to_measure , key = "result" )
327+
318328 parameterized_circuit = circuits .Circuit (
319329 input_circuit .unfreeze (), phased_gates , measurement_op , strategy = insert_strategy
320330 )
321- sweep_param = _pauli_strings_to_basis_change_with_sweep (pauli_strings , qid_list )
322- parameterized_circuits .append (parameterized_circuit )
331+ sweep_param = []
332+ for pauli_strings in pauli_string_groups :
333+ sweep_param .append (
334+ _pauli_strings_to_basis_change_with_sweep (pauli_strings , qubits_to_measure )
335+ )
323336 sweep_params .append (sweep_param )
337+ parameterized_circuits .append (parameterized_circuit )
338+
339+ else :
340+ for pauli_strings in pauli_string_groups :
341+ # Extract qubits from Pauli strings
342+ qid_list = _extract_readout_qubits (pauli_strings )
343+ phi_symbols = sympy .symbols (f"phi:{ len (qid_list )} " )
344+ theta_symbols = sympy .symbols (f"theta:{ len (qid_list )} " )
345+ # Create phased gates and measurement operator
346+ phased_gates = [
347+ ops .PhasedXPowGate (phase_exponent = (a - 1 ) / 2 , exponent = b )(qubit )
348+ for a , b , qubit in zip (phi_symbols , theta_symbols , qid_list )
349+ ]
350+ measurement_op = ops .M (* qid_list , key = "result" )
351+ parameterized_circuit = circuits .Circuit (
352+ input_circuit .unfreeze (), phased_gates , measurement_op , strategy = insert_strategy
353+ )
354+ sweep_param_dict = _pauli_strings_to_basis_change_with_sweep (
355+ pauli_strings , qid_list
356+ )
357+ parameterized_circuits .append (parameterized_circuit )
358+ sweep_params .append (sweep_param_dict )
324359 return parameterized_circuits , sweep_params
325360
326361
@@ -376,6 +411,7 @@ def _process_pauli_measurement_results(
376411 pauli_repetitions : int ,
377412 timestamp : float ,
378413 disable_readout_mitigation : bool = False ,
414+ fixed_calibration_key : tuple [ops .Qid , ...] | None = None ,
379415) -> list [PauliStringMeasurementResult ]:
380416 """Calculates both error-mitigated expectation values and unmitigated expectation values
381417 from measurement results.
@@ -396,22 +432,28 @@ def _process_pauli_measurement_results(
396432 timestamp: The timestamp of the calibration results.
397433 disable_readout_mitigation: If set to True, returns no error-mitigated error
398434 expectation values.
435+ fixed_calibration_key: If provided, uses this key to retrieve the calibration result
436+ from `calibration_results` for all Pauli strings, regardless of their specific
437+ support. This is used when `measure_on_full_support` is True.
399438
400439 Returns:
401440 A list of PauliStringMeasurementResult.
402441 """
403-
404442 pauli_measurement_results : list [PauliStringMeasurementResult ] = []
405443
406444 for pauli_group_index , circuit_result in enumerate (circuit_results ):
407445 measurement_results = circuit_result .measurements ["result" ]
408446 pauli_strs = pauli_string_groups [pauli_group_index ]
409- pauli_readout_qubits = _extract_readout_qubits (pauli_strs )
447+
448+ if fixed_calibration_key is not None :
449+ pauli_readout_qubits = list (fixed_calibration_key )
450+ calibration_key = fixed_calibration_key
451+ else :
452+ pauli_readout_qubits = _extract_readout_qubits (pauli_strs )
453+ calibration_key = tuple (pauli_readout_qubits )
410454
411455 calibration_result = (
412- calibration_results [tuple (pauli_readout_qubits )]
413- if not disable_readout_mitigation
414- else None
456+ calibration_results [calibration_key ] if not disable_readout_mitigation else None
415457 )
416458
417459 for pauli_str in pauli_strs :
@@ -488,6 +530,7 @@ def measure_pauli_strings(
488530 rng_or_seed : np .random .Generator | int ,
489531 use_sweep : bool = False ,
490532 insert_strategy : circuits .InsertStrategy = circuits .InsertStrategy .INLINE ,
533+ measure_on_full_support : bool = False ,
491534) -> list [CircuitToPauliStringsMeasurementResult ]:
492535 """Measures expectation values of Pauli strings on given circuits with/without
493536 readout error mitigation.
@@ -521,7 +564,11 @@ def measure_pauli_strings(
521564 use_sweep: If True, uses parameterized circuits and sweeps parameters
522565 for both Pauli measurements and readout benchmarking. Defaults to False.
523566 insert_strategy: The strategy for inserting measurement operations into the circuit.
524- Defaults to circuits.InsertStrategy.INLINE.
567+ measure_on_full_support: If True, calculates the union of all qubits used in all
568+ Pauli strings (the full support). All circuits will then measure this full set
569+ of qubits, and readout benchmarking will be performed only once on this full set,
570+ rather than for every unique subset of Pauli qubits. This significantly reduces
571+ overhead when measuring many Pauli strings with varying support.
525572
526573 Returns:
527574 A list of CircuitToPauliStringsMeasurementResult objects, where each object contains:
@@ -542,12 +589,24 @@ def measure_pauli_strings(
542589
543590 # Extract unique qubit tuples from input pauli strings
544591 unique_qubit_tuples = set ()
545- for pauli_string_groups in normalized_circuits_to_pauli .values ():
546- for pauli_strings in pauli_string_groups :
547- unique_qubit_tuples .add (tuple (_extract_readout_qubits (pauli_strings )))
592+ if measure_on_full_support :
593+ full_support : set [ops .Qid ] = set ()
594+ for pauli_string_groups in normalized_circuits_to_pauli .values ():
595+ for pauli_strings in pauli_string_groups :
596+ for pauli_string in pauli_strings :
597+ full_support .update (pauli_string .qubits )
598+ # One calibration group
599+ unique_qubit_tuples .add (tuple (sorted (full_support )))
600+ else :
601+ for pauli_string_groups in normalized_circuits_to_pauli .values ():
602+ for pauli_strings in pauli_string_groups :
603+ unique_qubit_tuples .add (tuple (_extract_readout_qubits (pauli_strings )))
604+
548605 # qubits_list is a list of qubit tuples
549606 qubits_list = sorted (unique_qubit_tuples )
550607
608+ qubits_to_measure_arg = list (qubits_list [0 ]) if measure_on_full_support else None
609+
551610 # Build the basis-change circuits for each Pauli string group
552611 pauli_measurement_circuits : list [circuits .Circuit ] = []
553612 sweep_params : list [cirq .Sweepable ] = []
@@ -561,7 +620,7 @@ def measure_pauli_strings(
561620
562621 if use_sweep :
563622 pauli_measurement_circuits , sweep_params = _generate_basis_change_circuits_with_sweep (
564- normalized_circuits_to_pauli , insert_strategy
623+ normalized_circuits_to_pauli , insert_strategy , qubits_to_measure_arg
565624 )
566625
567626 # Run benchmarking using sweep for readout calibration
@@ -578,7 +637,7 @@ def measure_pauli_strings(
578637
579638 else :
580639 pauli_measurement_circuits = _generate_basis_change_circuits (
581- normalized_circuits_to_pauli , insert_strategy
640+ normalized_circuits_to_pauli , insert_strategy , qubits_to_measure_arg
582641 )
583642
584643 # Run shuffled benchmarking for readout calibration
@@ -595,26 +654,43 @@ def measure_pauli_strings(
595654 # Process the results to calculate expectation values
596655 results : list [CircuitToPauliStringsMeasurementResult ] = []
597656 circuit_result_index = 0
657+ input_circuit_index = 0
658+
598659 for input_circuit , pauli_string_groups in normalized_circuits_to_pauli .items ():
599660 disable_readout_mitigation = False if num_random_bitstrings != 0 else True
600661
601662 circuits_results_for_group : Sequence [cirq .ResultDict ] | Sequence [cirq .Result ] = []
602- results_slice = slice (circuit_result_index , circuit_result_index + len (pauli_string_groups ))
603- if use_sweep :
604- circuits_results_for_group = [r [0 ] for r in sweep_circuits_results [results_slice ]]
605663
664+ if use_sweep :
665+ if measure_on_full_support :
666+ circuits_results_for_group = sweep_circuits_results [input_circuit_index ]
667+ input_circuit_index += 1
668+ else :
669+ results_slice = slice (
670+ circuit_result_index , circuit_result_index + len (pauli_string_groups )
671+ )
672+ circuits_results_for_group = [r [0 ] for r in sweep_circuits_results [results_slice ]]
673+ circuit_result_index += len (pauli_string_groups )
606674 else :
675+ results_slice = slice (
676+ circuit_result_index , circuit_result_index + len (pauli_string_groups )
677+ )
607678 circuits_results_for_group = circuits_results [results_slice ]
679+ circuit_result_index += len (pauli_string_groups )
608680
609- circuit_result_index += len (pauli_string_groups )
610-
681+ fixed_calibration_key = (
682+ tuple (qubits_to_measure_arg )
683+ if measure_on_full_support and qubits_to_measure_arg is not None
684+ else None
685+ )
611686 pauli_measurement_results = _process_pauli_measurement_results (
612687 pauli_string_groups ,
613688 circuits_results_for_group ,
614689 calibration_results ,
615690 pauli_repetitions ,
616691 time .time (),
617692 disable_readout_mitigation ,
693+ fixed_calibration_key ,
618694 )
619695 results .append (
620696 CircuitToPauliStringsMeasurementResult (
0 commit comments