11from sinter ._decoding ._decoding_decoder_class import Decoder , CompiledDecoder
22
33
4+ def check_pymatching_version_for_correlated_decoding (pymatching ):
5+ v = pymatching .__version__ .split ('.' )
6+ try :
7+ a = int (v [0 ])
8+ b = int (v [1 ])
9+ c = int ('' .join (e for e in v [2 ] if e in '0123456789' )) # In case dev version
10+ except (ValueError , IndexError ):
11+ return # Probably it's the future.
12+
13+ if (a , b , c ) < (2 , 3 , 1 ):
14+ raise ValueError (
15+ "PyMatching version must be at least 2.3.1 for correlated decoding.\n "
16+ f"Installed version: { pymatching .__version__ } \n "
17+ "To fix this, install a newer version of pymatching into your environment.\n "
18+ "For example, if you are using pip, run `pip install pymatching --upgrade`.\n "
19+ )
20+
21+
422class PyMatchingCompiledDecoder (CompiledDecoder ):
5- def __init__ (self , matcher : 'pymatching.Matching' ):
23+ def __init__ (self , matcher : 'pymatching.Matching' , use_correlated_decoding : bool ):
624 self .matcher = matcher
25+ self .use_correlated_decoding = use_correlated_decoding
726
827 def decode_shots_bit_packed (
928 self ,
1029 * ,
1130 bit_packed_detection_event_data : 'np.ndarray' ,
1231 ) -> 'np.ndarray' :
32+ kwargs = {}
33+ if self .use_correlated_decoding :
34+ kwargs ['enable_correlations' ] = True
1335 return self .matcher .decode_batch (
1436 shots = bit_packed_detection_event_data ,
1537 bit_packed_shots = True ,
1638 bit_packed_predictions = True ,
1739 return_weights = False ,
40+ ** kwargs ,
1841 )
1942
2043
2144class PyMatchingDecoder (Decoder ):
2245 """Use pymatching to predict observables from detection events."""
2346
47+ def __init__ (self , use_correlated_decoding : bool = False ):
48+ self .use_correlated_decoding = use_correlated_decoding
49+
2450 def compile_decoder_for_dem (self , * , dem : 'stim.DetectorErrorModel' ) -> CompiledDecoder :
2551 try :
2652 import pymatching
@@ -31,7 +57,14 @@ def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> Compiled
3157 "For example, if you are using pip, run `pip install pymatching`.\n "
3258 ) from ex
3359
34- return PyMatchingCompiledDecoder (pymatching .Matching .from_detector_error_model (dem ))
60+ kwargs = {}
61+ if self .use_correlated_decoding :
62+ check_pymatching_version_for_correlated_decoding (pymatching )
63+ kwargs ['enable_correlations' ] = True
64+ return PyMatchingCompiledDecoder (
65+ pymatching .Matching .from_detector_error_model (dem , ** kwargs ),
66+ use_correlated_decoding = self .use_correlated_decoding ,
67+ )
3568
3669 def decode_via_files (self ,
3770 * ,
@@ -60,7 +93,9 @@ def decode_via_files(self,
6093 if not hasattr (pymatching , 'cli' ):
6194 raise ValueError ("""
6295The installed version of pymatching has no `pymatching.cli` method.
96+
6397sinter requires pymatching 2.1.0 or later.
98+
6499If you're using pip to install packages, this can be fixed by running
65100
66101```
@@ -69,13 +104,18 @@ def decode_via_files(self,
69104
70105""" )
71106
72- result = pymatching . cli ( command_line_args = [
107+ args = [
73108 "predict" ,
74109 "--dem" , str (dem_path ),
75110 "--in" , str (dets_b8_in_path ),
76111 "--in_format" , "b8" ,
77112 "--out" , str (obs_predictions_b8_out_path ),
78113 "--out_format" , "b8" ,
79- ])
114+ ]
115+ if self .use_correlated_decoding :
116+ check_pymatching_version_for_correlated_decoding (pymatching )
117+ args .append ("--enable_correlations" )
118+
119+ result = pymatching .cli (command_line_args = args )
80120 if result :
81121 raise ValueError ("pymatching.cli returned a non-zero exit code" )
0 commit comments