Skip to content

Commit 23e7055

Browse files
committed
Add operation implementations for Tensor
1 parent 0d13238 commit 23e7055

1 file changed

Lines changed: 304 additions & 12 deletions

File tree

crates/providers/src/tensor.rs

Lines changed: 304 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13-
use ndarray::ArrayD;
13+
use ndarray::{ArrayD, IxDyn, Zip};
1414
use num_complex::Complex;
1515
use std::fmt;
1616

@@ -19,13 +19,13 @@ use std::fmt;
1919
pub enum DType {
2020
C128, // complex
2121
C64,
22-
F64, // float
22+
F64, // real
2323
F32,
24-
I64, // signed ints
24+
I64, // signed integer
2525
I32,
2626
I16,
2727
I8,
28-
U64, // unsigned ints
28+
U64, // unsigned integer
2929
U32,
3030
U16,
3131
U8,
@@ -53,9 +53,10 @@ impl fmt::Display for DType {
5353
}
5454
}
5555

56-
// A tensor data type whose value is yet unknown, but named.
56+
/// A tensor dtype that is unknown but identified by name.
5757
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
5858
pub struct DTypeVar {
59+
/// The variable name.
5960
pub name: String,
6061
}
6162

@@ -68,6 +69,7 @@ impl<T: Into<String>> From<T> for DTypeVar {
6869
/// A tensor data type whose value is yet unknown, but will be the promotion of others.
6970
#[derive(Debug, Clone)]
7071
pub struct DTypePromotion {
72+
/// The dtype arguments to promote over.
7173
pub args: Vec<DTypeLike>,
7274
}
7375

@@ -80,14 +82,17 @@ impl<T: Into<Vec<DTypeLike>>> From<T> for DTypePromotion {
8082
/// A tensor data type, known or unknown.
8183
#[derive(Debug, Clone)]
8284
pub enum DTypeLike {
85+
/// A fully resolved dtype.
8386
Concrete(DType),
87+
/// A dtype identified by a variable name, to be resolved later.
8488
Var(DTypeVar),
89+
/// A dtype that is the promotion of one or more other dtypes.
8590
Promotion(DTypePromotion),
8691
}
8792

8893
/// Promote a pair of DTypes to the smallest type compatible with both.
8994
///
90-
/// QuantumProgram operations often, but not necessarily, use this promotion rule
95+
/// QuantumProgram nodes often, but not necessarily, use this promotion rule
9196
/// to determine their output type.
9297
///
9398
/// This function implements the same promotion rules as NumPy, modulo that we don't
@@ -203,7 +208,7 @@ pub struct TensorType {
203208
}
204209

205210
impl TensorType {
206-
// Return a dimension vector if all are sizes are fixed, None otherwise.
211+
/// Return a dimension vector if all sizes are fixed, or `None` if any are named.
207212
pub fn concrete_shape(&self) -> Option<Vec<usize>> {
208213
let mut out = Vec::with_capacity(self.shape.len());
209214
for d in &self.shape {
@@ -219,21 +224,308 @@ impl TensorType {
219224
/// A tensor of one of the supported dtypes.
220225
#[derive(Debug, Clone)]
221226
pub enum Tensor {
222-
C64(ArrayD<Complex<f32>>),
227+
C64(ArrayD<Complex<f32>>), // complex
223228
C128(ArrayD<Complex<f64>>),
224-
F32(ArrayD<f32>),
229+
F32(ArrayD<f32>), // real
225230
F64(ArrayD<f64>),
226-
I8(ArrayD<i8>),
231+
I8(ArrayD<i8>), // signed integer
227232
I16(ArrayD<i16>),
228233
I32(ArrayD<i32>),
229234
I64(ArrayD<i64>),
230-
U8(ArrayD<u8>),
235+
U8(ArrayD<u8>), // unsigned integer
231236
U16(ArrayD<u16>),
232237
U32(ArrayD<u32>),
233238
U64(ArrayD<u64>),
234-
Bit(ArrayD<u8>),
239+
Bit(ArrayD<u8>), // bool
235240
}
236241

242+
/// Cast an `ArrayD` of a real numeric type to any supported dtype.
243+
macro_rules! cast_real {
244+
($arr:expr, $src:ty, $target:expr) => {
245+
match $target {
246+
DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)),
247+
DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)),
248+
DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)),
249+
DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)),
250+
DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)),
251+
DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)),
252+
DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)),
253+
DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)),
254+
DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)),
255+
DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)),
256+
DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)),
257+
DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex::new(x as f32, 0.0))),
258+
DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex::new(x as f64, 0.0))),
259+
}
260+
};
261+
}
262+
263+
/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets).
264+
macro_rules! cast_complex {
265+
($arr:expr, $target:expr) => {
266+
match $target {
267+
DType::C64 => Tensor::C64($arr.mapv(|x| Complex::new(x.re as f32, x.im as f32))),
268+
DType::C128 => Tensor::C128($arr.mapv(|x| Complex::new(x.re as f64, x.im as f64))),
269+
_ => panic!("cannot cast complex tensor to a real dtype"),
270+
}
271+
};
272+
}
273+
274+
/// Element-wise binary operation on two arrays with NumPy-style broadcasting.
275+
///
276+
/// Unlike ndarray's built-in arithmetic operators which handle broadcasting automatically,
277+
/// this helper is needed for operations without a Rust operator (e.g. `pow`).
278+
fn broadcast_elementwise<T, F>(a: &ArrayD<T>, b: &ArrayD<T>, op: F) -> ArrayD<T>
279+
where
280+
T: Clone,
281+
F: Fn(&T, &T) -> T,
282+
{
283+
let ndim = a.ndim().max(b.ndim());
284+
let out_shape: Vec<usize> = (0..ndim)
285+
.map(|i| {
286+
let d_a = if i >= ndim - a.ndim() {
287+
a.shape()[i - (ndim - a.ndim())]
288+
} else {
289+
1
290+
};
291+
let d_b = if i >= ndim - b.ndim() {
292+
b.shape()[i - (ndim - b.ndim())]
293+
} else {
294+
1
295+
};
296+
match (d_a, d_b) {
297+
(x, y) if x == y => x,
298+
(1, y) => y,
299+
(x, 1) => x,
300+
_ => panic!(
301+
"shapes {:?} and {:?} are not broadcast-compatible",
302+
a.shape(),
303+
b.shape()
304+
),
305+
}
306+
})
307+
.collect();
308+
let out_ix = IxDyn(&out_shape);
309+
let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed");
310+
let b_bc = b.broadcast(out_ix).expect("broadcast failed");
311+
Zip::from(a_bc).and(b_bc).map_collect(op)
312+
}
313+
314+
impl Tensor {
315+
/// Return the dtype of this tensor.
316+
pub fn dtype(&self) -> DType {
317+
match self {
318+
Tensor::C128(_) => DType::C128,
319+
Tensor::C64(_) => DType::C64,
320+
Tensor::F64(_) => DType::F64,
321+
Tensor::F32(_) => DType::F32,
322+
Tensor::I64(_) => DType::I64,
323+
Tensor::I32(_) => DType::I32,
324+
Tensor::I16(_) => DType::I16,
325+
Tensor::I8(_) => DType::I8,
326+
Tensor::U64(_) => DType::U64,
327+
Tensor::U32(_) => DType::U32,
328+
Tensor::U16(_) => DType::U16,
329+
Tensor::U8(_) => DType::U8,
330+
Tensor::Bit(_) => DType::Bit,
331+
}
332+
}
333+
334+
/// Return the shape of this tensor as a slice of dimension sizes.
335+
pub fn shape(&self) -> &[usize] {
336+
match self {
337+
Tensor::C128(a) => a.shape(),
338+
Tensor::C64(a) => a.shape(),
339+
Tensor::F64(a) => a.shape(),
340+
Tensor::F32(a) => a.shape(),
341+
Tensor::I64(a) => a.shape(),
342+
Tensor::I32(a) => a.shape(),
343+
Tensor::I16(a) => a.shape(),
344+
Tensor::I8(a) => a.shape(),
345+
Tensor::U64(a) => a.shape(),
346+
Tensor::U32(a) => a.shape(),
347+
Tensor::U16(a) => a.shape(),
348+
Tensor::U8(a) => a.shape(),
349+
Tensor::Bit(a) => a.shape(),
350+
}
351+
}
352+
353+
/// Return the [`TensorType`] that describes this tensor's dtype and concrete shape.
354+
pub fn tensor_type(&self) -> TensorType {
355+
TensorType {
356+
dtype: DTypeLike::Concrete(self.dtype()),
357+
shape: self.shape().iter().map(|&n| Dim::Fixed(n)).collect(),
358+
broadcastable: false,
359+
}
360+
}
361+
362+
/// Element-wise power with NumPy-style broadcasting.
363+
///
364+
/// For integer types the exponent is cast to `u32`; negative integer exponents
365+
/// are not supported.
366+
pub fn pow(&self, rhs: &Tensor) -> Tensor {
367+
match (self, rhs) {
368+
(Tensor::F32(a), Tensor::F32(b)) => {
369+
Tensor::F32(broadcast_elementwise(a, b, |&x, &y| x.powf(y)))
370+
}
371+
(Tensor::F64(a), Tensor::F64(b)) => {
372+
Tensor::F64(broadcast_elementwise(a, b, |&x, &y| x.powf(y)))
373+
}
374+
(Tensor::C64(a), Tensor::C64(b)) => {
375+
Tensor::C64(broadcast_elementwise(a, b, |&x, &y| x.powc(y)))
376+
}
377+
(Tensor::C128(a), Tensor::C128(b)) => {
378+
Tensor::C128(broadcast_elementwise(a, b, |&x, &y| x.powc(y)))
379+
}
380+
(Tensor::I8(a), Tensor::I8(b)) => {
381+
Tensor::I8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
382+
}
383+
(Tensor::I16(a), Tensor::I16(b)) => {
384+
Tensor::I16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
385+
}
386+
(Tensor::I32(a), Tensor::I32(b)) => {
387+
Tensor::I32(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
388+
}
389+
(Tensor::I64(a), Tensor::I64(b)) => {
390+
Tensor::I64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
391+
}
392+
(Tensor::U8(a), Tensor::U8(b)) => {
393+
Tensor::U8(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
394+
}
395+
(Tensor::U16(a), Tensor::U16(b)) => {
396+
Tensor::U16(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
397+
}
398+
(Tensor::U32(a), Tensor::U32(b)) => {
399+
Tensor::U32(broadcast_elementwise(a, b, |&x, &y| x.pow(y)))
400+
}
401+
(Tensor::U64(a), Tensor::U64(b)) => {
402+
Tensor::U64(broadcast_elementwise(a, b, |&x, &y| x.pow(y as u32)))
403+
}
404+
_ => panic!("type mismatch in Tensor::pow"),
405+
}
406+
}
407+
408+
/// Cast this tensor to `target`, consuming it. Returns `self` unchanged if already that dtype.
409+
pub fn cast(self, target: DType) -> Tensor {
410+
if self.dtype() == target {
411+
return self;
412+
}
413+
match &self {
414+
Tensor::Bit(a) | Tensor::U8(a) => cast_real!(a, u8, target),
415+
Tensor::U16(a) => cast_real!(a, u16, target),
416+
Tensor::U32(a) => cast_real!(a, u32, target),
417+
Tensor::U64(a) => cast_real!(a, u64, target),
418+
Tensor::I8(a) => cast_real!(a, i8, target),
419+
Tensor::I16(a) => cast_real!(a, i16, target),
420+
Tensor::I32(a) => cast_real!(a, i32, target),
421+
Tensor::I64(a) => cast_real!(a, i64, target),
422+
Tensor::F32(a) => cast_real!(a, f32, target),
423+
Tensor::F64(a) => cast_real!(a, f64, target),
424+
Tensor::C64(a) => cast_complex!(a, target),
425+
Tensor::C128(a) => cast_complex!(a, target),
426+
}
427+
}
428+
}
429+
430+
/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From<ArrayD<T>>` for a given `Tensor` variant.
431+
macro_rules! impl_tensor_from {
432+
($variant:ident, $t:ty) => {
433+
impl From<&[$t]> for Tensor {
434+
fn from(data: &[$t]) -> Self {
435+
Tensor::$variant(ndarray::arr1(data).into_dyn())
436+
}
437+
}
438+
impl<const N: usize> From<[$t; N]> for Tensor {
439+
fn from(data: [$t; N]) -> Self {
440+
Tensor::$variant(ndarray::arr1(&data).into_dyn())
441+
}
442+
}
443+
impl From<ArrayD<$t>> for Tensor {
444+
fn from(data: ArrayD<$t>) -> Self {
445+
Tensor::$variant(data)
446+
}
447+
}
448+
};
449+
}
450+
451+
impl_tensor_from!(C128, Complex<f64>);
452+
impl_tensor_from!(C64, Complex<f32>);
453+
impl_tensor_from!(F64, f64);
454+
impl_tensor_from!(F32, f32);
455+
impl_tensor_from!(I64, i64);
456+
impl_tensor_from!(I32, i32);
457+
impl_tensor_from!(I16, i16);
458+
impl_tensor_from!(I8, i8);
459+
impl_tensor_from!(U64, u64);
460+
impl_tensor_from!(U32, u32);
461+
impl_tensor_from!(U16, u16);
462+
impl_tensor_from!(U8, u8); // u8 → U8; Bit requires explicit construction
463+
464+
/// Implement a standard Rust binary operator trait for `Tensor` and `&Tensor`.
465+
macro_rules! impl_tensor_binop {
466+
($trait:ident, $method:ident, $op:tt) => {
467+
impl std::ops::$trait for &Tensor {
468+
type Output = Tensor;
469+
fn $method(self, rhs: Self) -> Tensor {
470+
match (self, rhs) {
471+
(Tensor::C128(a), Tensor::C128(b)) => Tensor::C128(a $op b),
472+
(Tensor::C64(a), Tensor::C64(b)) => Tensor::C64(a $op b),
473+
(Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b),
474+
(Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b),
475+
(Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b),
476+
(Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b),
477+
(Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b),
478+
(Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b),
479+
(Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b),
480+
(Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b),
481+
(Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b),
482+
(Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b),
483+
_ => panic!("type mismatch in Tensor::{}", stringify!($method)),
484+
}
485+
}
486+
}
487+
impl std::ops::$trait for Tensor {
488+
type Output = Tensor;
489+
fn $method(self, rhs: Self) -> Tensor { &self $op &rhs }
490+
}
491+
};
492+
}
493+
494+
/// Like [`impl_tensor_binop!`], but omits complex variants for ops that don't support them
495+
/// (e.g. `Rem`, which `num_complex` does not implement).
496+
macro_rules! impl_tensor_binop_real {
497+
($trait:ident, $method:ident, $op:tt) => {
498+
impl std::ops::$trait for &Tensor {
499+
type Output = Tensor;
500+
fn $method(self, rhs: Self) -> Tensor {
501+
match (self, rhs) {
502+
(Tensor::F64(a), Tensor::F64(b)) => Tensor::F64(a $op b),
503+
(Tensor::F32(a), Tensor::F32(b)) => Tensor::F32(a $op b),
504+
(Tensor::I64(a), Tensor::I64(b)) => Tensor::I64(a $op b),
505+
(Tensor::I32(a), Tensor::I32(b)) => Tensor::I32(a $op b),
506+
(Tensor::I16(a), Tensor::I16(b)) => Tensor::I16(a $op b),
507+
(Tensor::I8(a), Tensor::I8(b)) => Tensor::I8(a $op b),
508+
(Tensor::U64(a), Tensor::U64(b)) => Tensor::U64(a $op b),
509+
(Tensor::U32(a), Tensor::U32(b)) => Tensor::U32(a $op b),
510+
(Tensor::U16(a), Tensor::U16(b)) => Tensor::U16(a $op b),
511+
(Tensor::U8(a), Tensor::U8(b)) => Tensor::U8(a $op b),
512+
_ => panic!("type mismatch or unsupported dtype in Tensor::{}", stringify!($method)),
513+
}
514+
}
515+
}
516+
impl std::ops::$trait for Tensor {
517+
type Output = Tensor;
518+
fn $method(self, rhs: Self) -> Tensor { &self $op &rhs }
519+
}
520+
};
521+
}
522+
523+
impl_tensor_binop!(Add, add, +);
524+
impl_tensor_binop!(Sub, sub, -);
525+
impl_tensor_binop!(Mul, mul, *);
526+
impl_tensor_binop!(Div, div, /);
527+
impl_tensor_binop_real!(Rem, rem, %);
528+
237529
#[cfg(test)]
238530
mod test {
239531
use super::*;

0 commit comments

Comments
 (0)