@@ -12,8 +12,7 @@ def get_dets_logicals(error: stim.DemInstruction):
1212 dets = dets .symmetric_difference ({t .val })
1313 return dets , logicals
1414
15- def spatial_key (detector_coords : dict , min_t_coord :float , max_t_coord : float , error : stim .DemInstruction ):
16- dets , logicals = get_dets_logicals (error )
15+ def spatial_key (detector_coords : dict , min_t_coord :float , max_t_coord : float , dets , logicals ):
1716 d_coords = sorted ([tuple (detector_coords [d ]) for d in dets ])
1817 min_d_coord = d_coords [0 ]
1918 relative_d_coords = [tuple (np .array (c )- np .array (min_d_coord )) for c in d_coords ]
@@ -39,13 +38,35 @@ def get_detector_coords(dem: stim.DetectorErrorModel):
3938 max_t_coord = max (c [2 ] for c in detector_coords .values ())
4039 return detector_coords , min_t_coord , max_t_coord
4140
41+
42+ # Analyze the errors to make the flip tables
43+ def merged_errors (dem ):
44+ errors_by_symptom = {}
45+ for error in dem .flattened ():
46+ if error .type != "error" :
47+ continue
48+ probability = error .args_copy ()[0 ]
49+ assert 0 <= probability and probability <= 1 , error
50+ detectors , observables = get_dets_logicals (error )
51+ key = (tuple (sorted (detectors )), tuple (sorted (observables )))
52+ if key in errors_by_symptom :
53+ p0 = errors_by_symptom [key ]["probability" ]
54+ probability = p0 * (1 - probability ) + (1 - p0 ) * probability
55+ error = {
56+ "probability" : probability ,
57+ "likelihood_cost" : - np .log (probability / (1 - probability )),
58+ "detectors" : list (detectors ),
59+ "observables" : list (observables ),
60+ }
61+ errors_by_symptom [key ] = error
62+
63+ return list (errors_by_symptom .values ())
64+
4265def get_key_to_probabilities (spatial_data , template , verbose = False ):
4366 key_to_probabilities = {}
44- for inst in template .flattened ():
45- if inst .type != 'error' :
46- continue
47- probability = inst .args_copy ()[0 ]
48- key = spatial_key (* spatial_data , inst )
67+ for error in merged_errors (template ):
68+ probability = error ['probability' ]
69+ key = spatial_key (* spatial_data , error ['detectors' ], error ['observables' ])
4970 if key not in key_to_probabilities :
5071 key_to_probabilities [key ] = []
5172 key_to_probabilities [key ].append (probability )
@@ -61,7 +82,7 @@ def merge_concat(dictionaries: List[dict]):
6182 merged [k ] = []
6283 merged [k ] = np .concatenate ([merged [k ], d [k ]])
6384 return merged
64-
85+
6586
6687def generalize (templates : List [stim .DetectorErrorModel ], scaffold : stim .DetectorErrorModel , verbose : bool = False ) -> stim .DetectorErrorModel :
6788 # Get detector coords for all detectors
@@ -77,11 +98,22 @@ def generalize(templates: List[stim.DetectorErrorModel], scaffold: stim.Detector
7798 for key , probabilities in key_to_probabilities .items ()
7899 }
79100 output_dem = stim .DetectorErrorModel ()
80- for inst in scaffold .flattened ():
81- if inst .type == 'error' :
82- # update the probability
83- key = spatial_key (* spatial_data_scaffold , inst )
84- inst = stim .DemInstruction (type = 'error' , args = [key_to_probability [key ]], targets = inst .targets_copy ())
101+ for instruction in scaffold .flattened ():
102+ if instruction .type != 'error' :
103+ output_dem .append (instruction )
104+ for error in merged_errors (scaffold ):
105+ # update the probability
106+ key = spatial_key (* spatial_data_scaffold , error ['detectors' ], error ['observables' ])
107+ inst = stim .DemInstruction (
108+ type = 'error' ,
109+ args = [key_to_probability [key ]],
110+ targets = [
111+ stim .target_relative_detector_id (D )
112+ for D in error ['detectors' ]
113+ ] + [
114+ stim .target_logical_observable_id (L )
115+ for L in error ['observables' ]
116+ ])
85117 output_dem .append (inst )
86118
87119 return output_dem
0 commit comments