Skip to content

Commit

Permalink
Make packing of GEMM inputs more flexible
Browse files Browse the repository at this point in the history
Previously the GEMM code assumed that data would be packed in the same format as
the LHS / RHS input types, and with a fixed layout. To support new data types,
especially int8, more flexibility will be needed.

 - int8 matmuls using dot product rather than FMA instructions will use a
   different blocked layout

 - For some architectures / data types it will make sense to expand inputs to
   a wider data type during packing rather than in the kernel

 - For some architectures / data types the kernel can operate on an unpacked
   LHS matrix, if it has unit column stride. For others packing will always be
   required.

To enable this:

 - Add `Kernel` methods which return descriptors specifying the size and
   alignment required for packing a particular A or B input.

 - Modify kernel interface to use opaque `[u8]` slices for packing buffer
   contents. The kernel implementations will cast this to a slice of the type
   they use internally.

 - Add a `PackingBuffer` struct which wraps a `Vec<u32>` buffer and
   provides an API for reserving space in the buffer and casting its
   contents to `[u8]` slices for the kernel and its packing methods.

The `PackedAMatrix` and `PackedBMatrix` types still have assumptions about
the internal layout of packed buffers which will need to be removed. That will
happen in subsequent commits.
  • Loading branch information
robertknight committed Jan 3, 2025
1 parent 9373e69 commit 42b75a1
Show file tree
Hide file tree
Showing 9 changed files with 469 additions and 166 deletions.
248 changes: 122 additions & 126 deletions src/gemm.rs

Large diffs are not rendered by default.

72 changes: 67 additions & 5 deletions src/gemm/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ pub mod wasm;
/// LHS / A input to a GEMM kernel.
#[derive(Clone, Copy)]
pub enum Lhs<'a, T> {
/// Input packed into a contiguous buffer in column major order.
Packed(&'a [T]),
/// Input packed by the kernel's [`pack_a_block`](Kernel::pack_a_block)
/// impl.
Packed(&'a [u8]),

/// Unpacked input with a column stride of 1 and row stride of `row_stride`.
///
Expand All @@ -36,6 +37,55 @@ pub enum Lhs<'a, T> {
},
}

/// Metadata about a packed block of an input matrix.
///
/// The packed block is expected to be organized as a sequence of panels with
/// stride [`panel_stride`](PackedLayout::panel_stride), but the kernel is
/// otherwise free to choose the layout.
pub struct PackedLayout {
    // Minimum size of the packing buffer in bytes.
    size: usize,
    // Minimum alignment of the packing buffer in bytes.
    align: usize,
    // Stride between consecutive panels in bytes.
    panel_stride: usize,

    /// True if the input must be packed to be used by the kernel.
    pub must_pack: bool,
}

impl PackedLayout {
    /// Construct a new packing buffer descriptor.
    ///
    /// `size`, `align` and `panel_stride` specify the minimum size of the
    /// packing buffer, its alignment and the stride between panels
    /// respectively. All units are in bytes. The size must be a multiple of
    /// both the alignment and panel stride.
    ///
    /// The returned layout has `must_pack` set to `false`; kernels set the
    /// field when packing is mandatory for the given input.
    pub fn new(size: usize, align: usize, panel_stride: usize) -> PackedLayout {
        debug_assert_eq!(size % align, 0, "size must be a multiple of alignment");
        debug_assert_eq!(
            size % panel_stride,
            0,
            "size must be a multiple of panel stride"
        );

        PackedLayout {
            size,
            align,
            panel_stride,
            must_pack: false,
        }
    }

    /// Return size of the packed block in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Return minimum alignment of the packed block.
    pub fn align(&self) -> usize {
        self.align
    }

    /// Return stride between panels in bytes.
    pub fn panel_stride(&self) -> usize {
        self.panel_stride
    }
}

/// Kernel that computes a small tile of a general matrix multiplication (GEMM)
/// or general matrix-vector multiplication (GEMV).
///
Expand Down Expand Up @@ -67,19 +117,31 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
/// Return a name for this kernel for use in logging etc.
fn name(&self) -> &'static str;

/// Return the layout of a packing buffer required to pack a block of `a`
/// of size `rows x cols`.
fn packed_a_layout(&self, a: Matrix<LhsT>, rows: usize, cols: usize) -> PackedLayout;

/// Pack a block of the LHS / "A" input for use by this kernel.
fn pack_a_block(
&self,
out: &mut [MaybeUninit<LhsT>],
out: &mut [MaybeUninit<u8>],
a: Matrix<LhsT>,
rows: Range<usize>,
cols: Range<usize>,
);

/// Return the layout of a packing buffer required to pack a block of a "B"
/// / RHS input of size `rows x cols`.
///
/// Unlike `packed_a_layout` this doesn't take the matrix as an argument.
/// `packed_a_layout` may use this to indicate that the A input does not
/// need to be packed. For the B input it is assumed this is always packed.
fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout;

/// Pack a block of the RHS / "B" input for use by this kernel.
fn pack_b_block(
&self,
out: &mut [MaybeUninit<RhsT>],
out: &mut [MaybeUninit<u8>],
b: Matrix<RhsT>,
rows: Range<usize>,
cols: Range<usize>,
Expand Down Expand Up @@ -109,7 +171,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
tile_ptr: *mut OutT,
tile_row_stride: usize,
a: Lhs<LhsT>,
b: &[RhsT],
b: &[u8],
used_rows: usize,
used_cols: usize,
depth: usize,
Expand Down
27 changes: 21 additions & 6 deletions src/gemm/kernels/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use rten_simd::vec_count;
use rten_tensor::Matrix;
use rten_tensor::{Matrix, MatrixLayout};

use super::simd_generic::{simd_gemm, simd_gemv};
use super::{Kernel, Lhs, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block};
use super::{Kernel, Lhs, PackedLayout, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info};
use crate::number::{cast_pod_mut_slice, cast_pod_slice};

pub struct ArmNeonKernel {
_private: (),
Expand Down Expand Up @@ -37,23 +38,35 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
Self::NR
}

// Compute the packing-buffer layout for a `rows x cols` block of the LHS /
// "A" input. Packing is only mandatory when `a` does not have unit column
// stride, since this kernel can also consume an unpacked row-contiguous LHS
// (see `Lhs::Unpacked`).
fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout {
// Start from the generic f32 A-block layout for this kernel's MR.
let mut info = packed_a_info::<f32, { Self::MR }>(rows, cols);
// Require packing only for non-unit column stride inputs.
info.must_pack = a.col_stride() != 1;
info
}

fn pack_a_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer");
pack_a_block::<f32, { Self::MR }>(out, a, rows, cols);
}

// Compute the packing-buffer layout for a `rows x cols` block of the RHS /
// "B" input. Unlike the A input, B is assumed to always be packed.
fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout {
packed_b_info::<f32, { Self::NR }>(rows, cols)
}

fn pack_b_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
b: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer");
pack_b_block::<f32, { Self::NR }>(out, b, rows, cols);
}

Expand All @@ -62,7 +75,7 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
tile_ptr: *mut f32,
tile_row_stride: usize,
a: Lhs<f32>,
b: &[f32],
b: &[u8],
used_rows: usize,
used_cols: usize,
depth: usize,
Expand All @@ -73,6 +86,8 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
const NR: usize = ArmNeonKernel::NR;
const NR_REGS: usize = vec_count::<float32x4_t>(NR);

let b = cast_pod_slice(b).unwrap();

if used_cols == NR {
simd_gemm::<float32x4_t, MR, NR_REGS>(
tile_ptr,
Expand Down
27 changes: 21 additions & 6 deletions src/gemm/kernels/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use rten_simd::vec_count;
use rten_tensor::Matrix;
use rten_tensor::{Matrix, MatrixLayout};

use super::simd_generic::{simd_gemm, simd_gemv};
use super::{Kernel, Lhs, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block};
use super::{Kernel, Lhs, PackedLayout, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info};
use crate::number::{cast_pod_mut_slice, cast_pod_slice};

/// This is the base kernel that does not use architecture-specific intrinsics
/// but is autovectorization-friendly. It is expected to perform the same as
Expand Down Expand Up @@ -42,23 +43,35 @@ unsafe impl Kernel<f32, f32, f32> for GenericKernel {
"base"
}

// Compute the packing-buffer layout for a `rows x cols` block of the LHS /
// "A" input. Packing is only mandatory when `a` does not have unit column
// stride, since this kernel can also consume an unpacked row-contiguous LHS
// (see `Lhs::Unpacked`).
fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout {
// Start from the generic f32 A-block layout for this kernel's MR.
let mut info = packed_a_info::<f32, { Self::MR }>(rows, cols);
// Require packing only for non-unit column stride inputs.
info.must_pack = a.col_stride() != 1;
info
}

fn pack_a_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
pack_a_block::<f32, { Self::MR }>(out, a, rows, cols);
}

// Compute the packing-buffer layout for a `rows x cols` block of the RHS /
// "B" input. Unlike the A input, B is assumed to always be packed.
fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout {
packed_b_info::<f32, { Self::NR }>(rows, cols)
}

fn pack_b_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
b: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
pack_b_block::<f32, { Self::NR }>(out, b, rows, cols);
}

Expand All @@ -67,7 +80,7 @@ unsafe impl Kernel<f32, f32, f32> for GenericKernel {
tile_ptr: *mut f32,
tile_row_stride: usize,
a: Lhs<f32>,
b: &[f32],
b: &[u8],
used_rows: usize,
used_cols: usize,
depth: usize,
Expand All @@ -78,6 +91,8 @@ unsafe impl Kernel<f32, f32, f32> for GenericKernel {
const NR: usize = GenericKernel::NR;
const NR_REGS: usize = vec_count::<f32>(NR);

let b = cast_pod_slice(b).unwrap();

if used_cols == NR {
simd_gemm::<f32, MR, NR_REGS>(
tile_ptr,
Expand Down
20 changes: 16 additions & 4 deletions src/gemm/kernels/simd_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,14 @@ pub unsafe fn simd_gemm<S: SimdFloat, const MR: usize, const NR_REGS: usize>(
assert!(depth > 0);
let (a_ptr, a_row_stride) = match a {
Lhs::Packed(data) => {
assert!(data.len() >= depth * MR);
(data.as_ptr(), 1)
let min_len = depth * MR * size_of::<f32>();
assert!(
data.len() >= min_len,
"packed data len {} smaller than required {}",
data.len(),
min_len
);
(data.as_ptr() as *const f32, 1)
}
Lhs::Unpacked {
data,
Expand Down Expand Up @@ -365,8 +371,14 @@ pub unsafe fn simd_gemm_tail<S: SimdFloat, const MR: usize, const NR_REGS: usize
assert!(depth > 0);
let (a_ptr, a_row_stride) = match a {
Lhs::Packed(data) => {
assert!(data.len() >= depth * MR);
(data.as_ptr(), 1)
let min_len = depth * MR * size_of::<f32>();
assert!(
data.len() >= min_len,
"packed data len {} smaller than required {}",
data.len(),
min_len
);
(data.as_ptr() as *const f32, 1)
}
Lhs::Unpacked {
data,
Expand Down
29 changes: 21 additions & 8 deletions src/gemm/kernels/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::ops::Range;

use rten_simd::arch::wasm::v128f;
use rten_simd::vec_count;
use rten_tensor::Matrix;
use rten_tensor::{Matrix, MatrixLayout};

use super::simd_generic::{simd_gemm, simd_gemv};
use super::{Kernel, Lhs, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block};
use super::{Kernel, Lhs, PackedLayout, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info};
use crate::number::{cast_pod_mut_slice, cast_pod_slice};

pub struct WasmKernel {
_private: (),
Expand Down Expand Up @@ -41,23 +42,35 @@ unsafe impl Kernel<f32, f32, f32> for WasmKernel {
Self::NR
}

// Compute the packing-buffer layout for a `rows x cols` block of the LHS /
// "A" input. Packing is only mandatory when `a` does not have unit column
// stride, since this kernel can also consume an unpacked row-contiguous LHS
// (see `Lhs::Unpacked`).
fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout {
// Start from the generic f32 A-block layout for this kernel's MR.
let mut info = packed_a_info::<f32, { Self::MR }>(rows, cols);
// Require packing only for non-unit column stride inputs.
info.must_pack = a.col_stride() != 1;
info
}

fn pack_a_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
pack_a_block::<f32, { Self::MR }>(out, a, rows, cols);
}

// Compute the packing-buffer layout for a `rows x cols` block of the RHS /
// "B" input. Unlike the A input, B is assumed to always be packed.
fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout {
packed_b_info::<f32, { Self::NR }>(rows, cols)
}

fn pack_b_block(
&self,
out: &mut [MaybeUninit<f32>],
out: &mut [MaybeUninit<u8>],
b: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
pack_b_block::<f32, { Self::NR }>(out, b, rows, cols);
}

Expand All @@ -66,7 +79,7 @@ unsafe impl Kernel<f32, f32, f32> for WasmKernel {
tile_ptr: *mut f32,
tile_row_stride: usize,
a: Lhs<f32>,
b: &[f32],
b: &[u8],
used_rows: usize,
used_cols: usize,
depth: usize,
Expand All @@ -83,7 +96,7 @@ unsafe impl Kernel<f32, f32, f32> for WasmKernel {
tile_row_stride,
a,
used_rows,
b,
cast_pod_slice(b).unwrap(),
depth,
alpha,
beta,
Expand All @@ -95,7 +108,7 @@ unsafe impl Kernel<f32, f32, f32> for WasmKernel {
NR,
a,
used_rows,
b,
cast_pod_slice(b).unwrap(),
depth,
alpha,
0.,
Expand Down
Loading

0 comments on commit 42b75a1

Please sign in to comment.