Skip to content

Commit 76788a1

Browse files
committed
Add Store impl of ProgramNode
1 parent 49e8ebd commit 76788a1

2 files changed

Lines changed: 161 additions & 0 deletions

File tree

crates/providers/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
mod data_tree;
1414
mod program_node;
15+
mod store;
1516
pub mod tensor;
1617
pub use data_tree::{DataTree, PathEntry};
1718
pub use program_node::{MissingCallError, ProgramNode};
19+
pub use store::Store;

crates/providers/src/store.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// This code is part of Qiskit.
2+
//
3+
// (C) Copyright IBM 2026
4+
//
5+
// This code is licensed under the Apache License, Version 2.0. You may
6+
// obtain a copy of this license in the LICENSE.txt file in the root directory
7+
// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8+
//
9+
// Any modifications or derivative works of this code must retain this
10+
// copyright notice, and modified files need to carry a notice indicating
11+
// that they have been altered from the originals.
12+
13+
use crate::data_tree::DataTree;
14+
use crate::program_node::{MissingCallError, ProgramNode};
15+
use crate::tensor::{Tensor, TensorType};
16+
use std::sync::OnceLock;
17+
18+
/// A program node that owns constant data and outputs it unconditionally.
19+
///
20+
/// `Store` takes no inputs; its `call()` always returns the data it was constructed with.
21+
/// In a data-flow graph, `Store` nodes play the role of constants — they are wired to
22+
/// the input ports of computation nodes to supply fixed values.
23+
pub struct Store {
24+
data: DataTree<Tensor>,
25+
output_types: DataTree<TensorType>,
26+
}
27+
28+
impl Store {
29+
/// Construct a new `Store` holding the given data.
30+
pub fn new(data: DataTree<Tensor>) -> Self {
31+
let output_types = derive_output_types(&data);
32+
Self { data, output_types }
33+
}
34+
35+
/// Return a reference to the stored data.
36+
pub fn data(&self) -> &DataTree<Tensor> {
37+
&self.data
38+
}
39+
}
40+
41+
/// Recursively derive output types from concrete tensor data.
42+
fn derive_output_types(data: &DataTree<Tensor>) -> DataTree<TensorType> {
43+
match data {
44+
DataTree::Leaf(tensor) => DataTree::new_leaf(tensor.tensor_type()),
45+
DataTree::Branch(_) => {
46+
let mut result = DataTree::with_capacity(data.len());
47+
for (key, child) in data.iter_children() {
48+
let child_type = derive_output_types(child);
49+
if let Some(k) = key {
50+
result.insert_branch(k, child_type);
51+
} else {
52+
result.push_branch(child_type);
53+
}
54+
}
55+
result
56+
}
57+
}
58+
}
59+
60+
impl ProgramNode for Store {
61+
fn name(&self) -> &'static str {
62+
"store"
63+
}
64+
65+
fn namespace(&self) -> &'static str {
66+
"core"
67+
}
68+
69+
fn input_types(&self) -> &DataTree<TensorType> {
70+
static EMPTY: OnceLock<DataTree<TensorType>> = OnceLock::new();
71+
EMPTY.get_or_init(DataTree::new)
72+
}
73+
74+
fn output_types(&self) -> &DataTree<TensorType> {
75+
&self.output_types
76+
}
77+
78+
fn implements_call(&self) -> bool {
79+
true
80+
}
81+
82+
fn call(&self, _args: &DataTree<Tensor>) -> Result<DataTree<Tensor>, MissingCallError> {
83+
Ok(self.data.clone())
84+
}
85+
}
86+
87+
#[cfg(test)]
88+
mod tests {
89+
use super::*;
90+
use crate::tensor::{DType, DTypeLike, Dim, Tensor};
91+
92+
#[test]
93+
fn test_store_leaf_call() {
94+
let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0]));
95+
let store = Store::new(data);
96+
let result = store.call(&DataTree::new()).unwrap();
97+
let DataTree::Leaf(Tensor::F64(arr)) = result else {
98+
panic!("expected f64 leaf");
99+
};
100+
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
101+
}
102+
103+
#[test]
104+
fn test_store_output_types_leaf() {
105+
let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0]));
106+
let store = Store::new(data);
107+
let DataTree::Leaf(tt) = store.output_types() else {
108+
panic!("expected leaf output type");
109+
};
110+
assert!(matches!(tt.dtype, DTypeLike::Concrete(DType::F64)));
111+
assert_eq!(tt.shape, vec![Dim::Fixed(3)]);
112+
assert!(!tt.broadcastable);
113+
}
114+
115+
#[test]
116+
fn test_store_output_types_2d() {
117+
use ndarray::arr2;
118+
let data =
119+
DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn()));
120+
let store = Store::new(data);
121+
let DataTree::Leaf(tt) = store.output_types() else {
122+
panic!("expected leaf output type");
123+
};
124+
assert_eq!(tt.shape, vec![Dim::Fixed(2), Dim::Fixed(2)]);
125+
}
126+
127+
#[test]
128+
fn test_store_branched() {
129+
let mut data = DataTree::new();
130+
data.insert_leaf("a", Tensor::from([1.0_f64, 2.0]));
131+
data.insert_leaf("b", Tensor::from([10_i32, 20, 30]));
132+
let store = Store::new(data);
133+
134+
assert!(store.input_types().is_empty());
135+
assert_eq!(store.name(), "store");
136+
assert_eq!(store.namespace(), "core");
137+
assert_eq!(store.full_name(), "core.store");
138+
139+
let out_types = store.output_types();
140+
let DataTree::Leaf(tt_a) = out_types.get_by_str_key("a").unwrap() else {
141+
panic!("expected leaf at a");
142+
};
143+
assert!(matches!(tt_a.dtype, DTypeLike::Concrete(DType::F64)));
144+
assert_eq!(tt_a.shape, vec![Dim::Fixed(2)]);
145+
146+
let DataTree::Leaf(tt_b) = out_types.get_by_str_key("b").unwrap() else {
147+
panic!("expected leaf at b");
148+
};
149+
assert!(matches!(tt_b.dtype, DTypeLike::Concrete(DType::I32)));
150+
assert_eq!(tt_b.shape, vec![Dim::Fixed(3)]);
151+
}
152+
153+
#[test]
154+
fn test_store_no_inputs() {
155+
let store = Store::new(DataTree::new_leaf(Tensor::from([42.0_f64])));
156+
assert!(store.input_types().is_empty());
157+
assert!(store.implements_call());
158+
}
159+
}

0 commit comments

Comments
 (0)