Improve slow-path performance for unary ops
Replace iterators with a pattern that uses a fixed number of nested loops.
The same approach was previously applied to binary and ternary ops.

Part of #192.
robertknight committed May 31, 2024
1 parent 414e1a0 commit 146ef2a
Showing 4 changed files with 84 additions and 7 deletions.
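For context, a minimal standalone sketch of the pattern this commit adopts (all names here are hypothetical, not part of the diff): pad the view to a fixed rank of 4, then walk it with four nested index loops instead of a generic element iterator. Fixed trip counts and explicit stride arithmetic give the compiler far more room to unroll the inner loop and to elide per-element iterator state.

```rust
// Hypothetical illustration: map `f` over a strided 4-D view of `data`.
fn map_nested_loops<T: Copy, R>(
    data: &[T],
    shape: [usize; 4],
    strides: [usize; 4],
    f: impl Fn(T) -> R,
) -> Vec<R> {
    let mut out = Vec::with_capacity(shape.iter().product());
    for i0 in 0..shape[0] {
        for i1 in 0..shape[1] {
            for i2 in 0..shape[2] {
                for i3 in 0..shape[3] {
                    // Explicit offset arithmetic; the innermost loop becomes
                    // a unit-stride walk when `strides[3] == 1`.
                    let offset = i0 * strides[0]
                        + i1 * strides[1]
                        + i2 * strides[2]
                        + i3 * strides[3];
                    out.push(f(data[offset]));
                }
            }
        }
    }
    out
}
```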
42 changes: 42 additions & 0 deletions rten-tensor/src/copy.rs
@@ -250,6 +250,48 @@ pub fn copy_into_uninit<T: Clone>(mut src: TensorView<T>, mut dest: TensorViewMu
});
}

/// Apply `f` to each element of `src` and write the output to `dest` in
/// contiguous order.
pub fn map_into_slice<T, R, F: Fn(&T) -> R>(
    mut src: TensorView<T>,
    dest: &mut [MaybeUninit<R>],
    f: F,
) {
    assert!(src.len() == dest.len());

    while src.ndim() < 4 {
        src.insert_axis(0);
    }

    // This would benefit from the same optimizations that `copy_into_slice`
    // has for e.g. transposed inputs, preferably without generating a ton of
    // duplicate code for each map function `F`.

    let mut out_offset = 0;
    src.inner_iter::<4>().for_each(|src| {
        for i0 in 0..src.size(0) {
            for i1 in 0..src.size(1) {
                for i2 in 0..src.size(2) {
                    for i3 in 0..src.size(3) {
                        // Safety: i0..i3 are in `[0, src.size(i))`.
                        let x = unsafe { src.get_unchecked([i0, i1, i2, i3]) };
                        let y = f(x);

                        // Safety: We write to `src.len()` successive output
                        // elements, and `src` and `dest` have the same length.
                        unsafe {
                            dest.get_unchecked_mut(out_offset).write(y);
                        }
                        out_offset += 1;
                    }
                }
            }
        }
    });

    debug_assert!(out_offset == src.len());
}
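A possible usage sketch (the `Tensor` constructor and `transposed` method are assumptions about the wider rten-tensor API, not shown in this diff): map a transposed, hence non-contiguous, view into the spare capacity of a fresh `Vec`.

```rust
let tensor = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
let view = tensor.transposed(); // non-contiguous view: [[1, 3], [2, 4]]
let n = view.len();
let mut out: Vec<f32> = Vec::with_capacity(n);
map_into_slice(view.as_dyn(), &mut out.spare_capacity_mut()[..n], |x| x * 2.0);
// Safety: `map_into_slice` initialized all `n` elements.
unsafe { out.set_len(n) };
assert_eq!(out, [2.0, 6.0, 4.0, 8.0]);
```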

/// Copy a slice of `src` specified by `ranges` into `dest` in contiguous order.
pub fn copy_range_into_slice<T: Clone>(
    mut src: TensorView<T>,
26 changes: 26 additions & 0 deletions rten-tensor/src/iterators.rs
@@ -961,6 +961,32 @@ impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> {
    }
}

/// Call `f` on each element of `view`.
pub fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
    while view.ndim() < 4 {
        view.insert_axis(0);
    }

    // This could be improved by sorting dimensions of `view` in order of
    // decreasing stride. If the resulting view is contiguous, `f` can be
    // applied to the underlying data directly. Even if it isn't, this will
    // still make memory access as contiguous as possible.

    view.inner_iter_mut::<4>().for_each(|mut src| {
        for i0 in 0..src.size(0) {
            for i1 in 0..src.size(1) {
                for i2 in 0..src.size(2) {
                    for i3 in 0..src.size(3) {
                        // Safety: i0..i3 are in `[0, src.size(i))`.
                        let x = unsafe { src.get_unchecked_mut([i0, i1, i2, i3]) };
                        f(x);
                    }
                }
            }
        }
    });
}
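A possible usage sketch (constructors again assumed from the wider rten-tensor API): apply an in-place update to a transposed, non-contiguous view.

```rust
let mut tensor = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
tensor.transpose(); // layout is now non-contiguous
for_each_mut(tensor.as_dyn_mut(), |x| *x = -*x);
assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [-1.0, -3.0, -2.0, -4.0]);
```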

// Tests for iterator internals. Most tests of iterators are currently done via
// tests on tensor methods.
#[cfg(test)]
2 changes: 1 addition & 1 deletion rten-tensor/src/overlap.rs
@@ -3,7 +3,7 @@ use std::iter::zip;
use smallvec::SmallVec;

/// Return true if a given shape and strides describe a contiguous layout in
/// "C" order.
/// row-major ("C") order.
pub fn is_contiguous<S: AsRef<[usize]>>(shape: S, strides: S) -> bool {
    // Trim leading 1s from the shape. These dimensions can have a larger
    // stride than the product of inner dimensions without affecting whether
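Concretely, row-major contiguity means each dimension's stride equals the product of the sizes of all later dimensions, with leading size-1 dimensions exempt per the comment above. Expected behavior, sketched:

```rust
assert!(is_contiguous([2, 3, 4], [12, 4, 1])); // standard row-major strides
assert!(!is_contiguous([2, 3, 4], [1, 2, 6])); // column-major ("F") order
assert!(is_contiguous([1, 3, 4], [99, 4, 1])); // leading size-1 dim: stride is ignored
```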
21 changes: 15 additions & 6 deletions rten-tensor/src/tensor.rs
@@ -2,11 +2,13 @@ use std::borrow::Cow;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};

use crate::copy::{copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice};
use crate::copy::{
    copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
};
use crate::errors::{DimensionError, FromDataError, SliceError};
use crate::iterators::{
    AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterDyn, InnerIterDynMut,
    InnerIterMut, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
    for_each_mut, AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterDyn,
    InnerIterDynMut, InnerIterMut, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
};
use crate::layout::{
    AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, MatrixLayout, MutLayout, NdLayout,
@@ -519,7 +521,7 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
            // Fast path for contiguous tensors.
            data.iter_mut().for_each(|x| *x = f(x));
        } else {
            self.iter_mut().for_each(|x| *x = f(x));
            for_each_mut(self.as_dyn_mut(), |x| *x = f(x));
        }
    }
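This is the in-place map on `TensorBase` (the method name, presumably `apply`, sits outside the visible hunk). A hedged usage sketch of the slow path that now routes through `for_each_mut`:

```rust
let mut tensor = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
tensor.transpose(); // non-contiguous, so the contiguous fast path is skipped
tensor.apply(|x| x + 1.0); // each element updated via `for_each_mut`
```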

Expand Down Expand Up @@ -1526,12 +1528,19 @@ impl<T, S: Storage<Elem = T>, L: MutLayout + Clone> AsView for TensorBase<S, L>
where
F: Fn(&Self::Elem) -> U,
{
let mut buf = alloc.alloc(self.len());
let len = self.len();
let mut buf = alloc.alloc(len);
if let Some(data) = self.data() {
// Fast path for contiguous tensors.
buf.extend(data.iter().map(f));
} else {
buf.extend(self.iter().map(f));
let dest = &mut buf.spare_capacity_mut()[..len];
map_into_slice(self.as_dyn(), dest, f);

// Safety: `map_into` initialized all elements of `dest`.
unsafe {
buf.set_len(len);
}
};
TensorBase::from_data(self.shape(), buf)
}
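The slow path above fills the `Vec`'s uninitialized spare capacity and then commits the length with `set_len`. A minimal standalone sketch of that pattern, independent of any rten types:

```rust
use std::mem::MaybeUninit;

fn squares(len: usize) -> Vec<u64> {
    let mut buf: Vec<u64> = Vec::with_capacity(len);
    let dest: &mut [MaybeUninit<u64>] = &mut buf.spare_capacity_mut()[..len];
    for (i, slot) in dest.iter_mut().enumerate() {
        slot.write((i * i) as u64);
    }
    // Safety: the loop above initialized all `len` elements.
    unsafe { buf.set_len(len) };
    buf
}
```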
