diff --git a/doc/sinter_api.md b/doc/sinter_api.md index 5cb64f821..ac7e10006 100644 --- a/doc/sinter_api.md +++ b/doc/sinter_api.md @@ -42,6 +42,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`sinter.iter_collect`](#sinter.iter_collect) - [`sinter.log_binomial`](#sinter.log_binomial) - [`sinter.log_factorial`](#sinter.log_factorial) +- [`sinter.plot_custom`](#sinter.plot_custom) - [`sinter.plot_discard_rate`](#sinter.plot_discard_rate) - [`sinter.plot_error_rate`](#sinter.plot_error_rate) - [`sinter.post_selection_mask_from_4th_coord`](#sinter.post_selection_mask_from_4th_coord) @@ -1452,6 +1453,70 @@ def log_factorial( """ ``` + +```python +# sinter.plot_custom + +# (at top-level in the sinter module) +def plot_custom( + *, + ax: 'plt.Axes', + stats: 'Iterable[sinter.TaskStats]', + x_func: Callable[[sinter.TaskStats], Any], + y_func: Callable[[sinter.TaskStats], Union[sinter.Fit, float, int]], + group_func: Callable[[sinter.TaskStats], ~TCurveId] = lambda _: None, + point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None, + filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True, + plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), + line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None, +) -> None: + """Plots error rates in curves with uncertainty highlights. + + Args: + ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`. + stats: The collected statistics to plot. + x_func: The X coordinate to use for each stat's data point. For example, this could be + `x_func=lambda stat: stat.json_metadata['physical_error_rate']`. + y_func: The Y value to use for each stat's data point. This can be a float or it can be a + sinter.Fit value, in which case the curve will follow the fit.best value and a + highlighted area will be shown from fit.low to fit.high. + group_func: Optional. When specified, multiple curves will be plotted instead of one curve. + The statistics are grouped into curves based on whether or not they get the same result + out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. + point_label_func: Optional. Specifies text to draw next to data points. + filter_func: Optional. When specified, some curves will not be plotted. + The statistics are filtered and only plotted if filter_func(stat) returns True. + For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats + where the saved metadata indicates the basis was 'x'. + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to + `plot` and `fill_between` used to do the actual plotting. For example, this can be used + to specify markers and colors. Takes the index of the curve in sorted order and also a + curve_id (these will be 0 and None respectively if group_func is not specified). For example, + this could be: + + plot_args_func=lambda index, group_key, group_stats: { + 'color': ( + 'red' + if group_key == 'decoder=pymatching p=0.001' + else 'blue' + ), + } + line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line + fit to every curve. The scales determine how to transform the coordinates before + performing the fit, and can be set to 'linear', 'sqrt', or 'log'. + """ +``` + ```python # sinter.plot_discard_rate diff --git a/doc/usage_command_line.md b/doc/usage_command_line.md index b5f107f60..2e5418e32 100644 --- a/doc/usage_command_line.md +++ b/doc/usage_command_line.md @@ -1676,6 +1676,7 @@ SYNOPSIS [--out_format 01|b8|r8|ptb64|hits|dets] \ [--seed int] \ [--shots int] \ + [--skip_loop_folding] \ [--skip_reference_sample] DESCRIPTION @@ -1762,6 +1763,30 @@ OPTIONS Must be an integer between 0 and a quintillion (10^18). + --skip_loop_folding + Skips loop folding logic on the reference sample calculation. + + When this argument is specified, the reference sample (that is used + to convert measurement flip data from frame simulations into actual + measurement data) is generated by iterating through the entire + flattened circuit with no loop detection. + + Loop folding can enormously improve performance for circuits + containing REPEAT blocks with large repeat counts, by detecting + periodicity in loops and fast-forwarding across them when computing + the reference sample for the circuit. However, in some cases the + analysis is not able to detect the periodicity that is present. For + example, this has been observed in honeycomb code circuits. When + this happens, the folding-capable analysis is slower than simply + analyzing the flattened circuit without any specialized loop logic. + The `--skip_loop_folding` flag can be used to just analyze the + flattened circuit, bypassing this slowdown for circuits such as + honeycomb code circuits. + + By default, loop detection is enabled. Pass this flag to disable + it (when appropriate by use case). + + --skip_reference_sample Asserts the circuit can produce a noiseless sample that is just 0s. diff --git a/glue/cirq/stimcirq/__init__.py b/glue/cirq/stimcirq/__init__.py index d3a8e1aa1..2af970b75 100644 --- a/glue/cirq/stimcirq/__init__.py +++ b/glue/cirq/stimcirq/__init__.py @@ -3,6 +3,7 @@ from ._cx_swap_gate import CXSwapGate from ._cz_swap_gate import CZSwapGate from ._det_annotation import DetAnnotation +from ._feedback_pauli import FeedbackPauli from ._obs_annotation import CumulativeObservableAnnotation from ._shift_coords_annotation import ShiftCoordsAnnotation from ._stim_sampler import StimSampler @@ -19,6 +20,7 @@ JSON_RESOLVERS_DICT = { "CumulativeObservableAnnotation": CumulativeObservableAnnotation, "DetAnnotation": DetAnnotation, + "FeedbackPauli": FeedbackPauli, "MeasureAndOrResetGate": MeasureAndOrResetGate, "ShiftCoordsAnnotation": ShiftCoordsAnnotation, "SweepPauli": SweepPauli, diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index 66a4a3037..e94da2c41 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -5,9 +5,8 @@ import cirq import stim +import sympy -from ._i_error_gate import IErrorGate -from ._ii_error_gate import IIErrorGate from ._ii_gate import IIGate @@ -142,6 +141,58 @@ def cirq_circuit_to_stim_data( StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int], str], None] +StimOpTypeHandler = Callable[[stim.Circuit, cirq.Operation, List[int], str, List[Tuple[str, int]]], None] + + +def _stim_append_classically_controlled_gate( + circuit: stim.Circuit, + op: cirq.ClassicallyControlledOperation, + targets: List[int], + tag: str, + measurement_key_lengths: List[Tuple[str, int]]): + + if len(op.classical_controls) != 1: + raise NotImplementedError(f'Stim only supports single-control Pauli feedback, but got {op=}') + controls: list[cirq.KeyCondition] = [] + single_control, = op.classical_controls + if isinstance(single_control, cirq.KeyCondition): + controls.append(single_control) + elif isinstance(single_control, cirq.SympyCondition) and isinstance(single_control.expr, sympy.Xor) and all(isinstance(e, sympy.Symbol) for e in single_control.expr.args): + for symbol in single_control.expr.args: + controls.append(cirq.KeyCondition(key=cirq.MeasurementKey(str(symbol)), index=-1)) + else: + raise NotImplementedError(f'Stim only supports single-control Pauli feedback (i.e. a `cirq.KeyCondition` control), but got {single_control=}') + gate = op.without_classical_controls().gate + + if gate == cirq.X: + stim_gate = 'X' + elif gate == cirq.Y: + stim_gate = 'Y' + elif gate == cirq.Z: + stim_gate = 'Z' + else: + raise NotImplementedError(f'Stim only supports Pauli feedback, but got {op=}') + assert len(targets) == 1 + + for control in controls: + skips_left = control.index + for offset in range(len(measurement_key_lengths)): + m_key, m_len = measurement_key_lengths[-1 - offset] + if m_len != 1: + raise NotImplementedError(f"multi-qubit measurement {m_key!r}") + if m_key == control.key: + if skips_left > 0: + skips_left -= 1 + else: + rec_target = stim.target_rec(-1 - offset) + break + else: + raise ValueError( + f"{control!r} was processed before the measurement it referenced." + f" Make sure the referenced measurements keys are actually in the circuit, and come" + f" in an earlier moment (or earlier in the same moment's operation order)." + ) + circuit.append(f"C{stim_gate}", [rec_target, targets[0]], tag=tag) @functools.lru_cache(maxsize=1) @@ -278,6 +329,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: } +@functools.lru_cache() +def op_type_to_stim_append_func() -> Dict[Type[cirq.Operation], StimOpTypeHandler]: + """A dictionary mapping specific gate types to stim circuit appending functions.""" + return { + cirq.ClassicallyControlledOperation: _stim_append_classically_controlled_gate, + } + + def _stim_append_measurement_gate( circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str ): @@ -454,7 +513,8 @@ def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation, def process_operations(self, operations: Iterable[cirq.Operation]) -> None: g2f = gate_to_stim_append_func() - t2f = gate_type_to_stim_append_func() + tg2f = gate_type_to_stim_append_func() + to2f = op_type_to_stim_append_func() for op in operations: assert isinstance(op, cirq.Operation) tag = self.tag_func(op) @@ -500,11 +560,16 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: continue # Look for recognized gate types like cirq.DepolarizingChannel. - type_append_func = t2f.get(type(gate)) + type_append_func = tg2f.get(type(gate)) if type_append_func is not None: type_append_func(self.out, gate, targets, tag=tag) continue + op_type_append_func = to2f.get(type(op)) + if op_type_append_func is not None: + op_type_append_func(self.out, op, targets, tag, self.key_out) + continue + # Ask unrecognized operations to decompose themselves into simpler operations. try: self.process_operations(cirq.decompose_once(op)) diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index 439702113..ff5e5c2ea 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -6,6 +6,7 @@ import pytest import stim import stimcirq +import sympy from stimcirq._cirq_to_stim import cirq_circuit_to_stim_data, gate_to_stim_append_func @@ -425,3 +426,30 @@ def test_round_trip_example_circuit(): cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit.flattened()) circuit_back = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) assert len(circuit_back.shortest_graphlike_error()) == 3 + + +def test_xor_feedback(): + a, b, c, d, e = cirq.LineQubit.range(5) + cirq_circuit = cirq.Circuit([ + cirq.Moment( + cirq.measure(a, key='a'), + cirq.measure(b, key='b'), + cirq.measure(c, key='c'), + cirq.measure(d, key='d'), + ), + cirq.Moment( + cirq.X(e).with_classical_controls(cirq.SympyCondition(sympy.Xor( + sympy.Symbol('a'), + sympy.Symbol('b'), + sympy.Symbol('c'), + sympy.Symbol('d'), + ))), + ), + ]) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert stim_circuit == stim.Circuit(""" + M 0 1 2 3 + TICK + CX rec[-4] 4 rec[-3] 4 rec[-2] 4 rec[-1] 4 + TICK + """) diff --git a/glue/cirq/stimcirq/_feedback_pauli.py b/glue/cirq/stimcirq/_feedback_pauli.py new file mode 100644 index 000000000..bd33ce43a --- /dev/null +++ b/glue/cirq/stimcirq/_feedback_pauli.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Tuple, Optional + +import cirq +import stim + + +@cirq.value_equality +class FeedbackPauli(cirq.Gate): + """A Pauli gate conditioned on a prior measurement.""" + + def __init__( + self, + *, + relative_measurement_index: Optional[int] = None, + pauli: cirq.Pauli, + ): + r""" + + Args: + relative_measurement_index: A negative integer identifying how many measurements ago is the measurement that + controls the Pauli operation. + pauli: The cirq Pauli operation to apply when the bit is True. + """ + if relative_measurement_index is not None and (relative_measurement_index >= 0 or not isinstance(relative_measurement_index, int)): + raise ValueError(f"{relative_measurement_index=} isn't a negative int (note {type(relative_measurement_index)=})") + self.relative_measurement_index = relative_measurement_index + self.pauli = pauli + + def _is_parameterized_(self) -> bool: + return False + + def _num_qubits_(self) -> int: + return 1 + + def _value_equality_values_(self) -> Any: + return self.pauli, self.relative_measurement_index + + def _circuit_diagram_info_(self, args: Any) -> str: + return f"{self.pauli}^rec[{self.relative_measurement_index}]" + + @staticmethod + def _json_namespace_() -> str: + return '' + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'pauli': self.pauli, + 'relative_measurement_index': self.relative_measurement_index, + } + + def __repr__(self) -> str: + return ( + f'stimcirq.FeedbackPauli(' + f'relative_measurement_index={self.relative_measurement_index!r}, ' + f'pauli={self.pauli!r})' + ) + + def _stim_conversion_( + self, + *, + edit_circuit: stim.Circuit, + tag: str, + targets: List[int], + **kwargs, + ): + rec_target = stim.target_rec(self.relative_measurement_index) + edit_circuit.append(f"C{self.pauli}", [rec_target, targets[0]], tag=tag) diff --git a/glue/cirq/stimcirq/_feedback_pauli_test.py b/glue/cirq/stimcirq/_feedback_pauli_test.py new file mode 100644 index 000000000..2fc97f5ae --- /dev/null +++ b/glue/cirq/stimcirq/_feedback_pauli_test.py @@ -0,0 +1,209 @@ +import cirq +import pytest +import stim +import stimcirq + + +def test_cirq_to_stim_to_cirq_classical_control(): + q = cirq.LineQubit(0) + cirq_circuit = cirq.Circuit( + cirq.measure(q, key="test"), + cirq.X(q).with_classical_controls("test").with_tags("test2") + ) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert stim_circuit == stim.Circuit(""" + M 0 + TICK + CX[test2] rec[-1] 0 + TICK + """) + assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit( + cirq.measure(q, key="0"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags("test2") + ) + + +def test_cirq_to_stim_to_cirq_feedback_pauli(): + q = cirq.LineQubit(0) + cirq_circuit = cirq.Circuit( + cirq.measure(q, key="test"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3') + ) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert stim_circuit == stim.Circuit(""" + M 0 + TICK + CX[test3] rec[-1] 0 + TICK + """) + assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit( + cirq.measure(q, key="0"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3') + ) + + +def test_stim_to_cirq_conversion(): + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + XCZ rec[-1] 3 + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + YCZ rec[-1] 3 + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + CY 3 rec[-1] + """)) + with pytest.raises(NotImplementedError, match="wrong target"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + CX 3 rec[-1] + """)) + with pytest.raises(NotImplementedError, match="Two classical"): + stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 1 + TICK + CZ rec[-1] rec[-2] + """)) + + assert stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + M 0 + TICK + ZCX rec[-1] 0 + ZCY rec[-1] 1 + ZCZ rec[-1] 2 + XCZ 3 rec[-1] + YCZ 4 rec[-1] + ZCZ 5 rec[-1] + """)) == cirq.Circuit( + cirq.Moment( + cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='0')), + ), + cirq.Moment( + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(0)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(1)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(2)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(3)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(4)), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(5)), + ), + ) + + +def test_stim_conversion(): + a, b, c = cirq.LineQubit.range(3) + + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit(cirq.Moment(cirq.X(a).with_classical_controls("unknown"))) + ) + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment( + cirq.X(a).with_classical_controls("unknown"), cirq.measure(b, key="later") + ) + ) + ) + with pytest.raises(ValueError, match="earlier"): + stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.X(a).with_classical_controls("unknown")), + cirq.Moment(cirq.measure(b, key="later")), + ) + ) + assert stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.measure(b, key="earlier")), + cirq.Moment(cirq.X(b).with_classical_controls("earlier")), + ) + ) == stim.Circuit( + """ + QUBIT_COORDS(1) 0 + M 0 + TICK + CX rec[-1] 0 + TICK + """ + ) + + assert stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.Moment(cirq.measure(a, key="a"), cirq.measure(b, key="b")), + cirq.Moment( + cirq.X(b).with_classical_controls("a"), + ), + cirq.Moment( + cirq.Z(b).with_classical_controls("b"), + ), + ) + ) == stim.Circuit( + """ + M 0 1 + TICK + CX rec[-2] 1 + TICK + CZ rec[-1] 1 + TICK + """ + ) + + +def test_diagram(): + a, b = cirq.LineQubit.range(2) + cirq.testing.assert_has_diagram( + cirq.Circuit( + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli='Y').on(a), + ), + """ +0: ---M('a')---Y^rec[-1]--- + +1: ---M('b')--------------- + """, + use_unicode_characters=False, + ) + + +def test_repr(): + val = stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y) + assert eval(repr(val), {"cirq": cirq, "stimcirq": stimcirq}) == val + + +def test_equality(): + eq = cirq.testing.EqualsTester() + eq.add_equality_group( + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X), + stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X)) + eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y)) + eq.add_equality_group( + stimcirq.FeedbackPauli(relative_measurement_index=-4, pauli=cirq.X), + ) + eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-10, pauli=cirq.Z)) + + +def test_json_serialization(): + c = cirq.Circuit( + stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X).on(cirq.LineQubit(0)), + stimcirq.FeedbackPauli(relative_measurement_index=-5, pauli=cirq.Y).on(cirq.LineQubit(1)), + stimcirq.FeedbackPauli(relative_measurement_index=-7, pauli=cirq.Z).on(cirq.LineQubit(2)), + ) + json = cirq.to_json(c) + c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) + assert c == c2 + + +def test_json_backwards_compat_exact(): + raw = stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X) + packed = '{\n "cirq_type": "FeedbackPauli",\n "pauli": {\n "cirq_type": "_PauliX",\n "exponent": 1.0,\n "global_shift": 0.0\n },\n "relative_measurement_index": -3\n}' + assert cirq.to_json(raw) == packed + assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw diff --git a/glue/cirq/stimcirq/_stim_to_cirq.py b/glue/cirq/stimcirq/_stim_to_cirq.py index 593bf4797..392391830 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq.py +++ b/glue/cirq/stimcirq/_stim_to_cirq.py @@ -8,6 +8,7 @@ Dict, Iterable, List, + Optional, Tuple, Union, ) @@ -25,6 +26,7 @@ from ._obs_annotation import CumulativeObservableAnnotation from ._shift_coords_annotation import ShiftCoordsAnnotation from ._sweep_pauli import SweepPauli +from ._feedback_pauli import FeedbackPauli def _stim_targets_to_dense_pauli_string( @@ -64,7 +66,7 @@ def _proper_transform_circuit_qubits(circuit: cirq.AbstractCircuit, remap: Dict[ class CircuitTranslationTracker: - def __init__(self, flatten: bool): + def __init__(self, flatten: bool, single_measure_key: Optional[str] = None): self.qubit_coords: Dict[int, cirq.Qid] = {} self.origin: DefaultDict[float] = collections.defaultdict(float) self.num_measurements_seen = 0 @@ -72,11 +74,17 @@ def __init__(self, flatten: bool): self.tick_circuit = cirq.Circuit() self.flatten = flatten self.have_seen_loop = False + self.single_measure_key = single_measure_key def get_next_measure_id(self) -> int: self.num_measurements_seen += 1 return self.num_measurements_seen - 1 + def get_next_measure_key(self) -> str: + if self.single_measure_key is None: + return str(self.get_next_measure_id()) + return self.single_measure_key + def append_operation(self, op: cirq.Operation) -> None: self.tick_circuit.append(op, strategy=cirq.InsertStrategy.INLINE) @@ -186,7 +194,7 @@ def process_measurement_instruction( for t in targets: if not t.is_qubit_target: raise NotImplementedError(f"instruction={instruction!r}") - key = str(self.get_next_measure_id()) + key = self.get_next_measure_key() self.append_operation( MeasureAndOrResetGate( measure=measure, @@ -248,7 +256,7 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None: obs = _stim_targets_to_dense_pauli_string(group) qubits = [cirq.LineQubit(t.value) for t in group] - key = str(self.get_next_measure_id()) + key = self.get_next_measure_key() self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags)) def process_spp_dag(self, instruction: stim.CircuitInstruction) -> None: @@ -290,7 +298,7 @@ def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> No if targets[0].is_inverted_result_target ^ targets[1].is_inverted_result_target: obs *= -1 qubits = [cirq.LineQubit(targets[0].value), cirq.LineQubit(targets[1].value)] - key = str(self.get_next_measure_id()) + key = self.get_next_measure_key() self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags)) def process_mxx(self, instruction: stim.CircuitInstruction) -> None: @@ -309,7 +317,7 @@ def process_mpad(self, instruction: stim.CircuitInstruction) -> None: if t.value == 1: obs *= -1 qubits = [] - key = str(self.get_next_measure_id()) + key = self.get_next_measure_key() self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits)) def process_correlated_error(self, instruction: stim.CircuitInstruction) -> None: @@ -407,9 +415,11 @@ def __call__( tracker.process_gate_instruction(gate=self.gate, instruction=instruction) class SweepableGateHandler: - def __init__(self, pauli_gate: cirq.Pauli, gate: cirq.Gate): + def __init__(self, pauli_gate: cirq.Pauli, gate: cirq.Gate, allow_first: bool, allow_second: bool): self.pauli_gate = pauli_gate self.gate = gate + self.allow_first = allow_first + self.allow_second = allow_second def __call__( self, tracker: 'CircuitTranslationTracker', instruction: stim.CircuitInstruction @@ -422,8 +432,12 @@ def __call__( for k in range(0, len(targets), 2): a = targets[k] b = targets[k + 1] + if not a.is_qubit_target and not self.allow_first: + raise NotImplementedError(f"Classical control is on the wrong target: instruction={instruction!r}") + if not b.is_qubit_target and not self.allow_second: + raise NotImplementedError(f"Classical control is on the wrong target: instruction={instruction!r}") if not a.is_qubit_target and not b.is_qubit_target: - raise NotImplementedError(f"instruction={instruction!r}") + raise NotImplementedError(f"Two classical controls: instruction={instruction!r}") if a.is_sweep_bit_target or b.is_sweep_bit_target: if b.is_sweep_bit_target: a, b = b, a @@ -435,6 +449,16 @@ def __call__( pauli=self.pauli_gate, ).on(cirq.LineQubit(b.value)).with_tags(*tags) ) + elif a.is_measurement_record_target or b.is_measurement_record_target: + if b.is_measurement_record_target: + a, b = b, a + assert not a.is_inverted_result_target + tracker.append_operation( + FeedbackPauli( + relative_measurement_index=a.value, + pauli=self.pauli_gate, + ).on(cirq.LineQubit(b.value)).with_tags(*tags) + ) else: if not a.is_qubit_target or not b.is_qubit_target: raise NotImplementedError(f"instruction={instruction!r}") @@ -585,17 +609,17 @@ def handler( "ISWAP_DAG": gate(cirq.ISWAP ** -1), "XCX": gate(cirq.PauliInteractionGate(cirq.X, False, cirq.X, False)), "XCY": gate(cirq.PauliInteractionGate(cirq.X, False, cirq.Y, False)), - "XCZ": sweep_gate(cirq.X, cirq.PauliInteractionGate(cirq.X, False, cirq.Z, False)), + "XCZ": sweep_gate(cirq.X, cirq.PauliInteractionGate(cirq.X, False, cirq.Z, False), False, True), "YCX": gate(cirq.PauliInteractionGate(cirq.Y, False, cirq.X, False)), "YCY": gate(cirq.PauliInteractionGate(cirq.Y, False, cirq.Y, False)), - "YCZ": sweep_gate(cirq.Y, cirq.PauliInteractionGate(cirq.Y, False, cirq.Z, False)), - "CX": sweep_gate(cirq.X, cirq.CNOT), - "CNOT": sweep_gate(cirq.X, cirq.CNOT), - "ZCX": sweep_gate(cirq.X, cirq.CNOT), - "CY": sweep_gate(cirq.Y, cirq.Y.controlled(1)), - "ZCY": sweep_gate(cirq.Y, cirq.Y.controlled(1)), - "CZ": sweep_gate(cirq.Z, cirq.CZ), - "ZCZ": sweep_gate(cirq.Z, cirq.CZ), + "YCZ": sweep_gate(cirq.Y, cirq.PauliInteractionGate(cirq.Y, False, cirq.Z, False), False, True), + "CX": sweep_gate(cirq.X, cirq.CNOT, True, False), + "CNOT": sweep_gate(cirq.X, cirq.CNOT, True, False), + "ZCX": sweep_gate(cirq.X, cirq.CNOT, True, False), + "CY": sweep_gate(cirq.Y, cirq.Y.controlled(1), True, False), + "ZCY": sweep_gate(cirq.Y, cirq.Y.controlled(1), True, False), + "CZ": sweep_gate(cirq.Z, cirq.CZ, True, True), + "ZCZ": sweep_gate(cirq.Z, cirq.CZ, True, True), "DEPOLARIZE1": noise(lambda p: cirq.DepolarizingChannel(p, 1)), "DEPOLARIZE2": noise(lambda p: cirq.DepolarizingChannel(p, 2)), "X_ERROR": noise(cirq.X.with_probability), @@ -632,12 +656,17 @@ def handler( } -def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False) -> cirq.Circuit: +def stim_circuit_to_cirq_circuit( + circuit: stim.Circuit, + *, + flatten: bool = False, + single_measure_key: Optional[str] = None, +) -> cirq.Circuit: """Converts a stim circuit into an equivalent cirq circuit. Qubit indices are turned into cirq.LineQubit instances. Measurements are keyed by their ordering (e.g. the first measurement is keyed "0", the second - is keyed "1", etc). + is keyed "1", etc) unless a fixed measure_key is provided. Not all circuits can be converted: - ELSE_CORRELATED_ERROR instructions are not supported. @@ -652,6 +681,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False explicitly repeating their instructions multiple times. Also, SHIFT_COORDS instructions are removed by appropriately adjusting the coordinate metadata of later instructions. + single_measure_key: Defaults to None. If provided, all measurements are + keyed with this string instead of sequentially generated numbers. Returns: The converted circuit. @@ -671,6 +702,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False │ 1: ───────X──────────────────!M('0')─── """ - tracker = CircuitTranslationTracker(flatten=flatten) + tracker = CircuitTranslationTracker( + flatten=flatten, single_measure_key=single_measure_key + ) tracker.process_circuit(repetitions=1, circuit=circuit) return tracker.output() diff --git a/glue/cirq/stimcirq/_stim_to_cirq_test.py b/glue/cirq/stimcirq/_stim_to_cirq_test.py index facd79dd2..0439ce759 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq_test.py +++ b/glue/cirq/stimcirq/_stim_to_cirq_test.py @@ -778,4 +778,50 @@ def test_round_trip_with_pauli_obs(): """) cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) - assert restored_circuit == stim_circuit \ No newline at end of file + assert restored_circuit == stim_circuit + + +def test_single_measure_key_order(): + stim_circuits = [ + stim.Circuit( + """ + X 1 + X 1 3 + X 1 3 + X 1 3 2 + M 1 + M 3 + M 2 + M 0 + """ + ), + stim.Circuit( + """ + X 1 + X 1 + X 1 + X 1 + M 1 3 + X 2 + M 2 0 + """ + ) + ] + measure_key = "m" + for stim_circuit in stim_circuits: + cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit( + stim_circuit, single_measure_key=measure_key + ) + qubits = cirq.LineQubit.range(4) + expected_order = [ + qubits[targ.qubit_value] + for inst in stim_circuit if inst.name == "M" + for targ in inst.targets_copy() + ] + actual_order = [] + for op in cirq_circuit.all_operations(): + if isinstance(op.gate, cirq.MeasurementGate): + assert op.gate.key == measure_key + assert len(op.qubits) == 1 + actual_order.append(op.qubits[0]) + assert expected_order == actual_order diff --git a/glue/crumble/main.js b/glue/crumble/main.js index 1d16f559d..7888b9d32 100644 --- a/glue/crumble/main.js +++ b/glue/crumble/main.js @@ -195,12 +195,12 @@ function makeChordHandlers() { res.set('shift+t', preview => editorState.rotate45(-1, preview)); res.set('t', preview => editorState.rotate45(+1, preview)); - res.set('escape', () => editorState.clearFocus); + res.set('escape', () => editorState.clearFocus()); res.set('delete', preview => editorState.deleteAtFocus(preview)); res.set('backspace', preview => editorState.deleteAtFocus(preview)); res.set('ctrl+delete', preview => editorState.deleteCurLayer(preview)); res.set('ctrl+insert', preview => editorState.insertLayer(preview)); - res.set('ctrl+backspace', () => editorState.deleteCurLayer); + res.set('ctrl+backspace', preview => editorState.deleteCurLayer(preview)); res.set('ctrl+z', preview => { if (!preview) editorState.undo() }); res.set('ctrl+y', preview => { if (!preview) editorState.redo() }); res.set('ctrl+shift+z', preview => { if (!preview) editorState.redo() }); diff --git a/glue/sample/src/sinter/__init__.py b/glue/sample/src/sinter/__init__.py index 47f702793..9046657f2 100644 --- a/glue/sample/src/sinter/__init__.py +++ b/glue/sample/src/sinter/__init__.py @@ -37,6 +37,7 @@ better_sorted_str_terms, plot_discard_rate, plot_error_rate, + plot_custom, group_by, ) from sinter._predict import ( diff --git a/glue/sample/src/sinter/_collection/_collection_manager.py b/glue/sample/src/sinter/_collection/_collection_manager.py index 19cba6bf8..637cd868b 100644 --- a/glue/sample/src/sinter/_collection/_collection_manager.py +++ b/glue/sample/src/sinter/_collection/_collection_manager.py @@ -227,6 +227,8 @@ def _compute_task_ids(self): shots_left = options.max_shots errors_left = options.max_errors + if shots_left is None: + raise ValueError("Didn't specify --max_shots.") if errors_left is None: errors_left = shots_left errors_left = min(errors_left, shots_left) diff --git a/glue/sample/src/sinter/_data/_anon_task_stats.py b/glue/sample/src/sinter/_data/_anon_task_stats.py index 42281c07d..c7c4dcaa9 100644 --- a/glue/sample/src/sinter/_data/_anon_task_stats.py +++ b/glue/sample/src/sinter/_data/_anon_task_stats.py @@ -1,10 +1,6 @@ import collections import dataclasses -from typing import Counter, Union, TYPE_CHECKING -import numpy as np - -if TYPE_CHECKING: - from sinter._data._task_stats import TaskStats +from typing import Counter @dataclasses.dataclass(frozen=True) @@ -35,16 +31,10 @@ class AnonTaskStats: custom_counts: Counter[str] = dataclasses.field(default_factory=collections.Counter) def __post_init__(self): - assert isinstance(self.errors, (int, np.integer)) - assert isinstance(self.shots, (int, np.integer)) - assert isinstance(self.discards, (int, np.integer)) - assert isinstance(self.seconds, (int, float, np.integer, np.floating)) - assert isinstance(self.custom_counts, collections.Counter) assert self.errors >= 0 assert self.discards >= 0 assert self.seconds >= 0 assert self.shots >= self.errors + self.discards - assert all(isinstance(k, str) and isinstance(v, (int, np.integer)) for k, v in self.custom_counts.items()) def __repr__(self) -> str: terms = [] @@ -80,10 +70,10 @@ def __add__(self, other: 'AnonTaskStats') -> 'AnonTaskStats': """ if isinstance(other, AnonTaskStats): return AnonTaskStats( - shots=self.shots + other.shots, - errors=self.errors + other.errors, - discards=self.discards + other.discards, - seconds=self.seconds + other.seconds, + shots=int(self.shots + other.shots), + errors=int(self.errors + other.errors), + discards=int(self.discards + other.discards), + seconds=float(self.seconds + other.seconds), custom_counts=self.custom_counts + other.custom_counts, ) diff --git a/glue/sample/src/sinter/_decoding/_decoding.py b/glue/sample/src/sinter/_decoding/_decoding.py index 1e54f87ef..e45aef72b 100644 --- a/glue/sample/src/sinter/_decoding/_decoding.py +++ b/glue/sample/src/sinter/_decoding/_decoding.py @@ -177,13 +177,7 @@ def sample_decode(*, were executed. The detection fraction is the ratio of these two numbers. num_shots: The number of sample shots to take from the circuit. - decoder: The name of the decoder to use. Allowed values are: - "pymatching": - Use pymatching min-weight-perfect-match decoder. - "internal": - Use internal decoder with uncorrelated decoding. - "internal_correlated": - Use internal decoder with correlated decoding. + decoder: The name of the decoder to use. For example, 'pymatching'. tmp_dir: An existing directory that is currently empty where temporary files can be written as part of performing decoding. If set to None, one is created using the tempfile package. diff --git a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py index 92d8d49dd..93ffa584f 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py @@ -12,6 +12,7 @@ BUILT_IN_DECODERS: Dict[str, Decoder] = { 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), + 'pymatching-correlated': PyMatchingDecoder(use_correlated_decoding=True), 'fusion_blossom': FusionBlossomDecoder(), # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) 'hypergraph_union_find': HyperUFDecoder(), diff --git a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py index b57bb32bc..1473d037f 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_pymatching.py +++ b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py @@ -1,26 +1,52 @@ from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder +def check_pymatching_version_for_correlated_decoding(pymatching): + v = pymatching.__version__.split('.') + try: + a = int(v[0]) + b = int(v[1]) + c = int(''.join(e for e in v[2] if e in '0123456789')) # In case dev version + except (ValueError, IndexError): + return # Probably it's the future. + + if (a, b, c) < (2, 3, 1): + raise ValueError( + "PyMatching version must be at least 2.3.1 for correlated decoding.\n" + f"Installed version: {pymatching.__version__}\n" + "To fix this, install a newer version of pymatching into your environment.\n" + "For example, if you are using pip, run `pip install pymatching --upgrade`.\n" + ) + + class PyMatchingCompiledDecoder(CompiledDecoder): - def __init__(self, matcher: 'pymatching.Matching'): + def __init__(self, matcher: 'pymatching.Matching', use_correlated_decoding: bool): self.matcher = matcher + self.use_correlated_decoding = use_correlated_decoding def decode_shots_bit_packed( self, *, bit_packed_detection_event_data: 'np.ndarray', ) -> 'np.ndarray': + kwargs = {} + if self.use_correlated_decoding: + kwargs['enable_correlations'] = True return self.matcher.decode_batch( shots=bit_packed_detection_event_data, bit_packed_shots=True, bit_packed_predictions=True, return_weights=False, + **kwargs, ) class PyMatchingDecoder(Decoder): """Use pymatching to predict observables from detection events.""" + def __init__(self, use_correlated_decoding: bool = False): + self.use_correlated_decoding = use_correlated_decoding + def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder: try: import pymatching @@ -31,7 +57,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled "For example, if you are using pip, run `pip install pymatching`.\n" ) from ex - return PyMatchingCompiledDecoder(pymatching.Matching.from_detector_error_model(dem)) + kwargs = {} + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + kwargs['enable_correlations'] = True + return PyMatchingCompiledDecoder( + pymatching.Matching.from_detector_error_model(dem, **kwargs), + use_correlated_decoding=self.use_correlated_decoding, + ) def decode_via_files(self, *, @@ -60,7 +93,9 @@ def decode_via_files(self, if not hasattr(pymatching, 'cli'): raise ValueError(""" The installed version of pymatching has no `pymatching.cli` method. + sinter requires pymatching 2.1.0 or later. + If you're using pip to install packages, this can be fixed by running ``` @@ -69,13 +104,18 @@ def decode_via_files(self, """) - result = pymatching.cli(command_line_args=[ + args = [ "predict", "--dem", str(dem_path), "--in", str(dets_b8_in_path), "--in_format", "b8", "--out", str(obs_predictions_b8_out_path), "--out_format", "b8", - ]) + ] + if self.use_correlated_decoding: + check_pymatching_version_for_correlated_decoding(pymatching) + args.append("--enable_correlations") + + result = pymatching.cli(command_line_args=args) if result: raise ValueError("pymatching.cli returned a non-zero exit code") diff --git a/glue/sample/src/sinter/_decoding/_decoding_test.py b/glue/sample/src/sinter/_decoding/_decoding_test.py index cd4e28d0d..7dd08f379 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding/_decoding_test.py @@ -233,6 +233,8 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo @pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) def test_post_selection(decoder: str, force_streaming: Optional[bool]): + if decoder == 'pymatching-correlated': + pytest.skip("Correlated matching does not support error probabilities > 0.5 in from_detector_error_model") circuit = stim.Circuit(""" X_ERROR(0.6) 0 M 0 @@ -243,7 +245,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): M 1 DETECTOR(1, 0, 0) rec[-1] OBSERVABLE_INCLUDE(0) rec[-1] - + X_ERROR(0.1) 2 M 2 OBSERVABLE_INCLUDE(0) rec[-1] diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py index ea244b849..4111a2be1 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py @@ -81,7 +81,7 @@ def classify_discards_and_errors( out_count_observable_error_combos[err_key] += 1 num_errors = np.count_nonzero(fail_mask) - return num_discards, num_errors + return int(num_discards), int(num_errors) class DiskDecoder(CompiledDecoder): diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_flip.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_flip.py index 12d634557..a13b9e873 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_flip.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_flip.py @@ -261,6 +261,14 @@ def parse_leakage_tag(op: stim.CircuitInstruction) -> LeakageParams | None: if op.name != "MPAD": raise ValueError("Only MPAD can have a LEAKAGE_MEASUREMENT tag.") return _parse_leakage_measurement(tag) + elif tag == "LEAKAGE_SWAP": + if op.name not in ["II_ERROR", "SWAP", "II"]: + raise ValueError("Only II_ERROR and SWAP can have a LEAKAGE_SWAP tag.") + return None + elif tag.startswith("LEAKAGE_DETECTOR"): + if op.name != "DETECTOR": + raise ValueError("Only DETECTOR can have a LEAKAGE_DETECTOR tag.") + return None # from here on out, we raise an error on anything malformed match = LEAKAGE_TAG_MATCH.fullmatch(tag) diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_tableau.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_tableau.py index c05736b1f..722fbffd7 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_tableau.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_tag_parsing_tableau.py @@ -323,6 +323,7 @@ def _parse_conditioned_on_pair(args: str, tag: str) -> LeakageConditioningParams ADDITIONAL_TAGS = [ "LEAKAGE_MEASUREMENT", + "LEAKAGE_SWAP" ] @@ -396,6 +397,14 @@ def parse_leakage_tag(op: stim.CircuitInstruction) -> LeakageParams | None: if op.name != "MPAD": raise ValueError("Only MPAD can have a LEAKAGE_MEASUREMENT tag.") return _parse_leakage_measurement(tag) + elif tag.startswith("LEAKAGE_SWAP"): + if op.name not in ["II_ERROR", "SWAP", "II"]: + raise ValueError("Only II_ERROR and SWAP can have a LEAKAGE_SWAP tag.") + return None + elif tag.startswith("LEAKAGE_DETECTOR"): + if op.name != "DETECTOR": + raise ValueError("Only DETECTOR can have a LEAKAGE_DETECTOR tag.") + return None # from here on out, we raise an error on anything malformed match = LEAKAGE_TAG_MATCH.fullmatch(tag) @@ -466,6 +475,7 @@ def parse_leakage_tag(op: stim.CircuitInstruction) -> LeakageParams | None: f"Failed to recognise existing leakage tag name {name}. " "This one is on us, not you. File a bug." ) + raise ValueError( f"Unrecognised LEAKAGE tag name {name}: " f"must be one of { diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip.py index ce475924f..0c5b6aaea 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip.py @@ -1,5 +1,6 @@ import dataclasses import itertools as it +from copy import copy import numpy as np from numpy.typing import NDArray @@ -111,6 +112,8 @@ def handle_op(self, op: stim.CircuitInstruction, sss: FlipsideSimulator): ]: # if we're depolarizing on leak, these instructions undo that error, so we re-depolarize self._depolarize_leaked_qubits(fss=sss, op=op) + elif op.tag == "LEAKAGE_SWAP": + self.leakage_swap(op=op) return params = self.ops_to_params[op] @@ -128,6 +131,16 @@ def handle_op(self, op: stim.CircuitInstruction, sss: FlipsideSimulator): case _: raise ValueError(f"Unrecognised LEAKAGE params: {params}") + def leakage_swap( + self, + op: stim.CircuitInstruction + ): + """Implement swaping of leakage states""" + for target in op.target_groups(): + state_old = copy(self.state[target[0].value]) + self.state[target[0].value] = copy(self.state[target[1].value]) + self.state[target[1].value] = state_old + def _depolarize_leaked_qubits( self, op: stim.CircuitInstruction, fss: FlipsideSimulator ): diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip_test.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip_test.py index 345ceae0f..f812f926a 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip_test.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_flip_test.py @@ -177,6 +177,27 @@ def test_leakage_transition_1(self): assert np.all(xs_before == xs_after) # Check for NO depolarization assert np.all(zs_before == zs_after) + def test_leakage_swap(self): + circuit_a = stim.Circuit( + """ + R 0 1 + H 0 1 + I_ERROR[LEAKAGE_TRANSITION_1: (1.0, U-->2)] 0 + II[LEAKAGE_SWAP] 0 1 + """ + ) + fss_a = self._get_simulator_for_circuit(circuit_a) + fss_a.interactive_do(circuit_a[0]) + fss_a.interactive_do(circuit_a[1]) + assert np.all(fss_a.compiled_op_handler.state[0, :] == 0) + + fss_a.interactive_do(circuit_a[2]) + assert np.all(fss_a.compiled_op_handler.state[0, :] == 2) + assert np.all(fss_a.compiled_op_handler.state[1, :] == 0) + fss_a.interactive_do(circuit_a[3]) + assert np.all(fss_a.compiled_op_handler.state[0, :] == 0) + assert np.all(fss_a.compiled_op_handler.state[1, :] == 2) + def test_leakage_transition_Z(self): circuit = stim.Circuit( """ diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau.py index 2d48ee8eb..277c6227a 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau.py @@ -196,7 +196,11 @@ def _construct_stim_instruction( def handle_op(self, op: stim.CircuitInstruction, sss: TablesideSimulator): """handle a single stim CircuitInstruction.""" if op not in self.claimed_ops_keys: - if self.unconditional_condition_on_U: + if op.tag == "LEAKAGE_SWAP": + self.leakage_swap(op=op) + sss._do_bare_instruction(op) + return + elif self.unconditional_condition_on_U: op_name = op.name gate_data = stim.GateData(op_name) if gate_data.produces_measurements or not ( @@ -247,6 +251,16 @@ def handle_op(self, op: stim.CircuitInstruction, sss: TablesideSimulator): case _: raise ValueError(f"Unrecognised LEAKAGE params: {params}") + def leakage_swap( + self, + op: stim.CircuitInstruction + ): + """Implement swaping of leakage states""" + for target in op.target_groups(): + state_old = self.state[target[0].value] + self.state[target[0].value] = self.state[target[1].value] + self.state[target[1].value] = state_old + def leakage_conditioning( self, op: stim.CircuitInstruction, diff --git a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau_test.py b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau_test.py index 19e7c5b9d..d5fc08d73 100644 --- a/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau_test.py +++ b/glue/stimside/src/stimside/op_handlers/leakage_handlers/leakage_uint8_tableau_test.py @@ -123,6 +123,28 @@ def test_leakage_transition_1(self): assert np.any(records[:, :] == 0) assert np.any(records[:, :] == 1) + def test_leakage_swap(self): + # Part a: Computational to Leaked + circuit_a = stim.Circuit( + """ + R 0 1 + H 0 1 + I_ERROR[LEAKAGE_TRANSITION_1: (1.0, U-->2)] 0 + II[LEAKAGE_SWAP] 0 1 + """ + ) + tss_a = self._get_simulator_for_circuit(circuit_a) + tss_a.interactive_do(circuit_a[0]) + tss_a.interactive_do(circuit_a[1]) + assert np.all(tss_a._compiled_op_handler.state[:] == 0) + tss_a.interactive_do(circuit_a[2]) + assert np.all(tss_a._compiled_op_handler.state[0] == 2) + assert np.all(tss_a._compiled_op_handler.state[1] == 0) + + tss_a.interactive_do(circuit_a[3]) + assert np.all(tss_a._compiled_op_handler.state[0] == 0) + assert np.all(tss_a._compiled_op_handler.state[1] == 2) + def test_leakage_transition_2(self): # Test case: (U, U) -> (L2, L3) circuit = stim.Circuit( diff --git a/glue/stimside/src/stimside/sampler_tableau.py b/glue/stimside/src/stimside/sampler_tableau.py index 78870228a..a2c723b2c 100644 --- a/glue/stimside/src/stimside/sampler_tableau.py +++ b/glue/stimside/src/stimside/sampler_tableau.py @@ -23,16 +23,18 @@ def __init__( | None ) = None, decoder: sinter.Decoder = sinter.BUILT_IN_DECODERS["pymatching"], + seed: int | None = None, ): self.op_handler = op_handler self.decoder: sinter.Decoder = decoder self.batch_size = batch_size self.dem_gen = dem_gen + self.seed = seed def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler: if task.circuit is None: raise ValueError( - "FlipsideSampler requires a circuit in the task to compile a sampler." + "TablesideSampler requires a circuit in the task to compile a sampler." ) if self.dem_gen is None: dem_gen = task.detector_error_model or task.circuit.detector_error_model() @@ -50,6 +52,7 @@ def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler # The op_handler only handles 1 shot, # But the simulator generates multiple shots with the modified circuit on demand ), + seed=self.seed, ) @@ -64,12 +67,14 @@ def __init__( ), compiled_op_handler: CompiledOpHandler, batch_size: int, + seed: int | None = None, ): self.circuit = circuit self.tab_simulator = TablesideSimulator( circuit=circuit, compiled_op_handler=compiled_op_handler, batch_size=batch_size, + seed=seed, ) self.batch_size = batch_size diff --git a/glue/stimside/src/stimside/simulator_tableau.py b/glue/stimside/src/stimside/simulator_tableau.py index 21df2fa57..8939ce2de 100644 --- a/glue/stimside/src/stimside/simulator_tableau.py +++ b/glue/stimside/src/stimside/simulator_tableau.py @@ -65,7 +65,7 @@ def __init__( self.np_rng = np.random.default_rng(seed=seed) - self._tableau_simulator = stim.TableauSimulator() + self._tableau_simulator = stim.TableauSimulator(seed=seed) self._new_circuit = stim.Circuit() self._construct_reference_circuit = construct_reference_circuit if self._construct_reference_circuit: @@ -125,7 +125,7 @@ def clear(self): # self._tableau_simulator.set_inverse_tableau(stim.Tableau(0)) # Without a clear method from the TableauSimulator, we will just initiate a new one for now - self._tableau_simulator = stim.TableauSimulator() + self._tableau_simulator = stim.TableauSimulator(seed=self.seed) self._new_circuit.clear() self._final_measurement_records = None @@ -357,4 +357,4 @@ def get_detector_flips(self, append_observables=False) -> NDArray[np.bool_]: def get_observable_flips(self) -> NDArray[np.bool_]: if self._observable_flips is None: self._convert_measurements_to_detector_flips() - return self._observable_flips \ No newline at end of file + return self._observable_flips diff --git a/package-lock.json b/package-lock.json index dfb70ab3b..6165a9197 100644 --- a/package-lock.json +++ b/package-lock.json @@ -318,9 +318,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -888,9 +888,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "requires": { "brace-expansion": "^1.1.7" diff --git a/src/stim/cmd/command_sample.cc b/src/stim/cmd/command_sample.cc index 2d94acdc1..01d85f59b 100644 --- a/src/stim/cmd/command_sample.cc +++ b/src/stim/cmd/command_sample.cc @@ -21,18 +21,20 @@ #include "stim/simulators/tableau_simulator.h" #include "stim/util_bot/arg_parse.h" #include "stim/util_bot/probability_util.h" +#include "stim/util_top/reference_sample_tree.h" using namespace stim; int stim::command_sample(int argc, const char **argv) { check_for_unknown_arguments( - {"--seed", "--skip_reference_sample", "--out_format", "--out", "--in", "--shots"}, + {"--seed", "--skip_reference_sample", "--skip_loop_folding", "--out_format", "--out", "--in", "--shots"}, {"--sample", "--frame0"}, "sample", argc, argv); const auto &out_format = find_enum_argument("--out_format", "01", format_name_to_enum_map(), argc, argv); bool skip_reference_sample = find_bool_argument("--skip_reference_sample", argc, argv); + bool skip_loop_folding = find_bool_argument("--skip_loop_folding", argc, argv); uint64_t num_shots = find_argument("--shots", argc, argv) ? (uint64_t)find_int64_argument("--shots", 1, 0, INT64_MAX, argc, argv) : find_argument("--sample", argc, argv) ? (uint64_t)find_int64_argument("--sample", 1, 0, INT64_MAX, argc, argv) @@ -56,7 +58,13 @@ int stim::command_sample(int argc, const char **argv) { auto circuit = Circuit::from_file(in); simd_bits ref(0); if (!skip_reference_sample) { - ref = TableauSimulator::reference_sample_circuit(circuit); + if (skip_loop_folding) { + ref = TableauSimulator::reference_sample_circuit(circuit); + } else { + ReferenceSampleTree reference_sample_measurement_bits = + ReferenceSampleTree::from_circuit_reference_sample(circuit.aliased_noiseless_circuit()); + reference_sample_measurement_bits.decompress_into(ref); + } } sample_batch_measurements_writing_results_to_disk(circuit, ref, num_shots, out, out_format.id, rng); } @@ -128,6 +136,37 @@ SubCommandHelp stim::command_sample_help() { )PARAGRAPH"), }); + result.flags.push_back( + SubCommandHelpFlag{ + "--skip_loop_folding", + "bool", + "false", + {"[none]", "[switch]"}, + clean_doc_string(R"PARAGRAPH( + Skips loop folding logic on the reference sample calculation. + + When this argument is specified, the reference sample (that is used + to convert measurement flip data from frame simulations into actual + measurement data) is generated by iterating through the entire + flattened circuit with no loop detection. + + Loop folding can enormously improve performance for circuits + containing REPEAT blocks with large repeat counts, by detecting + periodicity in loops and fast-forwarding across them when computing + the reference sample for the circuit. However, in some cases the + analysis is not able to detect the periodicity that is present. For + example, this has been observed in honeycomb code circuits. When + this happens, the folding-capable analysis is slower than simply + analyzing the flattened circuit without any specialized loop logic. + The `--skip_loop_folding` flag can be used to just analyze the + flattened circuit, bypassing this slowdown for circuits such as + honeycomb code circuits. + + By default, loop detection is enabled. Pass this flag to disable + it (when appropriate by use case). + )PARAGRAPH"), + }); + result.flags.push_back( SubCommandHelpFlag{ "--out_format", diff --git a/src/stim/cmd/command_sample.test.cc b/src/stim/cmd/command_sample.test.cc index 2ed89ede5..4ce7abd4f 100644 --- a/src/stim/cmd/command_sample.test.cc +++ b/src/stim/cmd/command_sample.test.cc @@ -26,7 +26,7 @@ std::unordered_map line_freq_with_lifetime_matching_ar std::unordered_map result{}; size_t start = 0; for (size_t k = 0; k <= data.size(); k++) { - if (data[k] == '\n' || data[k] == '\0') { + if (k == data.size() || data[k] == '\n' || data[k] == '\0') { result[data.substr(start, k - start)]++; start = k + 1; } diff --git a/src/stim/diagram/crumble_data.cc b/src/stim/diagram/crumble_data.cc index 28611b433..c81dfb58d 100644 --- a/src/stim/diagram/crumble_data.cc +++ b/src/stim/diagram/crumble_data.cc @@ -680,14 +680,14 @@ std::string stim_draw_internal::make_crumble_html() { result.append(R"CRUMBLE_PART(=e.Z.name||t.Y[0]!==e.Y[0])}this.Le(a,t)}}}(document.getElementById("cvn"));function zt(){var t=m.xe().xt().replaceAll("\nPOLYGON","\n#!pragma POLYGON").replaceAll("\nERR","\n#!pragma ERR").replaceAll("\nMARK","\n#!pragma MARK"),r=y;r.value=t+"\n",r.focus(),r.select()}v.addEventListener("click",t=>{zt()}),Dt.addEventListener("click",t=>{var r=y.value,r=u.It(r);m.commit(r)}),St.addEventListener("click",t=>{var r=document.getElementById("divImportExport");"none"===r.style.display?(r.style.display="block",St.textContent="Hide Import/Export",zt()):(r.style.display="none",St.textContent="Show Import/Export",y.value=""),setTimeout(()=>{window.scrollTo(0,0)},0)}),xt.addEventListener("click",t=>{m.Be()}),kt.addEventListener("click",t=>{m.Nr()}),Lt.addEventListener("click",t=>{m.de=new Map(m.we.entries()),m.Ge()}),Ut.addEventListener("click",t=>{m.ei(!1),m.Ge()}),Gt.addEventListener("click",t=>{"none"===Ht.style.display?(Ht.style.display="block",Gt.textContent="Hide Example Circuits"):(Ht.style.display="none",Gt.textC)CRUMBLE_PART"); result.append(R"CRUMBLE_PART(ontent="Show Example Circuits")}),Ft.addEventListener("click",t=>{m.de=new Map,m.Ge()}),gt.addEventListener("click",t=>{m.Dr()}),Ct.addEventListener("click",t=>{m.qe()}),Tt.addEventListener("click",t=>{m.Ve(1,!1)}),Nt.addEventListener("click",t=>{m.Ve(-1,!1)}),t.addEventListener("click",t=>{m.Ke(!1)}),At.addEventListener("click",t=>{m.ze(!1)}),Pt.addEventListener("click",t=>{m.je(m.Ur+1)}),Ot.addEventListener("click",t=>{m.je(m.Ur-1)}),window.addEventListener("resize",t=>{m.canvas.width=m.canvas.scrollWidth,m.canvas.height=m.canvas.scrollHeight,m.Ge()}),m.canvas.addEventListener("mousemove",t=>{m.qr=t.offsetX+e,m.Wr=t.offsetY+o;var r=m.canvas.width/2;Kt&&1===t.buttons?m.je(Math.floor((t.offsetX-r)/8)):m.Ge()});let Kt=!1;m.canvas.addEventListener("mousedown",t=>{m.qr=t.offsetX+e,m.Wr=t.offsetY+o,m.ue=t.offsetX+e,m.Xe=t.offsetY+o;var r=m.canvas.width/2;(Kt=t.offsetY<20&&t.offsetX>r&&1===t.buttons)?m.je(Math.floor((t.offsetX-r)/8)):m.Ge()}),m.canvas.addEventListener("mouseup",t=>{var r=m.$e(t.altKey);m.ue=void 0)CRUMBLE_PART"); result.append(R"CRUMBLE_PART(,m.Xe=void 0,m.qr=t.offsetX+e,m.Wr=t.offsetY+o,m.Je(r,t.shiftKey,t.ctrlKey),1===t.buttons&&(Kt=!1)});let Qt=void 0;async function $t(){let e=m.xe();e.yt=[e.yt[m.Ur]],0{var r=e.Rt[2*t],t=e.Rt[2*t+1];return m.we.has(r+","+t)}),[r,t]=D(m.we.values()),e=e.St(-r,-t));var t,r=e.xt();Qt=r;try{await navigator.clipboard.writeText(r)}catch(t){console.warn("Failed to write to clipboard. Using fallback emulated clipboard.",t)}}async function Bt(e){let i;try{i=await navigator.clipboard.readText()}catch(t){console.warn("Failed to read from clipboard. Using fallback emulated clipboard.",t),i=Qt}if(void 0!==i){let r=u.It(i);if(1!==r.yt.length)throw new Error(i);let t=m.xe();0m.Ve(-1,t)),o.set("t",t=>m.Ve(1,t)),o.set("escape",()=>m.Ue),o.set("delete",t=>m.He(t)),o.set("backspace",t=>m.He(t)),o.set("ctrl+delete",t=>m.ze(t)),o.set("ctrl+insert",t=>m.Ke(t)),o.set("ctrl+backspace",()=>m.ze),o.set("ctrl+z",t=>{t||m.Nr()}),o.set("ctrl+y",t=>{t||m.Dr()}),o.set("ctrl+shift+z",t=>{t||m.Dr()}),o.set("ctrl+c",async t=>{await $t()}),o.set("ctrl+v",Bt),o.set("ctrl+x",async t=>{var r;await $t(),0===m.we.size?((r=m.xe()).yt[m.Ur].et.clear(),r.yt[m.Ur].it.length=0,m.Le(r,t)):m.He(t)}),o.set("l",t=>{t||(m.de=new Map(m.we.entries()),m.Ge())}),o.set(" ",t=>m.ei(t));for(let[t,r]of[["1",0],["2",1],["3",2],["4",3],["5",4],["6",5],["7",6],["8",7],["9",8],["0",9],["-",10],["=",11],["\\",12],["`",13]])o.set(""+t,t=>m.ri(t,r)),o.set(t+"+x",t=>m.ni(t,B.get("MARKX").M(r))),o.set(t+"+y",t=>m.ni(t,B.get("MARKY").M(r))),o.set(t+"+z",t=>m.ni(t,B.get("MARKZ").M(r))),o.set(t+"+d",t=>m.li(t,r)),o.set(t+)CRUMBLE_PART"); - result.append(R"CRUMBLE_PART("+o",t=>m.si(t,r)),o.set(t+"+j",t=>m.vi(t,r)),o.set(t+"+k",t=>m.fi(t,r));let r=.25;function a(t,r,e=void 0){for(var i of t){if(o.has(i))throw new Error("Chord collision: "+i);o.set(i,t=>m.ni(t,B.get(r)))}void 0!==e&&a(t.map(t=>"shift+"+t),e)}return o.set("p",t=>m.ni(t,B.get("POLYGON"),[1,0,0,r])),o.set("alt+p",t=>m.ni(t,B.get("POLYGON"),[0,1,0,r])),o.set("shift+p",t=>m.ni(t,B.get("POLYGON"),[0,0,1,r])),o.set("p+x",t=>m.ni(t,B.get("POLYGON"),[1,0,0,r])),o.set("p+y",t=>m.ni(t,B.get("POLYGON"),[0,1,0,r])),o.set("p+z",t=>m.ni(t,B.get("POLYGON"),[0,0,1,r])),o.set("p+x+y",t=>m.ni(t,B.get("POLYGON"),[1,1,0,r])),o.set("p+x+z",t=>m.ni(t,B.get("POLYGON"),[1,0,1,r])),o.set("p+y+z",t=>m.ni(t,B.get("POLYGON"),[0,1,1,r])),o.set("p+x+y+z",t=>m.ni(t,B.get("POLYGON"),[1,1,1,r])),o.set("m+p+x",t=>m.ni(t,f("X".repeat(m.we.size)),[])),o.set("m+p+y",t=>m.ni(t,f("Y".repeat(m.we.size)),[])),o.set("m+p+z",t=>m.ni(t,f("Z".repeat(m.we.size)),[])),o.set("f",t=>m.De(t)),o.set("g",t=>m.Fe(t)),o.set("shift+>",t=>m.We((t,r)=>[t+1,r],t,!1)))CRUMBLE_PART"); - result.append(R"CRUMBLE_PART(,o.set("shift+<",t=>m.We((t,r)=>[t-1,r],t,!1)),o.set("shift+v",t=>m.We((t,r)=>[t,r+1],t,!1)),o.set("shift+^",t=>m.We((t,r)=>[t,r-1],t,!1)),o.set(">",t=>m.We((t,r)=>[t+1,r],t,!1)),o.set("<",t=>m.We((t,r)=>[t-1,r],t,!1)),o.set("v",t=>m.We((t,r)=>[t,r+1],t,!1)),o.set("^",t=>m.We((t,r)=>[t,r-1],t,!1)),o.set(".",t=>m.We((t,r)=>[t+.5,r+.5],t,!1)),a(["h","h+y","h+x+z"],"H","H"),a(["h+z","h+x+y"],"H_XY","H_XY"),a(["h+x","h+y+z"],"H_YZ","H_YZ"),a(["s+x","s+y+z"],"SQRT_X","SQRT_X_DAG"),a(["s+y","s+x+z"],"SQRT_Y","SQRT_Y_DAG"),a(["s","s+z","s+x+y"],"S","S_DAG"),a(["r+x","r+y+z"],"RX"),a(["r+y","r+x+z"],"RY"),a(["r","r+z","r+x+y"],"R"),a(["m+x","m+y+z"],"MX"),a(["m+y","m+x+z"],"MY"),a(["m","m+z","m+x+y"],"M"),a(["m+r+x","m+r+y+z"],"MRX"),a(["m+r+y","m+r+x+z"],"MRY"),a(["m+r","m+r+z","m+r+x+y"],"MR"),a(["c"],"CX","CX"),a(["c+x"],"CX","CX"),a(["c+y"],"CY","CY"),a(["c+z"],"CZ","CZ"),a(["j+x"],"X","X"),a(["j+y"],"Y","Y"),a(["j+z"],"Z","Z"),a(["c+x+y"],"XCY","XCY"),a(["alt+c+x"],"XCX","XCX"),a(["alt+c+y"],"YCY","YCY"),a(["w"])CRUMBLE_PART"); - result.append(R"CRUMBLE_PART(,"SWAP","SWAP"),a(["w+x"],"CXSWAP",void 0),a(["c+w+x"],"CXSWAP",void 0),a(["i+w"],"ISWAP","ISWAP_DAG"),a(["w+z"],"CZSWAP",void 0),a(["c+w+z"],"CZSWAP",void 0),a(["c+w"],"CZSWAP",void 0),a(["c+t"],"C_XYZ","C_ZYX"),a(["c+s+x"],"SQRT_XX","SQRT_XX_DAG"),a(["c+s+y"],"SQRT_YY","SQRT_YY_DAG"),a(["c+s+z"],"SQRT_ZZ","SQRT_ZZ_DAG"),a(["c+s"],"SQRT_ZZ","SQRT_ZZ_DAG"),a(["c+m+x"],"MXX","MXX"),a(["c+m+y"],"MYY","MYY"),a(["c+m+z"],"MZZ","MZZ"),a(["c+m"],"MZZ","MZZ"),o}();function Wt(r){if(m.Oe.jt(r),"keydown"===r.type){if("q"===r.key.toLowerCase())return e=r.shiftKey?5:1,void m.je(m.Ur-e);if("e"===r.key.toLowerCase())return e=r.shiftKey?5:1,void m.je(m.Ur+e);if("Home"===r.key)return m.je(0),void r.preventDefault();if("End"===r.key)return m.je(m.xe().yt.length-1),void r.preventDefault()}var t=m.Oe.Bt;if(0!==t.length){for(var e=t[t.length-1];0{m.Te.set(m.Ne(void 0));var t=m.Oe.qt(!1),a=(r.width=r.scrollWidth,r.height=r.scrollHeight,r.getContext("2d"));a.clearRect(0,0,r.width,r.height),a.textAlign="right",a.textBaseline="middle",a.fillText("X",7.5,24.5),a.fillText("Y",7.5,56.5),a.fillText("Z",7.5,88.5),a.textAlign="center",a.textBaseline="bottom";for(let t=0;t{try{var t,r=(()=>{var t=document.location.hash.substring(1),r=new Map;if(""!==t)for(var e of t.split("&")){var i,o=e.indexOf("=");-1!==o&&(i=e.substring(0,o),e=decodeURIComponent(e.substring(o+1)),r.set(i,e))})CRUMBLE_PART"); - result.append(R"CRUMBLE_PART(return r})(),e=(r.has("circuit")||("[[[DEFAULT-CIRCUIT-CONTENT-LITERAL]]]"===(t=document.getElementById("txtDefaultCircuit")).value.replaceAll("_","-")?r.set("circuit",""):r.set("circuit",t.value)),u.It(r.get("circuit"))),i=e.xt();Vt.clear(i),e.yt.every(t=>t.ut())&&1===r.size&&i===r.get("circuit")?p.Ie():p.be(i,Yt(i))}catch(t){throw new Error(t)}},window.addEventListener("popstate",v),v(),Vt.kr().Yr().Zr(1).subscribe(t=>{p.be(t,Yt(t))})}m.Te.mr().subscribe(t=>requestAnimationFrame(()=>wt(m.canvas.getContext("2d"),t))),window.addEventListener("focus",()=>{m.Oe.Vt()}),window.addEventListener("blur",()=>{m.Oe.Vt()});for(let r of document.getElementById("examples-div").querySelectorAll("a"))r.onclick=t=>{if(!(t.shiftKey||t.ctrlKey||t.altKey||0!==t.button))return t=r.href.split("#circuit=")[1],m.rev.commit(t),!1}; + result.append(R"CRUMBLE_PART(.Ur].put(new F(h.Z,h.tag,h.Y,new Uint32Array(c)))}m.Le(t,e)}}const qt=function(){let o=new Map;o.set("shift+t",t=>m.Ve(-1,t)),o.set("t",t=>m.Ve(1,t)),o.set("escape",()=>m.Ue()),o.set("delete",t=>m.He(t)),o.set("backspace",t=>m.He(t)),o.set("ctrl+delete",t=>m.ze(t)),o.set("ctrl+insert",t=>m.Ke(t)),o.set("ctrl+backspace",t=>m.ze(t)),o.set("ctrl+z",t=>{t||m.Nr()}),o.set("ctrl+y",t=>{t||m.Dr()}),o.set("ctrl+shift+z",t=>{t||m.Dr()}),o.set("ctrl+c",async t=>{await $t()}),o.set("ctrl+v",Bt),o.set("ctrl+x",async t=>{var r;await $t(),0===m.we.size?((r=m.xe()).yt[m.Ur].et.clear(),r.yt[m.Ur].it.length=0,m.Le(r,t)):m.He(t)}),o.set("l",t=>{t||(m.de=new Map(m.we.entries()),m.Ge())}),o.set(" ",t=>m.ei(t));for(let[t,r]of[["1",0],["2",1],["3",2],["4",3],["5",4],["6",5],["7",6],["8",7],["9",8],["0",9],["-",10],["=",11],["\\",12],["`",13]])o.set(""+t,t=>m.ri(t,r)),o.set(t+"+x",t=>m.ni(t,B.get("MARKX").M(r))),o.set(t+"+y",t=>m.ni(t,B.get("MARKY").M(r))),o.set(t+"+z",t=>m.ni(t,B.get("MARKZ").M(r))),o.set(t+"+d",t=>m.li(t,r)),o.se)CRUMBLE_PART"); + result.append(R"CRUMBLE_PART(t(t+"+o",t=>m.si(t,r)),o.set(t+"+j",t=>m.vi(t,r)),o.set(t+"+k",t=>m.fi(t,r));let r=.25;function a(t,r,e=void 0){for(var i of t){if(o.has(i))throw new Error("Chord collision: "+i);o.set(i,t=>m.ni(t,B.get(r)))}void 0!==e&&a(t.map(t=>"shift+"+t),e)}return o.set("p",t=>m.ni(t,B.get("POLYGON"),[1,0,0,r])),o.set("alt+p",t=>m.ni(t,B.get("POLYGON"),[0,1,0,r])),o.set("shift+p",t=>m.ni(t,B.get("POLYGON"),[0,0,1,r])),o.set("p+x",t=>m.ni(t,B.get("POLYGON"),[1,0,0,r])),o.set("p+y",t=>m.ni(t,B.get("POLYGON"),[0,1,0,r])),o.set("p+z",t=>m.ni(t,B.get("POLYGON"),[0,0,1,r])),o.set("p+x+y",t=>m.ni(t,B.get("POLYGON"),[1,1,0,r])),o.set("p+x+z",t=>m.ni(t,B.get("POLYGON"),[1,0,1,r])),o.set("p+y+z",t=>m.ni(t,B.get("POLYGON"),[0,1,1,r])),o.set("p+x+y+z",t=>m.ni(t,B.get("POLYGON"),[1,1,1,r])),o.set("m+p+x",t=>m.ni(t,f("X".repeat(m.we.size)),[])),o.set("m+p+y",t=>m.ni(t,f("Y".repeat(m.we.size)),[])),o.set("m+p+z",t=>m.ni(t,f("Z".repeat(m.we.size)),[])),o.set("f",t=>m.De(t)),o.set("g",t=>m.Fe(t)),o.set("shift+>",t=>m.We((t,r)=>[t+1,r],t,)CRUMBLE_PART"); + result.append(R"CRUMBLE_PART(!1)),o.set("shift+<",t=>m.We((t,r)=>[t-1,r],t,!1)),o.set("shift+v",t=>m.We((t,r)=>[t,r+1],t,!1)),o.set("shift+^",t=>m.We((t,r)=>[t,r-1],t,!1)),o.set(">",t=>m.We((t,r)=>[t+1,r],t,!1)),o.set("<",t=>m.We((t,r)=>[t-1,r],t,!1)),o.set("v",t=>m.We((t,r)=>[t,r+1],t,!1)),o.set("^",t=>m.We((t,r)=>[t,r-1],t,!1)),o.set(".",t=>m.We((t,r)=>[t+.5,r+.5],t,!1)),a(["h","h+y","h+x+z"],"H","H"),a(["h+z","h+x+y"],"H_XY","H_XY"),a(["h+x","h+y+z"],"H_YZ","H_YZ"),a(["s+x","s+y+z"],"SQRT_X","SQRT_X_DAG"),a(["s+y","s+x+z"],"SQRT_Y","SQRT_Y_DAG"),a(["s","s+z","s+x+y"],"S","S_DAG"),a(["r+x","r+y+z"],"RX"),a(["r+y","r+x+z"],"RY"),a(["r","r+z","r+x+y"],"R"),a(["m+x","m+y+z"],"MX"),a(["m+y","m+x+z"],"MY"),a(["m","m+z","m+x+y"],"M"),a(["m+r+x","m+r+y+z"],"MRX"),a(["m+r+y","m+r+x+z"],"MRY"),a(["m+r","m+r+z","m+r+x+y"],"MR"),a(["c"],"CX","CX"),a(["c+x"],"CX","CX"),a(["c+y"],"CY","CY"),a(["c+z"],"CZ","CZ"),a(["j+x"],"X","X"),a(["j+y"],"Y","Y"),a(["j+z"],"Z","Z"),a(["c+x+y"],"XCY","XCY"),a(["alt+c+x"],"XCX","XCX"),a(["alt+c+y"],"YCY","YCY"),a([)CRUMBLE_PART"); + result.append(R"CRUMBLE_PART("w"],"SWAP","SWAP"),a(["w+x"],"CXSWAP",void 0),a(["c+w+x"],"CXSWAP",void 0),a(["i+w"],"ISWAP","ISWAP_DAG"),a(["w+z"],"CZSWAP",void 0),a(["c+w+z"],"CZSWAP",void 0),a(["c+w"],"CZSWAP",void 0),a(["c+t"],"C_XYZ","C_ZYX"),a(["c+s+x"],"SQRT_XX","SQRT_XX_DAG"),a(["c+s+y"],"SQRT_YY","SQRT_YY_DAG"),a(["c+s+z"],"SQRT_ZZ","SQRT_ZZ_DAG"),a(["c+s"],"SQRT_ZZ","SQRT_ZZ_DAG"),a(["c+m+x"],"MXX","MXX"),a(["c+m+y"],"MYY","MYY"),a(["c+m+z"],"MZZ","MZZ"),a(["c+m"],"MZZ","MZZ"),o}();function Wt(r){if(m.Oe.jt(r),"keydown"===r.type){if("q"===r.key.toLowerCase())return e=r.shiftKey?5:1,void m.je(m.Ur-e);if("e"===r.key.toLowerCase())return e=r.shiftKey?5:1,void m.je(m.Ur+e);if("Home"===r.key)return m.je(0),void r.preventDefault();if("End"===r.key)return m.je(m.xe().yt.length-1),void r.preventDefault()}var t=m.Oe.Bt;if(0!==t.length){for(var e=t[t.length-1];0{m.Te.set(m.Ne(void 0));var t=m.Oe.qt(!1),a=(r.width=r.scrollWidth,r.height=r.scrollHeight,r.getContext("2d"));a.clearRect(0,0,r.width,r.height),a.textAlign="right",a.textBaseline="middle",a.fillText("X",7.5,24.5),a.fillText("Y",7.5,56.5),a.fillText("Z",7.5,88.5),a.textAlign="center",a.textBaseline="bottom";for(let t=0;t{try{var t,r=(()=>{var t=document.location.hash.substring(1),r=new Map;if(""!==t)for(var e of t.split("&")){var i,o=e.indexOf("=");-1!==o&&(i=e.substring(0,o),e=decodeURIComponent(e.substring(o+1)),r.set(i,)CRUMBLE_PART"); + result.append(R"CRUMBLE_PART(e))}return r})(),e=(r.has("circuit")||("[[[DEFAULT-CIRCUIT-CONTENT-LITERAL]]]"===(t=document.getElementById("txtDefaultCircuit")).value.replaceAll("_","-")?r.set("circuit",""):r.set("circuit",t.value)),u.It(r.get("circuit"))),i=e.xt();Vt.clear(i),e.yt.every(t=>t.ut())&&1===r.size&&i===r.get("circuit")?p.Ie():p.be(i,Yt(i))}catch(t){throw new Error(t)}},window.addEventListener("popstate",v),v(),Vt.kr().Yr().Zr(1).subscribe(t=>{p.be(t,Yt(t))})}m.Te.mr().subscribe(t=>requestAnimationFrame(()=>wt(m.canvas.getContext("2d"),t))),window.addEventListener("focus",()=>{m.Oe.Vt()}),window.addEventListener("blur",()=>{m.Oe.Vt()});for(let r of document.getElementById("examples-div").querySelectorAll("a"))r.onclick=t=>{if(!(t.shiftKey||t.ctrlKey||t.altKey||0!==t.button))return t=r.href.split("#circuit=")[1],m.rev.commit(t),!1}; )CRUMBLE_PART"); result.append(R"CRUMBLE_PART( )CRUMBLE_PART"); diff --git a/src/stim/search/graphlike/algo.cc b/src/stim/search/graphlike/algo.cc index 96a5e234d..3e71102ad 100644 --- a/src/stim/search/graphlike/algo.cc +++ b/src/stim/search/graphlike/algo.cc @@ -15,7 +15,7 @@ #include "stim/search/graphlike/algo.h" #include -#include +#include #include #include @@ -27,7 +27,7 @@ using namespace stim; using namespace stim::impl_search_graphlike; -DetectorErrorModel backtrack_path(const std::map &back_map, const SearchState &final_state) { +DetectorErrorModel backtrack_path(const std::unordered_map &back_map, const SearchState &final_state) { DetectorErrorModel out; auto cur_state = final_state; while (true) { @@ -55,7 +55,7 @@ DetectorErrorModel stim::shortest_graphlike_undetectable_logical_error( } std::queue queue; - std::map back_map; + std::unordered_map back_map; // Mark the vacuous dead-end state as already seen. back_map.emplace(empty_search_state, empty_search_state); diff --git a/src/stim/search/graphlike/algo.perf.cc b/src/stim/search/graphlike/algo.perf.cc index dd297f91a..81f3234f9 100644 --- a/src/stim/search/graphlike/algo.perf.cc +++ b/src/stim/search/graphlike/algo.perf.cc @@ -36,3 +36,21 @@ BENCHMARK(find_graphlike_logical_error_surface_code_d25) { std::cout << "bad"; } } + +BENCHMARK(find_graphlike_logical_error_surface_code_d11_r1000) { + CircuitGenParameters params(1000, 11, "rotated_memory_x"); + params.after_clifford_depolarization = 0.001; + params.before_measure_flip_probability = 0.001; + params.after_reset_flip_probability = 0.001; + params.before_round_data_depolarization = 0.001; + auto circuit = generate_surface_code_circuit(params).circuit; + auto model = ErrorAnalyzer::circuit_to_detector_error_model(circuit, true, true, false, 0.0, false, true); + + size_t total = 0; + benchmark_go([&]() { + total += stim::shortest_graphlike_undetectable_logical_error(model, false).instructions.size(); + }).goal_millis(100); + if (total % 11 != 0 || total == 0) { + std::cout << "bad"; + } +} diff --git a/src/stim/search/graphlike/search_state.h b/src/stim/search/graphlike/search_state.h index 0e99343d3..d36820ad8 100644 --- a/src/stim/search/graphlike/search_state.h +++ b/src/stim/search/graphlike/search_state.h @@ -42,6 +42,22 @@ struct SearchState { }; std::ostream &operator<<(std::ostream &out, const SearchState &v); +inline void hash_combine(size_t &h, uint64_t x) { + h ^= std::hash{}(x) + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2); // mimic Boost's hash-combine function +} + +struct SearchStateHash { + size_t operator()(const SearchState &s) const { + SearchState c = s.canonical(); + size_t h = std::hash{}(c.det_active); + hash_combine(h, c.det_held); + for (size_t i = 0; i < c.obs_mask.num_u64_padded(); i++) { + hash_combine(h, c.obs_mask.u64[i]); + } + return h; + } +}; + } // namespace impl_search_graphlike } // namespace stim diff --git a/src/stim/search/graphlike/search_state.test.cc b/src/stim/search/graphlike/search_state.test.cc index c150e8bff..ec2dc1ded 100644 --- a/src/stim/search/graphlike/search_state.test.cc +++ b/src/stim/search/graphlike/search_state.test.cc @@ -148,3 +148,11 @@ TEST(search_graphlike, DemAdjGraphSearchState_canonical_ordering) { TEST(search_graphlike, DemAdjGraphSearchState_str) { ASSERT_EQ(SearchState(1, 2, obs_mask(3)).str(), "D1 D2 L0 L1 "); } + +TEST(search_graphlike, SearchStateHash_operator) { + ASSERT_EQ(SearchStateHash{}(SearchState(1, 2, obs_mask(3))), SearchStateHash{}(SearchState(2, 1, obs_mask(3)))); + ASSERT_EQ(SearchStateHash{}(SearchState(1, 2, obs_mask(3))), SearchStateHash{}(SearchState(1, 2, obs_mask(3)))); + + ASSERT_NE(SearchStateHash{}(SearchState(1, 2, obs_mask(3))), SearchStateHash{}(SearchState(2, 2, obs_mask(3)))); + ASSERT_NE(SearchStateHash{}(SearchState(1, 2, obs_mask(3))), SearchStateHash{}(SearchState(1, 2, obs_mask(4)))); +} diff --git a/src/stim/util_bot/probability_util.cc b/src/stim/util_bot/probability_util.cc index a0dd5ba7e..2b66c0857 100644 --- a/src/stim/util_bot/probability_util.cc +++ b/src/stim/util_bot/probability_util.cc @@ -21,16 +21,25 @@ using namespace stim; RareErrorIterator::RareErrorIterator(float probability) - : next_candidate(0), is_one(probability == 1), dist(probability) { + : next_candidate(0), probability(probability) { if (!(probability >= 0 && probability <= 1)) { throw std::out_of_range("Invalid probability: " + std::to_string(probability)); } + if (0 < probability && probability < 1) { + dist = std::geometric_distribution(probability); + } } size_t RareErrorIterator::next(std::mt19937_64 &rng) { - size_t result = next_candidate + (is_one ? 0 : dist(rng)); - next_candidate = result + 1; - return result; + if (probability == 0) { + return SIZE_MAX; + } else if (probability == 1) { + return next_candidate++; + } else { + size_t result = next_candidate + dist(rng); + next_candidate = result + 1; + return result; + } } std::vector stim::sample_hit_indices(float probability, size_t attempts, std::mt19937_64 &rng) { diff --git a/src/stim/util_bot/probability_util.h b/src/stim/util_bot/probability_util.h index 65da497d4..e825a5fba 100644 --- a/src/stim/util_bot/probability_util.h +++ b/src/stim/util_bot/probability_util.h @@ -32,8 +32,10 @@ constexpr uint64_t INTENTIONAL_VERSION_SEED_INCOMPATIBILITY = 0xDEADBEEF124CULL; /// Gets more efficient as the hit probability drops. struct RareErrorIterator { size_t next_candidate; - bool is_one = false; + float probability; std::geometric_distribution dist; + RareErrorIterator() = delete; + RareErrorIterator(const RareErrorIterator &) = delete; RareErrorIterator(float probability); size_t next(std::mt19937_64 &rng); diff --git a/src/stim/util_top/reference_sample_tree.h b/src/stim/util_top/reference_sample_tree.h index 92dfcf144..366f4e7d4 100644 --- a/src/stim/util_top/reference_sample_tree.h +++ b/src/stim/util_top/reference_sample_tree.h @@ -37,6 +37,10 @@ struct ReferenceSampleTree { /// Writes the contents of the tree into the given output vector. void decompress_into(std::vector &output) const; + /// Writes the contents of the tree into the given output simd_bits. + template + void decompress_into(simd_bits &output) const; + /// Folds redundant children into the repetition count, if they repeat this many times. /// /// For example, if the tree's children are [A, B, C, A, B, C] and the tree has no diff --git a/src/stim/util_top/reference_sample_tree.inl b/src/stim/util_top/reference_sample_tree.inl index 94c52e27c..670ad97c8 100644 --- a/src/stim/util_top/reference_sample_tree.inl +++ b/src/stim/util_top/reference_sample_tree.inl @@ -2,6 +2,19 @@ namespace stim { +template +void ReferenceSampleTree::decompress_into(simd_bits &output) const { + std::vector v; + this->decompress_into(v); + + simd_bits result(v.size()); + for (size_t k = 0; k < v.size(); k++) { + result[k] ^= v[k]; + } + + output = std::move(result); +} + template ReferenceSampleTree CompressedReferenceSampleHelper::do_loop_with_no_folding(const Circuit &loop, uint64_t reps) { ReferenceSampleTree result;