1010// copyright notice, and modified files need to carry a notice indicating
1111// that they have been altered from the originals.
1212
13- use ndarray:: ArrayD ;
13+ use ndarray:: { ArrayD , IxDyn , Zip } ;
1414use num_complex:: Complex ;
1515use std:: fmt;
1616
@@ -19,13 +19,13 @@ use std::fmt;
1919pub enum DType {
2020 C128 , // complex
2121 C64 ,
22- F64 , // float
22+ F64 , // real
2323 F32 ,
24- I64 , // signed ints
24+ I64 , // signed integer
2525 I32 ,
2626 I16 ,
2727 I8 ,
28- U64 , // unsigned ints
28+ U64 , // unsigned integer
2929 U32 ,
3030 U16 ,
3131 U8 ,
@@ -53,9 +53,10 @@ impl fmt::Display for DType {
5353 }
5454}
5555
56- // A tensor data type whose value is yet unknown, but named .
56+ /// A tensor dtype that is unknown but identified by name .
5757#[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
5858pub struct DTypeVar {
59+ /// The variable name.
5960 pub name : String ,
6061}
6162
@@ -68,6 +69,7 @@ impl<T: Into<String>> From<T> for DTypeVar {
6869/// A tensor data type whose value is yet unknown, but will be the promotion of others.
6970#[ derive( Debug , Clone ) ]
7071pub struct DTypePromotion {
72+ /// The dtype arguments to promote over.
7173 pub args : Vec < DTypeLike > ,
7274}
7375
@@ -80,14 +82,17 @@ impl<T: Into<Vec<DTypeLike>>> From<T> for DTypePromotion {
8082/// A tensor data type, known or unknown.
8183#[ derive( Debug , Clone ) ]
8284pub enum DTypeLike {
85+ /// A fully resolved dtype.
8386 Concrete ( DType ) ,
87+ /// A dtype identified by a variable name, to be resolved later.
8488 Var ( DTypeVar ) ,
89+ /// A dtype that is the promotion of one or more other dtypes.
8590 Promotion ( DTypePromotion ) ,
8691}
8792
8893/// Promote a pair of DTypes to the smallest type compatible with both.
8994///
90- /// QuantumProgram operations often, but not necessarily, use this promotion rule
95+ /// QuantumProgram nodes often, but not necessarily, use this promotion rule
9196/// to determine their output type.
9297///
9398/// This function implements the same promotion rules as NumPy, modulo that we don't
@@ -203,7 +208,7 @@ pub struct TensorType {
203208}
204209
205210impl TensorType {
206- // Return a dimension vector if all are sizes are fixed, None otherwise .
211+ /// Return a dimension vector if all sizes are fixed, or ` None` if any are named .
207212 pub fn concrete_shape ( & self ) -> Option < Vec < usize > > {
208213 let mut out = Vec :: with_capacity ( self . shape . len ( ) ) ;
209214 for d in & self . shape {
@@ -219,21 +224,308 @@ impl TensorType {
219224/// A tensor of one of the supported dtypes.
220225#[ derive( Debug , Clone ) ]
221226pub enum Tensor {
222- C64 ( ArrayD < Complex < f32 > > ) ,
227+ C64 ( ArrayD < Complex < f32 > > ) , // complex
223228 C128 ( ArrayD < Complex < f64 > > ) ,
224- F32 ( ArrayD < f32 > ) ,
229+ F32 ( ArrayD < f32 > ) , // real
225230 F64 ( ArrayD < f64 > ) ,
226- I8 ( ArrayD < i8 > ) ,
231+ I8 ( ArrayD < i8 > ) , // signed integer
227232 I16 ( ArrayD < i16 > ) ,
228233 I32 ( ArrayD < i32 > ) ,
229234 I64 ( ArrayD < i64 > ) ,
230- U8 ( ArrayD < u8 > ) ,
235+ U8 ( ArrayD < u8 > ) , // unsigned integer
231236 U16 ( ArrayD < u16 > ) ,
232237 U32 ( ArrayD < u32 > ) ,
233238 U64 ( ArrayD < u64 > ) ,
234- Bit ( ArrayD < u8 > ) ,
239+ Bit ( ArrayD < u8 > ) , // bool
235240}
236241
242+ /// Cast an `ArrayD` of a real numeric type to any supported dtype.
243+ macro_rules! cast_real {
244+ ( $arr: expr, $src: ty, $target: expr) => {
245+ match $target {
246+ DType :: Bit => Tensor :: Bit ( $arr. mapv( |x: $src| x as u8 ) ) ,
247+ DType :: U8 => Tensor :: U8 ( $arr. mapv( |x: $src| x as u8 ) ) ,
248+ DType :: U16 => Tensor :: U16 ( $arr. mapv( |x: $src| x as u16 ) ) ,
249+ DType :: U32 => Tensor :: U32 ( $arr. mapv( |x: $src| x as u32 ) ) ,
250+ DType :: U64 => Tensor :: U64 ( $arr. mapv( |x: $src| x as u64 ) ) ,
251+ DType :: I8 => Tensor :: I8 ( $arr. mapv( |x: $src| x as i8 ) ) ,
252+ DType :: I16 => Tensor :: I16 ( $arr. mapv( |x: $src| x as i16 ) ) ,
253+ DType :: I32 => Tensor :: I32 ( $arr. mapv( |x: $src| x as i32 ) ) ,
254+ DType :: I64 => Tensor :: I64 ( $arr. mapv( |x: $src| x as i64 ) ) ,
255+ DType :: F32 => Tensor :: F32 ( $arr. mapv( |x: $src| x as f32 ) ) ,
256+ DType :: F64 => Tensor :: F64 ( $arr. mapv( |x: $src| x as f64 ) ) ,
257+ DType :: C64 => Tensor :: C64 ( $arr. mapv( |x: $src| Complex :: new( x as f32 , 0.0 ) ) ) ,
258+ DType :: C128 => Tensor :: C128 ( $arr. mapv( |x: $src| Complex :: new( x as f64 , 0.0 ) ) ) ,
259+ }
260+ } ;
261+ }
262+
263+ /// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets).
264+ macro_rules! cast_complex {
265+ ( $arr: expr, $target: expr) => {
266+ match $target {
267+ DType :: C64 => Tensor :: C64 ( $arr. mapv( |x| Complex :: new( x. re as f32 , x. im as f32 ) ) ) ,
268+ DType :: C128 => Tensor :: C128 ( $arr. mapv( |x| Complex :: new( x. re as f64 , x. im as f64 ) ) ) ,
269+ _ => panic!( "cannot cast complex tensor to a real dtype" ) ,
270+ }
271+ } ;
272+ }
273+
274+ /// Element-wise binary operation on two arrays with NumPy-style broadcasting.
275+ ///
276+ /// Unlike ndarray's built-in arithmetic operators which handle broadcasting automatically,
277+ /// this helper is needed for operations without a Rust operator (e.g. `pow`).
278+ fn broadcast_elementwise < T , F > ( a : & ArrayD < T > , b : & ArrayD < T > , op : F ) -> ArrayD < T >
279+ where
280+ T : Clone ,
281+ F : Fn ( & T , & T ) -> T ,
282+ {
283+ let ndim = a. ndim ( ) . max ( b. ndim ( ) ) ;
284+ let out_shape: Vec < usize > = ( 0 ..ndim)
285+ . map ( |i| {
286+ let d_a = if i >= ndim - a. ndim ( ) {
287+ a. shape ( ) [ i - ( ndim - a. ndim ( ) ) ]
288+ } else {
289+ 1
290+ } ;
291+ let d_b = if i >= ndim - b. ndim ( ) {
292+ b. shape ( ) [ i - ( ndim - b. ndim ( ) ) ]
293+ } else {
294+ 1
295+ } ;
296+ match ( d_a, d_b) {
297+ ( x, y) if x == y => x,
298+ ( 1 , y) => y,
299+ ( x, 1 ) => x,
300+ _ => panic ! (
301+ "shapes {:?} and {:?} are not broadcast-compatible" ,
302+ a. shape( ) ,
303+ b. shape( )
304+ ) ,
305+ }
306+ } )
307+ . collect ( ) ;
308+ let out_ix = IxDyn ( & out_shape) ;
309+ let a_bc = a. broadcast ( out_ix. clone ( ) ) . expect ( "broadcast failed" ) ;
310+ let b_bc = b. broadcast ( out_ix) . expect ( "broadcast failed" ) ;
311+ Zip :: from ( a_bc) . and ( b_bc) . map_collect ( op)
312+ }
313+
314+ impl Tensor {
315+ /// Return the dtype of this tensor.
316+ pub fn dtype ( & self ) -> DType {
317+ match self {
318+ Tensor :: C128 ( _) => DType :: C128 ,
319+ Tensor :: C64 ( _) => DType :: C64 ,
320+ Tensor :: F64 ( _) => DType :: F64 ,
321+ Tensor :: F32 ( _) => DType :: F32 ,
322+ Tensor :: I64 ( _) => DType :: I64 ,
323+ Tensor :: I32 ( _) => DType :: I32 ,
324+ Tensor :: I16 ( _) => DType :: I16 ,
325+ Tensor :: I8 ( _) => DType :: I8 ,
326+ Tensor :: U64 ( _) => DType :: U64 ,
327+ Tensor :: U32 ( _) => DType :: U32 ,
328+ Tensor :: U16 ( _) => DType :: U16 ,
329+ Tensor :: U8 ( _) => DType :: U8 ,
330+ Tensor :: Bit ( _) => DType :: Bit ,
331+ }
332+ }
333+
334+ /// Return the shape of this tensor as a slice of dimension sizes.
335+ pub fn shape ( & self ) -> & [ usize ] {
336+ match self {
337+ Tensor :: C128 ( a) => a. shape ( ) ,
338+ Tensor :: C64 ( a) => a. shape ( ) ,
339+ Tensor :: F64 ( a) => a. shape ( ) ,
340+ Tensor :: F32 ( a) => a. shape ( ) ,
341+ Tensor :: I64 ( a) => a. shape ( ) ,
342+ Tensor :: I32 ( a) => a. shape ( ) ,
343+ Tensor :: I16 ( a) => a. shape ( ) ,
344+ Tensor :: I8 ( a) => a. shape ( ) ,
345+ Tensor :: U64 ( a) => a. shape ( ) ,
346+ Tensor :: U32 ( a) => a. shape ( ) ,
347+ Tensor :: U16 ( a) => a. shape ( ) ,
348+ Tensor :: U8 ( a) => a. shape ( ) ,
349+ Tensor :: Bit ( a) => a. shape ( ) ,
350+ }
351+ }
352+
353+ /// Return the [`TensorType`] that describes this tensor's dtype and concrete shape.
354+ pub fn tensor_type ( & self ) -> TensorType {
355+ TensorType {
356+ dtype : DTypeLike :: Concrete ( self . dtype ( ) ) ,
357+ shape : self . shape ( ) . iter ( ) . map ( |& n| Dim :: Fixed ( n) ) . collect ( ) ,
358+ broadcastable : false ,
359+ }
360+ }
361+
362+ /// Element-wise power with NumPy-style broadcasting.
363+ ///
364+ /// For integer types the exponent is cast to `u32`; negative integer exponents
365+ /// are not supported.
366+ pub fn pow ( & self , rhs : & Tensor ) -> Tensor {
367+ match ( self , rhs) {
368+ ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => {
369+ Tensor :: F32 ( broadcast_elementwise ( a, b, |& x, & y| x. powf ( y) ) )
370+ }
371+ ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => {
372+ Tensor :: F64 ( broadcast_elementwise ( a, b, |& x, & y| x. powf ( y) ) )
373+ }
374+ ( Tensor :: C64 ( a) , Tensor :: C64 ( b) ) => {
375+ Tensor :: C64 ( broadcast_elementwise ( a, b, |& x, & y| x. powc ( y) ) )
376+ }
377+ ( Tensor :: C128 ( a) , Tensor :: C128 ( b) ) => {
378+ Tensor :: C128 ( broadcast_elementwise ( a, b, |& x, & y| x. powc ( y) ) )
379+ }
380+ ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => {
381+ Tensor :: I8 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
382+ }
383+ ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => {
384+ Tensor :: I16 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
385+ }
386+ ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => {
387+ Tensor :: I32 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
388+ }
389+ ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => {
390+ Tensor :: I64 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
391+ }
392+ ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => {
393+ Tensor :: U8 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
394+ }
395+ ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => {
396+ Tensor :: U16 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
397+ }
398+ ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => {
399+ Tensor :: U32 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y) ) )
400+ }
401+ ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => {
402+ Tensor :: U64 ( broadcast_elementwise ( a, b, |& x, & y| x. pow ( y as u32 ) ) )
403+ }
404+ _ => panic ! ( "type mismatch in Tensor::pow" ) ,
405+ }
406+ }
407+
408+ /// Cast this tensor to `target`, consuming it. Returns `self` unchanged if already that dtype.
409+ pub fn cast ( self , target : DType ) -> Tensor {
410+ if self . dtype ( ) == target {
411+ return self ;
412+ }
413+ match & self {
414+ Tensor :: Bit ( a) | Tensor :: U8 ( a) => cast_real ! ( a, u8 , target) ,
415+ Tensor :: U16 ( a) => cast_real ! ( a, u16 , target) ,
416+ Tensor :: U32 ( a) => cast_real ! ( a, u32 , target) ,
417+ Tensor :: U64 ( a) => cast_real ! ( a, u64 , target) ,
418+ Tensor :: I8 ( a) => cast_real ! ( a, i8 , target) ,
419+ Tensor :: I16 ( a) => cast_real ! ( a, i16 , target) ,
420+ Tensor :: I32 ( a) => cast_real ! ( a, i32 , target) ,
421+ Tensor :: I64 ( a) => cast_real ! ( a, i64 , target) ,
422+ Tensor :: F32 ( a) => cast_real ! ( a, f32 , target) ,
423+ Tensor :: F64 ( a) => cast_real ! ( a, f64 , target) ,
424+ Tensor :: C64 ( a) => cast_complex ! ( a, target) ,
425+ Tensor :: C128 ( a) => cast_complex ! ( a, target) ,
426+ }
427+ }
428+ }
429+
430+ /// Implement `From<&[T]>`, `From<&[T; N]>`, and `From<ArrayD<T>>` for a given `Tensor` variant.
431+ macro_rules! impl_tensor_from {
432+ ( $variant: ident, $t: ty) => {
433+ impl From <& [ $t] > for Tensor {
434+ fn from( data: & [ $t] ) -> Self {
435+ Tensor :: $variant( ndarray:: arr1( data) . into_dyn( ) )
436+ }
437+ }
438+ impl <const N : usize > From <[ $t; N ] > for Tensor {
439+ fn from( data: [ $t; N ] ) -> Self {
440+ Tensor :: $variant( ndarray:: arr1( & data) . into_dyn( ) )
441+ }
442+ }
443+ impl From <ArrayD <$t>> for Tensor {
444+ fn from( data: ArrayD <$t>) -> Self {
445+ Tensor :: $variant( data)
446+ }
447+ }
448+ } ;
449+ }
450+
451+ impl_tensor_from ! ( C128 , Complex <f64 >) ;
452+ impl_tensor_from ! ( C64 , Complex <f32 >) ;
453+ impl_tensor_from ! ( F64 , f64 ) ;
454+ impl_tensor_from ! ( F32 , f32 ) ;
455+ impl_tensor_from ! ( I64 , i64 ) ;
456+ impl_tensor_from ! ( I32 , i32 ) ;
457+ impl_tensor_from ! ( I16 , i16 ) ;
458+ impl_tensor_from ! ( I8 , i8 ) ;
459+ impl_tensor_from ! ( U64 , u64 ) ;
460+ impl_tensor_from ! ( U32 , u32 ) ;
461+ impl_tensor_from ! ( U16 , u16 ) ;
462+ impl_tensor_from ! ( U8 , u8 ) ; // u8 → U8; Bit requires explicit construction
463+
464+ /// Implement a standard Rust binary operator trait for `Tensor` and `&Tensor`.
465+ macro_rules! impl_tensor_binop {
466+ ( $trait: ident, $method: ident, $op: tt) => {
467+ impl std:: ops:: $trait for & Tensor {
468+ type Output = Tensor ;
469+ fn $method( self , rhs: Self ) -> Tensor {
470+ match ( self , rhs) {
471+ ( Tensor :: C128 ( a) , Tensor :: C128 ( b) ) => Tensor :: C128 ( a $op b) ,
472+ ( Tensor :: C64 ( a) , Tensor :: C64 ( b) ) => Tensor :: C64 ( a $op b) ,
473+ ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Tensor :: F64 ( a $op b) ,
474+ ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Tensor :: F32 ( a $op b) ,
475+ ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Tensor :: I64 ( a $op b) ,
476+ ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Tensor :: I32 ( a $op b) ,
477+ ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Tensor :: I16 ( a $op b) ,
478+ ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Tensor :: I8 ( a $op b) ,
479+ ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Tensor :: U64 ( a $op b) ,
480+ ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Tensor :: U32 ( a $op b) ,
481+ ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Tensor :: U16 ( a $op b) ,
482+ ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Tensor :: U8 ( a $op b) ,
483+ _ => panic!( "type mismatch in Tensor::{}" , stringify!( $method) ) ,
484+ }
485+ }
486+ }
487+ impl std:: ops:: $trait for Tensor {
488+ type Output = Tensor ;
489+ fn $method( self , rhs: Self ) -> Tensor { & self $op & rhs }
490+ }
491+ } ;
492+ }
493+
494+ /// Like [`impl_tensor_binop!`], but omits complex variants for ops that don't support them
495+ /// (e.g. `Rem`, which `num_complex` does not implement).
496+ macro_rules! impl_tensor_binop_real {
497+ ( $trait: ident, $method: ident, $op: tt) => {
498+ impl std:: ops:: $trait for & Tensor {
499+ type Output = Tensor ;
500+ fn $method( self , rhs: Self ) -> Tensor {
501+ match ( self , rhs) {
502+ ( Tensor :: F64 ( a) , Tensor :: F64 ( b) ) => Tensor :: F64 ( a $op b) ,
503+ ( Tensor :: F32 ( a) , Tensor :: F32 ( b) ) => Tensor :: F32 ( a $op b) ,
504+ ( Tensor :: I64 ( a) , Tensor :: I64 ( b) ) => Tensor :: I64 ( a $op b) ,
505+ ( Tensor :: I32 ( a) , Tensor :: I32 ( b) ) => Tensor :: I32 ( a $op b) ,
506+ ( Tensor :: I16 ( a) , Tensor :: I16 ( b) ) => Tensor :: I16 ( a $op b) ,
507+ ( Tensor :: I8 ( a) , Tensor :: I8 ( b) ) => Tensor :: I8 ( a $op b) ,
508+ ( Tensor :: U64 ( a) , Tensor :: U64 ( b) ) => Tensor :: U64 ( a $op b) ,
509+ ( Tensor :: U32 ( a) , Tensor :: U32 ( b) ) => Tensor :: U32 ( a $op b) ,
510+ ( Tensor :: U16 ( a) , Tensor :: U16 ( b) ) => Tensor :: U16 ( a $op b) ,
511+ ( Tensor :: U8 ( a) , Tensor :: U8 ( b) ) => Tensor :: U8 ( a $op b) ,
512+ _ => panic!( "type mismatch or unsupported dtype in Tensor::{}" , stringify!( $method) ) ,
513+ }
514+ }
515+ }
516+ impl std:: ops:: $trait for Tensor {
517+ type Output = Tensor ;
518+ fn $method( self , rhs: Self ) -> Tensor { & self $op & rhs }
519+ }
520+ } ;
521+ }
522+
523+ impl_tensor_binop ! ( Add , add, +) ;
524+ impl_tensor_binop ! ( Sub , sub, -) ;
525+ impl_tensor_binop ! ( Mul , mul, * ) ;
526+ impl_tensor_binop ! ( Div , div, /) ;
527+ impl_tensor_binop_real ! ( Rem , rem, %) ;
528+
237529#[ cfg( test) ]
238530mod test {
239531 use super :: * ;
0 commit comments