diff --git a/Cargo.lock b/Cargo.lock index 4d5ec222efb0..7b614eacfad9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2341,6 +2341,14 @@ dependencies = [ [[package]] name = "qiskit-providers" version = "2.5.0-dev" +dependencies = [ + "anyhow", + "hashbrown 0.15.5", + "ndarray", + "num-complex", + "rustworkx-core", + "thiserror 2.0.18", +] [[package]] name = "qiskit-pyext" diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index e9f2e7ff623b..2a09144addc7 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -9,6 +9,18 @@ license.workspace = true name = "qiskit_providers" [dependencies] +rustworkx-core.workspace = true +anyhow.workspace = true +num-complex.workspace = true +thiserror.workspace = true + +[dependencies.hashbrown] +workspace = true +features = ["rayon", "serde"] [lints] workspace = true + +[dependencies.ndarray] +workspace = true +features = ["rayon", "approx"] \ No newline at end of file diff --git a/crates/providers/src/data_tree.rs b/crates/providers/src/data_tree.rs new file mode 100644 index 000000000000..0558bf00f34e --- /dev/null +++ b/crates/providers/src/data_tree.rs @@ -0,0 +1,907 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use hashbrown::HashMap; + +/// A path entry used for tracking a path through a [`DataTree`] +/// +/// Each entry can either be an index or a key. A slice of `PathEntry` are used to form +/// a traversal path through the [`DataTree`]. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub enum PathEntry<'a> { + Index(usize), + Key(&'a str), +} + +/// A struct representing a branch in a [`DataTree`]. +/// +/// Each branch contains a vec of [`DataTree`] that can also be assigned a +/// string key for accessing it. Typically you will not create these directly +/// but instead create them via the [`DataTree`] API. +#[derive(Debug, Clone)] +pub struct DataTreeBranch { + data: Vec>, + keys: HashMap, +} + +impl Default for DataTreeBranch { + fn default() -> Self { + Self::new() + } +} + +impl DataTreeBranch { + /// Construct a new empty [`DataTreeBranch`] + pub fn new() -> Self { + DataTreeBranch { + data: Vec::new(), + keys: HashMap::new(), + } + } + + /// Construct a new empty [`DataTreeBranch`] with a set capacity + pub fn with_capacity(capacity: usize) -> Self { + DataTreeBranch { + data: Vec::with_capacity(capacity), + keys: HashMap::with_capacity(capacity), + } + } + + /// Take a path slice and return the entry at the given path + /// + /// This will return `None` if a path can not be found. This includes an + /// invalid path, such as a path a leaf node in the middle. An empty path + /// will also return `self`. + pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&DataTree> { + let start = match path[0] { + PathEntry::Index(idx) => Some(&self.data[idx]), + PathEntry::Key(key) => self.keys.get(key).map(|idx| &self.data[*idx]), + }?; + match start { + DataTree::Leaf(_) => { + if path.len() > 1 { + // If there are more components in the path this is an invalid path + None + } else { + Some(start) + } + } + DataTree::Branch(inner_tree) => { + if path.len() > 1 { + inner_tree.get_by_path(&path[1..]) + } else { + Some(start) + } + } + } + } + + /// Return an iterator over the leaves in the `DataTree` + /// + /// This method will return an iterator over all leaf nodes in the tree by traversing the tree + /// in a DFS order. 
+ pub fn iter_path(&self) -> IterDataTree<'_, T> { + IterDataTree { + tree: None, + branch: Some(self), + index: 0, + inner: None, + inner_next: None, + path: vec![], + } + } + + /// The number of items in this `DataTree`. This length is just the number of items in this + /// local tree object and will not recurse through the tree to compute the total number of + /// leaves. If you want to do that you should use [`DataTree::iter_leaves`]. + pub fn iter_leaves(&self) -> IterLeaves<'_, T> { + IterLeaves { + tree: None, + branch: Some(self), + index: 0, + inner: None, + inner_next: None, + } + } + + /// The number of [`DataTree`] in this branch. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Check if there are any [`DataTree`] in this branch. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// The number of string keys set on this branch. + pub fn num_keys(&self) -> usize { + self.keys.len() + } + + /// Check if the branch has any string keys set. + pub fn has_keys(&self) -> bool { + !self.keys.is_empty() + } +} + +impl From> for DataTreeBranch { + fn from(input: DataTree) -> Self { + DataTreeBranch { + data: vec![input], + keys: HashMap::new(), + } + } +} + +/// A generic tree that is addressable either by either indices or string keys +#[derive(Debug, Clone)] +pub enum DataTree { + Leaf(T), + Branch(DataTreeBranch), +} + +impl Default for DataTree { + fn default() -> Self { + Self::new() + } +} + +impl DataTree { + /// Consume the entry and return the leaf value otherwise panic + pub fn unwrap_leaf(self) -> T { + match self { + Self::Leaf(data) => data, + Self::Branch(_) => panic!("called TreeEntry::unwrap_leaf() on a `Tree` value"), + } + } + + /// Create a new empty data tree + pub fn new() -> Self { + DataTree::Branch(DataTreeBranch::new()) + } + + /// Create a new leaf data tree + pub fn new_leaf(value: T) -> Self { + DataTree::Leaf(value) + } + + /// Create a new empty data tree with an underlying allocation of a given size. 
+ /// + /// The specified capacity is the number of items of type T stored in the `DataTree` + /// along with an associated `String` key for each element in the tree. This does not + /// account for nesting in the allocation as each layer in the tree is a separate + /// `DataTree` object. + pub fn with_capacity(capacity: usize) -> Self { + DataTree::Branch(DataTreeBranch::with_capacity(capacity)) + } + + /// The number of items in this `DataTree`. This length is just the number of items in this + /// local tree object and will not recurse through the tree to compute the total number of + /// leaves. If you want to do that you should use [`DataTree::iter_leaves`]. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut inner_tree = DataTree::with_capacity(5); + /// inner_tree.insert_leaf("y", 10); + /// inner_tree.insert_leaf("z", 11); + /// inner_tree.insert_leaf("a", 12); + /// inner_tree.insert_leaf("b", 13); + /// inner_tree.push_leaf(15); + /// + /// let mut tree = DataTree::new(); + /// tree.insert_branch("x", inner_tree); + /// assert_eq!(tree.len(), 1); + /// ``` + pub fn len(&self) -> usize { + match self { + Self::Leaf(_) => 1, + Self::Branch(branch) => branch.data.len(), + } + } + + /// Return whether this `DataTree` has an items in it. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Take a string key and return the entry at the given key. + /// + /// The "." character is reserved in keys and used to indicate a path + /// through the graph. + /// + /// This will return `None` if the string key can not be found. This includes + /// an invalid path, such as a path containing component or a leaf node in the + /// middle. An empty string for the path will return `self`. 
+    /// # Example
+    /// ```rust
+    /// use qiskit_providers::DataTree;
+    /// let mut inner_tree = DataTree::new();
+    /// inner_tree.insert_leaf("y", 10);
+    /// let mut tree = DataTree::new();
+    /// tree.insert_branch("x", inner_tree);
+    /// let result = tree.get_by_str_key("x.y").unwrap().clone().unwrap_leaf();
+    /// assert_eq!(result, 10);
+    /// ```
+    pub fn get_by_str_key(&self, key: &str) -> Option<&Self> {
+        if key.is_empty() {
+            return Some(self);
+        }
+        if key.contains(".") {
+            let path: Vec<PathEntry> = key.split(".").map(PathEntry::Key).collect();
+            self.get_by_path(&path)
+        } else {
+            match self {
+                Self::Leaf(_) => None,
+                Self::Branch(branch) => branch.keys.get(key).map(|value| &branch.data[*value]),
+            }
+        }
+    }
+
+    /// Take a path slice and return the entry at the given path
+    ///
+    /// This will return `None` if a path can not be found. This includes an
+    /// invalid path, such as a path passing through a leaf node in the middle.
+    /// An empty path will also return `self`.
+    pub fn get_by_path(&self, path: &[PathEntry]) -> Option<&Self> {
+        if path.is_empty() {
+            return Some(self);
+        }
+        let Self::Branch(branch) = self else {
+            return None;
+        };
+        let start = match path[0] {
+            // Checked indexing so an out-of-range index returns None per the
+            // documented contract instead of panicking.
+            PathEntry::Index(idx) => branch.data.get(idx),
+            PathEntry::Key(key) => branch.keys.get(key).map(|idx| &branch.data[*idx]),
+        }?;
+        match start {
+            DataTree::Leaf(_) => {
+                if path.len() > 1 {
+                    // If there are more components in the path this is an invalid path
+                    None
+                } else {
+                    Some(start)
+                }
+            }
+            DataTree::Branch(inner_tree) => {
+                if path.len() > 1 {
+                    inner_tree.get_by_path(&path[1..])
+                } else {
+                    Some(start)
+                }
+            }
+        }
+    }
+
+    /// Get an item from the `DataTree` by index.
+    ///
+    /// This will return `None` if the index is not valid.
+ /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut inner_tree = DataTree::new(); + /// inner_tree.insert_leaf("y", 10); + /// let mut tree = DataTree::new(); + /// tree.insert_branch("x", inner_tree); + /// tree.push_leaf(124); + /// let result = tree.get(1).unwrap().clone().unwrap_leaf(); + /// assert_eq!(result, 124); + /// let subtree = tree.get(0).unwrap(); + /// let subtree_result = subtree.get(0).unwrap().clone().unwrap_leaf(); + /// assert_eq!(subtree_result, 10); + /// ``` + pub fn get(&self, index: usize) -> Option<&DataTree> { + match self { + Self::Leaf(_) => panic!("Called get() on a leaf node"), + Self::Branch(branch) => branch.data.get(index), + } + } + + /// Iterate over direct children, yielding `(optional_key, child)` pairs in index order. + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.push_leaf(10); // unnamed + /// tree.insert_leaf("b", 20); // named + /// tree.push_leaf(30); // unnamed + /// let children: Vec<_> = tree.iter_children().collect(); + /// assert_eq!(children[0], (None, &DataTree::Leaf(10))); + /// assert_eq!(children[1], (Some("b"), &DataTree::Leaf(20))); + /// assert_eq!(children[2], (None, &DataTree::Leaf(30))); + /// ``` + pub fn iter_children(&self) -> impl Iterator, &DataTree)> + '_ { + let branch = match self { + Self::Branch(branch) => branch, + Self::Leaf(_) => panic!("called iter_children() on a leaf node"), + }; + let rev: HashMap = branch.keys.iter().map(|(k, &v)| (v, k.as_str())).collect(); + branch + .data + .iter() + .enumerate() + .map(move |(i, child)| (rev.get(&i).copied(), child)) + } + + /// Insert a new leaf node with an associated string key + /// + /// If a key is provided that is already in the tree the new value will be associated with + /// with the key and the old value will remain in the tree but without a string key. 
+ /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("y", 10); + /// tree.insert_leaf("y", 1000); + /// let result = tree.get_by_str_key("y").unwrap().clone().unwrap_leaf(); + /// assert_eq!(result, 1000); + /// ``` + pub fn insert_leaf(&mut self, key: &str, value: T) { + match self { + Self::Leaf(_) => panic!("Called insert_leaf() on a leaf node"), + Self::Branch(branch) => { + branch.data.push(Self::Leaf(value)); + branch.keys.insert(key.to_string(), branch.data.len() - 1); + } + } + } + + /// Add a new leaf to the tree + /// + /// # Example + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut tree = DataTree::new(); + /// tree.push_leaf(10); + /// tree.push_leaf(1000); + /// assert_eq!(vec![10, 1000], tree.iter_leaves().copied().collect::>()); + /// ``` + pub fn push_leaf(&mut self, value: T) { + match self { + Self::Leaf(_) => panic!("Called push_leaf() on a leaf_node"), + Self::Branch(branch) => branch.data.push(DataTree::Leaf(value)), + } + } + + /// Add a subtree to the tree with an associated string key + /// + /// If a key is provided that is already in the tree the new value will be associated with + /// with the key and the old value will remain in the tree but without a string key. 
+    /// # Example
+    /// ```rust
+    /// use qiskit_providers::DataTree;
+    /// let mut tree = DataTree::new();
+    /// tree.insert_leaf("y", 10);
+    /// let mut subtree = DataTree::with_capacity(2);
+    /// subtree.push_leaf(123);
+    /// subtree.push_leaf(456);
+    /// tree.insert_branch("y", subtree);
+    /// let result = tree.get_by_str_key("y").unwrap();
+    /// let leaves: Vec<_> = result.iter_leaves().copied().collect();
+    /// assert_eq!(leaves, vec![123, 456]);
+    /// ```
+    pub fn insert_branch(&mut self, key: &str, value: DataTree<T>) {
+        match self {
+            Self::Leaf(_) => panic!("Called insert_branch() on a leaf_node"),
+            Self::Branch(branch) => {
+                branch.data.push(value);
+                branch.keys.insert(key.to_string(), branch.data.len() - 1);
+            }
+        }
+    }
+
+    /// Add a new branch to the tree
+    ///
+    /// # Example
+    /// ```rust
+    /// use qiskit_providers::DataTree;
+    /// let mut tree = DataTree::new();
+    /// tree.push_leaf(10);
+    /// let mut subtree = DataTree::with_capacity(2);
+    /// subtree.push_leaf(123);
+    /// subtree.push_leaf(456);
+    /// tree.push_branch(subtree);
+    /// assert_eq!(vec![10, 123, 456], tree.iter_leaves().copied().collect::<Vec<_>>());
+    /// ```
+    pub fn push_branch(&mut self, value: DataTree<T>) {
+        match self {
+            // Fixed: the panic message previously named insert_branch().
+            Self::Leaf(_) => panic!("Called push_branch() on a leaf_node"),
+            Self::Branch(branch) => branch.data.push(value),
+        }
+    }
+
+    /// Return an iterator over the leaves in the `DataTree`
+    ///
+    /// This method will return an iterator over all leaf nodes in the tree by traversing the tree
+    /// in a DFS order.
+ /// + /// # Example + /// + /// Traversing this tree: + /// + /// ```rust + /// use qiskit_providers::DataTree; + /// let mut subsubsubtree = DataTree::new(); + /// subsubsubtree.push_leaf(3); + /// subsubsubtree.push_leaf(4); + /// let mut subsubtree = DataTree::new(); + /// subsubtree.push_branch(subsubsubtree); + /// subsubtree.insert_leaf("b", 5); + /// let mut subsubtree_prime = DataTree::new(); + /// subsubtree_prime.push_leaf(7); + /// let mut subtree = DataTree::new(); + /// subtree.insert_branch("c", subsubtree); + /// subtree.insert_leaf("d", 6); + /// subtree.push_branch(subsubtree_prime); + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("a", 0); + /// tree.insert_branch("root", subtree); + /// tree.insert_leaf("z", 26); + /// let leaves: Vec<_> = tree.iter_leaves().copied().collect(); + /// let expected = vec![0, 3, 4, 5, 6, 7, 26]; + /// assert_eq!(leaves, expected); + /// ``` + pub fn iter_leaves(&self) -> IterLeaves<'_, T> { + IterLeaves { + tree: Some(self), + branch: None, + index: 0, + inner: None, + inner_next: None, + } + } + + /// Return an iterator over the leaves in the `DataTree` that returns the path and leaf value. + /// + /// This method will return an iterator over all the leaf nodes in the tree in a DFS order. + /// Unlike [`iter_leaves`] which just returns the value this will return an owned `Vec` of the + /// path through the data tree to get to that value. This has allocation overhead for each leaf + /// node in the tree and should only be used if you need the path along with the value. 
+ /// + /// ```rust + /// use qiskit_providers::{DataTree, PathEntry}; + /// let mut subsubsubtree = DataTree::new(); + /// subsubsubtree.push_leaf(3); + /// subsubsubtree.push_leaf(4); + /// let mut subsubtree = DataTree::new(); + /// subsubtree.push_branch(subsubsubtree); + /// subsubtree.insert_leaf("b", 5); + /// let mut subsubtree_prime = DataTree::new(); + /// subsubtree_prime.push_leaf(7); + /// let mut subtree = DataTree::new(); + /// subtree.insert_branch("c", subsubtree); + /// subtree.insert_leaf("d", 6); + /// subtree.push_branch(subsubtree_prime); + /// let mut tree = DataTree::new(); + /// tree.insert_leaf("a", 0); + /// tree.insert_branch("root", subtree); + /// tree.insert_leaf("z", 26); + /// let result: Vec<_> = tree.iter_path().map(|(a, b)| (a, *b)).collect(); + /// let expected_paths: Vec> = vec![ + /// vec![0], + /// vec![1, 0, 0, 0], + /// vec![1, 0, 0, 1], + /// vec![1, 0, 1], + /// vec![1, 1], + /// vec![1, 2, 0], + /// vec![2], + /// ]; + /// let expected_vals = vec![0, 3, 4, 5, 6, 7, 26]; + /// let expected: Vec<_> = expected_paths + /// .into_iter() + /// .map(|x| x.into_iter().map(PathEntry::Index).collect::>()) + /// .zip(expected_vals) + /// .collect(); + /// assert_eq!(result, expected); + /// ``` + pub fn iter_path(&self) -> IterDataTree<'_, T> { + IterDataTree { + tree: Some(self), + branch: None, + index: 0, + inner: None, + inner_next: None, + path: Vec::new(), + } + } +} + +pub struct IterDataTree<'a, T> { + tree: Option<&'a DataTree>, + branch: Option<&'a DataTreeBranch>, + index: usize, + inner: Option>>, + inner_next: Option<(Vec>, &'a T)>, + path: Vec>, +} + +impl<'a, T> Iterator for IterDataTree<'a, T> { + type Item = (Vec>, &'a T); + + fn next(&mut self) -> Option { + if let Some(tree) = self.tree { + if let DataTree::Leaf(val) = tree { + if self.index == 0 { + self.index += 1; + return Some((vec![], val)); + } else { + return None; + } + } + let DataTree::Branch(branch) = tree else { + unreachable!("Must be a branch 
variant"); + }; + if self.index >= branch.data.len() { + return None; + } + let entry = &branch.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + let mut out_path = self.path.clone(); + out_path.push(PathEntry::Index(self.index - 1)); + Some((out_path, val)) + } + DataTree::Branch(sub_branch) => { + if let Some(ref mut inner) = self.inner { + if let Some(val) = inner.next() { + let (return_path, return_val) = self.inner_next.replace(val).unwrap(); + Some((return_path, return_val)) + } else { + self.inner = None; + self.index += 1; + let (return_path, return_val) = self.inner_next.take().unwrap(); + self.inner_next = None; + Some((return_path, return_val)) + } + } else { + let mut inner = sub_branch.iter_path(); + let mut inner_path = self.path.clone(); + inner_path.push(PathEntry::Index(self.index)); + inner.path = inner_path; + self.inner = Some(Box::new(inner)); + let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some((inner_path, val)) + } + } + } + } else if let Some(subtree) = self.branch { + if self.index >= subtree.data.len() { + return None; + } + let entry = &subtree.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + let mut out_path = self.path.clone(); + out_path.push(PathEntry::Index(self.index - 1)); + Some((out_path, val)) + } + DataTree::Branch(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let (return_path, return_val) = self.inner_next.replace(val).unwrap(); + Some((return_path, return_val)) + } else { + self.inner = None; + self.index += 1; + let (return_path, return_val) = self.inner_next.take().unwrap(); + self.inner_next = None; + Some((return_path, return_val)) + } + } + None => { + let mut inner = subtree.iter_path(); + let mut inner_path = 
self.path.clone(); + inner_path.push(PathEntry::Index(self.index)); + inner.path = inner_path; + self.inner = Some(Box::new(inner)); + let (inner_path, val) = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some((inner_path, val)) + } + }, + } + } else { + None + } + } +} + +pub struct IterLeaves<'a, T> { + tree: Option<&'a DataTree>, + branch: Option<&'a DataTreeBranch>, + index: usize, + inner: Option>>, + inner_next: Option<&'a T>, +} + +impl<'a, T: std::fmt::Debug> Iterator for IterLeaves<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if let Some(tree) = self.tree { + if let DataTree::Leaf(val) = tree { + if self.index == 0 { + self.index += 1; + return Some(val); + } else { + return None; + } + } + let DataTree::Branch(branch) = tree else { + unreachable!("Must be a branch variant"); + }; + if self.index >= branch.data.len() { + return None; + } + let entry = &branch.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + Some(val) + } + DataTree::Branch(sub_branch) => { + if let Some(ref mut inner) = self.inner { + if let Some(val) = inner.next() { + let return_val = self.inner_next.replace(val).unwrap(); + Some(return_val) + } else { + self.inner = None; + self.index += 1; + let return_val = self.inner_next.take().unwrap(); + self.inner_next = None; + Some(return_val) + } + } else { + let inner = sub_branch.iter_leaves(); + self.inner = Some(Box::new(inner)); + let val = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some(val) + } + } + } + } else if let Some(subtree) = self.branch { + if self.index >= subtree.data.len() { + return None; + } + let entry = 
&subtree.data[self.index]; + match entry { + DataTree::Leaf(val) => { + self.index += 1; + Some(val) + } + DataTree::Branch(subtree) => match self.inner { + Some(ref mut inner) => { + if let Some(val) = inner.next() { + let return_val = self.inner_next.replace(val).unwrap(); + Some(return_val) + } else { + self.inner = None; + self.index += 1; + let return_val = self.inner_next.take().unwrap(); + self.inner_next = None; + Some(return_val) + } + } + None => { + let inner = subtree.iter_leaves(); + self.inner = Some(Box::new(inner)); + let val = self.inner.as_mut().map(|x| x.next().unwrap())?; + self.inner_next = self.inner.as_mut().and_then(|x| x.next()); + if self.inner_next.is_none() { + self.index += 1; + self.inner = None; + self.inner_next = None; + } + Some(val) + } + }, + } + } else { + None + } + } +} + +impl PartialEq for DataTree { + fn eq(&self, other: &DataTree) -> bool { + match self { + Self::Leaf(val) => { + let Self::Leaf(other_val) = other else { + return false; + }; + val == other_val + } + Self::Branch(branch) => { + let Self::Branch(other) = other else { + return false; + }; + branch.data == other.data && branch.keys == other.keys + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_data_leaf() { + let mut tree = DataTree::new(); + tree.push_leaf(42); + let result = tree.get(0).unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 42); + } + + #[test] + fn test_flat_dict() { + let mut tree = DataTree::with_capacity(3); + tree.insert_leaf("a", 1); + tree.insert_leaf("b", 2); + let result = tree.get_by_str_key("b").unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 2); + let result = tree.get_by_str_key("a").unwrap().clone(); + assert_eq!(result.unwrap_leaf(), 1); + } + + #[test] + fn test_nested_dict() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + let mut tree = DataTree::new(); + tree.insert_branch("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!(None, 
tree.get_by_str_key("z.y")); + assert_eq!(Some(&inner_tree), tree.get_by_str_key("x")); + } + + #[test] + fn test_nested_dict_iter() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.push_leaf(4); + inner_inner_tree.push_leaf(5); + inner_tree.push_branch(inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_branch("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!( + vec![10, 1, 2, 3, 4, 5, 100], + tree.iter_leaves().copied().collect::>() + ); + } + + #[test] + fn test_nested_dict_iter_path() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.push_leaf(4); + inner_inner_tree.push_leaf(5); + inner_tree.push_branch(inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_branch("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + let expected_paths = vec![ + vec![PathEntry::Index(0), PathEntry::Index(0)], + vec![PathEntry::Index(0), PathEntry::Index(1)], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(0), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(1), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(2), + ], + vec![ + PathEntry::Index(0), + PathEntry::Index(2), + PathEntry::Index(3), + ], + vec![PathEntry::Index(1)], + ]; + let expected_vals = vec![10, 1, 2, 3, 4, 5, 100]; + let expected = expected_paths + .into_iter() + .zip(expected_vals.into_iter()) + .collect::>(); + assert_eq!( + expected, + tree.iter_path().map(|(a, b)| (a, *b)).collect::>() + ); + } + + #[test] + fn test_get_by_str() { + let mut inner_tree = DataTree::new(); + 
inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.insert_leaf("a", 4); + inner_inner_tree.push_leaf(5); + let inner_inner_tree_expected = inner_inner_tree.clone(); + inner_tree.insert_branch("yyy", inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_branch("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + let result = tree.get_by_str_key("x.yyy.a"); + assert_eq!(result, Some(&DataTree::Leaf(4))); + assert_eq!(tree.get_by_str_key("z"), Some(&DataTree::Leaf(100))); + assert_eq!( + tree.get_by_str_key("x.yyy"), + Some(&inner_inner_tree_expected), + ); + assert_eq!(tree.get_by_str_key("x.yy"), Some(&DataTree::Leaf(1))); + } + + #[test] + fn test_get_by_str_no_match() { + let mut inner_tree = DataTree::new(); + inner_tree.insert_leaf("y", 10); + inner_tree.insert_leaf("yy", 1); + let mut inner_inner_tree = DataTree::new(); + inner_inner_tree.push_leaf(2); + inner_inner_tree.push_leaf(3); + inner_inner_tree.insert_leaf("a", 4); + inner_inner_tree.push_leaf(5); + inner_tree.insert_branch("yyy", inner_inner_tree); + let mut tree = DataTree::new(); + tree.insert_branch("x", inner_tree.clone()); + tree.insert_leaf("z", 100); + assert_eq!(None, tree.get_by_str_key("a")); + assert_eq!(None, tree.get_by_str_key("x.yyyy")); + assert_eq!(None, tree.get_by_str_key("x.yy.a")); + assert_eq!(None, tree.get_by_str_key("🎩")); + assert_eq!(None, tree.get_by_str_key("z.yyy.a")); + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 95dc338739c7..b81596dff981 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -9,3 +9,13 @@ // Any modifications or derivative works of this code must retain this // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. 
+ +mod data_tree; +pub mod math_nodes; +mod program_node; +mod store; +pub mod tensor; + +pub use data_tree::{DataTree, PathEntry}; +pub use program_node::ProgramNode; +pub use store::Store; diff --git a/crates/providers/src/math_nodes/binary.rs b/crates/providers/src/math_nodes/binary.rs new file mode 100644 index 000000000000..c77e111452f5 --- /dev/null +++ b/crates/providers/src/math_nodes/binary.rs @@ -0,0 +1,278 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DTypeLike, Tensor, TensorType, promotion}; +use std::borrow::Cow; +use std::sync::OnceLock; + +/// Shared input type spec for all elementwise binary nodes: two broadcastable tensors `x` and `y`. +fn elementwise_binary_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Var("y".into()), + shape: vec![], + broadcastable: true, + }, + ); + types + }) +} + +/// Shared output type spec for all elementwise binary nodes: a single tensor of the promoted dtype. 
+fn elementwise_binary_output_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Promotion( + vec![DTypeLike::Var("x".into()), DTypeLike::Var("y".into())].into(), + ), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Extract `x` and `y` from `args`, promote dtypes, and apply `op` element-wise. +fn binary_elementwise_call( + args: &DataTree, + op: impl Fn(&Tensor, &Tensor) -> Tensor, +) -> anyhow::Result> { + let DataTree::Leaf(x) = args.get_by_str_key("x").expect("missing input x") else { + panic!("expected leaf at x"); + }; + let DataTree::Leaf(y) = args.get_by_str_key("y").expect("missing input y") else { + panic!("expected leaf at y"); + }; + let out_dtype = promotion(x.dtype(), y.dtype()); + + // Use copy-on-write smart pointer to avoid cloning when promotion is unnecessary + let x = if x.dtype() == out_dtype { + Cow::Borrowed(x) + } else { + Cow::Owned(x.clone().cast(out_dtype)) + }; + let y = if y.dtype() == out_dtype { + Cow::Borrowed(y) + } else { + Cow::Owned(y.clone().cast(out_dtype)) + }; + Ok(DataTree::new_leaf(op(x.as_ref(), y.as_ref()))) +} + +/// Generate a [`ProgramNode`] struct for an elementwise binary operation. +macro_rules! 
elementwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable tensors.")] + pub struct $name; + + impl ProgramNode for $name { + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + elementwise_binary_input_types() + } + fn output_types(&self) -> &DataTree { + elementwise_binary_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + binary_elementwise_call(args, $call_fn) + } + } + }; +} + +elementwise_binary_node!(Add, "add", |x, y| x + y); +elementwise_binary_node!(Subtract, "subtract", |x, y| x - y); +elementwise_binary_node!(Multiply, "multiply", |x, y| x * y); +elementwise_binary_node!(Divide, "divide", |x, y| x / y); +elementwise_binary_node!(Remainder, "remainder", |x, y| x % y); +elementwise_binary_node!(Power, "power", |x, y| x.pow(y)); + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, Tensor}; + + fn args(x: Tensor, y: Tensor) -> DataTree { + let mut tree = DataTree::new(); + tree.insert_leaf("x", x); + tree.insert_leaf("y", y); + tree + } + + #[test] + fn test_add_same_dtype() { + let result = Add + .call(&args( + Tensor::from([1.0_f64, 2.0, 3.0]), + Tensor::from([4.0_f64, 5.0, 6.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 7.0, 9.0]); + } + + #[test] + fn test_add_promotes_dtype() { + let result = Add + .call(&args( + Tensor::from([1.0_f32, 2.0]), + Tensor::from([3.0_f64, 4.0]), + )) + .unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf") + }; + assert_eq!(tensor.dtype(), DType::F64); + let Tensor::F64(arr) = tensor else { + panic!("expected f64") + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 6.0]); + } + + #[test] + fn 
test_add_broadcasts_1d_scalar() { + // shape [3] + shape [1] -> shape [3] + let result = Add + .call(&args( + Tensor::from([1.0_f64, 2.0, 3.0]), + Tensor::from([10.0_f64]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + assert_eq!(arr.as_slice().unwrap(), &[11.0, 12.0, 13.0]); + } + + #[test] + fn test_add_broadcasts_2d_with_1d() { + // shape [2, 3] + shape [3] -> shape [2, 3] + use ndarray::arr2; + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let y = Tensor::from([10.0_f64, 20.0, 30.0]); + let result = Add.call(&args(x, y)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf") + }; + let expected = arr2(&[[11.0_f64, 22.0, 33.0], [14.0, 25.0, 36.0]]).into_dyn(); + assert_eq!(arr, expected); + } + + #[test] + fn test_subtract() { + let result = Subtract + .call(&args( + Tensor::from([5.0_f64, 6.0, 7.0]), + Tensor::from([1.0_f64, 2.0, 3.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 4.0, 4.0]); + } + + #[test] + fn test_multiply() { + let result = Multiply + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([10.0_f64, 10.0, 10.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[20.0, 30.0, 40.0]); + } + + #[test] + fn test_divide() { + let result = Divide + .call(&args( + Tensor::from([10.0_f64, 9.0, 8.0]), + Tensor::from([2.0_f64, 3.0, 4.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[5.0, 3.0, 2.0]); + } + + #[test] + fn test_remainder() { + let result = Remainder + .call(&args( + Tensor::from([7.0_f64, 8.0, 9.0]), + Tensor::from([3.0_f64, 3.0, 3.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + 
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 0.0]); + } + + #[test] + fn test_power() { + let result = Power + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([3.0_f64, 2.0, 1.0]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[8.0, 9.0, 4.0]); + } + + #[test] + fn test_power_broadcasts() { + // shape [3] ** shape [1] -> shape [3] + let result = Power + .call(&args( + Tensor::from([2.0_f64, 3.0, 4.0]), + Tensor::from([2.0_f64]), + )) + .unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!() + }; + assert_eq!(arr.as_slice().unwrap(), &[4.0, 9.0, 16.0]); + } +} diff --git a/crates/providers/src/math_nodes/bitwise.rs b/crates/providers/src/math_nodes/bitwise.rs new file mode 100644 index 000000000000..d11347b2034c --- /dev/null +++ b/crates/providers/src/math_nodes/bitwise.rs @@ -0,0 +1,291 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::{ArrayD, Axis}; +use std::sync::OnceLock; + +/// Shared input type spec for binary bitwise nodes: two broadcastable `Bit` tensors `x` and `y`. 
+fn bit_binary_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + let mut types = DataTree::with_capacity(2); + types.insert_leaf( + "x", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types.insert_leaf( + "y", + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }, + ); + types + }) +} + +/// A single broadcastable `Bit` leaf — used for unary inputs and all bitwise outputs. +fn bit_leaf_type() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Extract `x` and `y` from `args` as `Bit` arrays and apply `op` element-wise with broadcasting. +fn bitwise_binary_call( + args: &DataTree, + op: impl Fn(&ArrayD, &ArrayD) -> ArrayD, +) -> anyhow::Result> { + let DataTree::Leaf(x) = args.get_by_str_key("x").expect("missing input x") else { + panic!("expected leaf at x"); + }; + let DataTree::Leaf(y) = args.get_by_str_key("y").expect("missing input y") else { + panic!("expected leaf at y"); + }; + let (Tensor::Bit(x_arr), Tensor::Bit(y_arr)) = (x, y) else { + panic!("bitwise operations require Bit tensors"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(op(x_arr, y_arr)))) +} + +/// Generate a [`ProgramNode`] struct for an elementwise binary bitwise operation on `Bit` tensors. +macro_rules! 
bitwise_binary_node { + ($name:ident, $node_name:literal, $call_fn:expr) => { + #[doc = concat!("Elementwise `", $node_name, "` of two broadcastable `Bit` tensors.")] + pub struct $name; + + impl ProgramNode for $name { + fn name(&self) -> &'static str { + $node_name + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_binary_input_types() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + bitwise_binary_call(args, $call_fn) + } + } + }; +} + +bitwise_binary_node!(BitwiseAnd, "bitwise_and", |x, y| x & y); +bitwise_binary_node!(BitwiseOr, "bitwise_or", |x, y| x | y); +bitwise_binary_node!(BitwiseXor, "bitwise_xor", |x, y| x ^ y); + +/// Elementwise bitwise NOT of a broadcastable `Bit` tensor. +pub struct BitwiseNot; + +impl ProgramNode for BitwiseNot { + fn name(&self) -> &'static str { + "bitwise_not" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_leaf_type() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let Tensor::Bit(arr) = x else { + panic!("bitwise_not requires a Bit tensor"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(arr.mapv(|b| b ^ 1)))) + } +} + +/// XOR-reduction of a `Bit` tensor along a specified axis, removing that axis. +/// +/// The parity of a sequence of bits is 1 if an odd number of bits are 1, and 0 otherwise, +/// which is equivalent to XOR-folding the sequence. The output has one fewer dimension than +/// the input, with the reduction axis removed. +pub struct Parity { + axis: usize, +} + +impl Parity { + /// Construct a `Parity` node that reduces along `axis`. 
+ pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Parity { + fn name(&self) -> &'static str { + "parity" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + bit_leaf_type() + } + fn output_types(&self) -> &DataTree { + bit_leaf_type() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let Tensor::Bit(arr) = x else { + panic!("parity requires a Bit tensor"); + }; + Ok(DataTree::new_leaf(Tensor::Bit(arr.fold_axis( + Axis(self.axis), + 0u8, + |&acc, &b| acc ^ b, + )))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::{arr1, arr2}; + + fn bit(data: &[u8]) -> Tensor { + Tensor::Bit(arr1(data).into_dyn()) + } + + fn args2(x: Tensor, y: Tensor) -> DataTree { + let mut tree = DataTree::new(); + tree.insert_leaf("x", x); + tree.insert_leaf("y", y); + tree + } + + #[test] + fn test_bitwise_and() { + let result = BitwiseAnd + .call(&args2(bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 0, 1]); + } + + #[test] + fn test_bitwise_or() { + let result = BitwiseOr + .call(&args2(bit(&[1, 0, 1, 0]), bit(&[0, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 1, 1]); + } + + #[test] + fn test_bitwise_xor() { + let result = BitwiseXor + .call(&args2(bit(&[1, 0, 1, 1]), bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 1, 0]); + } + + #[test] + fn test_bitwise_and_broadcasts() { + // shape [3] & shape [1] -> shape [3] + let result = BitwiseAnd.call(&args2(bit(&[1, 0, 1]), 
bit(&[1]))).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 0, 1]); + } + + #[test] + fn test_bitwise_not() { + let result = BitwiseNot + .call(&DataTree::new_leaf(bit(&[1, 0, 1, 0]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 1, 0, 1]); + } + + #[test] + fn test_parity_axis0() { + // Reduce rows: XOR each column across rows + // [[1, 0, 1], + // [0, 1, 1], axis 0 → [1^0, 0^1, 1^1] = [1, 1, 0] + // [0, 0, 0]] + let x = Tensor::Bit(arr2(&[[1u8, 0, 1], [0, 1, 1], [0, 0, 0]]).into_dyn()); + let result = Parity::new(0).call(&DataTree::new_leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1, 1, 0]); + } + + #[test] + fn test_parity_axis1() { + // Reduce cols: XOR each row across columns + // [[1, 0, 1], axis 1 → [1^0^1, 0^1^1] = [0, 0] + // [0, 1, 1]] + let x = Tensor::Bit(arr2(&[[1u8, 0, 1], [0, 1, 1]]).into_dyn()); + let result = Parity::new(1).call(&DataTree::new_leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[0, 0]); + } + + #[test] + fn test_parity_1d() { + // [1, 1, 0, 1] → 1^1^0^1 = 1 + let result = Parity::new(0) + .call(&DataTree::new_leaf(bit(&[1, 1, 0, 1]))) + .unwrap(); + let DataTree::Leaf(Tensor::Bit(arr)) = result else { + panic!("expected Bit leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1]); + } +} diff --git a/crates/providers/src/math_nodes/mod.rs b/crates/providers/src/math_nodes/mod.rs new file mode 100644 index 000000000000..f86b768b201f --- /dev/null +++ b/crates/providers/src/math_nodes/mod.rs @@ -0,0 +1,15 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. 
You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +pub mod binary; +pub mod bitwise; +pub mod reduction; diff --git a/crates/providers/src/math_nodes/reduction.rs b/crates/providers/src/math_nodes/reduction.rs new file mode 100644 index 000000000000..145ec6f8a657 --- /dev/null +++ b/crates/providers/src/math_nodes/reduction.rs @@ -0,0 +1,469 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{DType, DTypeLike, Tensor, TensorType}; +use ndarray::Axis; +use num_complex::Complex; +use std::sync::OnceLock; + +/// Shared input type spec for reduction nodes: a single broadcastable tensor of any dtype. +fn reduction_input_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("x".into()), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Shared output type spec for reduction nodes: a single broadcastable tensor of any dtype. 
+fn reduction_output_types() -> &'static DataTree { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| { + DataTree::new_leaf(TensorType { + dtype: DTypeLike::Var("out".into()), + shape: vec![], + broadcastable: true, + }) + }) +} + +/// Mean of a tensor along a specified axis, removing that axis. +/// +/// Integer inputs are cast to `F64` before computing the mean. `F32` inputs +/// produce `F32` output; all other float and integer types produce `F64`. +/// Complex inputs (`C64`, `C128`) preserve their complex dtype. +pub struct Mean { + axis: usize, +} + +impl Mean { + /// Construct a `Mean` node that reduces along `axis`. + pub fn new(axis: usize) -> Self { + Self { axis } + } +} + +impl ProgramNode for Mean { + fn name(&self) -> &'static str { + "mean" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.mean_axis(Axis(self.axis)).unwrap()), + Tensor::F64(a) => Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap()), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + Tensor::C64(a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + Tensor::C128(a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + } + other => { + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + unreachable!() + }; + Tensor::F64(a.mean_axis(Axis(self.axis)).unwrap()) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +/// Variance of a tensor along a specified axis, removing that axis. 
+/// +/// The `ddof` (delta degrees of freedom) parameter adjusts the divisor: the result +/// is divided by `n - ddof` where `n` is the number of elements along the axis. +/// Use `ddof=0` for population variance and `ddof=1` for sample variance. +/// +/// Integer inputs are cast to `F64`. `F32` produces `F32`; all other real types +/// produce `F64`. Complex inputs (`C64`, `C128`) produce real output (`F32`, `F64` +/// respectively), computed as the mean squared modulus of the deviations. +pub struct Variance { + axis: usize, + ddof: f64, +} + +impl Variance { + /// Construct a `Variance` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Variance { + fn name(&self) -> &'static str { + "variance" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.var_axis(Axis(self.axis), self.ddof as f32)), + Tensor::F64(a) => Tensor::F64(a.var_axis(Axis(self.axis), self.ddof)), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32(sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64(sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)) + } + other => 
{ + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + unreachable!() + }; + Tensor::F64(a.var_axis(Axis(self.axis), self.ddof)) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +/// Standard deviation of a tensor along a specified axis, removing that axis. +/// +/// This is the square root of [`Variance`]. See that type for details on `ddof`, +/// output dtypes, and complex handling. +pub struct Std { + axis: usize, + ddof: f64, +} + +impl Std { + /// Construct a `Std` node that reduces along `axis` with degrees-of-freedom + /// correction `ddof`. + pub fn new(axis: usize, ddof: f64) -> Self { + Self { axis, ddof } + } +} + +impl ProgramNode for Std { + fn name(&self) -> &'static str { + "std" + } + fn namespace(&self) -> &'static str { + "math" + } + fn input_types(&self) -> &DataTree { + reduction_input_types() + } + fn output_types(&self) -> &DataTree { + reduction_output_types() + } + fn implements_call(&self) -> bool { + true + } + fn call(&self, args: &DataTree) -> anyhow::Result> { + let DataTree::Leaf(x) = args else { + panic!("expected leaf input"); + }; + let result = match x { + Tensor::F32(a) => Tensor::F32(a.std_axis(Axis(self.axis), self.ddof as f32)), + Tensor::F64(a) => Tensor::F64(a.std_axis(Axis(self.axis), self.ddof)), + Tensor::C64(a) => { + let n = a.shape()[self.axis] as f32; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F32( + (sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof as f32)).mapv(f32::sqrt), + ) + } + Tensor::C128(a) => { + let n = a.shape()[self.axis] as f64; + let mean = (a.sum_axis(Axis(self.axis)) / Complex::new(n, 0.0)) + .insert_axis(Axis(self.axis)); + let sq_mod = (a - &mean).mapv(|c| c.re * c.re + c.im * c.im); + Tensor::F64((sq_mod.sum_axis(Axis(self.axis)) / (n - self.ddof)).mapv(f64::sqrt)) + } + other => { + let Tensor::F64(a) = other.clone().cast(DType::F64) else { + 
unreachable!() + }; + Tensor::F64(a.std_axis(Axis(self.axis), self.ddof)) + } + }; + Ok(DataTree::new_leaf(result)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, Tensor}; + use ndarray::arr2; + + fn leaf(t: Tensor) -> DataTree { + DataTree::new_leaf(t) + } + + fn approx_eq_slice(a: &[f64], b: &[f64]) { + assert_eq!(a.len(), b.len(), "slice lengths differ"); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-10, "{x} != {y}"); + } + } + + fn approx_eq_slice_f32(a: &[f32], b: &[f32]) { + assert_eq!(a.len(), b.len(), "slice lengths differ"); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-6, "{x} != {y}"); + } + } + + // --- Mean tests --- + + #[test] + fn test_mean_f64_axis0() { + // [[1,2,3],[4,5,6]] along axis 0 → [2.5, 3.5, 4.5] + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5, 3.5, 4.5]); + } + + #[test] + fn test_mean_f64_axis1() { + // [[1,2,3],[4,5,6]] along axis 1 → [2.0, 5.0] + let x = Tensor::F64(arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]).into_dyn()); + let result = Mean::new(1).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0, 5.0]); + } + + #[test] + fn test_mean_f32() { + let x = Tensor::from([1.0_f32, 2.0, 3.0, 4.0]); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F32, + "F32 input should produce F32 mean" + ); + let Tensor::F32(arr) = tensor else { panic!() }; + approx_eq_slice_f32(arr.as_slice().unwrap(), &[2.5]); + } + + #[test] + fn test_mean_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4]); + let 
result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F64, + "integer input should produce F64 mean" + ); + let Tensor::F64(arr) = tensor else { panic!() }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.5]); + } + + #[test] + fn test_mean_c128() { + use num_complex::Complex; + let data: Vec> = vec![ + Complex::new(1.0, 2.0), + Complex::new(3.0, 4.0), + Complex::new(5.0, 6.0), + ]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Mean::new(0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::C128(arr)) = result else { + panic!("expected C128 leaf"); + }; + let v = arr.as_slice().unwrap()[0]; + assert!((v.re - 3.0).abs() < 1e-10); + assert!((v.im - 4.0).abs() < 1e-10); + } + + // --- Variance tests --- + + #[test] + fn test_variance_f64_ddof0() { + // [2, 4, 4, 4, 5, 5, 7, 9] — classic example, population variance = 4.0 + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[4.0]); + } + + #[test] + fn test_variance_f64_ddof1() { + // Sample variance (ddof=1) of the same sequence + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 1.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + // sample variance = population variance * n / (n-1) = 4.0 * 8/7 + approx_eq_slice(arr.as_slice().unwrap(), &[4.0 * 8.0 / 7.0]); + } + + #[test] + fn test_variance_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4, 5]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64); + let 
Tensor::F64(arr) = tensor else { panic!() }; + // mean=3, deviations=[-2,-1,0,1,2], sq=[4,1,0,1,4], population variance=2.0 + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + #[test] + fn test_variance_f32() { + let x = Tensor::from([2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F32, + "F32 input should produce F32 variance" + ); + let Tensor::F32(arr) = tensor else { panic!() }; + approx_eq_slice_f32(arr.as_slice().unwrap(), &[4.0]); + } + + #[test] + fn test_variance_c128_returns_real() { + use num_complex::Complex; + // [1+1i, 3+3i] — mean = 2+2i, deviations = [−1−i, 1+i], |.|^2 = [2, 2], var = 2.0 + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Variance::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!( + tensor.dtype(), + DType::F64, + "C128 variance should return F64" + ); + let Tensor::F64(arr) = tensor else { panic!() }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + // --- Std tests --- + + #[test] + fn test_std_f64() { + // std = sqrt(variance) = 2.0 for the classic example + let x = Tensor::from([2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected F64 leaf"); + }; + approx_eq_slice(arr.as_slice().unwrap(), &[2.0]); + } + + #[test] + fn test_std_i32_casts_to_f64() { + let x = Tensor::from([1_i32, 2, 3, 4, 5]); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64); + let Tensor::F64(arr) = tensor else { panic!() }; + // population std = 
sqrt(2.0) + approx_eq_slice(arr.as_slice().unwrap(), &[2.0_f64.sqrt()]); + } + + #[test] + fn test_std_matches_sqrt_of_variance() { + // Verify std = sqrt(variance) numerically + let x = Tensor::from([1.0_f64, 3.0, 5.0, 7.0, 9.0]); + let var_result = Variance::new(0, 0.0).call(&leaf(x.clone())).unwrap(); + let std_result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + + let DataTree::Leaf(Tensor::F64(var_arr)) = var_result else { + panic!() + }; + let DataTree::Leaf(Tensor::F64(std_arr)) = std_result else { + panic!() + }; + + let var_val = var_arr.as_slice().unwrap()[0]; + let std_val = std_arr.as_slice().unwrap()[0]; + assert!((std_val - var_val.sqrt()).abs() < 1e-10); + } + + #[test] + fn test_std_c128_returns_real() { + use num_complex::Complex; + let data: Vec> = vec![Complex::new(1.0, 1.0), Complex::new(3.0, 3.0)]; + let x = Tensor::C128(ndarray::Array1::from(data).into_dyn()); + let result = Std::new(0, 0.0).call(&leaf(x)).unwrap(); + let DataTree::Leaf(tensor) = result else { + panic!("expected leaf"); + }; + assert_eq!(tensor.dtype(), DType::F64, "C128 std should return F64"); + let Tensor::F64(arr) = tensor else { panic!() }; + // std = sqrt(2.0) + approx_eq_slice(arr.as_slice().unwrap(), &[2.0_f64.sqrt()]); + } +} diff --git a/crates/providers/src/program_node.rs b/crates/providers/src/program_node.rs new file mode 100644 index 000000000000..be8980d250ca --- /dev/null +++ b/crates/providers/src/program_node.rs @@ -0,0 +1,40 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. 
+ +use crate::data_tree::DataTree; +use crate::tensor::{Tensor, TensorType}; + +/// A node in a quantum program graph that transforms tensors. +pub trait ProgramNode { + /// The name of this program node. + fn name(&self) -> &'static str; + + /// The namespace this program node belongs to. + fn namespace(&self) -> &'static str; + + /// The namespace and name as one string. + fn full_name(&self) -> String { + format_args!("{}.{}", self.namespace(), self.name()).to_string() + } + + /// The inputs expected at `call` time. + fn input_types(&self) -> &DataTree; + + /// The outputs promised on `call` return. + fn output_types(&self) -> &DataTree; + + /// Whether this program node implements the call method. + fn implements_call(&self) -> bool; + + /// The action of this program node. + fn call(&self, args: &DataTree) -> anyhow::Result>; +} diff --git a/crates/providers/src/store.rs b/crates/providers/src/store.rs new file mode 100644 index 000000000000..6e0e84172bb2 --- /dev/null +++ b/crates/providers/src/store.rs @@ -0,0 +1,159 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use crate::data_tree::DataTree; +use crate::program_node::ProgramNode; +use crate::tensor::{Tensor, TensorType}; +use std::sync::OnceLock; + +/// A program node that owns constant data and outputs it unconditionally. +/// +/// `Store` takes no inputs; its `call()` always returns the data it was constructed with. 
+/// In a data-flow graph, `Store` nodes play the role of constants — they are wired to +/// the input ports of computation nodes to supply fixed values. +pub struct Store { + data: DataTree, + output_types: DataTree, +} + +impl Store { + /// Construct a new `Store` holding the given data. + pub fn new(data: DataTree) -> Self { + let output_types = derive_output_types(&data); + Self { data, output_types } + } + + /// Return a reference to the stored data. + pub fn data(&self) -> &DataTree { + &self.data + } +} + +/// Recursively derive output types from concrete tensor data. +fn derive_output_types(data: &DataTree) -> DataTree { + match data { + DataTree::Leaf(tensor) => DataTree::new_leaf(tensor.tensor_type()), + DataTree::Branch(_) => { + let mut result = DataTree::with_capacity(data.len()); + for (key, child) in data.iter_children() { + let child_type = derive_output_types(child); + if let Some(k) = key { + result.insert_branch(k, child_type); + } else { + result.push_branch(child_type); + } + } + result + } + } +} + +impl ProgramNode for Store { + fn name(&self) -> &'static str { + "store" + } + + fn namespace(&self) -> &'static str { + "core" + } + + fn input_types(&self) -> &DataTree { + static EMPTY: OnceLock> = OnceLock::new(); + EMPTY.get_or_init(DataTree::new) + } + + fn output_types(&self) -> &DataTree { + &self.output_types + } + + fn implements_call(&self) -> bool { + true + } + + fn call(&self, _args: &DataTree) -> anyhow::Result> { + Ok(self.data.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::{DType, DTypeLike, Dim, Tensor}; + + #[test] + fn test_store_leaf_call() { + let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0])); + let store = Store::new(data); + let result = store.call(&DataTree::new()).unwrap(); + let DataTree::Leaf(Tensor::F64(arr)) = result else { + panic!("expected f64 leaf"); + }; + assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_store_output_types_leaf() { 
+ let data = DataTree::new_leaf(Tensor::from([1.0_f64, 2.0, 3.0])); + let store = Store::new(data); + let DataTree::Leaf(tt) = store.output_types() else { + panic!("expected leaf output type"); + }; + assert!(matches!(tt.dtype, DTypeLike::Concrete(DType::F64))); + assert_eq!(tt.shape, vec![Dim::Fixed(3)]); + assert!(!tt.broadcastable); + } + + #[test] + fn test_store_output_types_2d() { + use ndarray::arr2; + let data = + DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn())); + let store = Store::new(data); + let DataTree::Leaf(tt) = store.output_types() else { + panic!("expected leaf output type"); + }; + assert_eq!(tt.shape, vec![Dim::Fixed(2), Dim::Fixed(2)]); + } + + #[test] + fn test_store_branched() { + let mut data = DataTree::new(); + data.insert_leaf("a", Tensor::from([1.0_f64, 2.0])); + data.insert_leaf("b", Tensor::from([10_i32, 20, 30])); + let store = Store::new(data); + + assert!(store.input_types().is_empty()); + assert_eq!(store.name(), "store"); + assert_eq!(store.namespace(), "core"); + assert_eq!(store.full_name(), "core.store"); + + let out_types = store.output_types(); + let DataTree::Leaf(tt_a) = out_types.get_by_str_key("a").unwrap() else { + panic!("expected leaf at a"); + }; + assert!(matches!(tt_a.dtype, DTypeLike::Concrete(DType::F64))); + assert_eq!(tt_a.shape, vec![Dim::Fixed(2)]); + + let DataTree::Leaf(tt_b) = out_types.get_by_str_key("b").unwrap() else { + panic!("expected leaf at b"); + }; + assert!(matches!(tt_b.dtype, DTypeLike::Concrete(DType::I32))); + assert_eq!(tt_b.shape, vec![Dim::Fixed(3)]); + } + + #[test] + fn test_store_no_inputs() { + let store = Store::new(DataTree::new_leaf(Tensor::from([42.0_f64]))); + assert!(store.input_types().is_empty()); + assert!(store.implements_call()); + } +} diff --git a/crates/providers/src/tensor.rs b/crates/providers/src/tensor.rs new file mode 100644 index 000000000000..25e50645fa68 --- /dev/null +++ b/crates/providers/src/tensor.rs @@ -0,0 +1,673 @@ +// This 
code is part of Qiskit. +// +// (C) Copyright IBM 2026 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at https://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use ndarray::{ArrayD, IxDyn, Zip}; +use num_complex::Complex; +use std::fmt; + +/// The possible data types for a Tensor. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DType { + C128, // complex + C64, + F64, // real + F32, + I64, // signed integer + I32, + I16, + I8, + U64, // unsigned integer + U32, + U16, + U8, + Bit, // bool +} + +impl fmt::Display for DType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let string_repr = match self { + DType::C128 => "C128", + DType::C64 => "C64", + DType::F64 => "F64", + DType::F32 => "F32", + DType::I64 => "I64", + DType::I32 => "I32", + DType::I16 => "I16", + DType::I8 => "I8", + DType::U64 => "U64", + DType::U32 => "U32", + DType::U16 => "U16", + DType::U8 => "U8", + DType::Bit => "Bit", + }; + write!(f, "{string_repr}") + } +} + +/// A tensor dtype that is unknown but identified by name. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DTypeVar { + /// The variable name. + pub name: String, +} + +impl> From for DTypeVar { + fn from(value: T) -> Self { + Self { name: value.into() } + } +} + +/// A tensor data type whose value is yet unknown, but will be the promotion of others. +#[derive(Debug, Clone)] +pub struct DTypePromotion { + /// The dtype arguments to promote over. + pub args: Vec, +} + +impl>> From for DTypePromotion { + fn from(args: T) -> Self { + Self { args: args.into() } + } +} + +/// A tensor data type, known or unknown. 
+#[derive(Debug, Clone)] +pub enum DTypeLike { + /// A fully resolved dtype. + Concrete(DType), + /// A dtype identified by a variable name, to be resolved later. + Var(DTypeVar), + /// A dtype that is the promotion of one or more other dtypes. + Promotion(DTypePromotion), +} + +/// Promote a pair of DTypes to the smallest type compatible with both. +/// +/// QuantumProgram nodes often, but not necessarily, use this promotion rule +/// to determine their output type. +/// +/// This function implements the same promotion rules as NumPy, modulo that we don't +/// need to contend with the arbitrary precision types for each type kind, and that +/// we omit F16 entirely because it's ustable in rust: +/// https://numpy.org/doc/stable/reference/arrays.promotion.html#numerical-promotion +/// In short, if you view the linked diagram as a DAG, this function hard-codes the +/// least-common-descendent algorithm. +pub fn promotion(lhs: DType, rhs: DType) -> DType { + use DType::*; + + // painfully write a lookup table as a nested match statement. to check if it's right, + // compare agaist the linked image, or more easily, study the test + // test_promotion_against_promotion_dag that tests every input combination. 
+ match lhs { + C128 => C128, + + C64 => match rhs { + U32 | U64 | I32 | I64 | F64 | C128 => C128, + _ => C64, + }, + + F64 => match rhs { + C64 | C128 => C128, + _ => F64, + }, + + F32 => match rhs { + C128 => C128, + C64 => C64, + U32 | U64 | I32 | I64 | F64 => F64, + _ => F32, + }, + + I64 => match rhs { + C64 | C128 => C128, + U64 | F32 | F64 => F64, + _ => I64, + }, + + I32 => match rhs { + C64 | C128 => C128, + U64 | F32 | F64 => F64, + U32 | I64 => I64, + _ => I32, + }, + + I16 => match rhs { + U64 => F64, + U32 => I64, + U16 => I32, + Bit | U8 | I8 => I16, + _ => rhs, + }, + + I8 => match rhs { + U64 => F64, + U32 => I64, + U16 => I32, + U8 => I16, + Bit => I8, + _ => rhs, + }, + + U64 => match rhs { + C128 | C64 => C128, + F32 | F64 | I8 | I16 | I32 | I64 => F64, + _ => U64, + }, + + U32 => match rhs { + C64 | C128 => C128, + F32 | F64 => F64, + I8 | I16 | I32 | I64 => I64, + U64 => U64, + _ => U32, + }, + + U16 => match rhs { + I8 | I16 => I32, + Bit | U8 => U16, + _ => rhs, + }, + + U8 => match rhs { + I8 => I16, + Bit => U8, + _ => rhs, + }, + + Bit => rhs, + } +} + +/// A tensor axis dimension. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Dim { + /// A known size. + Fixed(usize), + /// An unresolved, named size. + Named(String), +} + +/// A specification of a tensor without any data. +#[derive(Debug, Clone)] +pub struct TensorType { + /// The type of the tensor. + pub dtype: DTypeLike, + /// The shape of the tensor, possibly with axes of unknown size. + pub shape: Vec, + /// Whether the tensor supports leading-axis (i.e. NumPy-style) broadcasting semantics. + pub broadcastable: bool, +} + +impl TensorType { + /// Return a dimension vector if all sizes are fixed, or `None` if any are named. 
+ pub fn concrete_shape(&self) -> Option> { + let mut out = Vec::with_capacity(self.shape.len()); + for d in &self.shape { + match d { + Dim::Fixed(n) => out.push(*n), + Dim::Named(_) => return None, + } + } + Some(out) + } +} + +/// A tensor of one of the supported dtypes. +#[derive(Debug, Clone)] +pub enum Tensor { + C64(ArrayD>), // complex + C128(ArrayD>), + F32(ArrayD), // real + F64(ArrayD), + I8(ArrayD), // signed integer + I16(ArrayD), + I32(ArrayD), + I64(ArrayD), + U8(ArrayD), // unsigned integer + U16(ArrayD), + U32(ArrayD), + U64(ArrayD), + Bit(ArrayD), // bool +} + +/// Cast an `ArrayD` of a real numeric type to any supported dtype. +macro_rules! cast_real { + ($arr:expr, $src:ty, $target:expr) => { + match $target { + DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)), + DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)), + DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)), + DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)), + DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)), + DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)), + DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)), + DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)), + DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)), + DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)), + DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)), + DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex::new(x as f32, 0.0))), + DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex::new(x as f64, 0.0))), + } + }; +} + +/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets). +macro_rules! 
cast_complex { + ($arr:expr, $target:expr) => { + match $target { + DType::C64 => Tensor::C64($arr.mapv(|x| Complex::new(x.re as f32, x.im as f32))), + DType::C128 => Tensor::C128($arr.mapv(|x| Complex::new(x.re as f64, x.im as f64))), + _ => panic!("cannot cast complex tensor to a real dtype"), + } + }; +} + +/// Element-wise binary operation on two arrays with NumPy-style broadcasting. +/// +/// Unlike ndarray's built-in arithmetic operators which handle broadcasting automatically, +/// this helper is needed for operations without a Rust operator (e.g. `pow`). +fn broadcast_elementwise(a: &ArrayD, b: &ArrayD, op: F) -> ArrayD +where + T: Clone, + F: Fn(&T, &T) -> T, +{ + let ndim = a.ndim().max(b.ndim()); + let out_shape: Vec = (0..ndim) + .map(|i| { + let d_a = if i >= ndim - a.ndim() { + a.shape()[i - (ndim - a.ndim())] + } else { + 1 + }; + let d_b = if i >= ndim - b.ndim() { + b.shape()[i - (ndim - b.ndim())] + } else { + 1 + }; + match (d_a, d_b) { + (x, y) if x == y => x, + (1, y) => y, + (x, 1) => x, + _ => panic!( + "shapes {:?} and {:?} are not broadcast-compatible", + a.shape(), + b.shape() + ), + } + }) + .collect(); + let out_ix = IxDyn(&out_shape); + let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed"); + let b_bc = b.broadcast(out_ix).expect("broadcast failed"); + Zip::from(a_bc).and(b_bc).map_collect(op) +} + +impl Tensor { + /// Return the dtype of this tensor. + pub fn dtype(&self) -> DType { + match self { + Tensor::C128(_) => DType::C128, + Tensor::C64(_) => DType::C64, + Tensor::F64(_) => DType::F64, + Tensor::F32(_) => DType::F32, + Tensor::I64(_) => DType::I64, + Tensor::I32(_) => DType::I32, + Tensor::I16(_) => DType::I16, + Tensor::I8(_) => DType::I8, + Tensor::U64(_) => DType::U64, + Tensor::U32(_) => DType::U32, + Tensor::U16(_) => DType::U16, + Tensor::U8(_) => DType::U8, + Tensor::Bit(_) => DType::Bit, + } + } + + /// Return the shape of this tensor as a slice of dimension sizes. 
+ pub fn shape(&self) -> &[usize] { + match self { + Tensor::C128(a) => a.shape(), + Tensor::C64(a) => a.shape(), + Tensor::F64(a) => a.shape(), + Tensor::F32(a) => a.shape(), + Tensor::I64(a) => a.shape(), + Tensor::I32(a) => a.shape(), + Tensor::I16(a) => a.shape(), + Tensor::I8(a) => a.shape(), + Tensor::U64(a) => a.shape(), + Tensor::U32(a) => a.shape(), + Tensor::U16(a) => a.shape(), + Tensor::U8(a) => a.shape(), + Tensor::Bit(a) => a.shape(), + } + } + + /// Return the [`TensorType`] that describes this tensor's dtype and concrete shape. + pub fn tensor_type(&self) -> TensorType { + TensorType { + dtype: DTypeLike::Concrete(self.dtype()), + shape: self.shape().iter().map(|&n| Dim::Fixed(n)).collect(), + broadcastable: false, + } + } + + /// Element-wise power with NumPy-style broadcasting. + /// + /// For integer types the exponent is cast to `u32`; negative integer exponents + /// are not supported. + pub fn pow(&self, rhs: &Tensor) -> Tensor { + match (self, rhs) { + (Tensor::F32(a), Tensor::F32(b)) => { + Tensor::F32(broadcast_elementwise(a, b, |&x, &y| x.powf(y))) + } + (Tensor::F64(a), Tensor::F64(b)) => { + Tensor::F64(broadcast_elementwise(a, b, |&x, &y| x.powf(y))) + } + (Tensor::C64(a), Tensor::C64(b)) => { + Tensor::C64(broadcast_elementwise(a, b, |&x, &y| x.powc(y))) + } + (Tensor::C128(a), Tensor::C128(b)) => { + Tensor::C128(broadcast_elementwise(a, b, |&x, &y| x.powc(y))) + } + (Tensor::I8(a), Tensor::I8(b)) => { + Tensor::I8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I16(a), Tensor::I16(b)) => { + Tensor::I16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I32(a), Tensor::I32(b)) => { + Tensor::I32(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::I64(a), Tensor::I64(b)) => { + Tensor::I64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U8(a), Tensor::U8(b)) => { + Tensor::U8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U16(a), 
Tensor::U16(b)) => { + Tensor::U16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + (Tensor::U32(a), Tensor::U32(b)) => { + Tensor::U32(broadcast_elementwise(a, b, |&x, &y| x.pow(y))) + } + (Tensor::U64(a), Tensor::U64(b)) => { + Tensor::U64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32))) + } + _ => panic!("type mismatch in Tensor::pow"), + } + } + + /// Cast this tensor to `target`, consuming it. Returns `self` unchanged if already that dtype. + pub fn cast(self, target: DType) -> Tensor { + if self.dtype() == target { + return self; + } + match &self { + Tensor::Bit(a) | Tensor::U8(a) => cast_real!(a, u8, target), + Tensor::U16(a) => cast_real!(a, u16, target), + Tensor::U32(a) => cast_real!(a, u32, target), + Tensor::U64(a) => cast_real!(a, u64, target), + Tensor::I8(a) => cast_real!(a, i8, target), + Tensor::I16(a) => cast_real!(a, i16, target), + Tensor::I32(a) => cast_real!(a, i32, target), + Tensor::I64(a) => cast_real!(a, i64, target), + Tensor::F32(a) => cast_real!(a, f32, target), + Tensor::F64(a) => cast_real!(a, f64, target), + Tensor::C64(a) => cast_complex!(a, target), + Tensor::C128(a) => cast_complex!(a, target), + } + } +} + +/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From>` for a given `Tensor` variant. +macro_rules! 
impl_tensor_from { + ($variant:ident, $t:ty) => { + impl From<&[$t]> for Tensor { + fn from(data: &[$t]) -> Self { + Tensor::$variant(ndarray::arr1(data).into_dyn()) + } + } + impl From<[$t; N]> for Tensor { + fn from(data: [$t; N]) -> Self { + Tensor::$variant(ndarray::arr1(&data).into_dyn()) + } + } + impl From> for Tensor { + fn from(data: ArrayD<$t>) -> Self { + Tensor::$variant(data) + } + } + }; +} + +impl_tensor_from!(C128, Complex); +impl_tensor_from!(C64, Complex); +impl_tensor_from!(F64, f64); +impl_tensor_from!(F32, f32); +impl_tensor_from!(I64, i64); +impl_tensor_from!(I32, i32); +impl_tensor_from!(I16, i16); +impl_tensor_from!(I8, i8); +impl_tensor_from!(U64, u64); +impl_tensor_from!(U32, u32); +impl_tensor_from!(U16, u16); +impl_tensor_from!(U8, u8); // u8 → U8; Bit requires explicit construction + +/// Implement a standard Rust binary operator trait for `Tensor` and `&Tensor`. +macro_rules! impl_tensor_binop { + ($trait:ident, $method:ident, $op:tt) => { + impl std::ops::$trait for &Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { + match (self, rhs) { + (Tensor::C128(a), Tensor::C128(b)) => Tensor::C128(a $op b), + (Tensor::C64(a), Tensor::C64(b)) => Tensor::C64(a $op b), + (Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b), + (Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b), + (Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b), + (Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b), + (Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b), + (Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b), + (Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b), + (Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b), + (Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b), + (Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b), + _ => panic!("type mismatch in Tensor::{}", stringify!($method)), + } + } + } + impl std::ops::$trait for Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> 
Tensor { &self $op &rhs } + } + }; +} + +/// Like [`impl_tensor_binop!`], but omits complex variants for ops that don't support them +/// (e.g. `Rem`, which `num_complex` does not implement). +macro_rules! impl_tensor_binop_real { + ($trait:ident, $method:ident, $op:tt) => { + impl std::ops::$trait for &Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { + match (self, rhs) { + (Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b), + (Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b), + (Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b), + (Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b), + (Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b), + (Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b), + (Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b), + (Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b), + (Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b), + (Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b), + _ => panic!("type mismatch or unsupported dtype in Tensor::{}", stringify!($method)), + } + } + } + impl std::ops::$trait for Tensor { + type Output = Tensor; + fn $method(self, rhs: Self) -> Tensor { &self $op &rhs } + } + }; +} + +impl_tensor_binop!(Add, add, +); +impl_tensor_binop!(Sub, sub, -); +impl_tensor_binop!(Mul, mul, *); +impl_tensor_binop!(Div, div, /); +impl_tensor_binop_real!(Rem, rem, %); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_promotion_against_promotion_dag() { + use DType::*; + use rustworkx_core::dag_algo::lexicographical_topological_sort; + use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; + use rustworkx_core::traversal::descendants; + use std::collections::{HashMap, HashSet}; + + // define a DAG that implements all promotion rules; two DTypes + // should be promoted to their least common descendent in the DAG + let mut g: DiGraph = DiGraph::new(); + let mut idx: HashMap = HashMap::new(); + + let nodes = [ + Bit, U8, U16, U32, U64, 
I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &dtype in &nodes { + idx.insert(dtype, g.add_node(dtype)); + } + + // within-kind promotions + g.add_edge(idx[&U8], idx[&U16], ()); + g.add_edge(idx[&U16], idx[&U32], ()); + g.add_edge(idx[&U32], idx[&U64], ()); + + g.add_edge(idx[&I8], idx[&I16], ()); + g.add_edge(idx[&I16], idx[&I32], ()); + g.add_edge(idx[&I32], idx[&I64], ()); + + g.add_edge(idx[&F32], idx[&F64], ()); + + g.add_edge(idx[&C64], idx[&C128], ()); + + // bit promotions + g.add_edge(idx[&Bit], idx[&U8], ()); + g.add_edge(idx[&Bit], idx[&I8], ()); + + // uint promotions + g.add_edge(idx[&U8], idx[&I16], ()); + g.add_edge(idx[&U16], idx[&I32], ()); + g.add_edge(idx[&U16], idx[&F32], ()); + g.add_edge(idx[&U32], idx[&I64], ()); + g.add_edge(idx[&U64], idx[&F64], ()); + + // int promotions + g.add_edge(idx[&I16], idx[&F32], ()); + g.add_edge(idx[&I32], idx[&F64], ()); + g.add_edge(idx[&I64], idx[&F64], ()); + + // float promotions + g.add_edge(idx[&F32], idx[&C64], ()); + g.add_edge(idx[&F64], idx[&C128], ()); + + let order = lexicographical_topological_sort( + &g, + |n: NodeIndex| Ok::(n.index()), + false, + None, + ) + .ok() + .unwrap(); + + let least_common_decendent = move |a: &DType, b: &DType| -> DType { + let da: HashSet<_> = descendants(&g, idx[&a]).collect(); + let db: HashSet<_> = descendants(&g, idx[&b]).collect(); + let common: HashSet = da.intersection(&db).copied().collect(); + let least_idx = order.iter().find(|n| common.contains(n)).unwrap(); + nodes[least_idx.index()] + }; + + for &a in &nodes { + for &b in &nodes { + assert_eq!( + promotion(a, b), + least_common_decendent(&a, &b), + "For promotion ({a}, {b})" + ) + } + } + } + + #[test] + fn test_promotion_idempotence() { + use DType::*; + let nodes = [ + Bit, U8, U16, U32, U64, I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &a in &nodes { + assert_eq!(promotion(a, a), a, "For promotion ({a}, {a})") + } + } + + #[test] + fn test_promotion_commutativity() { + use DType::*; + let 
nodes = [ + Bit, U8, U16, U32, U64, I8, I16, I32, I64, F32, F64, C64, C128, + ]; + + for &a in &nodes { + for &b in &nodes { + assert_eq!(promotion(a, b), promotion(b, a), "For promotion ({a}, {b})") + } + } + } + + #[test] + fn test_tensor_type_concrete_shape() { + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3)], + broadcastable: false, + } + .concrete_shape(), + Some(vec![3]) + ); + + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3), Dim::Fixed(8)], + broadcastable: true, + } + .concrete_shape(), + Some(vec![3, 8]) + ); + + assert_eq!( + TensorType { + dtype: DTypeLike::Concrete(DType::Bit), + shape: vec![Dim::Fixed(3), Dim::Named("foo".into())], + broadcastable: false, + } + .concrete_shape(), + None + ); + } +}