Skip to content

Commit 5f78ba0

Browse files
authored
Add plot_custom to sinter python API (#1035)
The general sinter plotting function already exists, but is hidden from the user. This PR liberates plot_custom, so that I can perform fun tricks like plotting error rate / physical error rate vs distance, or plotting cpu-seconds per number of rounds, or plotting strong_id vs number of qubits, or whatever else sounds fun at the time, without having to write any plotting code other than y_func.
1 parent 3a83081 commit 5f78ba0

2 files changed

Lines changed: 66 additions & 0 deletions

File tree

doc/sinter_api.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi
4242
- [`sinter.iter_collect`](#sinter.iter_collect)
4343
- [`sinter.log_binomial`](#sinter.log_binomial)
4444
- [`sinter.log_factorial`](#sinter.log_factorial)
45+
- [`sinter.plot_custom`](#sinter.plot_custom)
4546
- [`sinter.plot_discard_rate`](#sinter.plot_discard_rate)
4647
- [`sinter.plot_error_rate`](#sinter.plot_error_rate)
4748
- [`sinter.post_selection_mask_from_4th_coord`](#sinter.post_selection_mask_from_4th_coord)
@@ -1452,6 +1453,70 @@ def log_factorial(
14521453
"""
14531454
```
14541455

1456+
<a name="sinter.plot_custom"></a>
1457+
```python
1458+
# sinter.plot_custom
1459+
1460+
# (at top-level in the sinter module)
1461+
def plot_custom(
1462+
*,
1463+
ax: 'plt.Axes',
1464+
stats: 'Iterable[sinter.TaskStats]',
1465+
x_func: Callable[[sinter.TaskStats], Any],
1466+
y_func: Callable[[sinter.TaskStats], Union[sinter.Fit, float, int]],
1467+
group_func: Callable[[sinter.TaskStats], ~TCurveId] = lambda _: None,
1468+
point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
1469+
filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True,
1470+
plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
1471+
line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
1472+
) -> None:
1473+
"""Plots error rates in curves with uncertainty highlights.
1474+
1475+
Args:
1476+
ax: The plt.Axes to plot onto. For example, the `ax` value from `fig, ax = plt.subplots(1, 1)`.
1477+
stats: The collected statistics to plot.
1478+
x_func: The X coordinate to use for each stat's data point. For example, this could be
1479+
`x_func=lambda stat: stat.json_metadata['physical_error_rate']`.
1480+
y_func: The Y value to use for each stat's data point. This can be a float or it can be a
1481+
sinter.Fit value, in which case the curve will follow the fit.best value and a
1482+
highlighted area will be shown from fit.low to fit.high.
1483+
group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
1484+
The statistics are grouped into curves based on whether or not they get the same result
1485+
out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
1486+
If the result of the function is a dictionary, then optional keys in the dictionary will
1487+
also control the plotting of each curve. Available keys are:
1488+
'label': the label added to the legend for the curve
1489+
'color': the color used for plotting the curve
1490+
'marker': the marker used for the curve
1491+
'linestyle': the linestyle used for the curve
1492+
'sort': the order in which the curves will be plotted and added to the legend
1493+
e.g. if two curves (with different resulting dictionaries from group_func) share the same
1494+
value for key 'marker', they will be plotted with the same marker.
1495+
Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
1496+
point_label_func: Optional. Specifies text to draw next to data points.
1497+
filter_func: Optional. When specified, some curves will not be plotted.
1498+
The statistics are filtered and only plotted if filter_func(stat) returns True.
1499+
For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
1500+
where the saved metadata indicates the basis was 'x'.
1501+
plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
1502+
`plot` and `fill_between` used to do the actual plotting. For example, this can be used
1503+
to specify markers and colors. Takes the index of the curve in sorted order and also a
1504+
curve_id (these will be 0 and None respectively if group_func is not specified). For example,
1505+
this could be:
1506+
1507+
plot_args_func=lambda index, group_key, group_stats: {
1508+
'color': (
1509+
'red'
1510+
if group_key == 'decoder=pymatching p=0.001'
1511+
else 'blue'
1512+
),
1513+
}
1514+
line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
1515+
fit to every curve. The scales determine how to transform the coordinates before
1516+
performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
1517+
"""
1518+
```
1519+
14551520
<a name="sinter.plot_discard_rate"></a>
14561521
```python
14571522
# sinter.plot_discard_rate

glue/sample/src/sinter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
better_sorted_str_terms,
3838
plot_discard_rate,
3939
plot_error_rate,
40+
plot_custom,
4041
group_by,
4142
)
4243
from sinter._predict import (

0 commit comments

Comments
 (0)