Skip to content

Commit c8bea7a

Browse files
committed
Add impls of ProgramNode for various math operations
1 parent 76788a1 commit c8bea7a

4 files changed

Lines changed: 998 additions & 0 deletions

File tree

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
// This code is part of Qiskit.
2+
//
3+
// (C) Copyright IBM 2026
4+
//
5+
// This code is licensed under the Apache License, Version 2.0. You may
6+
// obtain a copy of this license in the LICENSE.txt file in the root directory
7+
// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8+
//
9+
// Any modifications or derivative works of this code must retain this
10+
// copyright notice, and modified files need to carry a notice indicating
11+
// that they have been altered from the originals.
12+
13+
use crate::data_tree::DataTree;
14+
use crate::program_node::{MissingCallError, ProgramNode};
15+
use crate::tensor::{DTypeLike, Tensor, TensorType, promotion};
16+
use std::borrow::Cow;
17+
use std::sync::OnceLock;
18+
19+
/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`.
20+
fn elementwise_binary_input_types() -> &'static DataTree<TensorType> {
21+
static LOCK: OnceLock<DataTree<TensorType>> = OnceLock::new();
22+
LOCK.get_or_init(|| {
23+
let mut types = DataTree::with_capacity(2);
24+
types.insert_leaf(
25+
"x",
26+
TensorType {
27+
dtype: DTypeLike::Var("x".into()),
28+
shape: vec![],
29+
broadcastable: true,
30+
},
31+
);
32+
types.insert_leaf(
33+
"y",
34+
TensorType {
35+
dtype: DTypeLike::Var("y".into()),
36+
shape: vec![],
37+
broadcastable: true,
38+
},
39+
);
40+
types
41+
})
42+
}
43+
44+
/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype.
45+
fn elementwise_binary_output_types() -> &'static DataTree<TensorType> {
46+
static LOCK: OnceLock<DataTree<TensorType>> = OnceLock::new();
47+
LOCK.get_or_init(|| {
48+
DataTree::new_leaf(TensorType {
49+
dtype: DTypeLike::Promotion(
50+
vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(),
51+
),
52+
shape: vec![],
53+
broadcastable: true,
54+
})
55+
})
56+
}
57+
58+
/// Extract `x` and `y` from `args`, promote dtypes, and apply `op` element-wise.
59+
fn binary_elementwise_call(
60+
args: &DataTree<Tensor>,
61+
op: impl Fn(&Tensor, &Tensor) -> Tensor,
62+
) -> Result<DataTree<Tensor>, MissingCallError> {
63+
let DataTree::Leaf(x) = args.get_by_str_key("x").expect("missing input x") else {
64+
panic!("expected leaf at x");
65+
};
66+
let DataTree::Leaf(y) = args.get_by_str_key("y").expect("missing input y") else {
67+
panic!("expected leaf at y");
68+
};
69+
let out_dtype = promotion(x.dtype(), y.dtype());
70+
71+
// Use copy-on-write smart pointer to avoid cloning when promotion is unnecessary
72+
let x = if x.dtype() == out_dtype {
73+
Cow::Borrowed(x)
74+
} else {
75+
Cow::Owned(x.clone().cast(out_dtype))
76+
};
77+
let y = if y.dtype() == out_dtype {
78+
Cow::Borrowed(y)
79+
} else {
80+
Cow::Owned(y.clone().cast(out_dtype))
81+
};
82+
Ok(DataTree::new_leaf(op(x.as_ref(), y.as_ref())))
83+
}
84+
85+
/// Generate a [`ProgramNode`] struct for an elementwise binary operation.
86+
macro_rules! elementwise_binary_node {
87+
($name:ident, $node_name:literal, $call_fn:expr) => {
88+
#[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")]
89+
pub struct $name;
90+
91+
impl ProgramNode for $name {
92+
fn name(&self) -> &'static str {
93+
$node_name
94+
}
95+
fn namespace(&self) -> &'static str {
96+
"math"
97+
}
98+
fn input_types(&self) -> &DataTree<TensorType> {
99+
elementwise_binary_input_types()
100+
}
101+
fn output_types(&self) -> &DataTree<TensorType> {
102+
elementwise_binary_output_types()
103+
}
104+
fn implements_call(&self) -> bool {
105+
true
106+
}
107+
fn call(&self, args: &DataTree<Tensor>) -> Result<DataTree<Tensor>, MissingCallError> {
108+
binary_elementwise_call(args, $call_fn)
109+
}
110+
}
111+
};
112+
}
113+
114+
elementwise_binary_node!(Add, "add", |x, y| x + y);
115+
elementwise_binary_node!(Subtract, "subtract", |x, y| x - y);
116+
elementwise_binary_node!(Multiply, "multiply", |x, y| x * y);
117+
elementwise_binary_node!(Divide, "divide", |x, y| x / y);
118+
elementwise_binary_node!(Remainder, "remainder", |x, y| x % y);
119+
elementwise_binary_node!(Power, "power", |x, y| x.pow(y));
120+
121+
#[cfg(test)]
122+
mod tests {
123+
use super::*;
124+
use crate::tensor::{DType, Tensor};
125+
126+
fn args(x: Tensor, y: Tensor) -> DataTree<Tensor> {
127+
let mut tree = DataTree::new();
128+
tree.insert_leaf("x", x);
129+
tree.insert_leaf("y", y);
130+
tree
131+
}
132+
133+
#[test]
134+
fn test_add_same_dtype() {
135+
let result = Add.call(&args(Tensor::from([1.0_f64, 2.0, 3.0]), Tensor::from([4.0_f64, 5.0, 6.0]))).unwrap();
136+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
137+
panic!("expected f64 leaf")
138+
};
139+
assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
140+
}
141+
142+
#[test]
143+
fn test_add_promotes_dtype() {
144+
let result = Add.call(&args(Tensor::from([1.0_f32, 2.0]), Tensor::from([3.0_f64, 4.0]))).unwrap();
145+
let DataTree::Leaf(tensor) = result else {
146+
panic!("expected leaf")
147+
};
148+
assert_eq!(tensor.dtype(), DType::F64);
149+
let Tensor::F64(arr) = tensor else {
150+
panic!("expected f64")
151+
};
152+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]);
153+
}
154+
155+
#[test]
156+
fn test_add_broadcasts_1d_scalar() {
157+
// shape [3] + shape [1] -> shape [3]
158+
let result = Add.call(&args(Tensor::from([1.0_f64, 2.0, 3.0]), Tensor::from([10.0_f64]))).unwrap();
159+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
160+
panic!("expected f64 leaf")
161+
};
162+
assert_eq!(arr.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
163+
}
164+
165+
#[test]
166+
fn test_add_broadcasts_2d_with_1d() {
167+
// shape [2, 3] + shape [3] -> shape [2, 3]
168+
use ndarray::arr2;
169+
let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn());
170+
let y = Tensor::from([10.0_f64, 20.0, 30.0]);
171+
let result = Add.call(&args(x, y)).unwrap();
172+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
173+
panic!("expected f64 leaf")
174+
};
175+
let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]).into_dyn();
176+
assert_eq!(arr, expected);
177+
}
178+
179+
#[test]
180+
fn test_subtract() {
181+
let result = Subtract
182+
.call(&args(Tensor::from([5.0_f64, 6.0, 7.0]), Tensor::from([1.0_f64, 2.0, 3.0])))
183+
.unwrap();
184+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
185+
panic!()
186+
};
187+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]);
188+
}
189+
190+
#[test]
191+
fn test_multiply() {
192+
let result = Multiply
193+
.call(&args(Tensor::from([2.0_f64, 3.0, 4.0]), Tensor::from([10.0_f64, 10.0, 10.0])))
194+
.unwrap();
195+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
196+
panic!()
197+
};
198+
assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]);
199+
}
200+
201+
#[test]
202+
fn test_divide() {
203+
let result = Divide
204+
.call(&args(Tensor::from([10.0_f64, 9.0, 8.0]), Tensor::from([2.0_f64, 3.0, 4.0])))
205+
.unwrap();
206+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
207+
panic!()
208+
};
209+
assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]);
210+
}
211+
212+
#[test]
213+
fn test_remainder() {
214+
let result = Remainder
215+
.call(&args(Tensor::from([7.0_f64, 8.0, 9.0]), Tensor::from([3.0_f64, 3.0, 3.0])))
216+
.unwrap();
217+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
218+
panic!()
219+
};
220+
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]);
221+
}
222+
223+
#[test]
224+
fn test_power() {
225+
let result = Power
226+
.call(&args(Tensor::from([2.0_f64, 3.0, 4.0]), Tensor::from([3.0_f64, 2.0, 1.0])))
227+
.unwrap();
228+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
229+
panic!()
230+
};
231+
assert_eq!(arr.as_slice().unwrap(), &[8.0, 9.0, 4.0]);
232+
}
233+
234+
#[test]
235+
fn test_power_broadcasts() {
236+
// shape [3] ** shape [1] -> shape [3]
237+
let result = Power.call(&args(Tensor::from([2.0_f64, 3.0, 4.0]), Tensor::from([2.0_f64]))).unwrap();
238+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
239+
panic!()
240+
};
241+
assert_eq!(arr.as_slice().unwrap(), &[4.0, 9.0, 16.0]);
242+
}
243+
}

0 commit comments

Comments
 (0)