Skip to content

Commit 1e70959

Browse files
committed
fix merging bug in generalize errors
1 parent e2f55f4 commit 1e70959

1 file changed

Lines changed: 45 additions & 13 deletions

File tree

src/py/generalize_dem.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ def get_dets_logicals(error: stim.DemInstruction):
1212
dets = dets.symmetric_difference({t.val})
1313
return dets, logicals
1414

15-
def spatial_key(detector_coords: dict, min_t_coord:float, max_t_coord: float, error: stim.DemInstruction):
16-
dets, logicals = get_dets_logicals(error)
15+
def spatial_key(detector_coords: dict, min_t_coord:float, max_t_coord: float, dets, logicals):
1716
d_coords = sorted([tuple(detector_coords[d]) for d in dets])
1817
min_d_coord = d_coords[0]
1918
relative_d_coords = [tuple(np.array(c)-np.array(min_d_coord)) for c in d_coords]
@@ -39,13 +38,35 @@ def get_detector_coords(dem: stim.DetectorErrorModel):
3938
max_t_coord = max(c[2] for c in detector_coords.values())
4039
return detector_coords, min_t_coord, max_t_coord
4140

41+
42+
# Analyze the errors to make the flip tables
43+
def merged_errors(dem):
44+
errors_by_symptom = {}
45+
for error in dem.flattened():
46+
if error.type != "error":
47+
continue
48+
probability = error.args_copy()[0]
49+
assert 0 <= probability and probability <= 1, error
50+
detectors, observables = get_dets_logicals(error)
51+
key = (tuple(sorted(detectors)), tuple(sorted(observables)))
52+
if key in errors_by_symptom:
53+
p0 = errors_by_symptom[key]["probability"]
54+
probability = p0 * (1 - probability) + (1 - p0) * probability
55+
error = {
56+
"probability": probability,
57+
"likelihood_cost": -np.log(probability / (1 - probability)),
58+
"detectors": list(detectors),
59+
"observables": list(observables),
60+
}
61+
errors_by_symptom[key] = error
62+
63+
return list(errors_by_symptom.values())
64+
4265
def get_key_to_probabilities(spatial_data, template, verbose=False):
4366
key_to_probabilities = {}
44-
for inst in template.flattened():
45-
if inst.type != 'error':
46-
continue
47-
probability = inst.args_copy()[0]
48-
key = spatial_key(*spatial_data, inst)
67+
for error in merged_errors(template):
68+
probability = error['probability']
69+
key = spatial_key(*spatial_data, error['detectors'], error['observables'])
4970
if key not in key_to_probabilities:
5071
key_to_probabilities[key] = []
5172
key_to_probabilities[key].append(probability)
@@ -61,7 +82,7 @@ def merge_concat(dictionaries: List[dict]):
6182
merged[k] = []
6283
merged[k] = np.concatenate([merged[k], d[k]])
6384
return merged
64-
85+
6586

6687
def generalize(templates: List[stim.DetectorErrorModel], scaffold: stim.DetectorErrorModel, verbose: bool=False) -> stim.DetectorErrorModel:
6788
# Get detector coords for all detectors
@@ -77,11 +98,22 @@ def generalize(templates: List[stim.DetectorErrorModel], scaffold: stim.Detector
7798
for key, probabilities in key_to_probabilities.items()
7899
}
79100
output_dem = stim.DetectorErrorModel()
80-
for inst in scaffold.flattened():
81-
if inst.type == 'error':
82-
# update the probability
83-
key = spatial_key(*spatial_data_scaffold, inst)
84-
inst = stim.DemInstruction(type='error', args=[key_to_probability[key]], targets=inst.targets_copy())
101+
for instruction in scaffold.flattened():
102+
if instruction.type != 'error':
103+
output_dem.append(instruction)
104+
for error in merged_errors(scaffold):
105+
# update the probability
106+
key = spatial_key(*spatial_data_scaffold, error['detectors'], error['observables'])
107+
inst = stim.DemInstruction(
108+
type='error',
109+
args=[key_to_probability[key]],
110+
targets=[
111+
stim.target_relative_detector_id(D)
112+
for D in error['detectors']
113+
] + [
114+
stim.target_logical_observable_id(L)
115+
for L in error['observables']
116+
])
85117
output_dem.append(inst)
86118

87119
return output_dem

0 commit comments

Comments
 (0)