Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 148 additions & 2 deletions crates/circuit/src/circuit_drawer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crossterm::terminal;
use hashbrown::HashSet;
use itertools::{Itertools, MinMaxResult};
use pyo3::prelude::*;
use std::f64::consts::PI;
use std::fmt::Debug;
use std::ops::Index;
use unicode_segmentation::UnicodeSegmentation;
Expand Down Expand Up @@ -60,7 +61,10 @@ pub fn draw_circuit(
if approx::abs_diff_eq!(*f, 0.) {
String::new()
} else {
format!("global phase: {}\n", f)
format!(
"global phase: {}\n",
format_float_pi(*f).unwrap_or_else(|| f.to_string())
)
}
}
Param::ParameterExpression(expr) => {
Expand Down Expand Up @@ -763,7 +767,7 @@ impl TextDrawer {
.params_view()
.iter()
.map(|param| match param {
Param::Float(f) => f.to_string(),
Param::Float(f) => format_float_pi(*f).unwrap_or_else(|| f.to_string()),
Param::ParameterExpression(expr) => expr.to_string(),
_ => format!("{:?}", param),
})
Expand Down Expand Up @@ -1202,6 +1206,93 @@ impl TextDrawer {
}
}

/// Computes if a number is close to an integer
/// fraction or multiple of PI and returns the
/// corresponding string.
///
/// Args:
/// f : Number to check.
///
/// Returns:
/// The string representation of output. None if no Pi formatting is found.
pub fn format_float_pi(f: f64) -> Option<String> {
const DENOMINATOR: i64 = 16;
// epsilon value defines the threshold to detect pi.
const EPS: f64 = 1e-9;

// pi_str is needed to match the output expected according to the format needed
let pi_str = "π";

// f_abs and sign help us working through each steps
let f_abs = f.abs();
let sign = if f < 0.0 { "-" } else { "" };

// Detecting 0 before moving on
if f_abs < EPS {
return Some("0".to_string());
}

// First check is for whole multiples of pi
let val = f_abs / PI;
let round = val.round();
if val >= 1.0 - EPS && (val - round).abs() < EPS {
let round = round as usize;
return Some(if round == 1 {
format!("{}{}", sign, pi_str)
} else {
format!("{}{}{}", sign, round, pi_str)
});
}

// Second is a check for powers of pi
if f_abs > PI {
if let Some(k) = (2..=4).find(|k| (f_abs - PI.powi(*k)).abs() < EPS) {
return Some(format!("{}{}^{}", sign, pi_str, k));
}
}
Comment thread
OnyxBrumeSky marked this conversation as resolved.

// Third is a check for a number larger than DENOMINATOR * pi, not a
// multiple or power of pi, since no fractions will exceed DENOMINATOR * pi
if f_abs > (DENOMINATOR as f64 * PI) {
return None;
}

// Fourth check is for fractions for 1*pi in the numer and any
// number in the denom.
let val = PI / f_abs;
let round = val.round();
if round >= 1.0 && (val - round).abs() < EPS {
let d = round as usize;
let str_out = format!("{}{}/{}", sign, pi_str, d);
return Some(str_out);
}

// Fifth check is for fractions of the form (numer/denom) * pi or (numer/denom) / pi
// where 1 <= numer,denom <= DENOMINATOR, which are not covered in the previous checks.
// Ex. 15pi/16, 2pi/5, 15pi/2, 16pi/9 or 15/16pi, 2/5pi, 15/2pi, 16/9pi
for denom in 1..=DENOMINATOR {
for numer in 1..=DENOMINATOR {
let up = numer as f64 / denom as f64;
let val = up * PI;
if (f_abs - val).abs() < EPS {
let str_out = format!("{}{}{}/{}", sign, numer, pi_str, denom);
return Some(str_out);
}
let val = up / PI;
if (f_abs - val).abs() < EPS {
let str_out = match denom {
1 => format!("{}{}/{}", sign, numer, pi_str),
d => format!("{}{}/{}{}", sign, numer, d, pi_str),
};
return Some(str_out);
}
}
}

// fall back when no conversion is possible
None
}

#[cfg(test)]
mod tests {
use ndarray::Array2;
Expand Down Expand Up @@ -1998,4 +2089,59 @@ q_1: ┤ Ry(🎩) ├┤1 ├┤ 💶🔉(🎩) ├┤1 ├
";
assert_eq!(result, expected.trim_start_matches("\n"));
}

#[test]
fn test_format_float_pi() {
let test_points = [
(0.0, Some("0")),
(-0.0, Some("0")),
(1e-12, Some("0")),
(PI, Some("π")),
(-PI, Some("-π")),
(2.0 * PI, Some("2π")),
(3.0 * PI, Some("3π")),
(10.0 * PI, Some("10π")),
(16.0 * PI, Some("16π")),
(-2.0 * PI, Some("-2π")),
(-5.0 * PI, Some("-5π")),
(PI.powi(2), Some("π^2")),
(-PI.powi(2), Some("-π^2")),
(PI.powi(3), Some("π^3")),
(PI.powi(4), Some("π^4")),
(PI / 2.0, Some("π/2")),
(PI / 3.0, Some("π/3")),
(PI / 4.0, Some("π/4")),
(PI / 6.0, Some("π/6")),
(-PI / 2.0, Some("-π/2")),
(2.0 * PI / 3.0, Some("2π/3")),
(3.0 * PI / 4.0, Some("3π/4")),
(5.0 * PI / 6.0, Some("5π/6")),
(7.0 * PI / 4.0, Some("7π/4")),
(15.0 * PI / 16.0, Some("15π/16")),
(-2.0 * PI / 3.0, Some("-2π/3")),
(1.0 / PI, Some("1/π")),
(2.0 / PI, Some("2/π")),
(1.0 / (2.0 * PI), Some("1/2π")),
(3.0 / (4.0 * PI), Some("3/4π")),
(-1.0 / PI, Some("-1/π")),
(-1.0 / (2.0 * PI), Some("-1/2π")),
(-18.0 / 16.0 * PI, Some("-9π/8")),
(60.0 / 44.0 / PI, Some("15/11π")),
(17.0 * PI + 1.0, None),
(100.0, None),
(1.0, None),
(2.0, None),
(1.5, None),
(-7.3, None),
(PI + 1e-6, None),
(PI - 1e-6, None),
(PI / 2.0 + 1e-6, None),
(17.0 * PI / 2.0, None),
(9.0 / (17.0 * PI), None),
];

for test in test_points {
assert_eq!(format_float_pi(test.0), test.1.map(|s| s.to_string()));
}
}
}