From dc27da7c045eac257248b13ff30af011467f1502 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sat, 28 Mar 2026 07:44:56 -0400 Subject: [PATCH 01/10] Add data tree struct This commit adds a new data structure to the new providers crate DataTree. The DataTree is a generic tree structure that will be used to define the operation ports in the QuantumProgram's tensor compute graph's nodes. Right now this is just one of the building blocks towards defining the QuantumProgram. As it isn't being used right now since the rest of the components don't exist yet, this is solely self tested. When subsequent components are added tests using the DataTree as part of Operation types and eventually in a QuantumProgram will be needed as well. --- Cargo.lock | 3 + crates/providers/Cargo.toml | 4 + crates/providers/src/data_tree.rs | 460 ++++++++++++++++++++++++++++++ crates/providers/src/lib.rs | 4 + 4 files changed, 471 insertions(+) create mode 100644 crates/providers/src/data_tree.rs diff --git a/Cargo.lock b/Cargo.lock index 4d5ec222efb0..b2dfb578593a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2341,6 +2341,9 @@ dependencies = [ [[package]] name = "qiskit-providers" version = "2.5.0-dev" +dependencies = [ + "hashbrown 0.15.5", +] [[package]] name = "qiskit-pyext" diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index e9f2e7ff623b..84b0184037af 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -10,5 +10,9 @@ name = "qiskit_providers" [dependencies] +[dependencies.hashbrown] +workspace = true +features = ["rayon", "serde"] + [lints] workspace = true diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs new file mode 100644 index 000000000000..5b1b409d48db --- /dev/null +++ b/crates/providers/src/data_tree.rs @@ -0,0 +1,460 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use std::marker::PhantomData; + +use hashbrown::HashMap; + +/// An item stored in a `DataTree` +/// +/// This can either be a Leaf which is a concrete item of type `T` or a subtree. +#[derive(Debug, Clone, PartialEq)] +pub enum TreeEntry<'a, T> { + Leaf(T), + // TODO: Box this to reduce memory consumption + Tree(DataTree<'a, T>), +} + +impl<'a, T> TreeEntry<'a, T> { + /// Return true if the entry is a leaf + pub fn is_leaf(&self) -> bool { + match self { + Self::Leaf(_) => true, + Self::Tree(_) => false, + } + } + + /// Consume the entry and return the leaf value otherwise panic + pub fn unwrap_leaf(self) -> T { + match self { + Self::Leaf(data) => data, + Self::Tree(_) => panic!("called TreeEntry::unwrap_leaf() on a `Tree` value"), + } + } + + /// Consume the entry and return the data tree otherwise panic + pub fn unwrap_tree(self) -> DataTree<'a, T> { + match self { + Self::Leaf(_) => panic!("called TreeEntry::unwrap_tree() on a `Leaf` value"), + Self::Tree(data) => data, + } + } + + /// Return a reference to the underlying tree + /// + /// This will be None if the `TreeEntry` is a `Leaf` + pub fn as_tree_ref(&self) -> Option<&DataTree<'a, T>> { + match *self { + Self::Leaf(_) => None, + Self::Tree(ref tree) => Some(tree), + } + } + + /// Return a reference to the underlying tree + /// + /// This will be None if the `TreeEntry` is a `Tree` + pub fn as_leaf_ref(&self) -> Option<&T> { + match *self { + Self::Leaf(ref val) => Some(val), + Self::Tree(_) => None, + } + } +} + +/// A generic tree that is addressable either by either indices or string keys +#[derive(Debug, Clone)] +pub struct DataTree<'a, T> { + data: Vec>, + keys: HashMap, + _marker: PhantomData<&'a T>, +} + +impl<'a, T> Default for DataTree<'a, T> { + fn default() -> Self { + Self::new() + } +} + +impl<'a, T> DataTree<'a, T> { + /// Create a new empty data tree + pub fn new() -> Self { + DataTree { + data: Vec::new(), + keys: HashMap::new(), + _marker: PhantomData, + } + } + + /// Create a new empty data tree with an underlying allocation of a given size. + /// + /// The specified capacity is the number of items of type T stored in the `DataTree` + /// along with an associated `String` key for each element in the tree. This does not + /// account for nesting in the allocation as each layer in the tree is a separate + /// `DataTree` object. + pub fn with_capacity(capacity: usize) -> Self { + DataTree { + data: Vec::with_capacity(capacity), + keys: HashMap::with_capacity(capacity), + _marker: PhantomData, + } + } + + /// The number of items in this `DataTree`. This length is just the number of items in this + /// local tree object and will not recurse through the tree to compute the total number of + /// leaves. If you want to do that you should use [`DataTree::iter_leaves`]. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut inner_tree = DataTree::with_capacity(5); + /// inner_tree.insert_leaf("y", 10); + /// inner_tree.insert_leaf("z", 11); + /// inner_tree.insert_leaf("a", 12); + /// inner_tree.insert_leaf("b", 13); + /// inner_tree.push_leaf(15); + /// + /// let mut tree = DataTree::new(); + /// tree.insert_tree("x", inner_tree); + /// assert_eq!(tree.len(), 1); + /// ``` + pub fn len(&self) -> usize { + self.data.len() + } + + /// Return whether this `DataTree` has an items in it. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Take a string key and return the entry at the given key. + /// + /// The "." character is reserved in keys and used to indicate a path + /// through the graph. + /// + /// This will return `None` if the string key can not be found. This includes + /// an invalid path, such as a path containing component or a leaf node in the + /// middle. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut inner_tree = DataTree::new(); + /// inner_tree.insert_leaf("y", 10); + /// let mut tree = DataTree::new(); + /// tree.insert_tree("x", inner_tree); + /// let result = tree.get_by_str_key("x.y").unwrap().as_leaf_ref(); + /// assert_eq!(*result.unwrap(), 10); + /// ``` + pub fn get_by_str_key(&self, key: &str) -> Option<&TreeEntry<'_, T>> { + if key.contains(".") { + let mut components = key.split("."); + let first = self.get_by_str_key(components.next().unwrap()); + components.fold(first, |tree, key| match tree { + Some(entry) => { + match entry { + // If we encounter a leaf in the accumulated tree than + // that means we have an incorrect path and there is no + // match + TreeEntry::Leaf(_) => None, + TreeEntry::Tree(tree) => tree.get_by_str_key(key), + } + } + None => None, + }) + } else { + self.keys.get(key).map(|value| &self.data[*value]) + } + } + + /// Get an item from the `DataTree` by index. + /// + /// This will return `None` if the index is not valid. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut inner_tree = DataTree::new(); + /// inner_tree.insert_leaf("y", 10); + /// let mut tree = DataTree::new(); + /// tree.insert_tree("x", inner_tree); + /// tree.push_leaf(124); + /// let Some(result) = tree.get(1).unwrap().as_leaf_ref() else { + /// panic!("Encountered an unexpected Tree"); + /// }; + /// assert_eq!(*result, 124); + /// let subtree = tree.get(0).unwrap().as_tree_ref().unwrap(); + /// let subtree_result = subtree.get(0).unwrap().as_leaf_ref(); + /// assert_eq!(*subtree_result.unwrap(), 10); + /// ``` + pub fn get(&self, index: usize) -> Option<&TreeEntry<'_, T>> { + self.data.get(index) + } + + /// Insert a new leaf node with an associated string key + /// + /// If a key is provided that is already in the tree the new value will be associated with + /// with the key and the old value will remain in the tree but without a string key. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("y", 10); + /// tree.insert_leaf("y", 1000); + /// let result = tree.get_by_str_key("y").unwrap().as_leaf_ref(); + /// assert_eq!(*result.unwrap(), 1000); + /// ``` + pub fn insert_leaf(&mut self, key: &str, value: T) { + self.data.push(TreeEntry::Leaf(value)); + self.keys.insert(key.to_string(), self.data.len() - 1); + } + + /// Add a new leaf to the tree + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.push_leaf(10); + /// tree.push_leaf(1000); + /// assert_eq!(vec![10, 1000], tree.iter_leaves().copied().collect::>()); + /// ``` + pub fn push_leaf(&mut self, value: T) { + self.data.push(TreeEntry::Leaf(value)); + } + + /// Add a subtree to the tree with an associated string key + /// + /// If a key is provided that is already in the tree the new value will be associated with + /// with the key and the old value will remain in the tree but without a string key. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("y", 10); + /// let mut subtree = DataTree::with_capacity(2); + /// subtree.push_leaf(123); + /// subtree.push_leaf(456); + /// tree.insert_tree("y", subtree); + /// let result = tree.get_by_str_key("y").unwrap().as_tree_ref().unwrap(); + /// let leaves: Vec<_> = result.iter_leaves().copied().collect(); + /// assert_eq!(leaves, vec![123, 456]); + /// ``` + pub fn insert_tree(&mut self, key: &str, value: DataTree<'a, T>) { + self.data.push(TreeEntry::Tree(value)); + self.keys.insert(key.to_string(), self.data.len() - 1); + } + + pub fn push_tree(&mut self, value: DataTree<'a, T>) { + self.data.push(TreeEntry::Tree(value)); + } + + /// Return an iterator over the leaves in the `DataTree` + /// + /// This method will return an iterator over all leave nodes in the tree by traversing the tree + /// in a DFS order. + /// + /// # Example + /// + /// Traversing this tree: + /// + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut subsubsubtree = DataTree::new(); + /// subsubsubtree.push_leaf(3); + /// subsubsubtree.push_leaf(4); + /// let mut subsubtree = DataTree::new(); + /// subsubtree.push_tree(subsubsubtree); + /// subsubtree.insert_leaf("b", 5); + /// let mut subsubtree_prime = DataTree::new(); + /// subsubtree_prime.push_leaf(7); + /// let mut subtree = DataTree::new(); + /// subtree.insert_tree("c", subsubtree); + /// subtree.insert_leaf("d", 6); + /// subtree.push_tree(subsubtree_prime); + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("a", 0); + /// tree.insert_tree("root", subtree); + /// tree.insert_leaf("z", 26); + /// let leaves: Vec<_> = tree.iter_leaves().copied().collect(); + /// let expected = vec![0, 3, 4, 5, 6, 7, 26]; + /// assert_eq!(leaves, expected); + /// ``` + pub fn iter_leaves(&self) -> IterLeaves<'_, T> { + IterLeaves { + tree: self, + index: 0, + inner: None, + inner_next: None, + } + } +} + +#[derive(Debug)] +pub struct IterLeaves<'a, T> { + tree: &'a DataTree<'a, T>, + index: usize, + inner: Option>>, + inner_next: Option<&'a T>, +} + +impl<'a, T: std::fmt::Debug> Iterator for IterLeaves<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index >= self.tree.len() { + return None; + } + let entry = &self.tree.data[self.index]; + match entry { + TreeEntry::Leaf(val) => { + self.index += 1; + Some(val) + } + TreeEntry::Tree(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let return_val = self.inner_next; + self.inner_next = Some(val); + return_val + } else { + self.inner = None; + self.index += 1; + let return_val = self.inner_next; + self.inner_next = None; + return_val + } + } + None => { + self.inner = Some(Box::new(subtree.iter_leaves())); + let val = self.inner.as_mut().map(|x| x.next().unwrap()); + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + val + } + }, + } + } +} + +impl<'a, T: PartialEq> PartialEq for DataTree<'a, T> { + fn eq(&self, other: &DataTree) -> bool { + self.data == other.data && self.keys == other.keys + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_data_leaf() { + let mut tree = DataTree::new(); + tree.push_leaf(42); + let result = tree.get(0).unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 42); + } + + #[test] + fn test_flat_dict() { + let mut tree = DataTree::with_capacity(3); + tree.insert_leaf("a", 1); + tree.insert_leaf("b", 2); + let result = tree.get_by_str_key("b").unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 2); + let result = tree.get_by_str_key("a").unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 1); + } + + #[test] + fn test_nested_dict() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + let mut tree = DataTree::new(); + tree.insert_tree("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!(None, tree.get_by_str_key("z.y")); + let expected = TreeEntry::Tree(inner_tree); + assert_eq!(Some(&expected), tree.get_by_str_key("x")); + } + + #[test] + fn test_nested_dict_iter() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.push_leaf(4); + inner_inner_tree.push_leaf(5); + inner_tree.push_tree(inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_tree("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!( + vec![10, 1, 2, 3, 4, 5, 100], + tree.iter_leaves().copied().collect::>() + ); + } + + #[test] + fn test_get_by_str() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.insert_leaf("a", 4); + inner_inner_tree.push_leaf(5); + let inner_inner_tree_expected = inner_inner_tree.clone(); + inner_tree.insert_tree("yyy", inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_tree("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + let result = tree.get_by_str_key("x.yyy.a"); + assert_eq!(result, Some(&TreeEntry::Leaf(4))); + assert_eq!(tree.get_by_str_key("z"), Some(&TreeEntry::Leaf(100))); + assert_eq!( + tree.get_by_str_key("x.yyy"), + Some(&TreeEntry::Tree(inner_inner_tree_expected)) + ); + assert_eq!(tree.get_by_str_key("x.yy"), Some(&TreeEntry::Leaf(1))); + } + + #[test] + fn test_get_by_str_no_match() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.insert_leaf("a", 4); + inner_inner_tree.push_leaf(5); + inner_tree.insert_tree("yyy", inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_tree("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!(None, tree.get_by_str_key("a")); + assert_eq!(None, tree.get_by_str_key("x.yyyy")); + assert_eq!(None, tree.get_by_str_key("x.yy.a")); + assert_eq!(None, tree.get_by_str_key("🎩")); + assert_eq!(None, tree.get_by_str_key("z.yyy.a")); + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 95dc338739c7..ccc1b9b38e95 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -9,3 +9,7 @@ // Any modifications or derivative works of this code must retain this // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. + +mod data_tree; + +pub use data_tree::{DataTree, TreeEntry}; From e32cf5e25bceb5a1904bb34d47ea5865b4e9fcaa Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 14 Apr 2026 14:52:50 -0400 Subject: [PATCH 02/10] Add PathEntry enum to track a path through a DataTree This commit adds a new type PathEntry that is used to outline a path through the DataTree into a leaf node. Along with this are two new methods to lookup a leaf node by a path and also to traverse the tree to get leaf nodes along with the path to that node. --- crates/providers/src/data_tree.rs | 205 +++++++++++++++++++++++++++++- crates/providers/src/lib.rs | 2 +- 2 files changed, 204 insertions(+), 3 deletions(-) diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs index 5b1b409d48db..ac893123fa97 100644 --- a/crates/providers/src/data_tree.rs +++ b/crates/providers/src/data_tree.rs @@ -14,6 +14,16 @@ use std::marker::PhantomData; use hashbrown::HashMap; +/// A path entry used for tracking a path through a [`DataTree`] +/// +/// Each entry can either be an index or a key. A slice of `PathEntry` are used to form +/// a traversal path through the [`DataTree`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum PathEntry<'a> { + Index(usize), + Key(&'a str), +} + /// An item stored in a `DataTree` /// /// This can either be a Leaf which is a concrete item of type `T` or a subtree. @@ -175,6 +185,38 @@ impl<'a, T> DataTree<'a, T> { } } + /// Take a path slice and return the entry at the given path + /// + /// This will return `None` if a path can not be found. This includes an + /// invalid path, such as a path a leaf node in the middle. An empty path + /// will also return `None`. + pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&TreeEntry<'_, T>> { + if path.is_empty() { + return None; + } + let start = match path[0] { + PathEntry::Index(idx) => Some(&self.data[idx]), + PathEntry::Key(key) => self.keys.get(key).map(|idx| &self.data[*idx]), + }?; + match start { + TreeEntry::Leaf(_) => { + if path.len() > 1 { + // If there are more components in the path this is an invalid path + None + } else { + Some(start) + } + } + TreeEntry::Tree(inner_tree) => { + if path.len() > 1 { + inner_tree.get_by_path(&path[1..]) + } else { + Some(start) + } + } + } + } + /// Get an item from the `DataTree` by index. /// /// This will return `None` if the index is not valid. @@ -261,7 +303,7 @@ impl<'a, T> DataTree<'a, T> { /// Return an iterator over the leaves in the `DataTree` /// - /// This method will return an iterator over all leave nodes in the tree by traversing the tree + /// This method will return an iterator over all leaf nodes in the tree by traversing the tree /// in a DFS order. /// /// # Example @@ -298,9 +340,117 @@ impl<'a, T> DataTree<'a, T> { inner_next: None, } } + + /// Return an iterator over the leaves in the `DataTree` that returns the path and leaf value. + /// + /// This method will return an iterator over all the leaf nodes in the tree in a DFS order. + /// Unlike [`iter_leaves`] which just returns the value this will return an owned `Vec` of the + /// path through the data tree to get to that value. This has allocation overhead for each leaf + /// node in the tree and should only be used if you need the path along with the value. + /// + /// ```rust + /// use qiskit_providers::{DataTree, PathEntry}; + /// let mut subsubsubtree = DataTree::new(); + /// subsubsubtree.push_leaf(3); + /// subsubsubtree.push_leaf(4); + /// let mut subsubtree = DataTree::new(); + /// subsubtree.push_tree(subsubsubtree); + /// subsubtree.insert_leaf("b", 5); + /// let mut subsubtree_prime = DataTree::new(); + /// subsubtree_prime.push_leaf(7); + /// let mut subtree = DataTree::new(); + /// subtree.insert_tree("c", subsubtree); + /// subtree.insert_leaf("d", 6); + /// subtree.push_tree(subsubtree_prime); + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("a", 0); + /// tree.insert_tree("root", subtree); + /// tree.insert_leaf("z", 26); + /// let result: Vec<_> = tree.iter_path().map(|(a, b)| (a, *b)).collect(); + /// let expected_paths: Vec> = vec![ + /// vec![0], + /// vec![1, 0, 0, 0], + /// vec![1, 0, 0, 1], + /// vec![1, 0, 1], + /// vec![1, 1], + /// vec![1, 2, 0], + /// vec![2], + /// ]; + /// let expected_vals = vec![0, 3, 4, 5, 6, 7, 26]; + /// let expected: Vec<_> = expected_paths + /// .into_iter() + /// .map(|x| x.into_iter().map(PathEntry::Index).collect::>()) + /// .zip(expected_vals) + /// .collect(); + /// assert_eq!(result, expected); + /// ``` + pub fn iter_path(&self) -> IterDataTree<'_, T> { + IterDataTree { + tree: self, + index: 0, + inner: None, + inner_next: None, + path: Vec::new(), + } + } +} + +pub struct IterDataTree<'a, T> { + tree: &'a DataTree<'a, T>, + index: usize, + inner: Option>>, + inner_next: Option<(Vec>, &'a T)>, + path: Vec>, +} + +impl<'a, T> Iterator for IterDataTree<'a, T> { + type Item = (Vec>, &'a T); + + fn next(&mut self) -> Option { + if self.index >= self.tree.len() { + return None; + } + let entry = &self.tree.data[self.index]; + match entry { + TreeEntry::Leaf(val) => { + self.index += 1; + let mut out_path = self.path.clone(); + out_path.push(PathEntry::Index(self.index - 1)); + Some((out_path, val)) + } + TreeEntry::Tree(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let (return_path, return_val) = self.inner_next.replace(val).unwrap(); + Some((return_path, return_val)) + } else { + self.inner = None; + self.index += 1; + let (return_path, return_val) = self.inner_next.take().unwrap(); + self.inner_next = None; + Some((return_path, return_val)) + } + } + None => { + let mut inner = subtree.iter_path(); + let mut inner_path = self.path.clone(); + inner_path.push(PathEntry::Index(self.index)); + inner.path = inner_path; + self.inner = Some(Box::new(inner)); + let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some((inner_path, val)) + } + }, + } + } } -#[derive(Debug)] pub struct IterLeaves<'a, T> { tree: &'a DataTree<'a, T>, index: usize, @@ -335,6 +485,7 @@ impl<'a, T: std::fmt::Debug> Iterator for IterLeaves<'a, T> { return_val } } + None => { self.inner = Some(Box::new(subtree.iter_leaves())); let val = self.inner.as_mut().map(|x| x.next().unwrap()); @@ -412,6 +563,56 @@ mod test { ); } + #[test] + fn test_nested_dict_iter_path() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.push_leaf(4); + inner_inner_tree.push_leaf(5); + inner_tree.push_tree(inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_tree("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + let expected_paths = vec![ + vec![PathEntry::Index(0), PathEntry::Index(0)], + vec![PathEntry::Index(0), PathEntry::Index(1)], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(0), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(1), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(2), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(3), + ], + vec![PathEntry::Index(1)], + ]; + let expected_vals = vec![10, 1, 2, 3, 4, 5, 100]; + let expected = expected_paths + .into_iter() + .zip(expected_vals.into_iter()) + .collect::>(); + assert_eq!( + expected, + tree.iter_path().map(|(a, b)| (a, *b)).collect::>() + ); + } + #[test] fn test_get_by_str() { let mut inner_tree = DataTree::new(); diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index ccc1b9b38e95..47072ccd4ed2 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -12,4 +12,4 @@ mod data_tree; -pub use data_tree::{DataTree, TreeEntry}; +pub use data_tree::{DataTree, PathEntry, TreeEntry}; From a5b30271e0a06c378a7bcc2c12a153b66ef911b8 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 28 Apr 2026 11:08:51 -0400 Subject: [PATCH 03/10] Invert type hierarchy This commit changes the type hierarchy for DataTree to move the struct from being a vec of enums of either leaves or branches. To each DataTree being an enum of either a leaf or branch and each branch contains a vec of DataTrees. This is a more natural form as the outer type can either be a leaf or branch and simplifies working with the tree. --- crates/providers/src/data_tree.rs | 615 ++++++++++++++++++++---------- crates/providers/src/lib.rs | 2 +- 2 files changed, 415 insertions(+), 202 deletions(-) diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs index ac893123fa97..13b881bfc30f 100644 --- a/crates/providers/src/data_tree.rs +++ b/crates/providers/src/data_tree.rs @@ -10,8 +10,6 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. -use std::marker::PhantomData; - use hashbrown::HashMap; /// A path entry used for tracking a path through a [`DataTree`] @@ -24,84 +22,151 @@ pub enum PathEntry<'a> { Key(&'a str), } -/// An item stored in a `DataTree` +/// A struct representing a branch in a [`DataTree`]. /// -/// This can either be a Leaf which is a concrete item of type `T` or a subtree. -#[derive(Debug, Clone, PartialEq)] -pub enum TreeEntry<'a, T> { - Leaf(T), - // TODO: Box this to reduce memory consumption - Tree(DataTree<'a, T>), +/// Each branch contains a vec of [`DataTree`] that can also be assigned a +/// string key for accessing it. Typically you will not create these directly +/// but instead create them via the [`DataTree`] API. +#[derive(Debug, Clone)] +pub struct DataTreeBranch { + data: Vec>, + keys: HashMap, } -impl<'a, T> TreeEntry<'a, T> { - /// Return true if the entry is a leaf - pub fn is_leaf(&self) -> bool { - match self { - Self::Leaf(_) => true, - Self::Tree(_) => false, +impl DataTreeBranch { + /// Construct a new empty [`DataTreeBranch`] + pub fn new() -> Self { + DataTreeBranch { + data: Vec::new(), + keys: HashMap::new(), } } - /// Consume the entry and return the leaf value otherwise panic - pub fn unwrap_leaf(self) -> T { - match self { - Self::Leaf(data) => data, - Self::Tree(_) => panic!("called TreeEntry::unwrap_leaf() on a `Tree` value"), + /// Construct a new empty [`DataTreeBranch`] with a set capacity + pub fn with_capacity(capacity: usize) -> Self { + DataTreeBranch { + data: Vec::with_capacity(capacity), + keys: HashMap::with_capacity(capacity), } } - /// Consume the entry and return the data tree otherwise panic - pub fn unwrap_tree(self) -> DataTree<'a, T> { - match self { - Self::Leaf(_) => panic!("called TreeEntry::unwrap_tree() on a `Leaf` value"), - Self::Tree(data) => data, + /// Take a path slice and return the entry at the given path + /// + /// This will return `None` if a path can not be found. This includes an + /// invalid path, such as a path a leaf node in the middle. An empty path + /// will also return `self`. + pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&DataTree> { + let start = match path[0] { + PathEntry::Index(idx) => Some(&self.data[idx]), + PathEntry::Key(key) => self.keys.get(key).map(|idx| &self.data[*idx]), + }?; + match start { + DataTree::Leaf(_) => { + if path.len() > 1 { + // If there are more components in the path this is an invalid path + None + } else { + Some(start) + } + } + DataTree::Branch(inner_tree) => { + if path.len() > 1 { + inner_tree.get_by_path(&path[1..]) + } else { + Some(start) + } + } } } - /// Return a reference to the underlying tree + /// Return an iterator over the leaves in the `DataTree` /// - /// This will be None if the `TreeEntry` is a `Leaf` - pub fn as_tree_ref(&self) -> Option<&DataTree<'a, T>> { - match *self { - Self::Leaf(_) => None, - Self::Tree(ref tree) => Some(tree), + /// This method will return an iterator over all leaf nodes in the tree by traversing the tree + /// in a DFS order. + pub fn iter_path(&self) -> IterDataTree<'_, T> { + IterDataTree { + tree: None, + branch: Some(self), + index: 0, + inner: None, + inner_next: None, + path: vec![], } } - /// Return a reference to the underlying tree - /// - /// This will be None if the `TreeEntry` is a `Tree` - pub fn as_leaf_ref(&self) -> Option<&T> { - match *self { - Self::Leaf(ref val) => Some(val), - Self::Tree(_) => None, + /// The number of items in this `DataTree`. This length is just the number of items in this + /// local tree object and will not recurse through the tree to compute the total number of + /// leaves. If you want to do that you should use [`DataTree::iter_leaves`]. + pub fn iter_leaves(&self) -> IterLeaves<'_, T> { + IterLeaves { + tree: None, + branch: Some(self), + index: 0, + inner: None, + inner_next: None, + } + } + + /// The number of [`DataTree`] in this branch. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Check if there are any [`DataTree`] in this branch. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// The number of string keys set on this branch. + pub fn num_keys(&self) -> usize { + self.keys.len() + } + + /// Check if the branch has any string keys set. + pub fn has_keys(&self) -> bool { + !self.keys.is_empty() + } +} + +impl From> for DataTreeBranch { + fn from(input: DataTree) -> Self { + DataTreeBranch { + data: vec![input], + keys: HashMap::new(), } } } /// A generic tree that is addressable either by either indices or string keys #[derive(Debug, Clone)] -pub struct DataTree<'a, T> { - data: Vec>, - keys: HashMap, - _marker: PhantomData<&'a T>, +pub enum DataTree { + Leaf(T), + Branch(DataTreeBranch), } -impl<'a, T> Default for DataTree<'a, T> { +impl Default for DataTree { fn default() -> Self { Self::new() } } -impl<'a, T> DataTree<'a, T> { +impl DataTree { + /// Consume the entry and return the leaf value otherwise panic + pub fn unwrap_leaf(self) -> T { + match self { + Self::Leaf(data) => data, + Self::Branch(_) => panic!("called TreeEntry::unwrap_leaf() on a `Tree` value"), + } + } + /// Create a new empty data tree pub fn new() -> Self { - DataTree { - data: Vec::new(), - keys: HashMap::new(), - _marker: PhantomData, - } + DataTree::Branch(DataTreeBranch::new()) + } + + /// Create a new leaf data tree + pub fn new_leaf(value: T) -> Self { + DataTree::Leaf(value) } /// Create a new empty data tree with an underlying allocation of a given size. @@ -111,11 +176,7 @@ impl<'a, T> DataTree<'a, T> { /// account for nesting in the allocation as each layer in the tree is a separate /// `DataTree` object. pub fn with_capacity(capacity: usize) -> Self { - DataTree { - data: Vec::with_capacity(capacity), - keys: HashMap::with_capacity(capacity), - _marker: PhantomData, - } + DataTree::Branch(DataTreeBranch::with_capacity(capacity)) } /// The number of items in this `DataTree`. This length is just the number of items in this @@ -133,16 +194,19 @@ impl<'a, T> DataTree<'a, T> { /// inner_tree.push_leaf(15); /// /// let mut tree = DataTree::new(); - /// tree.insert_tree("x", inner_tree); + /// tree.insert_branch("x", inner_tree); /// assert_eq!(tree.len(), 1); /// ``` pub fn len(&self) -> usize { - self.data.len() + match self { + Self::Leaf(_) => 1, + Self::Branch(branch) => branch.data.len(), + } } /// Return whether this `DataTree` has an items in it. pub fn is_empty(&self) -> bool { - self.data.is_empty() + self.len() == 0 } /// Take a string key and return the entry at the given key. @@ -152,7 +216,7 @@ impl<'a, T> DataTree<'a, T> { /// /// This will return `None` if the string key can not be found. This includes /// an invalid path, such as a path containing component or a leaf node in the - /// middle. + /// middle. An empty string for the path will return `self`. /// /// # Example /// ```rust @@ -160,28 +224,22 @@ impl<'a, T> DataTree<'a, T> { /// let mut inner_tree = DataTree::new(); /// inner_tree.insert_leaf("y", 10); /// let mut tree = DataTree::new(); - /// tree.insert_tree("x", inner_tree); - /// let result = tree.get_by_str_key("x.y").unwrap().as_leaf_ref(); - /// assert_eq!(*result.unwrap(), 10); + /// tree.insert_branch("x", inner_tree); + /// let result = tree.get_by_str_key("x.y").unwrap().clone().unwrap_leaf(); + /// assert_eq!(result, 10); /// ``` - pub fn get_by_str_key(&self, key: &str) -> Option<&TreeEntry<'_, T>> { + pub fn get_by_str_key(&self, key: &str) -> Option<&Self> { + if key.is_empty() { + return Some(self); + } if key.contains(".") { - let mut components = key.split("."); - let first = self.get_by_str_key(components.next().unwrap()); - components.fold(first, |tree, key| match tree { - Some(entry) => { - match entry { - // If we encounter a leaf in the accumulated tree than - // that means we have an incorrect path and there is no - // match - TreeEntry::Leaf(_) => None, - TreeEntry::Tree(tree) => tree.get_by_str_key(key), - } - } - None => None, - }) + let path: Vec = key.split(".").map(PathEntry::Key).collect(); + self.get_by_path(&path) } else { - self.keys.get(key).map(|value| &self.data[*value]) + match self { + Self::Leaf(_) => None, + Self::Branch(branch) => branch.keys.get(key).map(|value| &branch.data[*value]), + } } } @@ -189,17 +247,20 @@ impl<'a, T> DataTree<'a, T> { /// /// This will return `None` if a path can not be found. This includes an /// invalid path, such as a path a leaf node in the middle. An empty path - /// will also return `None`. - pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&TreeEntry<'_, T>> { + /// will also return `self`. + pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&Self> { if path.is_empty() { - return None; + return Some(self); } + let Self::Branch(branch) = self else { + return None; + }; let start = match path[0] { - PathEntry::Index(idx) => Some(&self.data[idx]), - PathEntry::Key(key) => self.keys.get(key).map(|idx| &self.data[*idx]), + PathEntry::Index(idx) => Some(&branch.data[idx]), + PathEntry::Key(key) => branch.keys.get(key).map(|idx| &branch.data[*idx]), }?; match start { - TreeEntry::Leaf(_) => { + DataTree::Leaf(_) => { if path.len() > 1 { // If there are more components in the path this is an invalid path None @@ -207,7 +268,7 @@ impl<'a, T> DataTree<'a, T> { Some(start) } } - TreeEntry::Tree(inner_tree) => { + DataTree::Branch(inner_tree) => { if path.len() > 1 { inner_tree.get_by_path(&path[1..]) } else { @@ -227,18 +288,19 @@ impl<'a, T> DataTree<'a, T> { /// let mut inner_tree = DataTree::new(); /// inner_tree.insert_leaf("y", 10); /// let mut tree = DataTree::new(); - /// tree.insert_tree("x", inner_tree); + /// tree.insert_branch("x", inner_tree); /// tree.push_leaf(124); - /// let Some(result) = tree.get(1).unwrap().as_leaf_ref() else { - /// panic!("Encountered an unexpected Tree"); - /// }; - /// assert_eq!(*result, 124); - /// let subtree = tree.get(0).unwrap().as_tree_ref().unwrap(); - /// let subtree_result = subtree.get(0).unwrap().as_leaf_ref(); - /// assert_eq!(*subtree_result.unwrap(), 10); + /// let result = tree.get(1).unwrap().clone().unwrap_leaf(); + /// assert_eq!(result, 124); + /// let subtree = tree.get(0).unwrap(); + /// let subtree_result = subtree.get(0).unwrap().clone().unwrap_leaf(); + /// assert_eq!(subtree_result, 10); /// ``` - pub fn get(&self, index: usize) -> Option<&TreeEntry<'_, T>> { - self.data.get(index) + pub fn get(&self, index: usize) -> Option<&DataTree> { + match self { + Self::Leaf(_) => panic!("Called get() on a leaf node"), + Self::Branch(branch) => branch.data.get(index), + } } /// Insert a new leaf node with an associated string key @@ -252,12 +314,17 @@ impl<'a, T> DataTree<'a, T> { /// let mut tree = DataTree::new(); /// tree.insert_leaf("y", 10); /// tree.insert_leaf("y", 1000); - /// let result = tree.get_by_str_key("y").unwrap().as_leaf_ref(); - /// assert_eq!(*result.unwrap(), 1000); + /// let result = tree.get_by_str_key("y").unwrap().clone().unwrap_leaf(); + /// assert_eq!(result, 1000); /// ``` pub fn insert_leaf(&mut self, key: &str, value: T) { - self.data.push(TreeEntry::Leaf(value)); - self.keys.insert(key.to_string(), self.data.len() - 1); + match self { + Self::Leaf(_) => panic!("Called insert_leaf() on a leaf node"), + Self::Branch(branch) => { + branch.data.push(Self::Leaf(value)); + branch.keys.insert(key.to_string(), branch.data.len() - 1); + } + } } /// Add a new leaf to the tree @@ -271,7 +338,10 @@ impl<'a, T> DataTree<'a, T> { /// assert_eq!(vec![10, 1000], tree.iter_leaves().copied().collect::>()); /// ``` pub fn push_leaf(&mut self, value: T) { - self.data.push(TreeEntry::Leaf(value)); + match self { + Self::Leaf(_) => panic!("Called push_leaf() on a leaf_node"), + Self::Branch(branch) => branch.data.push(DataTree::Leaf(value)), + } } /// Add a subtree to the tree with an associated string key @@ -287,18 +357,39 @@ impl<'a, T> DataTree<'a, T> { /// let mut subtree = DataTree::with_capacity(2); /// subtree.push_leaf(123); /// subtree.push_leaf(456); - /// tree.insert_tree("y", subtree); - /// let result = tree.get_by_str_key("y").unwrap().as_tree_ref().unwrap(); + /// tree.insert_branch("y", subtree); + /// let result = tree.get_by_str_key("y").unwrap(); /// let leaves: Vec<_> = result.iter_leaves().copied().collect(); /// assert_eq!(leaves, vec![123, 456]); /// ``` - pub fn insert_tree(&mut self, key: &str, value: DataTree<'a, T>) { - self.data.push(TreeEntry::Tree(value)); - self.keys.insert(key.to_string(), self.data.len() - 1); + pub fn insert_branch(&mut self, key: &str, value: DataTree) { + match self { + Self::Leaf(_) => panic!("Called insert_branch() on a leaf_node"), + Self::Branch(branch) => { + branch.data.push(value); + branch.keys.insert(key.to_string(), branch.data.len() - 1); + } + } } - pub fn push_tree(&mut self, value: DataTree<'a, T>) { - self.data.push(TreeEntry::Tree(value)); + /// Add a new branch to the tree + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.push_leaf(10); + /// let mut subtree = DataTree::with_capacity(2); + /// subtree.push_leaf(123); + /// subtree.push_leaf(456); + /// tree.push_branch(subtree); + /// assert_eq!(vec![10, 123, 456], tree.iter_leaves().copied().collect::>()); + /// ``` + pub fn push_branch(&mut self, value: DataTree) { + match self { + Self::Leaf(_) => panic!("Called insert_branch() on a leaf_node"), + Self::Branch(branch) => branch.data.push(value), + } } /// Return an iterator over the leaves in the `DataTree` @@ -316,17 +407,17 @@ impl<'a, T> DataTree<'a, T> { /// subsubsubtree.push_leaf(3); /// subsubsubtree.push_leaf(4); /// let mut subsubtree = DataTree::new(); - /// subsubtree.push_tree(subsubsubtree); + /// subsubtree.push_branch(subsubsubtree); /// subsubtree.insert_leaf("b", 5); /// let mut subsubtree_prime = DataTree::new(); /// subsubtree_prime.push_leaf(7); /// let mut subtree = DataTree::new(); - /// subtree.insert_tree("c", subsubtree); + /// subtree.insert_branch("c", subsubtree); /// subtree.insert_leaf("d", 6); - /// subtree.push_tree(subsubtree_prime); + /// subtree.push_branch(subsubtree_prime); /// let mut tree = DataTree::new(); /// tree.insert_leaf("a", 0); - /// tree.insert_tree("root", subtree); + /// tree.insert_branch("root", subtree); /// tree.insert_leaf("z", 26); /// let leaves: Vec<_> = tree.iter_leaves().copied().collect(); /// let expected = vec![0, 3, 4, 5, 6, 7, 26]; @@ -334,7 +425,8 @@ impl<'a, T> DataTree<'a, T> { /// ``` pub fn iter_leaves(&self) -> IterLeaves<'_, T> { IterLeaves { - tree: self, + tree: Some(self), + branch: None, index: 0, inner: None, inner_next: None, @@ -354,17 +446,17 @@ impl<'a, T> DataTree<'a, T> { /// subsubsubtree.push_leaf(3); /// subsubsubtree.push_leaf(4); /// let mut subsubtree = DataTree::new(); - /// subsubtree.push_tree(subsubsubtree); + /// subsubtree.push_branch(subsubsubtree); /// subsubtree.insert_leaf("b", 5); /// let mut subsubtree_prime = DataTree::new(); /// subsubtree_prime.push_leaf(7); /// let mut subtree = DataTree::new(); - /// subtree.insert_tree("c", subsubtree); + /// subtree.insert_branch("c", subsubtree); /// subtree.insert_leaf("d", 6); - /// subtree.push_tree(subsubtree_prime); + /// subtree.push_branch(subsubtree_prime); /// let mut tree = DataTree::new(); /// tree.insert_leaf("a", 0); - /// tree.insert_tree("root", subtree); + /// tree.insert_branch("root", subtree); /// tree.insert_leaf("z", 26); /// let result: Vec<_> = tree.iter_path().map(|(a, b)| (a, *b)).collect(); /// let expected_paths: Vec> = vec![ @@ -386,7 +478,8 @@ impl<'a, T> DataTree<'a, T> { /// ``` pub fn iter_path(&self) -> IterDataTree<'_, T> { IterDataTree { - tree: self, + tree: Some(self), + branch: None, index: 0, inner: None, inner_next: None, @@ -396,7 +489,8 @@ impl<'a, T> DataTree<'a, T> { } pub struct IterDataTree<'a, T> { - tree: &'a DataTree<'a, T>, + tree: Option<&'a DataTree>, + branch: Option<&'a DataTreeBranch>, index: usize, inner: Option>>, inner_next: Option<(Vec>, &'a T)>, @@ -407,52 +501,109 @@ impl<'a, T> Iterator for IterDataTree<'a, T> { type Item = (Vec>, &'a T); fn next(&mut self) -> Option { - if self.index >= self.tree.len() { - return None; - } - let entry = &self.tree.data[self.index]; - match entry { - TreeEntry::Leaf(val) => { - self.index += 1; - let mut out_path = self.path.clone(); - out_path.push(PathEntry::Index(self.index - 1)); - Some((out_path, val)) + if let Some(tree) = self.tree { + if let DataTree::Leaf(val) = tree { + if self.index == 0 { + self.index += 1; + return Some((vec![], val)); + } else { + return None; + } + } + let DataTree::Branch(branch) = tree else { + unreachable!("Must be a branch variant"); + }; + if self.index >= branch.data.len() { + return None; } - TreeEntry::Tree(subtree) => match self.inner { - Some(ref mut inner) => { - if let Some(val) = inner.next() { - let (return_path, return_val) = self.inner_next.replace(val).unwrap(); - Some((return_path, return_val)) + let entry = &branch.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + let mut out_path = self.path.clone(); + out_path.push(PathEntry::Index(self.index - 1)); + Some((out_path, val)) + } + DataTree::Branch(sub_branch) => { + if let Some(ref mut inner) = self.inner { + if let Some(val) = inner.next() { + let (return_path, return_val) = self.inner_next.replace(val).unwrap(); + Some((return_path, return_val)) + } else { + self.inner = None; + self.index += 1; + let (return_path, return_val) = self.inner_next.take().unwrap(); + self.inner_next = None; + Some((return_path, return_val)) + } } else { - self.inner = None; - self.index += 1; - let (return_path, return_val) = self.inner_next.take().unwrap(); - self.inner_next = None; - Some((return_path, return_val)) + let mut inner = sub_branch.iter_path(); + let mut inner_path = self.path.clone(); + inner_path.push(PathEntry::Index(self.index)); + inner.path = inner_path; + self.inner = Some(Box::new(inner)); + let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some((inner_path, val)) } } - None => { - let mut inner = subtree.iter_path(); - let mut inner_path = self.path.clone(); - inner_path.push(PathEntry::Index(self.index)); - inner.path = inner_path; - self.inner = Some(Box::new(inner)); - let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; - self.inner_next = self.inner.as_mut().and_then(|x| x.next()); - if self.inner_next.is_none() { - self.index += 1; - self.inner = None; - self.inner_next = None; - } - Some((inner_path, val)) + } + } else if let Some(subtree) = self.branch { + if self.index >= subtree.data.len() { + return None; + } + let entry = &subtree.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + let mut out_path = self.path.clone(); + out_path.push(PathEntry::Index(self.index - 1)); + Some((out_path, val)) } - }, + DataTree::Branch(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let (return_path, return_val) = self.inner_next.replace(val).unwrap(); + Some((return_path, return_val)) + } else { + self.inner = None; + self.index += 1; + let (return_path, return_val) = self.inner_next.take().unwrap(); + self.inner_next = None; + Some((return_path, return_val)) + } + } + None => { + let mut inner = subtree.iter_path(); + let mut inner_path = self.path.clone(); + inner_path.push(PathEntry::Index(self.index)); + inner.path = inner_path; + self.inner = Some(Box::new(inner)); + let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some((inner_path, val)) + } + }, + } + } else { + None } } } pub struct IterLeaves<'a, T> { - tree: &'a DataTree<'a, T>, + tree: Option<&'a DataTree>, + branch: Option<&'a DataTreeBranch>, index: usize, inner: Option>>, inner_next: Option<&'a T>, @@ -462,49 +613,112 @@ impl<'a, T: std::fmt::Debug> Iterator for IterLeaves<'a, T> { type Item = &'a T; fn next(&mut self) -> Option { - if self.index >= self.tree.len() { - return None; - } - let entry = &self.tree.data[self.index]; - match entry { - TreeEntry::Leaf(val) => { - self.index += 1; - Some(val) + if let Some(tree) = self.tree { + if let DataTree::Leaf(val) = tree { + if self.index == 0 { + self.index += 1; + return Some(val); + } else { + return None; + } } - TreeEntry::Tree(subtree) => match self.inner { - Some(ref mut inner) => { - if let Some(val) = inner.next() { - let return_val = self.inner_next; - self.inner_next = Some(val); - return_val + let DataTree::Branch(branch) = tree else { + unreachable!("Must be a branch variant"); + }; + if self.index >= branch.data.len() { + return None; + } + let entry = &branch.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + Some(val) + } + DataTree::Branch(sub_branch) => { + if let Some(ref mut inner) = self.inner { + if let Some(val) = inner.next() { + let return_val = self.inner_next.replace(val).unwrap(); + Some(return_val) + } else { + self.inner = None; + self.index += 1; + let return_val = self.inner_next.take().unwrap(); + self.inner_next = None; + Some(return_val) + } } else { - self.inner = None; - self.index += 1; - let return_val = self.inner_next; - self.inner_next = None; - return_val + let inner = sub_branch.iter_leaves(); + self.inner = Some(Box::new(inner)); + let val = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some(val) } } - - None => { - self.inner = Some(Box::new(subtree.iter_leaves())); - let val = self.inner.as_mut().map(|x| x.next().unwrap()); - self.inner_next = self.inner.as_mut().and_then(|x| x.next()); - if self.inner_next.is_none() { - self.index += 1; - self.inner = None; - self.inner_next = None; - } - val + } + } else if let Some(subtree) = self.branch { + if self.index >= subtree.data.len() { + return None; + } + let entry = &subtree.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + Some(val) } - }, + DataTree::Branch(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let return_val = self.inner_next.replace(val).unwrap(); + Some(return_val) + } else { + self.inner = None; + self.index += 1; + let return_val = self.inner_next.take().unwrap(); + self.inner_next = None; + Some(return_val) + } + } + None => { + let inner = subtree.iter_leaves(); + self.inner = Some(Box::new(inner)); + let val = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some(val) + } + }, + } + } else { + None } } } -impl<'a, T: PartialEq> PartialEq for DataTree<'a, T> { +impl PartialEq for DataTree { fn eq(&self, other: &DataTree) -> bool { - self.data == other.data && self.keys == other.keys + match self { + Self::Leaf(val) => { + let Self::Leaf(other_val) = other else { + return false; + }; + val == other_val + } + Self::Branch(branch) => { + let Self::Branch(other) = other else { + return false; + }; + branch.data == other.data && branch.keys == other.keys + } + } } } @@ -536,11 +750,10 @@ mod test { let mut inner_tree = DataTree::new(); inner_tree.insert_leaf("y", 10); let mut tree = DataTree::new(); - tree.insert_tree("x", inner_tree.clone()); + tree.insert_branch("x", inner_tree.clone()); tree.insert_leaf("z", 100); assert_eq!(None, tree.get_by_str_key("z.y")); - let expected = TreeEntry::Tree(inner_tree); - assert_eq!(Some(&expected), tree.get_by_str_key("x")); + assert_eq!(Some(&inner_tree), tree.get_by_str_key("x")); } #[test] @@ -553,9 +766,9 @@ mod test { inner_inner_tree.push_leaf(3); inner_inner_tree.push_leaf(4); inner_inner_tree.push_leaf(5); - inner_tree.push_tree(inner_inner_tree); + inner_tree.push_branch(inner_inner_tree); let mut tree = DataTree::new(); - tree.insert_tree("x", inner_tree.clone()); + tree.insert_branch("x", inner_tree.clone()); tree.insert_leaf("z", 100); assert_eq!( vec![10, 1, 2, 3, 4, 5, 100], @@ -573,9 +786,9 @@ mod test { inner_inner_tree.push_leaf(3); inner_inner_tree.push_leaf(4); inner_inner_tree.push_leaf(5); - inner_tree.push_tree(inner_inner_tree); + inner_tree.push_branch(inner_inner_tree); let mut tree = DataTree::new(); - tree.insert_tree("x", inner_tree.clone()); + tree.insert_branch("x", inner_tree.clone()); tree.insert_leaf("z", 100); let expected_paths = vec![ vec![PathEntry::Index(0), PathEntry::Index(0)], @@ -624,18 +837,18 @@ mod test { inner_inner_tree.insert_leaf("a", 4); inner_inner_tree.push_leaf(5); let inner_inner_tree_expected = inner_inner_tree.clone(); - inner_tree.insert_tree("yyy", inner_inner_tree); + inner_tree.insert_branch("yyy", inner_inner_tree); let mut tree = DataTree::new(); - tree.insert_tree("x", inner_tree.clone()); + tree.insert_branch("x", inner_tree.clone()); tree.insert_leaf("z", 100); let result = tree.get_by_str_key("x.yyy.a"); - assert_eq!(result, Some(&TreeEntry::Leaf(4))); - assert_eq!(tree.get_by_str_key("z"), Some(&TreeEntry::Leaf(100))); + assert_eq!(result, Some(&DataTree::Leaf(4))); + assert_eq!(tree.get_by_str_key("z"), Some(&DataTree::Leaf(100))); assert_eq!( tree.get_by_str_key("x.yyy"), - Some(&TreeEntry::Tree(inner_inner_tree_expected)) + Some(&inner_inner_tree_expected), ); - assert_eq!(tree.get_by_str_key("x.yy"), Some(&TreeEntry::Leaf(1))); + assert_eq!(tree.get_by_str_key("x.yy"), Some(&DataTree::Leaf(1))); } #[test] @@ -648,9 +861,9 @@ mod test { inner_inner_tree.push_leaf(3); inner_inner_tree.insert_leaf("a", 4); inner_inner_tree.push_leaf(5); - inner_tree.insert_tree("yyy", inner_inner_tree); + inner_tree.insert_branch("yyy", inner_inner_tree); let mut tree = DataTree::new(); - tree.insert_tree("x", inner_tree.clone()); + tree.insert_branch("x", inner_tree.clone()); tree.insert_leaf("z", 100); assert_eq!(None, tree.get_by_str_key("a")); assert_eq!(None, tree.get_by_str_key("x.yyyy")); diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 47072ccd4ed2..e6036bf8669e 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -12,4 +12,4 @@ mod data_tree; -pub use data_tree::{DataTree, PathEntry, TreeEntry}; +pub use data_tree::{DataTree, PathEntry}; From 9755531bd3c3822e7e7dd40337c69ae53020329a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 28 Apr 2026 13:47:46 -0400 Subject: [PATCH 04/10] Fix MSRV clippy error --- crates/providers/src/data_tree.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs index 13b881bfc30f..541dc14aaab2 100644 --- a/crates/providers/src/data_tree.rs +++ b/crates/providers/src/data_tree.rs @@ -33,6 +33,12 @@ pub struct DataTreeBranch { keys: HashMap, } +impl Default for DataTreeBranch { + fn default() -> Self { + Self::new() + } +} + impl DataTreeBranch { /// Construct a new empty [`DataTreeBranch`] pub fn new() -> Self { From 0d132381899a5f74e9c32a0d8ac74ca83833f159 Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Fri, 10 Apr 2026 10:07:14 -0400 Subject: [PATCH 05/10] Add DType, Tensor & friends to providers crate --- Cargo.lock | 4 + crates/providers/Cargo.toml | 7 + crates/providers/src/lib.rs | 2 +- crates/providers/src/tensor.rs | 381 +++++++++++++++++++++++++++++++++ 4 files changed, 393 insertions(+), 1 deletion(-) create mode 100644 crates/providers/src/tensor.rs diff --git a/Cargo.lock b/Cargo.lock index b2dfb578593a..613d376505ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2342,7 +2342,11 @@ dependencies = [ name = "qiskit-providers" version = "2.5.0-dev" dependencies = [ + "anyhow", "hashbrown 0.15.5", + "ndarray", + "num-complex", + "rustworkx-core", ] [[package]] diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 84b0184037af..7e2805f33e15 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -9,6 +9,9 @@ license.workspace = true name = "qiskit_providers" [dependencies] +rustworkx-core.workspace = true +anyhow.workspace = true +num-complex.workspace = true [dependencies.hashbrown] workspace = true @@ -16,3 +19,7 @@ features = ["rayon", "serde"] [lints] workspace = true + +[dependencies.ndarray] +workspace = true +features = ["rayon", "approx"] \ No newline at end of file diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index e6036bf8669e..c5036f18423b 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -11,5 +11,5 @@ // that they have been altered from the originals. mod data_tree; - +pub mod tensor; pub use data_tree::{DataTree, PathEntry}; diff --git a/crates/providers/src/tensor.rs b/crates/providers/src/tensor.rs new file mode 100644 index 000000000000..017765d5e1f1 --- /dev/null +++ b/crates/providers/src/tensor.rs @@ -0,0 +1,381 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use ndarray::ArrayD; +use num_complex::Complex; +use std::fmt; + +/// The possible data types for a Tensor. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DType { + C128, // complex + C64, + F64, // float + F32, + I64, // signed ints + I32, + I16, + I8, + U64, // unsigned ints + U32, + U16, + U8, + Bit, // bool +} + +impl fmt::Display for DType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let string_repr = match self { + DType::C128 => "C128", + DType::C64 => "C64", + DType::F64 => "F64", + DType::F32 => "F32", + DType::I64 => "I64", + DType::I32 => "I32", + DType::I16 => "I16", + DType::I8 => "I8", + DType::U64 => "U64", + DType::U32 => "U32", + DType::U16 => "U16", + DType::U8 => "U8", + DType::Bit => "Bit", + }; + write!(f, "{string_repr}") + } +} + +// A tensor data type whose value is yet unknown, but named. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DTypeVar { + pub name: String, +} + +impl> From for DTypeVar { + fn from(value: T) -> Self { + Self { name: value.into() } + } +} + +/// A tensor data type whose value is yet unknown, but will be the promotion of others. +#[derive(Debug, Clone)] +pub struct DTypePromotion { + pub args: Vec, +} + +impl>> From for DTypePromotion { + fn from(args: T) -> Self { + Self { args: args.into() } + } +} + +/// A tensor data type, known or unknown. +#[derive(Debug, Clone)] +pub enum DTypeLike { + Concrete(DType), + Var(DTypeVar), + Promotion(DTypePromotion), +} + +/// Promote a pair of DTypes to the smallest type compatible with both. +/// +/// QuantumProgram operations often, but not necessarily, use this promotion rule +/// to determine their output type. +/// +/// This function implements the same promotion rules as NumPy, modulo that we don't +/// need to contend with the arbitrary precision types for each type kind, and that +/// we omit F16 entirely because it's ustable in rust: +/// https://numpy.org/doc/stable/reference/arrays.promotion.html#numerical-promotion +/// In short, if you view the linked diagram as a DAG, this function hard-codes the +/// least-common-descendent algorithm. +pub fn promotion(lhs: DType, rhs: DType) -> DType { + use DType::*; + + // painfully write a lookup table as a nested match statement. to check if it's right, + // compare agaist the linked image, or more easily, study the test + // test_promotion_against_promotion_dag that tests every input combination. + match lhs { + C128 => C128, + + C64 => match rhs { + U32 | U64 | I32 | I64 | F64 | C128 => C128, + _ => C64, + }, + + F64 => match rhs { + C64 | C128 => C128, + _ => F64, + }, + + F32 => match rhs { + C128 => C128, + C64 => C64, + U32 | U64 | I32 | I64 | F64 => F64, + _ => F32, + }, + + I64 => match rhs { + C64 | C128 => C128, + U64 | F32 | F64 => F64, + _ => I64, + }, + + I32 => match rhs { + C64 | C128 => C128, + U64 | F32 | F64 => F64, + U32 | I64 => I64, + _ => I32, + }, + + I16 => match rhs { + U64 => F64, + U32 => I64, + U16 => I32, + Bit | U8 | I8 => I16, + _ => rhs, + }, + + I8 => match rhs { + U64 => F64, + U32 => I64, + U16 => I32, + U8 => I16, + Bit => I8, + _ => rhs, + }, + + U64 => match rhs { + C128 | C64 => C128, + F32 | F64 | I8 | I16 | I32 | I64 => F64, + _ => U64, + }, + + U32 => match rhs { + C64 | C128 => C128, + F32 | F64 => F64, + I8 | I16 | I32 | I64 => I64, + U64 => U64, + _ => U32, + }, + + U16 => match rhs { + I8 | I16 => I32, + Bit | U8 => U16, + _ => rhs, + }, + + U8 => match rhs { + I8 => I16, + Bit => U8, + _ => rhs, + }, + + Bit => rhs, + } +} + +/// A tensor axis dimension. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Dim { + /// A known size. + Fixed(usize), + /// An unresolved, named size. + Named(String), +} + +/// A specification of a tensor without any data. +#[derive(Debug, Clone)] +pub struct TensorType { + /// The type of the tensor. + pub dtype: DTypeLike, + /// The shape of the tensor, possibly with axes of unknown size. + pub shape: Vec, + /// Whether the tensor supports leading-axis (i.e. NumPy-style) broadcasting semantics. + pub broadcastable: bool, +} + +impl TensorType { + // Return a dimension vector if all are sizes are fixed, None otherwise. + pub fn concrete_shape(&self) -> Option> { + let mut out = Vec::with_capacity(self.shape.len()); + for d in &self.shape { + match d { + Dim::Fixed(n) => out.push(*n), + Dim::Named(_) => return None, + } + } + Some(out) + } +} + +/// A tensor of one of the supported dtypes. +#[derive(Debug, Clone)] +pub enum Tensor { + C64(ArrayD>), + C128(ArrayD>), + F32(ArrayD), + F64(ArrayD), + I8(ArrayD), + I16(ArrayD), + I32(ArrayD), + I64(ArrayD), + U8(ArrayD), + U16(ArrayD), + U32(ArrayD), + U64(ArrayD), + Bit(ArrayD), +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_promotion_against_promotion_dag() { + use DType::*; + use rustworkx_core::dag_algo::lexicographical_topological_sort; + use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; + use rustworkx_core::traversal::descendants; + use std::collections::{HashMap, HashSet}; + + // define a DAG that implements all promotion rules; two DTypes + // should be promoted to their least common descendent in the DAG + let mut g: DiGraph = DiGraph::new(); + let mut idx: HashMap = HashMap::new(); + + let nodes = [ + Bit, U8, U16, U32, U64, I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &dtype in &nodes { + idx.insert(dtype, g.add_node(dtype)); + } + + // within-kind promotions + g.add_edge(idx[&U8], idx[&U16], ()); + g.add_edge(idx[&U16], idx[&U32], ()); + g.add_edge(idx[&U32], idx[&U64], ()); + + g.add_edge(idx[&I8], idx[&I16], ()); + g.add_edge(idx[&I16], idx[&I32], ()); + g.add_edge(idx[&I32], idx[&I64], ()); + + g.add_edge(idx[&F32], idx[&F64], ()); + + g.add_edge(idx[&C64], idx[&C128], ()); + + // bit promotions + g.add_edge(idx[&Bit], idx[&U8], ()); + g.add_edge(idx[&Bit], idx[&I8], ()); + + // uint promotions + g.add_edge(idx[&U8], idx[&I16], ()); + g.add_edge(idx[&U16], idx[&I32], ()); + g.add_edge(idx[&U16], idx[&F32], ()); + g.add_edge(idx[&U32], idx[&I64], ()); + g.add_edge(idx[&U64], idx[&F64], ()); + + // int promotions + g.add_edge(idx[&I16], idx[&F32], ()); + g.add_edge(idx[&I32], idx[&F64], ()); + g.add_edge(idx[&I64], idx[&F64], ()); + + // float promotions + g.add_edge(idx[&F32], idx[&C64], ()); + g.add_edge(idx[&F64], idx[&C128], ()); + + let order = lexicographical_topological_sort( + &g, + |n: NodeIndex| Ok::(n.index()), + false, + None, + ) + .ok() + .unwrap(); + + let least_common_decendent = move |a: &DType, b: &DType| -> DType { + let da: HashSet<_> = descendants(&g, idx[&a]).collect(); + let db: HashSet<_> = descendants(&g, idx[&b]).collect(); + let common: HashSet = da.intersection(&db).copied().collect(); + let least_idx = order.iter().find(|n| common.contains(n)).unwrap(); + nodes[least_idx.index()] + }; + + for &a in &nodes { + for &b in &nodes { + assert_eq!( + promotion(a, b), + least_common_decendent(&a, &b), + "For promotion ({a}, {b})" + ) + } + } + } + + #[test] + fn test_promotion_idempotence() { + use DType::*; + let nodes = [ + Bit, U8, U16, U32, U64, I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &a in &nodes { + assert_eq!(promotion(a, a), a, "For promotion ({a}, {a})") + } + } + + #[test] + fn test_promotion_commutativity() { + use DType::*; + let nodes = [ + Bit, U8, U16, U32, U64, I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &a in &nodes { + for &b in &nodes { + assert_eq!(promotion(a, b), promotion(b, a), "For promotion ({a}, {b})") + } + } + } + + #[test] + fn test_tensor_type_concrete_shape() { + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3)], + broadcastable: false, + } + .concrete_shape(), + Some(vec![3]) + ); + + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3), Dim::Fixed(8)], + broadcastable: true, + } + .concrete_shape(), + Some(vec![3, 8]) + ); + + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3), Dim::Named("foo".into())], + broadcastable: false, + } + .concrete_shape(), + None + ); + } +} From 23e705529290610dc919887867dee776b65aab6c Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Wed, 29 Apr 2026 15:33:16 -0400 Subject: [PATCH 06/10] Add operation implementations for Tensor --- crates/providers/src/tensor.rs | 316 +++++++++++++++++++++++++++++++-- 1 file changed, 304 insertions(+), 12 deletions(-) diff --git a/crates/providers/src/tensor.rs b/crates/providers/src/tensor.rs index 017765d5e1f1..25e50645fa68 100644 --- a/crates/providers/src/tensor.rs +++ b/crates/providers/src/tensor.rs @@ -10,7 +10,7 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. -use ndarray::ArrayD; +use ndarray::{ArrayD, IxDyn, Zip}; use num_complex::Complex; use std::fmt; @@ -19,13 +19,13 @@ use std::fmt; pub enum DType { C128, // complex C64, - F64, // float + F64, // real F32, - I64, // signed ints + I64, // signed integer I32, I16, I8, - U64, // unsigned ints + U64, // unsigned integer U32, U16, U8, @@ -53,9 +53,10 @@ impl fmt::Display for DType { } } -// A tensor data type whose value is yet unknown, but named. +/// A tensor dtype that is unknown but identified by name. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DTypeVar { + /// The variable name. pub name: String, } @@ -68,6 +69,7 @@ impl> From for DTypeVar { /// A tensor data type whose value is yet unknown, but will be the promotion of others. #[derive(Debug, Clone)] pub struct DTypePromotion { + /// The dtype arguments to promote over. pub args: Vec, } @@ -80,14 +82,17 @@ impl>> From for DTypePromotion { /// A tensor data type, known or unknown. #[derive(Debug, Clone)] pub enum DTypeLike { + /// A fully resolved dtype. Concrete(DType), + /// A dtype identified by a variable name, to be resolved later. Var(DTypeVar), + /// A dtype that is the promotion of one or more other dtypes. Promotion(DTypePromotion), } /// Promote a pair of DTypes to the smallest type compatible with both. /// -/// QuantumProgram operations often, but not necessarily, use this promotion rule +/// QuantumProgram nodes often, but not necessarily, use this promotion rule /// to determine their output type. /// /// This function implements the same promotion rules as NumPy, modulo that we don't @@ -203,7 +208,7 @@ pub struct TensorType { } impl TensorType { - // Return a dimension vector if all are sizes are fixed, None otherwise. + /// Return a dimension vector if all sizes are fixed, or `None` if any are named. pub fn concrete_shape(&self) -> Option> { let mut out = Vec::with_capacity(self.shape.len()); for d in &self.shape { @@ -219,21 +224,308 @@ impl TensorType { /// A tensor of one of the supported dtypes. #[derive(Debug, Clone)] pub enum Tensor { - C64(ArrayD>), + C64(ArrayD>), // complex C128(ArrayD>), - F32(ArrayD), + F32(ArrayD), // real F64(ArrayD), - I8(ArrayD), + I8(ArrayD), // signed integer I16(ArrayD), I32(ArrayD), I64(ArrayD), - U8(ArrayD), + U8(ArrayD), // unsigned integer U16(ArrayD), U32(ArrayD), U64(ArrayD), - Bit(ArrayD), + Bit(ArrayD), // bool } +/// Cast an `ArrayD` of a real numeric type to any supported dtype. +macro_rules! cast_real { + ($arr:expr, $src:ty, $target:expr) => { + match $target { + DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)), + DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)), + DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)), + DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)), + DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)), + DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)), + DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)), + DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)), + DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)), + DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)), + DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)), + DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex::new(x as f32, 0.0))), + DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex::new(x as f64, 0.0))), + } + }; +} + +/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets). +macro_rules! cast_complex { + ($arr:expr, $target:expr) => { + match $target { + DType::C64 => Tensor::C64($arr.mapv(|x| Complex::new(x.re as f32, x.im as f32))), + DType::C128 => Tensor::C128($arr.mapv(|x| Complex::new(x.re as f64, x.im as f64))), + _ => panic!("cannot cast complex tensor to a real dtype"), + } + }; +} + +/// Element-wise binary operation on two arrays with NumPy-style broadcasting. +/// +/// Unlike ndarray's built-in arithmetic operators which handle broadcasting automatically, +/// this helper is needed for operations without a Rust operator (e.g. `pow`). +fn broadcast_elementwise(a: &ArrayD, b: &ArrayD, op: F) -> ArrayD +where + T: Clone, + F: Fn(&T, &T) -> T, +{ + let ndim = a.ndim().max(b.ndim()); + let out_shape: Vec = (0..ndim) + .map(|i| { + let d_a = if i >= ndim - a.ndim() { + a.shape()[i - (ndim - a.ndim())] + } else { + 1 + }; + let d_b = if i >= ndim - b.ndim() { + b.shape()[i - (ndim - b.ndim())] + } else { + 1 + }; + match (d_a, d_b) { + (x, y) if x == y => x, + (1, y) => y, + (x, 1) => x, + _ => panic!( + "shapes {:?} and {:?} are not broadcast-compatible", + a.shape(), + b.shape() + ), + } + }) + .collect(); + let out_ix = IxDyn(&out_shape); + let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed"); + let b_bc = b.broadcast(out_ix).expect("broadcast failed"); + Zip::from(a_bc).and(b_bc).map_collect(op) +} + +impl Tensor { + /// Return the dtype of this tensor. + pub fn dtype(&self) -> DType { + match self { + Tensor::C128(_) => DType::C128, + Tensor::C64(_) => DType::C64, + Tensor::F64(_) => DType::F64, + Tensor::F32(_) => DType::F32, + Tensor::I64(_) => DType::I64, + Tensor::I32(_) => DType::I32, + Tensor::I16(_) => DType::I16, + Tensor::I8(_) => DType::I8, + Tensor::U64(_) => DType::U64, + Tensor::U32(_) => DType::U32, + Tensor::U16(_) => DType::U16, + Tensor::U8(_) => DType::U8, + Tensor::Bit(_) => DType::Bit, + } + } + + /// Return the shape of this tensor as a slice of dimension sizes. + pub fn shape(&self) -> &[usize] { + match self { + Tensor::C128(a) => a.shape(), + Tensor::C64(a) => a.shape(), + Tensor::F64(a) => a.shape(), + Tensor::F32(a) => a.shape(), + Tensor::I64(a) => a.shape(), + Tensor::I32(a) => a.shape(), + Tensor::I16(a) => a.shape(), + Tensor::I8(a) => a.shape(), + Tensor::U64(a) => a.shape(), + Tensor::U32(a) => a.shape(), + Tensor::U16(a) => a.shape(), + Tensor::U8(a) => a.shape(), + Tensor::Bit(a) => a.shape(), + } + } + + /// Return the [`TensorType`] that describes this tensor's dtype and concrete shape. + pub fn tensor_type(&self) -> TensorType { + TensorType { + dtype: DTypeLike::Concrete(self.dtype()), + shape: self.shape().iter().map(|&n| Dim::Fixed(n)).collect(), + broadcastable: false, + } + } + + /// Element-wise power with NumPy-style broadcasting. + /// + /// For integer types the exponent is cast to `u32`; negative integer exponents + /// are not supported. + pub fn pow(&self, rhs: &Tensor) -> Tensor { + match (self, rhs) { + (Tensor::F32(a), Tensor::F32(b)) => { + Tensor::F32(broadcast_elementwise(a, b, |&x, &y| x.powf(y))) + } + (Tensor::F64(a), Tensor::F64(b)) => { + Tensor::F64(broadcast_elementwise(a, b, |&x, &y| x.powf(y))) + } + (Tensor::C64(a), Tensor::C64(b)) => { + Tensor::C64(broadcast_elementwise(a, b, |&x, &y| x.powc(y))) + } + (Tensor::C128(a), Tensor::C128(b)) => { + Tensor::C128(broadcast_elementwise(a, b, |&x, &y| x.powc(y))) + } + (Tensor::I8(a), Tensor::I8(b)) => { + Tensor::I8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I16(a), Tensor::I16(b)) => { + Tensor::I16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I32(a), Tensor::I32(b)) => { + Tensor::I32(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I64(a), Tensor::I64(b)) => { + Tensor::I64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U8(a), Tensor::U8(b)) => { + Tensor::U8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U16(a), Tensor::U16(b)) => { + Tensor::U16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U32(a), Tensor::U32(b)) => { + Tensor::U32(broadcast_elementwise(a, b, |&x, &y| x.pow(y))) + } + (Tensor::U64(a), Tensor::U64(b)) => { + Tensor::U64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + _ => panic!("type mismatch in Tensor::pow"), + } + } + + /// Cast this tensor to `target`, consuming it. Returns `self` unchanged if already that dtype. + pub fn cast(self, target: DType) -> Tensor { + if self.dtype() == target { + return self; + } + match &self { + Tensor::Bit(a) | Tensor::U8(a) => cast_real!(a, u8, target), + Tensor::U16(a) => cast_real!(a, u16, target), + Tensor::U32(a) => cast_real!(a, u32, target), + Tensor::U64(a) => cast_real!(a, u64, target), + Tensor::I8(a) => cast_real!(a, i8, target), + Tensor::I16(a) => cast_real!(a, i16, target), + Tensor::I32(a) => cast_real!(a, i32, target), + Tensor::I64(a) => cast_real!(a, i64, target), + Tensor::F32(a) => cast_real!(a, f32, target), + Tensor::F64(a) => cast_real!(a, f64, target), + Tensor::C64(a) => cast_complex!(a, target), + Tensor::C128(a) => cast_complex!(a, target), + } + } +} + +/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From>` for a given `Tensor` variant. +macro_rules! impl_tensor_from { + ($variant:ident, $t:ty) => { + impl From<&[$t]> for Tensor { + fn from(data: &[$t]) -> Self { + Tensor::$variant(ndarray::arr1(data).into_dyn()) + } + } + impl From<[$t; N]> for Tensor { + fn from(data: [$t; N]) -> Self { + Tensor::$variant(ndarray::arr1(&data).into_dyn()) + } + } + impl From> for Tensor { + fn from(data: ArrayD<$t>) -> Self { + Tensor::$variant(data) + } + } + }; +} + +impl_tensor_from!(C128, Complex); +impl_tensor_from!(C64, Complex); +impl_tensor_from!(F64, f64); +impl_tensor_from!(F32, f32); +impl_tensor_from!(I64, i64); +impl_tensor_from!(I32, i32); +impl_tensor_from!(I16, i16); +impl_tensor_from!(I8, i8); +impl_tensor_from!(U64, u64); +impl_tensor_from!(U32, u32); +impl_tensor_from!(U16, u16); +impl_tensor_from!(U8, u8); // u8 → U8; Bit requires explicit construction + +/// Implement a standard Rust binary operator trait for `Tensor` and `&Tensor`. +macro_rules! impl_tensor_binop { + ($trait:ident, $method:ident, $op:tt) => { + impl std::ops::$trait for &Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { + match (self, rhs) { + (Tensor::C128(a), Tensor::C128(b)) => Tensor::C128(a $op b), + (Tensor::C64(a), Tensor::C64(b)) => Tensor::C64(a $op b), + (Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b), + (Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b), + (Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b), + (Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b), + (Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b), + (Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b), + (Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b), + (Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b), + (Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b), + (Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b), + _ => panic!("type mismatch in Tensor::{}", stringify!($method)), + } + } + } + impl std::ops::$trait for Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { &self $op &rhs } + } + }; +} + +/// Like [`impl_tensor_binop!`], but omits complex variants for ops that don't support them +/// (e.g. `Rem`, which `num_complex` does not implement). +macro_rules! impl_tensor_binop_real { + ($trait:ident, $method:ident, $op:tt) => { + impl std::ops::$trait for &Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { + match (self, rhs) { + (Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b), + (Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b), + (Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b), + (Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b), + (Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b), + (Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b), + (Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b), + (Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b), + (Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b), + (Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b), + _ => panic!("type mismatch or unsupported dtype in Tensor::{}", stringify!($method)), + } + } + } + impl std::ops::$trait for Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { &self $op &rhs } + } + }; +} + +impl_tensor_binop!(Add, add, +); +impl_tensor_binop!(Sub, sub, -); +impl_tensor_binop!(Mul, mul, *); +impl_tensor_binop!(Div, div, /); +impl_tensor_binop_real!(Rem, rem, %); + #[cfg(test)] mod test { use super::*; From 9ffc188d73ce3a4585b878c05e8dedd3fc656b13 Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Thu, 30 Apr 2026 13:20:12 -0400 Subject: [PATCH 07/10] Add ProgramNode --- Cargo.lock | 1 + crates/providers/Cargo.toml | 1 + crates/providers/src/lib.rs | 5 ++++ crates/providers/src/program_node.rs | 40 ++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+) create mode 100644 crates/providers/src/program_node.rs diff --git a/Cargo.lock b/Cargo.lock index 613d376505ee..7b614eacfad9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2347,6 +2347,7 @@ dependencies = [ "ndarray", "num-complex", "rustworkx-core", + "thiserror 2.0.18", ] [[package]] diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 7e2805f33e15..2a09144addc7 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -12,6 +12,7 @@ name = "qiskit_providers" rustworkx-core.workspace = true anyhow.workspace = true num-complex.workspace = true +thiserror.workspace = true [dependencies.hashbrown] workspace = true diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index c5036f18423b..cb59bffbca6b 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -11,5 +11,10 @@ // that they have been altered from the originals. mod data_tree; +mod program_node; +mod store; pub mod tensor; + pub use data_tree::{DataTree, PathEntry}; +pub use program_node::ProgramNode; +pub use store::Store; diff --git a/crates/providers/src/program_node.rs b/crates/providers/src/program_node.rs new file mode 100644 index 000000000000..be8980d250ca --- /dev/null +++ b/crates/providers/src/program_node.rs @@ -0,0 +1,40 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::tensor::{Tensor, TensorType}; + +/// A node in a quantum program graph that transforms tensors. +pub trait ProgramNode { + /// The name of this program node. + fn name(&self) -> &'static str; + + /// The namespace this program node belongs to. + fn namespace(&self) -> &'static str; + + /// The namespace and name as one string. + fn full_name(&self) -> String { + format_args!("{}.{}", self.namespace(), self.name()).to_string() + } + + /// The inputs expected at `call` time. + fn input_types(&self) -> &DataTree; + + /// The outputs promised on `call` return. + fn output_types(&self) -> &DataTree; + + /// Whether this program node implements the call method. + fn implements_call(&self) -> bool; + + /// The action of this program node. + fn call(&self, args: &DataTree) -> anyhow::Result>; +} From 03e8e5b00ef36095c8459f1e531df6543681fd6c Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Wed, 29 Apr 2026 15:25:32 -0400 Subject: [PATCH 08/10] Add DataTree.iter_children() --- crates/providers/src/data_tree.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs index 541dc14aaab2..0558bf00f34e 100644 --- a/crates/providers/src/data_tree.rs +++ b/crates/providers/src/data_tree.rs @@ -309,6 +309,33 @@ impl DataTree { } } + /// Iterate over direct children, yielding `(optional_key, child)` pairs in index order. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.push_leaf(10); // unnamed + /// tree.insert_leaf("b", 20); // named + /// tree.push_leaf(30); // unnamed + /// let children: Vec<_> = tree.iter_children().collect(); + /// assert_eq!(children[0], (None, &DataTree::Leaf(10))); + /// assert_eq!(children[1], (Some("b"), &DataTree::Leaf(20))); + /// assert_eq!(children[2], (None, &DataTree::Leaf(30))); + /// ``` + pub fn iter_children(&self) -> impl Iterator, &DataTree)> + '_ { + let branch = match self { + Self::Branch(branch) => branch, + Self::Leaf(_) => panic!("called iter_children() on a leaf node"), + }; + let rev: HashMap = branch.keys.iter().map(|(k, &v)| (v, k.as_str())).collect(); + branch + .data + .iter() + .enumerate() + .map(move |(i, child)| (rev.get(&i).copied(), child)) + } + /// Insert a new leaf node with an associated string key /// /// If a key is provided that is already in the tree the new value will be associated with From bda047a45ea2a02314dcd93bd30aad0ffe6898a0 Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Wed, 29 Apr 2026 15:26:50 -0400 Subject: [PATCH 09/10] Add Store impl of ProgramNode --- crates/providers/src/store.rs | 159 ++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 crates/providers/src/store.rs diff --git a/crates/providers/src/store.rs b/crates/providers/src/store.rs new file mode 100644 index 000000000000..6e0e84172bb2 --- /dev/null +++ b/crates/providers/src/store.rs @@ -0,0 +1,159 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{Tensor, TensorType}; +use std::sync::OnceLock; + +/// A program node that owns constant data and outputs it unconditionally. +/// +/// `Store` takes no inputs; its `call()` always returns the data it was constructed with. +/// In a data-flow graph, `Store` nodes play the role of constants — they are wired to +/// the input ports of computation nodes to supply fixed values. +pub struct Store { + data: DataTree, + output_types: DataTree, +} + +impl Store { + /// Construct a new `Store` holding the given data. + pub fn new(data: DataTree) -> Self { + let output_types = derive_output_types(&data); + Self { data, output_types } + } + + /// Return a reference to the stored data. + pub fn data(&self) -> &DataTree { + &self.data + } +} + +/// Recursively derive output types from concrete tensor data. +fn derive_output_types(data: &DataTree) -> DataTree { + match data { + DataTree::Leaf(tensor) => DataTree::new_leaf(tensor.tensor_type()), + DataTree::Branch(_) => { + let mut result = DataTree::with_capacity(data.len()); + for (key, child) in data.iter_children() { + let child_type = derive_output_types(child); + if let Some(k) = key { + result.insert_branch(k, child_type); + } else { + result.push_branch(child_type); + } + } + result + } + } +} + +impl ProgramNode for Store { + fn name(&self) -> &'static str { + "store" + } + + fn namespace(&self) -> &'static str { + "core" + } + + fn input_types(&self) -> &DataTree { + static EMPTY: OnceLock> = OnceLock::new(); + EMPTY.get_or_init(DataTree::new) + } + + fn output_types(&self) -> &DataTree { + &self.output_types + } + + fn implements_call(&self) -> bool { + true + } + + fn call(&self, _args: &DataTree) -> anyhow::Result> { + Ok(self.data.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, DTypeLike, Dim, Tensor}; + + #[test] + fn test_store_leaf_call() { + let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0])); + let store = Store::new(data); + let result = store.call(&DataTree::new()).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_store_output_types_leaf() { + let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0])); + let store = Store::new(data); + let DataTree::Leaf(tt) = store.output_types() else { + panic!("expected leaf output type"); + }; + assert!(matches!(tt.dtype, DTypeLike::Concrete(DType::F64))); + assert_eq!(tt.shape, vec![Dim::Fixed(3)]); + assert!(!tt.broadcastable); + } + + #[test] + fn test_store_output_types_2d() { + use ndarray::arr2; + let data = + DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn())); + let store = Store::new(data); + let DataTree::Leaf(tt) = store.output_types() else { + panic!("expected leaf output type"); + }; + assert_eq!(tt.shape, vec![Dim::Fixed(2), Dim::Fixed(2)]); + } + + #[test] + fn test_store_branched() { + let mut data = DataTree::new(); + data.insert_leaf("a", Tensor::from([1.0_f64, 2.0])); + data.insert_leaf("b", Tensor::from([10_i32, 20, 30])); + let store = Store::new(data); + + assert!(store.input_types().is_empty()); + assert_eq!(store.name(), "store"); + assert_eq!(store.namespace(), "core"); + assert_eq!(store.full_name(), "core.store"); + + let out_types = store.output_types(); + let DataTree::Leaf(tt_a) = out_types.get_by_str_key("a").unwrap() else { + panic!("expected leaf at a"); + }; + assert!(matches!(tt_a.dtype, DTypeLike::Concrete(DType::F64))); + assert_eq!(tt_a.shape, vec![Dim::Fixed(2)]); + + let DataTree::Leaf(tt_b) = out_types.get_by_str_key("b").unwrap() else { + panic!("expected leaf at b"); + }; + assert!(matches!(tt_b.dtype, DTypeLike::Concrete(DType::I32))); + assert_eq!(tt_b.shape, vec![Dim::Fixed(3)]); + } + + #[test] + fn test_store_no_inputs() { + let store = Store::new(DataTree::new_leaf(Tensor::from([42.0_f64]))); + assert!(store.input_types().is_empty()); + assert!(store.implements_call()); + } +} From a437668f6c6d3b84ef13181c0464ecb4e7508272 Mon Sep 17 00:00:00 2001 From: Ian Hincks Date: Wed, 29 Apr 2026 15:43:13 -0400 Subject: [PATCH 10/10] Add impls of ProgramNode for various math operations --- crates/providers/src/lib.rs | 1 + crates/providers/src/math_nodes/binary.rs | 278 +++++++++++ crates/providers/src/math_nodes/bitwise.rs | 291 ++++++++++++ crates/providers/src/math_nodes/mod.rs | 15 + crates/providers/src/math_nodes/reduction.rs | 469 +++++++++++++++++++ 5 files changed, 1054 insertions(+) create mode 100644 crates/providers/src/math_nodes/binary.rs create mode 100644 crates/providers/src/math_nodes/bitwise.rs create mode 100644 crates/providers/src/math_nodes/mod.rs create mode 100644 crates/providers/src/math_nodes/reduction.rs diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index cb59bffbca6b..b81596dff981 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -11,6 +11,7 @@ // that they have been altered from the originals. mod data_tree; +pub mod math_nodes; mod program_node; mod store; pub mod tensor; diff --git a/crates/providers/src/math_nodes/binary.rs b/crates/providers/src/math_nodes/binary.rs new file mode 100644 index 000000000000..c77e111452f5 --- /dev/null +++ b/crates/providers/src/math_nodes/binary.rs @@ -0,0 +1,278 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DTypeLike, Tensor, TensorType, promotion}; +use std::borrow::Cow; +use std::sync::OnceLock; + +/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`. +fn elementwise_binary_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Var("y".into()), + shape: vec![], + broadcastable: true, + }, + ); + types + }) +} + +/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype. +fn elementwise_binary_output_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Promotion( + vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(), + ), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Extract `x` and `y` from `args`, promote dtypes, and apply `op` element-wise. +fn binary_elementwise_call( + args: &DataTree, + op: impl Fn(&Tensor, &Tensor) -> Tensor, +) -> anyhow::Result> { + let DataTree::Leaf(x) = args.get_by_str_key("x").expect("missing input x") else { + panic!("expected leaf at x"); + }; + let DataTree::Leaf(y) = args.get_by_str_key("y").expect("missing input y") else { + panic!("expected leaf at y"); + }; + let out_dtype = promotion(x.dtype(), y.dtype()); + + // Use copy-on-write smart pointer to avoid cloning when promotion is unnecessary + let x = if x.dtype() == out_dtype { + Cow::Borrowed(x) + } else { + Cow::Owned(x.clone().cast(out_dtype)) + }; + let y = if y.dtype() == out_dtype { + Cow::Borrowed(y) + } else { + Cow::Owned(y.clone().cast(out_dtype)) + }; + Ok(DataTree::new_leaf(op(x.as_ref(), y.as_ref()))) +} + +/// Generate a [`ProgramNode`] struct for an elementwise binary operation. +macro_rules! elementwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")] + pub struct $name; + + impl ProgramNode for $name { + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + elementwise_binary_input_types() + } + fn output_types(&self) -> &DataTree { + elementwise_binary_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + binary_elementwise_call(args, $call_fn) + } + } + }; +} + +elementwise_binary_node!(Add, "add", |x, y| x + y); +elementwise_binary_node!(Subtract, "subtract", |x, y| x - y); +elementwise_binary_node!(Multiply, "multiply", |x, y| x * y); +elementwise_binary_node!(Divide, "divide", |x, y| x / y); +elementwise_binary_node!(Remainder, "remainder", |x, y| x % y); +elementwise_binary_node!(Power, "power", |x, y| x.pow(y)); + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, Tensor}; + + fn args(x: Tensor, y: Tensor) -> DataTree { + let mut tree = DataTree::new(); + tree.insert_leaf("x", x); + tree.insert_leaf("y", y); + tree + } + + #[test] + fn test_add_same_dtype() { + let result = Add + .call(&args( + Tensor::from([1.0_f64, 2.0, 3.0]), + Tensor::from([4.0_f64, 5.0, 6.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]); + } + + #[test] + fn test_add_promotes_dtype() { + let result = Add + .call(&args( + Tensor::from([1.0_f32, 2.0]), + Tensor::from([3.0_f64, 4.0]), + )) + .unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf") + }; + assert_eq!(tensor.dtype(), DType::F64); + let Tensor::F64(arr) = tensor else { + panic!("expected f64") + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]); + } + + #[test] + fn test_add_broadcasts_1d_scalar() { + // shape [3] + shape [1] -> shape [3] + let result = Add + .call(&args( + Tensor::from([1.0_f64, 2.0, 3.0]), + Tensor::from([10.0_f64]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + assert_eq!(arr.as_slice().unwrap(), &[11.0, 12.0, 13.0]); + } + + #[test] + fn test_add_broadcasts_2d_with_1d() { + // shape [2, 3] + shape [3] -> shape [2, 3] + use ndarray::arr2; + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let y = Tensor::from([10.0_f64, 20.0, 30.0]); + let result = Add.call(&args(x, y)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]).into_dyn(); + assert_eq!(arr, expected); + } + + #[test] + fn test_subtract() { + let result = Subtract + .call(&args( + Tensor::from([5.0_f64, 6.0, 7.0]), + Tensor::from([1.0_f64, 2.0, 3.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]); + } + + #[test] + fn test_multiply() { + let result = Multiply + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([10.0_f64, 10.0, 10.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]); + } + + #[test] + fn test_divide() { + let result = Divide + .call(&args( + Tensor::from([10.0_f64, 9.0, 8.0]), + Tensor::from([2.0_f64, 3.0, 4.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]); + } + + #[test] + fn test_remainder() { + let result = Remainder + .call(&args( + Tensor::from([7.0_f64, 8.0, 9.0]), + Tensor::from([3.0_f64, 3.0, 3.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]); + } + + #[test] + fn test_power() { + let result = Power + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([3.0_f64, 2.0, 1.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[8.0, 9.0, 4.0]); + } + + #[test] + fn test_power_broadcasts() { + // shape [3] ** shape [1] -> shape [3] + let result = Power + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([2.0_f64]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 9.0, 16.0]); + } +} diff --git a/crates/providers/src/math_nodes/bitwise.rs b/crates/providers/src/math_nodes/bitwise.rs new file mode 100644 index 000000000000..d11347b2034c --- /dev/null +++ b/crates/providers/src/math_nodes/bitwise.rs @@ -0,0 +1,291 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::{ArrayD, Axis}; +use std::sync::OnceLock; + +/// Shared input type spec for binary bitwise nodes: two broadcastable `Bit` tensors `x` and `y`. +fn bit_binary_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types + }) +} + +/// A single broadcastable `Bit` leaf — used for unary inputs and all bitwise outputs. +fn bit_leaf_type() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Extract `x` and `y` from `args` as `Bit` arrays and apply `op` element-wise with broadcasting. +fn bitwise_binary_call( + args: &DataTree, + op: impl Fn(&ArrayD, &ArrayD) -> ArrayD, +) -> anyhow::Result> { + let DataTree::Leaf(x) = args.get_by_str_key("x").expect("missing input x") else { + panic!("expected leaf at x"); + }; + let DataTree::Leaf(y) = args.get_by_str_key("y").expect("missing input y") else { + panic!("expected leaf at y"); + }; + let (Tensor::Bit(x_arr), Tensor::Bit(y_arr)) = (x, y) else { + panic!("bitwise operations require Bit tensors"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(op(x_arr, y_arr)))) +} + +/// Generate a [`ProgramNode`] struct for an elementwise binary bitwise operation on `Bit` tensors. +macro_rules! bitwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable `Bit` tensors.")] + pub struct $name; + + impl ProgramNode for $name { + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_binary_input_types() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + bitwise_binary_call(args, $call_fn) + } + } + }; +} + +bitwise_binary_node!(BitwiseAnd, "bitwise_and", |x, y| x & y); +bitwise_binary_node!(BitwiseOr, "bitwise_or", |x, y| x | y); +bitwise_binary_node!(BitwiseXor, "bitwise_xor", |x, y| x ^ y); + +/// Elementwise bitwise NOT of a broadcastable `Bit` tensor. +pub struct BitwiseNot; + +impl ProgramNode for BitwiseNot { + fn name(&self) -> &'static str { + "bitwise_not" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_leaf_type() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let Tensor::Bit(arr) = x else { + panic!("bitwise_not requires a Bit tensor"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(arr.mapv(|b| b ^ 1)))) + } +} + +/// XOR-reduction of a `Bit` tensor along a specified axis, removing that axis. +/// +/// The parity of a sequence of bits is 1 if an odd number of bits are 1, and 0 otherwise, +/// which is equivalent to XOR-folding the sequence. The output has one fewer dimension than +/// the input, with the reduction axis removed. +pub struct Parity { + axis: usize, +} + +impl Parity { + /// Construct a `Parity` node that reduces along `axis`. + pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Parity { + fn name(&self) -> &'static str { + "parity" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_leaf_type() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let Tensor::Bit(arr) = x else { + panic!("parity requires a Bit tensor"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(arr.fold_axis( + Axis(self.axis), + 0u8, + |&acc, &b| acc ^ b, + )))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::{arr1, arr2}; + + fn bit(data: &[u8]) -> Tensor { + Tensor::Bit(arr1(data).into_dyn()) + } + + fn args2(x: Tensor, y: Tensor) -> DataTree { + let mut tree = DataTree::new(); + tree.insert_leaf("x", x); + tree.insert_leaf("y", y); + tree + } + + #[test] + fn test_bitwise_and() { + let result = BitwiseAnd + .call(&args2(bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 0, 1]); + } + + #[test] + fn test_bitwise_or() { + let result = BitwiseOr + .call(&args2(bit(&[1, 0, 1, 0]), bit(&[0, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 1, 1]); + } + + #[test] + fn test_bitwise_xor() { + let result = BitwiseXor + .call(&args2(bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 1, 0]); + } + + #[test] + fn test_bitwise_and_broadcasts() { + // shape [3] & shape [1] -> shape [3] + let result = BitwiseAnd.call(&args2(bit(&[1, 0, 1]), bit(&[1]))).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 1]); + } + + #[test] + fn test_bitwise_not() { + let result = BitwiseNot + .call(&DataTree::new_leaf(bit(&[1, 0, 1, 0]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 0, 1]); + } + + #[test] + fn test_parity_axis0() { + // Reduce rows: XOR each column across rows + // [[1, 0, 1], + // [0, 1, 1], axis 0 → [1^0, 0^1, 1^1] = [1, 1, 0] + // [0, 0, 0]] + let x = Tensor::Bit(arr2(&[[1u8, 0, 1], [0, 1, 1], [0, 0, 0]]).into_dyn()); + let result = Parity::new(0).call(&DataTree::new_leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 0]); + } + + #[test] + fn test_parity_axis1() { + // Reduce cols: XOR each row across columns + // [[1, 0, 1], axis 1 → [1^0^1, 0^1^1] = [0, 0] + // [0, 1, 1]] + let x = Tensor::Bit(arr2(&[[1u8, 0, 1], [0, 1, 1]]).into_dyn()); + let result = Parity::new(1).call(&DataTree::new_leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 0]); + } + + #[test] + fn test_parity_1d() { + // [1, 1, 0, 1] → 1^1^0^1 = 1 + let result = Parity::new(0) + .call(&DataTree::new_leaf(bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1]); + } +} diff --git a/crates/providers/src/math_nodes/mod.rs b/crates/providers/src/math_nodes/mod.rs new file mode 100644 index 000000000000..f86b768b201f --- /dev/null +++ b/crates/providers/src/math_nodes/mod.rs @@ -0,0 +1,15 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +pub mod binary; +pub mod bitwise; +pub mod reduction; diff --git a/crates/providers/src/math_nodes/reduction.rs b/crates/providers/src/math_nodes/reduction.rs new file mode 100644 index 000000000000..145ec6f8a657 --- /dev/null +++ b/crates/providers/src/math_nodes/reduction.rs @@ -0,0 +1,469 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::Axis; +use num_complex::Complex; +use std::sync::OnceLock; + +/// Shared input type spec for reduction nodes: a single broadcastable tensor of any dtype. +fn reduction_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Shared output type spec for reduction nodes: a single broadcastable tensor of any dtype. +fn reduction_output_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("out".into()), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Mean of a tensor along a specified axis, removing that axis. +/// +/// Integer inputs are cast to `F64` before computing the mean. `F32` inputs +/// produce `F32` output; all other float and integer types produce `F64`. +/// Complex inputs (`C64`, `C128`) preserve their complex dtype. +pub struct Mean { + axis: usize, +} + +impl Mean { + /// Construct a `Mean` node that reduces along `axis`. + pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Mean { + fn name(&self) -> &'static str { + "mean" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.mean_axis(Axis(self.axis)).unwrap()), + Tensor::F64(a) => Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap()), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + Tensor::C64(a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + Tensor::C128(a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + } + other => { + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + unreachable!() + }; + Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap()) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +/// Variance of a tensor along a specified axis, removing that axis. +/// +/// The `ddof` (delta degrees of freedom) parameter adjusts the divisor: the result +/// is divided by `n - ddof` where `n` is the number of elements along the axis. +/// Use `ddof=0` for population variance and `ddof=1` for sample variance. +/// +/// Integer inputs are cast to `F64`. `F32` produces `F32`; all other real types +/// produce `F64`. Complex inputs (`C64`, `C128`) produce real output (`F32`, `F64` +/// respectively), computed as the mean squared modulus of the deviations. +pub struct Variance { + axis: usize, + ddof: f64, +} + +impl Variance { + /// Construct a `Variance` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Variance { + fn name(&self) -> &'static str { + "variance" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.var_axis(Axis(self.axis), self.ddof as f32)), + Tensor::F64(a) => Tensor::F64(a.var_axis(Axis(self.axis), self.ddof)), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32(sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64(sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)) + } + other => { + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + unreachable!() + }; + Tensor::F64(a.var_axis(Axis(self.axis), self.ddof)) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +/// Standard deviation of a tensor along a specified axis, removing that axis. +/// +/// This is the square root of [`Variance`]. See that type for details on `ddof`, +/// output dtypes, and complex handling. +pub struct Std { + axis: usize, + ddof: f64, +} + +impl Std { + /// Construct a `Std` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Std { + fn name(&self) -> &'static str { + "std" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.std_axis(Axis(self.axis), self.ddof as f32)), + Tensor::F64(a) => Tensor::F64(a.std_axis(Axis(self.axis), self.ddof)), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32( + (sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)).mapv(f32::sqrt), + ) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64((sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)).mapv(f64::sqrt)) + } + other => { + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + unreachable!() + }; + Tensor::F64(a.std_axis(Axis(self.axis), self.ddof)) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, Tensor}; + use ndarray::arr2; + + fn leaf(t: Tensor) -> DataTree { + DataTree::new_leaf(t) + } + + fn approx_eq_slice(a: &[f64], b: &[f64]) { + assert_eq!(a.len(), b.len(), "slice lengths differ"); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-10, "{x} != {y}"); + } + } + + fn approx_eq_slice_f32(a: &[f32], b: &[f32]) { + assert_eq!(a.len(), b.len(), "slice lengths differ"); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-6, "{x} != {y}"); + } + } + + // --- Mean tests --- + + #[test] + fn test_mean_f64_axis0() { + // [[1,2,3],[4,5,6]] along axis 0 → [2.5, 3.5, 4.5] + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5, 3.5, 4.5]); + } + + #[test] + fn test_mean_f64_axis1() { + // [[1,2,3],[4,5,6]] along axis 1 → [2.0, 5.0] + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let result = Mean::new(1).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0, 5.0]); + } + + #[test] + fn test_mean_f32() { + let x = Tensor::from([1.0_f32, 2.0, 3.0, 4.0]); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F32, + "F32 input should produce F32 mean" + ); + let Tensor::F32(arr) = tensor else { panic!() }; + approx_eq_slice_f32(arr.as_slice().unwrap(), &[2.5]); + } + + #[test] + fn test_mean_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4]); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F64, + "integer input should produce F64 mean" + ); + let Tensor::F64(arr) = tensor else { panic!() }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5]); + } + + #[test] + fn test_mean_c128() { + use num_complex::Complex; + let data: Vec> = vec![ + Complex::new(1.0, 2.0), + Complex::new(3.0, 4.0), + Complex::new(5.0, 6.0), + ]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::C128(arr)) = result else { + panic!("expected C128 leaf"); + }; + let v = arr.as_slice().unwrap()[0]; + assert!((v.re - 3.0).abs() < 1e-10); + assert!((v.im - 4.0).abs() < 1e-10); + } + + // --- Variance tests --- + + #[test] + fn test_variance_f64_ddof0() { + // [2, 4, 4, 4, 5, 5, 7, 9] — classic example, population variance = 4.0 + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[4.0]); + } + + #[test] + fn test_variance_f64_ddof1() { + // Sample variance (ddof=1) of the same sequence + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 1.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + // sample variance = population variance * n / (n-1) = 4.0 * 8/7 + approx_eq_slice(arr.as_slice().unwrap(), &[4.0 * 8.0 / 7.0]); + } + + #[test] + fn test_variance_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4, 5]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64); + let Tensor::F64(arr) = tensor else { panic!() }; + // mean=3, deviations=[-2,-1,0,1,2], sq=[4,1,0,1,4], population variance=2.0 + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + #[test] + fn test_variance_f32() { + let x = Tensor::from([2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F32, + "F32 input should produce F32 variance" + ); + let Tensor::F32(arr) = tensor else { panic!() }; + approx_eq_slice_f32(arr.as_slice().unwrap(), &[4.0]); + } + + #[test] + fn test_variance_c128_returns_real() { + use num_complex::Complex; + // [1+1i, 3+3i] — mean = 2+2i, deviations = [−1−i, 1+i], |.|^2 = [2, 2], var = 2.0 + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F64, + "C128 variance should return F64" + ); + let Tensor::F64(arr) = tensor else { panic!() }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + // --- Std tests --- + + #[test] + fn test_std_f64() { + // std = sqrt(variance) = 2.0 for the classic example + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + #[test] + fn test_std_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4, 5]); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64); + let Tensor::F64(arr) = tensor else { panic!() }; + // population std = sqrt(2.0) + approx_eq_slice(arr.as_slice().unwrap(), &[2.0_f64.sqrt()]); + } + + #[test] + fn test_std_matches_sqrt_of_variance() { + // Verify std = sqrt(variance) numerically + let x = Tensor::from([1.0_f64, 3.0, 5.0, 7.0, 9.0]); + let var_result = Variance::new(0, 0.0).call(&leaf(x.clone())).unwrap(); + let std_result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + + let DataTree::Leaf(Tensor::F64(var_arr)) = var_result else { + panic!() + }; + let DataTree::Leaf(Tensor::F64(std_arr)) = std_result else { + panic!() + }; + + let var_val = var_arr.as_slice().unwrap()[0]; + let std_val = std_arr.as_slice().unwrap()[0]; + assert!((std_val - var_val.sqrt()).abs() < 1e-10); + } + + #[test] + fn test_std_c128_returns_real() { + use num_complex::Complex; + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64, "C128 std should return F64"); + let Tensor::F64(arr) = tensor else { panic!() }; + // std = sqrt(2.0) + approx_eq_slice(arr.as_slice().unwrap(), &[2.0_f64.sqrt()]); + } +}