diff --git a/src/layout/dyn_repr.rs b/src/layout/dyn_repr.rs new file mode 100644 index 000000000..f09892d79 --- /dev/null +++ b/src/layout/dyn_repr.rs @@ -0,0 +1,188 @@ +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::iter::Cloned; +use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; +use core::ops::{Deref, DerefMut, Index, IndexMut}; +use core::slice::Iter; + +use crate::layout::rank::DynRank; +use crate::layout::ranked::Ranked; +use crate::layout::shape::Shape; +use crate::Axis; + +const CAP: usize = 4; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DynAxesRepr +{ + Inline(usize, [T; CAP]), + Alloc(Box<[T]>), +} + +/// An array shape with a dynamic rank. +pub type DShape = DynAxesRepr; + +impl Deref for DynAxesRepr +{ + type Target = [T]; + + fn deref(&self) -> &Self::Target + { + match self { + DynAxesRepr::Inline(len, arr) => { + debug_assert!(*len <= arr.len()); + unsafe { arr.get_unchecked(..*len) } + } + DynAxesRepr::Alloc(items) => items, + } + } +} + +impl DerefMut for DynAxesRepr +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + match self { + DynAxesRepr::Inline(len, arr) => { + debug_assert!(*len <= arr.len()); + unsafe { arr.get_unchecked_mut(..*len) } + } + DynAxesRepr::Alloc(items) => items, + } + } +} + +impl Index for DynAxesRepr +{ + type Output = T; + + fn index(&self, index: usize) -> &Self::Output + { + &(**self)[index] + } +} + +impl IndexMut for DynAxesRepr +{ + fn index_mut(&mut self, index: usize) -> &mut Self::Output + { + &mut (**self)[index] + } +} + +impl Index for DynAxesRepr +{ + type Output = T; + + fn index(&self, index: Axis) -> &Self::Output + { + self.index(index.0) + } +} + +impl IndexMut for DynAxesRepr +{ + fn index_mut(&mut self, index: Axis) -> &mut Self::Output + { + self.index_mut(index.0) + } +} + +impl From for DynAxesRepr +where + Rhs: AsRef<[T]>, + T: Default + Copy, +{ + fn from(value: Rhs) -> Self + { + let value = value.as_ref(); + let n = value.len(); + if n <= CAP { + let mut inline = [T::default(); CAP]; + inline.split_at_mut(n).0.copy_from_slice(value); + Self::Inline(n, inline) + } else { + Self::Alloc(value.into()) + } + } +} + +impl Ranked for DynAxesRepr +{ + type NDim = DynRank; + + fn ndim(&self) -> usize + { + match self { + DynAxesRepr::Inline(d, _) => d.clone(), + DynAxesRepr::Alloc(items) => items.len(), + } + } +} + +macro_rules! impl_op { + ($op_trait:ty, $op_fn:ident, $op_assign_trait:ty, $op_assign_fn:ident) => { + /// *Panics* if the two dimensionalities are different + impl $op_trait for DShape + where + Rhs: Into, + { + type Output = DShape; + + fn $op_fn(self, rhs: Rhs) -> ::Output { + let mut output = self.clone(); + output.$op_assign_fn(rhs); + output + } + } + + /// *Panics* if the two dimensionalities are different + impl $op_assign_trait for DShape + where + Rhs: Into, + { + fn $op_assign_fn(&mut self, rhs: Rhs) { + let other = rhs.into(); + for i in 0..self.ndim().max(other.ndim()) { + self[i].$op_assign_fn(other[i]); + } + } + } + }; +} + +impl_op!(Add, add, AddAssign, add_assign); +impl_op!(Sub, sub, SubAssign, sub_assign); +impl_op!(Mul, mul, MulAssign, mul_assign); + +impl IntoIterator for DynAxesRepr +where T: Clone +{ + type Item = T; + type IntoIter = alloc::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter + { + match self { + DynAxesRepr::Inline(len, arr) => Vec::from(arr[..len].to_vec()).into_iter(), + DynAxesRepr::Alloc(b) => b.into_vec().into_iter(), + } + } +} + +impl Shape for DynAxesRepr +{ + type Iter<'a> + = Cloned> + where Self: 'a; + + fn axis_len(&self, axis: usize) -> usize + { + self[axis] + } + + fn iter(&self) -> Self::Iter<'_> + { + (**self).iter().cloned() + } +} diff --git a/src/layout/mod.rs b/src/layout/mod.rs index 9f5d1e4e1..fa2358873 100644 --- a/src/layout/mod.rs +++ b/src/layout/mod.rs @@ -11,6 +11,46 @@ mod bitset; pub mod rank; pub mod ranked; +mod shape; +mod n_repr; +mod dyn_repr; + +use core::any::type_name; +use core::error::Error; +use core::fmt::{Debug, Display}; +use core::marker::PhantomData; + +use crate::layout::ranked::Ranked; #[allow(deprecated)] pub use bitset::{Layout, LayoutBitset}; +pub use dyn_repr::DShape; +pub use n_repr::NShape; +pub use shape::Shape; + +/// The error type for dealing with shapes and strides +#[derive(Debug, Clone, Copy)] +pub enum ShapeStrideError +{ + /// Out of bounds; specifically, using an index that is larger than the dimensionality of the shape or strides `S`. + OutOfBounds(PhantomData, usize), + /// The error when trying to construct or mutate a shape or strides with the wrong dimensionality value. + RankMismatch(PhantomData, usize), + /// The desired shape would represent an array with more elements than `isize::MAX` + ShapeOverflow, +} + +impl Display for ShapeStrideError +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + { + match self { + ShapeStrideError::OutOfBounds(_, idx) => + write!(f, "Index {idx} is larger than the dimensionality of {}", type_name::()), + ShapeStrideError::RankMismatch(_, rank) => write!(f, "{} has a rank of {}, which is incompatible with requested rank of {rank}", type_name::(), type_name::()), + ShapeStrideError::ShapeOverflow => write!(f, "The desired shape would represent an array with more elements than `usize::MAX`") + } + } +} + +impl Error for ShapeStrideError {} diff --git a/src/layout/n_repr.rs b/src/layout/n_repr.rs new file mode 100644 index 000000000..a29d932bf --- /dev/null +++ b/src/layout/n_repr.rs @@ -0,0 +1,288 @@ +use core::{ + array::IntoIter, + iter::Cloned, + marker::PhantomData, + ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Sub, SubAssign}, + slice::Iter, +}; + +use num_traits::Zero; + +use crate::layout::{ + rank::{ConstRank, Rank}, + ranked::Ranked, + shape::Shape, + ShapeStrideError, +}; + +/// A wrapper for fixed-length arrays that can be used for shape and strides. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ShapeStrideN([T; N]); + +/// An array shape with a constant rank. +pub type NShape = ShapeStrideN; + +impl Deref for ShapeStrideN +{ + type Target = [T; N]; + + fn deref(&self) -> &Self::Target + { + &self.0 + } +} + +impl DerefMut for ShapeStrideN +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + &mut self.0 + } +} + +impl Index for ShapeStrideN +{ + type Output = T; + + fn index(&self, index: usize) -> &Self::Output + { + &self.0[index] + } +} + +impl IndexMut for ShapeStrideN +{ + fn index_mut(&mut self, index: usize) -> &mut Self::Output + { + &mut self.0[index] + } +} + +impl From<[T; N]> for ShapeStrideN +{ + fn from(value: [T; N]) -> Self + { + Self { 0: value } + } +} + +impl From> for [T; N] +{ + fn from(value: ShapeStrideN) -> Self + { + value.0 + } +} + +impl<'a, T, const N: usize> TryFrom<&'a [T]> for ShapeStrideN +where T: Copy + Zero +{ + type Error = ShapeStrideError; + + fn try_from(value: &'a [T]) -> Result + { + if value.len() != N { + Err(ShapeStrideError::RankMismatch(PhantomData, value.len())) + } else { + let mut arr = [T::zero(); N]; + arr.copy_from_slice(value); + Ok(Self { 0: arr }) + } + } +} + +impl IntoIterator for ShapeStrideN +{ + type Item = T; + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter + { + self.0.into_iter() + } +} + +macro_rules! shapestride_and_tuples { + ($(($tuple:ty, $N:literal)),*) => { + $( + impl From<$tuple> for ShapeStrideN + { + fn from(value: $tuple) -> Self + { + Self { 0: value.into() } + } + } + + impl From> for $tuple + { + fn from(value: ShapeStrideN) -> Self + { + value.0.into() + } + } + )* + }; +} + +shapestride_and_tuples!( + ((T,), 1), + ((T, T), 2), + ((T, T, T), 3), + ((T, T, T, T), 4), + ((T, T, T, T, T), 5), + ((T, T, T, T, T, T), 6), + ((T, T, T, T, T, T, T), 7), + ((T, T, T, T, T, T, T, T), 8), + ((T, T, T, T, T, T, T, T, T), 9), + ((T, T, T, T, T, T, T, T, T, T), 10), + ((T, T, T, T, T, T, T, T, T, T, T), 11), + ((T, T, T, T, T, T, T, T, T, T, T, T), 12) +); + +impl PartialEq for ShapeStrideN +where + T: PartialEq, + Rhs: AsRef<[T]>, +{ + fn eq(&self, other: &Rhs) -> bool + { + let other = other.as_ref(); + if other.len() != N { + return false; + } + for i in 0..N { + if self[i] != other[i] { + return false; + } + } + return true; + } +} + +impl Ranked for ShapeStrideN +where ConstRank: Rank +{ + type NDim = ConstRank; + + fn ndim(&self) -> usize + { + N + } +} + +impl Zero for NShape +{ + fn zero() -> Self + { + Self([usize::zero(); N]) + } + + fn is_zero(&self) -> bool + { + self.0.iter().all(|e| *e == 0usize) + } +} + +macro_rules! impl_op { + ($op_trait:ty, $op_fn:ident, $op_assign_trait:ty, $op_assign_fn:ident) => { + impl $op_trait for NShape + where + Rhs: Into>, + { + type Output = NShape; + + fn $op_fn(self, rhs: Rhs) -> Self::Output { + let mut output = self.clone(); + output.$op_assign_fn(rhs); + output + } + } + + impl $op_assign_trait for NShape + where + Rhs: Into>, + { + fn $op_assign_fn(&mut self, rhs: Rhs) { + let other = rhs.into(); + for i in 0..N { + self[i].$op_assign_fn(other[i]); + } + } + } + }; +} + +impl_op!(Add, add, AddAssign, add_assign); +impl_op!(Sub, sub, SubAssign, sub_assign); +impl_op!(Mul, mul, MulAssign, mul_assign); + +impl Add for NShape +{ + type Output = NShape; + + fn add(self, rhs: usize) -> Self::Output + { + let mut output = self.clone(); + output += rhs; + output + } +} + +impl AddAssign for NShape +{ + fn add_assign(&mut self, rhs: usize) + { + for o in self.iter_mut() { + *o += rhs; + } + } +} + +impl Mul for NShape +{ + type Output = NShape; + + fn mul(self, rhs: usize) -> Self::Output + { + let mut output = self.clone(); + for o in output.iter_mut() { + *o = *o * rhs; + } + output + } +} + +impl MulAssign for NShape +{ + fn mul_assign(&mut self, rhs: usize) + { + for o in self.iter_mut() { + *o += rhs; + } + } +} + +impl Shape for NShape +where ConstRank: Rank +{ + type Iter<'a> = Cloned>; + + fn axis_len(&self, axis: usize) -> usize + { + self.0[axis] + } + + fn iter(&self) -> Self::Iter<'_> + { + self.0.iter().cloned() + } +} + +impl Default for NShape +{ + fn default() -> Self + { + Self::zero() + } +} diff --git a/src/layout/shape.rs b/src/layout/shape.rs new file mode 100644 index 000000000..0207549e3 --- /dev/null +++ b/src/layout/shape.rs @@ -0,0 +1,73 @@ +use core::iter::Map; + +use crate::layout::ranked::Ranked; + +/// A trait for array shapes: lists of `usize` describing the length of each dimension of an array. +/// +/// ## Size Limits +/// Arrays cannot have more than [`usize::MAX`] elements; otherwise, it would not be possible +/// to address all of the elements in-memory. In addition, since an array may have negative strides +/// (i.e., elements _behind_ the current pointer in memory), arrays must be addressable using +/// [`isize`]. So no array can have more than [`isize::MAX`] elements. Finally, for multi-byte +/// element types, the total byte size of the array also cannot exceed `isize::MAX`. +/// +/// Implementing `Shape` for a type does not guarantee that the type will always represent an +/// array that adheres to these size limits. It does, however, provide access to two methods that +/// can cheaply check these invariants: [`Shape::size_checked`] and [`Shape::size_bytes_checked`]. +/// In order for an instance of a `Shape` to be valid, both of these methods must return `Some(_)`. +/// +/// ## Mutability +/// The `Shape` trait does not provide any sort of mutability for the lengths of each axis. +/// This allows users to define constant-sized shapes, which can significantly increase performance. +/// +/// Since `Shape` is still experimental, the mechanisms for mutability are still being designed. +pub trait Shape: Ranked +{ + /// The iterator type over the dimensions of the shape. + type Iter<'a>: Iterator + ExactSizeIterator + DoubleEndedIterator + where Self: 'a; + + /// The length of the array along a given axis. + fn axis_len(&self, axis: usize) -> usize; + + /// Iterate over the dimensions of the shape. + fn iter(&self) -> Self::Iter<'_>; + + /// Get the number of elements that the array contains. + /// + /// If the number of elements is greater than `isize::MAX`, returns `None`. + fn size_checked(&self) -> Option + { + self.iter() + .try_fold(1_usize, |acc, i| acc.checked_mul(i)) + .and_then(as_usize_if_isize_compatible) + } + + /// Get the number of bytes that this array would fill. + /// + /// This method checks for bytes overflow past `isize`. If this method returns `Some(_)`, + /// then users know that an allocated array is indexable using `isize` offsets. + fn size_bytes_checked(&self) -> Option + { + self.size_checked() + .and_then(|v| v.checked_mul(size_of::())) + .and_then(as_usize_if_isize_compatible) + } + + /// Iterate over the shape as `isize`. + /// + /// If the number of elements is greater than `isize::MAX`, returns `None`. + fn iter_isize<'a>(&'a self) -> Option, impl FnMut(usize) -> isize>> + { + self.size_checked().map(|_| self.iter().map(|v| v as isize)) + } +} + +fn as_usize_if_isize_compatible(v: usize) -> Option +{ + if v <= (isize::MAX as usize) { + Some(v) + } else { + None + } +}