Skip to content

Commit 2a3693b

Browse files
m-mcewenmliqai
authored andcommitted
Add pymatching correlated to sinter (#1046)
Basically a straight theft of Oscar's PR, but without adding the `package` dependency at Craig's request.
1 parent 97ec43b commit 2a3693b

5 files changed

Lines changed: 51 additions & 12 deletions

File tree

glue/sample/src/sinter/_collection/_collection_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def _compute_task_ids(self):
227227

228228
shots_left = options.max_shots
229229
errors_left = options.max_errors
230+
if shots_left is None:
231+
raise ValueError("Didn't specify --max_shots.")
230232
if errors_left is None:
231233
errors_left = shots_left
232234
errors_left = min(errors_left, shots_left)

glue/sample/src/sinter/_decoding/_decoding.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,7 @@ def sample_decode(*,
177177
were executed. The detection fraction is the ratio of these two
178178
numbers.
179179
num_shots: The number of sample shots to take from the circuit.
180-
decoder: The name of the decoder to use. Allowed values are:
181-
"pymatching":
182-
Use pymatching min-weight-perfect-match decoder.
183-
"internal":
184-
Use internal decoder with uncorrelated decoding.
185-
"internal_correlated":
186-
Use internal decoder with correlated decoding.
180+
decoder: The name of the decoder to use. For example, 'pymatching'.
187181
tmp_dir: An existing directory that is currently empty where temporary
188182
files can be written as part of performing decoding. If set to
189183
None, one is created using the tempfile package.

glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
BUILT_IN_DECODERS: Dict[str, Decoder] = {
1313
'vacuous': VacuousDecoder(),
1414
'pymatching': PyMatchingDecoder(),
15+
'pymatching-correlated': PyMatchingDecoder(use_correlated_decoding=True),
1516
'fusion_blossom': FusionBlossomDecoder(),
1617
# an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049)
1718
'hypergraph_union_find': HyperUFDecoder(),

glue/sample/src/sinter/_decoding/_decoding_pymatching.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,52 @@
11
from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder
22

33

4+
def check_pymatching_version_for_correlated_decoding(pymatching):
5+
v = pymatching.__version__.split('.')
6+
try:
7+
a = int(v[0])
8+
b = int(v[1])
9+
c = int(''.join(e for e in v[2] if e in '0123456789')) # In case dev version
10+
except (ValueError, IndexError):
11+
return # Probably it's the future.
12+
13+
if (a, b, c) < (2, 3, 1):
14+
raise ValueError(
15+
"PyMatching version must be at least 2.3.1 for correlated decoding.\n"
16+
f"Installed version: {pymatching.__version__}\n"
17+
"To fix this, install a newer version of pymatching into your environment.\n"
18+
"For example, if you are using pip, run `pip install pymatching --upgrade`.\n"
19+
)
20+
21+
422
class PyMatchingCompiledDecoder(CompiledDecoder):
5-
def __init__(self, matcher: 'pymatching.Matching'):
23+
def __init__(self, matcher: 'pymatching.Matching', use_correlated_decoding: bool):
624
self.matcher = matcher
25+
self.use_correlated_decoding = use_correlated_decoding
726

827
def decode_shots_bit_packed(
928
self,
1029
*,
1130
bit_packed_detection_event_data: 'np.ndarray',
1231
) -> 'np.ndarray':
32+
kwargs = {}
33+
if self.use_correlated_decoding:
34+
kwargs['enable_correlations'] = True
1335
return self.matcher.decode_batch(
1436
shots=bit_packed_detection_event_data,
1537
bit_packed_shots=True,
1638
bit_packed_predictions=True,
1739
return_weights=False,
40+
**kwargs,
1841
)
1942

2043

2144
class PyMatchingDecoder(Decoder):
2245
"""Use pymatching to predict observables from detection events."""
2346

47+
def __init__(self, use_correlated_decoding: bool = False):
48+
self.use_correlated_decoding = use_correlated_decoding
49+
2450
def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
2551
try:
2652
import pymatching
@@ -31,7 +57,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled
3157
"For example, if you are using pip, run `pip install pymatching`.\n"
3258
) from ex
3359

34-
return PyMatchingCompiledDecoder(pymatching.Matching.from_detector_error_model(dem))
60+
kwargs = {}
61+
if self.use_correlated_decoding:
62+
check_pymatching_version_for_correlated_decoding(pymatching)
63+
kwargs['enable_correlations'] = True
64+
return PyMatchingCompiledDecoder(
65+
pymatching.Matching.from_detector_error_model(dem, **kwargs),
66+
use_correlated_decoding=self.use_correlated_decoding,
67+
)
3568

3669
def decode_via_files(self,
3770
*,
@@ -60,7 +93,9 @@ def decode_via_files(self,
6093
if not hasattr(pymatching, 'cli'):
6194
raise ValueError("""
6295
The installed version of pymatching has no `pymatching.cli` method.
96+
6397
sinter requires pymatching 2.1.0 or later.
98+
6499
If you're using pip to install packages, this can be fixed by running
65100
66101
```
@@ -69,13 +104,18 @@ def decode_via_files(self,
69104
70105
""")
71106

72-
result = pymatching.cli(command_line_args=[
107+
args = [
73108
"predict",
74109
"--dem", str(dem_path),
75110
"--in", str(dets_b8_in_path),
76111
"--in_format", "b8",
77112
"--out", str(obs_predictions_b8_out_path),
78113
"--out_format", "b8",
79-
])
114+
]
115+
if self.use_correlated_decoding:
116+
check_pymatching_version_for_correlated_decoding(pymatching)
117+
args.append("--enable_correlations")
118+
119+
result = pymatching.cli(command_line_args=args)
80120
if result:
81121
raise ValueError("pymatching.cli returned a non-zero exit code")

glue/sample/src/sinter/_decoding/_decoding_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo
233233

234234
@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES)
235235
def test_post_selection(decoder: str, force_streaming: Optional[bool]):
236+
if decoder == 'pymatching-correlated':
237+
pytest.skip("Correlated matching does not support error probabilities > 0.5 in from_detector_error_model")
236238
circuit = stim.Circuit("""
237239
X_ERROR(0.6) 0
238240
M 0
@@ -243,7 +245,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]):
243245
M 1
244246
DETECTOR(1, 0, 0) rec[-1]
245247
OBSERVABLE_INCLUDE(0) rec[-1]
246-
248+
247249
X_ERROR(0.1) 2
248250
M 2
249251
OBSERVABLE_INCLUDE(0) rec[-1]

0 commit comments

Comments
 (0)