From 42b75a1e1e154a3c424afbabefce25bb9720dfd0 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 3 Jan 2025 12:29:45 +0000 Subject: [PATCH] Make packing of GEMM inputs more flexible Previously the GEMM code assumed that data would be packed in the same format as the LHS / RHS input types, and with a fixed layout. To support new data types, especially int8, more flexibility will be needed. - int8 matmuls using dot product rather than FMA instructions will use a different blocked layout - For some architectures / data types it will make sense to expand inputs to a wider data type during packing rather than in the kernel - For some architectures / data types the kernel can operate on an unpacked LHS matrix, if it has unit column stride. For others packing will always be required. To enable this: - Add `Kernel` methods which return descriptors specifying the size and alignment required for packing a particular A or B input. - Modify kernel interface to use opaque `[u8]` slices for packing buffer contents. The kernel implementations will cast this to a slice of the type they use internally. - Add a `PackingBuffer` struct which wraps a `Vec` buffer and provides an API for reserving space in the buffer and casting its contents to `[u8]` slices for the kernel and its packing methods. The `PackedAMatrix` and `PackedBMatrix` types still have assumptions about the internal layout of packed buffers which will need to removed. That will happen in subsequent commits. --- src/gemm.rs | 248 +++++++++++++++---------------- src/gemm/kernels.rs | 72 ++++++++- src/gemm/kernels/aarch64.rs | 27 +++- src/gemm/kernels/generic.rs | 27 +++- src/gemm/kernels/simd_generic.rs | 20 ++- src/gemm/kernels/wasm.rs | 29 +++- src/gemm/kernels/x86_64.rs | 53 +++++-- src/gemm/packing.rs | 158 +++++++++++++++++++- src/number.rs | 1 + 9 files changed, 469 insertions(+), 166 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index a686dd26..11c9ff38 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -15,7 +15,7 @@ use rten_tensor::prelude::*; use rten_tensor::{Alloc, GlobalAlloc, Matrix, MatrixLayout, MatrixMut, NdTensorView, Storage}; use crate::iter_util::{range_chunks, MaybeParIter}; -use crate::number::{cast_pod_mut_slice, cast_pod_slice, Identities, Pod}; +use crate::number::{cast_pod_mut_slice, Identities, Pod}; use crate::tensor_pool::ExtractBuffer; mod kernels; @@ -23,17 +23,22 @@ mod packing; use kernels::generic::GenericKernel; use kernels::Kernel; +use packing::{PackElem, PackingBuffer}; /// Left-hand or "A" GEMM input that has been pre-packed. #[derive(Clone)] pub struct PackedAMatrix { - /// Sequence of packed row panels. - data: Vec, + /// Sequence of packed row panels. The exact format depends upon the kernel + /// that packed the data. + data: PackingBuffer, /// Height of row panel. This should match the kernel's [`mr`](Kernel::mr) /// value. panel_height: usize, + /// Stride of each panel in `data`. + panel_stride: usize, + /// Number of rows in the unpacked matrix. rows: usize, @@ -42,6 +47,8 @@ pub struct PackedAMatrix { /// Name of the kernel that packed this buffer. See [`Kernel::name`]. kernel_name: &'static str, + + _marker: PhantomData, } impl PackedAMatrix { @@ -50,37 +57,47 @@ impl PackedAMatrix { fn block(&self, rows: Range, depth: Range) -> LhsBlock { assert_eq!(rows.start % self.panel_height, 0); - let panel_stride = self.panel_len(); + // Size of each column in the packed block in bytes. This assumes the + // specific column major layout for each row panel currently used by + // the kernels. 
This will need to change as new packed formats are + // introduced. + let col_size = self.panel_height * size_of::(); + let panel_range = rows.start / self.panel_height..rows.end.div_ceil(self.panel_height); - let start = panel_range.start * panel_stride + depth.start * self.panel_height; - let end = (panel_range.end - 1) * panel_stride + depth.end * self.panel_height; - let data = &self.data[start..end]; - LhsBlock::Packed { data, panel_stride } - } + let start = panel_range.start * self.panel_stride + depth.start * col_size; + let end = (panel_range.end - 1) * self.panel_stride + depth.end * col_size; + let data = &self.data.as_bytes()[start..end]; - fn panel_len(&self) -> usize { - self.panel_height * self.cols + LhsBlock::Packed { + data, + panel_stride: self.panel_stride, + panel_len: (depth.end - depth.start) * col_size, + } } } impl ExtractBuffer for PackedAMatrix { - type Elem = T; + type Elem = PackElem; - fn extract_buffer(self) -> Option> { - Some(self.data) + fn extract_buffer(self) -> Option> { + Some(self.data.into_vec()) } } /// Right-hand or "B" GEMM input that has been pre-packed. #[derive(Clone)] pub struct PackedBMatrix { - /// Sequence of packed column panels. - data: Vec, + /// Sequence of packed column panels. The exact format depends upon the + /// kernel that packed the data. + data: PackingBuffer, /// Width of column panel. This should match the kernel's [`nr`](Kernel::nr) /// value. panel_width: usize, + /// Stride of each panel in `data`. + panel_stride: usize, + /// Number of rows in the unpacked matrix. rows: usize, @@ -89,6 +106,8 @@ pub struct PackedBMatrix { /// Name of the kernel that packed this buffer. See [`Kernel::name`]. kernel_name: &'static str, + + _marker: PhantomData, } impl PackedBMatrix { @@ -97,16 +116,19 @@ impl PackedBMatrix { fn block(&self, cols: Range, depth: Range) -> RhsBlock { assert_eq!(cols.start % self.panel_width, 0); - let panel_stride = self.panel_len(); + let row_size = self.panel_width * size_of::(); + let panel_range = cols.start / self.panel_width..cols.end.div_ceil(self.panel_width); - let start = panel_range.start * panel_stride + depth.start * self.panel_width; - let end = (panel_range.end - 1) * panel_stride + depth.end * self.panel_width; - let data = &self.data[start..end]; - RhsBlock { data, panel_stride } - } + let start = panel_range.start * self.panel_stride + depth.start * row_size; + let end = (panel_range.end - 1) * self.panel_stride + depth.end * row_size; + let data = &self.data.as_bytes()[start..end]; - fn panel_len(&self) -> usize { - self.panel_width * self.rows + RhsBlock { + data, + panel_stride: self.panel_stride, + panel_len: (depth.end - depth.start) * row_size, + _marker: PhantomData, + } } /// Number of rows in the unpacked matrix. @@ -121,10 +143,10 @@ impl PackedBMatrix { } impl ExtractBuffer for PackedBMatrix { - type Elem = T; + type Elem = PackElem; - fn extract_buffer(self) -> Option> { - Some(self.data) + fn extract_buffer(self) -> Option> { + Some(self.data.into_vec()) } } @@ -349,29 +371,16 @@ impl GemmExecutor(&self, alloc: A, a: Matrix) -> PackedAMatrix { - let mr = self.kernel.mr(); - let panel_len = mr * a.cols(); - let packed_len = a.rows().next_multiple_of(mr) * a.cols(); - let mut data = alloc.alloc(packed_len); - - // Pack input as a sequence of row panels. 
- let mut out_panels = data.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len); - let mut n_init = 0; - for panel_rows in range_chunks(0..a.rows(), mr) { - let out_panel = out_panels.next().unwrap(); - let used_size = panel_rows.len().next_multiple_of(mr) * a.cols(); - let (used, unused) = out_panel.split_at_mut(used_size); - - self.kernel.pack_a_block(used, a, panel_rows, 0..a.cols()); - - unused.fill(MaybeUninit::new(LhsT::zero())); - n_init += out_panel.len(); - } + let layout = self.kernel.packed_a_layout(a, a.rows(), a.cols()); + let mut data = PackingBuffer::new(); + let uninit_data = data.alloc_in(alloc, &layout); - // Safety: We used `pack_a_block` to initialize `packed_len` elements. - assert_eq!(n_init, packed_len); + self.kernel + .pack_a_block(uninit_data, a, 0..a.rows(), 0..a.cols()); + + // Safety: We used `pack_a_block` to initialize `layout.size` bytes unsafe { - data.set_len(packed_len); + data.set_len(layout.size()); } PackedAMatrix { @@ -379,7 +388,9 @@ impl GemmExecutor GemmExecutor(&self, alloc: A, b: Matrix) -> PackedBMatrix { - let nr = self.kernel.nr(); - let packed_len = b.cols().next_multiple_of(nr) * b.rows(); - let panel_len = nr * b.rows(); - let mut out = alloc.alloc(packed_len); - - // Pack input as a sequence of column panels. - let mut out_panels = out.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len); - let mut n_init = 0; - for panel_cols in range_chunks(0..b.cols(), nr) { - let out_panel = out_panels.next().unwrap(); - let used_size = panel_cols.len().next_multiple_of(nr) * b.rows(); - let (used, unused) = out_panel.split_at_mut(used_size); - - self.kernel - .pack_b_block(used, b, 0..b.rows(), panel_cols.clone()); - - unused.fill(MaybeUninit::new(RhsT::zero())); - n_init += out_panel.len(); - } + let layout = self.kernel.packed_b_layout(b.rows(), b.cols()); + let mut data = PackingBuffer::new(); + let uninit_data = data.alloc_in(alloc, &layout); + + self.kernel + .pack_b_block(uninit_data, b, 0..b.rows(), 0..b.cols()); - // Safety: We used `pack_b_block` to initialize `packed_len` elements. - assert_eq!(n_init, packed_len); + // Safety: We used `pack_b_block` to initialize `layout.size` bytes. unsafe { - out.set_len(packed_len); + data.set_len(layout.size()); } PackedBMatrix { - data: out, + data, rows: b.rows(), cols: b.cols(), - panel_width: nr, + panel_width: self.kernel.nr(), + panel_stride: layout.panel_stride(), kernel_name: self.kernel.name(), + _marker: PhantomData, } } @@ -933,14 +932,8 @@ fn gemm_impl( } // Buffers for packed blocks of the matrix. - // - // These use `u64` rather than LhsT / RhsT because statics cannot be generic. - // `u64` is assumed to have an alignment that is greater or equal to the - // alignment of any LhsT / RhsT. - thread_local!(static PACKED_A: RefCell> = const { RefCell::new(Vec::new()) }); - thread_local!(static PACKED_B: RefCell> = const { RefCell::new(Vec::new()) }); - assert!(align_of::() <= align_of::()); - assert!(align_of::() <= align_of::()); + thread_local!(static PACKED_A: RefCell = const { RefCell::new(PackingBuffer::new()) }); + thread_local!(static PACKED_B: RefCell = const { RefCell::new(PackingBuffer::new()) }); let n_col_blocks = b.cols().div_ceil(nc); let n_row_blocks = a.rows().div_ceil(mc); @@ -964,30 +957,27 @@ fn gemm_impl( for depth_range in range_chunks(0..a.cols(), kc) { // Borrowed packing buffer for current thread. Returned after // the GEMM block is computed. 
- let mut thread_local_packed_b: Option> = None; - let panel_length = depth_range.len(); - let packed_b_size = (col_end - col_start).next_multiple_of(nr) * panel_length; + let mut thread_local_packed_b: Option = None; let rhs_block = match b { GemmInputB::Unpacked(_) | GemmInputB::Virtual(_) => PACKED_B.with(|cell| { let mut packed_b = cell.take(); - packed_b.clear(); - packed_b - .reserve(packed_b_size.div_ceil(size_of::() / size_of::())); - let packed_b_slice = - cast_pod_mut_slice(packed_b.spare_capacity_mut()).unwrap(); - let packed_b_slice = &mut packed_b_slice[..packed_b_size]; + let layout = kernel.packed_b_layout(depth_range.len(), col_end - col_start); + let packed_uninit = packed_b.alloc(&layout); match b { GemmInputB::Unpacked(b) => kernel.pack_b_block( - packed_b_slice, + packed_uninit, b, depth_range.clone(), col_start..col_end, ), GemmInputB::Virtual(vm) => vm.pack_b( - packed_b_slice, + // Cast [MaybeUninit] => [MaybeUninit] as im2col packing + // currently assumes the packed data is in the same format as the + // RHS input. + cast_pod_mut_slice(packed_uninit).unwrap(), kernel.nr(), depth_range.clone(), col_start..col_end, @@ -995,17 +985,16 @@ fn gemm_impl( GemmInputB::Packed(_) => unreachable!(), } - // Safety: The packing call initialized `packed_b_size` elements. + // Safety: `pack_b_block` will have initialized `layout.size()` bytes. unsafe { - packed_b.set_len(packed_b_size); + packed_b.set_len(layout.size()); } thread_local_packed_b = Some(packed_b); - let packed_b_data = - cast_pod_slice::<_, RhsT>(thread_local_packed_b.as_deref().unwrap()) - .unwrap(); RhsBlock { - data: packed_b_data, - panel_stride: kernel.nr() * depth_range.len(), + data: thread_local_packed_b.as_ref().unwrap().as_bytes(), + panel_stride: layout.panel_stride(), + panel_len: layout.panel_stride(), + _marker: PhantomData, } }), GemmInputB::Packed(pm) => pm.block(col_range.clone(), depth_range.clone()), @@ -1026,45 +1015,41 @@ fn gemm_impl( let row_start = row_idx * mc; let row_end = (row_start + mc).min(a.rows()); let row_range = row_start..row_end; - let packed_a_size = - (row_end - row_start).next_multiple_of(mr) * depth_range.len(); // Borrowed packing buffer for current thread. Returned after // the GEMM block is computed. - let mut thread_local_packed_a: Option> = None; + let mut thread_local_packed_a: Option = None; let lhs_block = match a { - GemmInputA::Unpacked(a) if a.col_stride() == 1 => LhsBlock::Unpacked(a), GemmInputA::Unpacked(a) => PACKED_A.with(|cell| { - let mut packed_a = cell.take(); - packed_a.clear(); - packed_a.reserve( - packed_a_size.div_ceil(size_of::() / size_of::()), + let layout = kernel.packed_a_layout( + a, + row_end - row_start, + depth_range.len(), ); + if !layout.must_pack { + return LhsBlock::Unpacked(a); + }; + + let mut packed_a = cell.take(); + let packed_uninit = packed_a.alloc(&layout); - let packed_a_block = cast_pod_mut_slice::<_, MaybeUninit>( - packed_a.spare_capacity_mut(), - ) - .unwrap(); kernel.pack_a_block( - &mut packed_a_block[..packed_a_size], + packed_uninit, a, row_start..row_end, depth_range.clone(), ); - // Safety: `pack_a_block` will have initialized - // `packed_a_size` elements. + + // Safety: We initialized `layout.size` bytes. 
unsafe { - packed_a.set_len(packed_a_size); + packed_a.set_len(layout.size()); } thread_local_packed_a = Some(packed_a); - let packed_a = cast_pod_slice::<_, LhsT>( - thread_local_packed_a.as_deref().unwrap(), - ) - .unwrap(); LhsBlock::Packed { - data: packed_a, - panel_stride: mr * depth_range.len(), + data: thread_local_packed_a.as_ref().unwrap().as_bytes(), + panel_stride: layout.panel_stride(), + panel_len: layout.panel_stride(), } }), GemmInputA::Packed(pm) => { @@ -1102,10 +1087,13 @@ fn gemm_impl( enum LhsBlock<'a, T> { /// Packed block of A matrix, arranged as a sequence of row panels. Packed { - data: &'a [T], + data: &'a [u8], /// Stride between each row panel. panel_stride: usize, + + /// Length of each row panel. + panel_len: usize, }, /// Unpacked A matrix. This must have a column stride of 1. @@ -1116,10 +1104,15 @@ enum LhsBlock<'a, T> { /// a sequence of column panels. #[derive(Copy, Clone)] struct RhsBlock<'a, T> { - data: &'a [T], + data: &'a [u8], /// Stride between each column panel. panel_stride: usize, + + /// Size between each column panel. + panel_len: usize, + + _marker: PhantomData, } /// Process a single block (ie. a slice along each of the M/N/K dimensions) of a @@ -1163,7 +1156,7 @@ fn gemm_block( .enumerate() .for_each(|(block_col_tile, col_tile)| { let b_panel_offset = block_col_tile * b.panel_stride; - let b_panel = &b.data[b_panel_offset..b_panel_offset + nr * depth_range.len()]; + let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_len]; // Loop over row tiles. for (block_row_tile, row_tile) in row_tiles.clone().enumerate() { @@ -1173,10 +1166,13 @@ fn gemm_block( let out_tile = unsafe { output.tile(row_tile, col_tile) }; let kernel_lhs = match a { - LhsBlock::Packed { data, panel_stride } => { + LhsBlock::Packed { + data, + panel_stride, + panel_len, + } => { let a_panel_offset = block_row_tile * panel_stride; - let a_panel = - &data[a_panel_offset..a_panel_offset + mr * depth_range.len()]; + let a_panel = &data[a_panel_offset..a_panel_offset + panel_len]; kernels::Lhs::Packed(a_panel) } LhsBlock::Unpacked(mat) => { diff --git a/src/gemm/kernels.rs b/src/gemm/kernels.rs index c6c8a4c5..5fd0fe4a 100644 --- a/src/gemm/kernels.rs +++ b/src/gemm/kernels.rs @@ -20,8 +20,9 @@ pub mod wasm; /// LHS / A input to a GEMM kernel. #[derive(Clone, Copy)] pub enum Lhs<'a, T> { - /// Input packed into a contiguous buffer in column major order. - Packed(&'a [T]), + /// Input packed by the kernel's [`pack_a_block`](Kernel::pack_a_block) + /// impl. + Packed(&'a [u8]), /// Unpacked input with a column stride of 1 and row stride of `row_stride`. /// @@ -36,6 +37,55 @@ pub enum Lhs<'a, T> { }, } +/// Metadata about a packed block of an input matrix. +/// +/// The packed block is expected to be organized as a sequence of panels with +/// stride [`panel_stride`](PackedInfo::panel_stride), but the kernel is +/// otherwise free to choose the layout. +pub struct PackedLayout { + size: usize, + align: usize, + panel_stride: usize, + + /// True if the input must be packed to be used by the kernel. + pub must_pack: bool, +} + +impl PackedLayout { + /// Construct a new packing buffer descriptor. + /// + /// `size`, `align` and `panel_stride` specify the minimum size of the + /// packing buffer, its alignment and the stride between panels + /// respectively. All units are in bytes. The size must be a multiple of + /// both the alignment and panel stride. 
+ pub fn new(size: usize, align: usize, panel_stride: usize) -> PackedLayout { + debug_assert_eq!(size % align, 0); + debug_assert_eq!(size % panel_stride, 0); + + PackedLayout { + size, + align, + panel_stride, + must_pack: false, + } + } + + /// Return size of the packed block in bytes. + pub fn size(&self) -> usize { + self.size + } + + /// Return minimum alignment of the packed block. + pub fn align(&self) -> usize { + self.align + } + + /// Return stride between panels in bytes. + pub fn panel_stride(&self) -> usize { + self.panel_stride + } +} + /// Kernel that computes a small tile of a general matrix multiplication (GEMM) /// or general matrix-vector multiplication (GEMV). /// @@ -67,19 +117,31 @@ pub unsafe trait Kernel: Sync { /// Return a name for this kernel for use in logging etc. fn name(&self) -> &'static str; + /// Return the layout of a packing buffer required to pack a block of `a` + /// of size `rows x cols`. + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout; + /// Pack a block of the LHS / "A" input for use by this kernel. fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ); + /// Return the layout of a packing buffer required to pack a block of a "B" + /// / RHS input of size `rows x cols`. + /// + /// Unlike `packed_a_layout` this doesn't take the matrix as an argument. + /// `packed_a_layout` may use this to indicate that the A input does not + /// need to be packed. For the B input it is assumed this is always packed. + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout; + /// Pack a block of the RHS / "B" input for use by this kernel. fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, @@ -109,7 +171,7 @@ pub unsafe trait Kernel: Sync { tile_ptr: *mut OutT, tile_row_stride: usize, a: Lhs, - b: &[RhsT], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, diff --git a/src/gemm/kernels/aarch64.rs b/src/gemm/kernels/aarch64.rs index 23c5346c..5eea16db 100644 --- a/src/gemm/kernels/aarch64.rs +++ b/src/gemm/kernels/aarch64.rs @@ -3,11 +3,12 @@ use std::mem::MaybeUninit; use std::ops::Range; use rten_simd::vec_count; -use rten_tensor::Matrix; +use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemm, simd_gemv}; -use super::{Kernel, Lhs, TempTile}; -use crate::gemm::packing::{pack_a_block, pack_b_block}; +use super::{Kernel, Lhs, PackedLayout, TempTile}; +use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info}; +use crate::number::{cast_pod_mut_slice, cast_pod_slice}; pub struct ArmNeonKernel { _private: (), @@ -37,23 +38,35 @@ unsafe impl Kernel for ArmNeonKernel { Self::NR } + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + let mut info = packed_a_info::(rows, cols); + info.must_pack = a.col_stride() != 1; + info + } + fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_a_block::(out, a, rows, cols); } + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + packed_b_info::(rows, cols) + } + fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_b_block::(out, b, rows, cols); } @@ 
-62,7 +75,7 @@ unsafe impl Kernel for ArmNeonKernel { tile_ptr: *mut f32, tile_row_stride: usize, a: Lhs, - b: &[f32], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, @@ -73,6 +86,8 @@ unsafe impl Kernel for ArmNeonKernel { const NR: usize = ArmNeonKernel::NR; const NR_REGS: usize = vec_count::(NR); + let b = cast_pod_slice(b).unwrap(); + if used_cols == NR { simd_gemm::( tile_ptr, diff --git a/src/gemm/kernels/generic.rs b/src/gemm/kernels/generic.rs index 75080eeb..5dcfb5f1 100644 --- a/src/gemm/kernels/generic.rs +++ b/src/gemm/kernels/generic.rs @@ -2,11 +2,12 @@ use std::mem::MaybeUninit; use std::ops::Range; use rten_simd::vec_count; -use rten_tensor::Matrix; +use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemm, simd_gemv}; -use super::{Kernel, Lhs, TempTile}; -use crate::gemm::packing::{pack_a_block, pack_b_block}; +use super::{Kernel, Lhs, PackedLayout, TempTile}; +use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info}; +use crate::number::{cast_pod_mut_slice, cast_pod_slice}; /// This is the base kernel that does not use architecture-specific intrinsics /// but is autovectorization-friendly. It is expected to perform the same as @@ -42,23 +43,35 @@ unsafe impl Kernel for GenericKernel { "base" } + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + let mut info = packed_a_info::(rows, cols); + info.must_pack = a.col_stride() != 1; + info + } + fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + packed_b_info::(rows, cols) + } + fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); } @@ -67,7 +80,7 @@ unsafe impl Kernel for GenericKernel { tile_ptr: *mut f32, tile_row_stride: usize, a: Lhs, - b: &[f32], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, @@ -78,6 +91,8 @@ unsafe impl Kernel for GenericKernel { const NR: usize = GenericKernel::NR; const NR_REGS: usize = vec_count::(NR); + let b = cast_pod_slice(b).unwrap(); + if used_cols == NR { simd_gemm::( tile_ptr, diff --git a/src/gemm/kernels/simd_generic.rs b/src/gemm/kernels/simd_generic.rs index 01a482b5..61abaf95 100644 --- a/src/gemm/kernels/simd_generic.rs +++ b/src/gemm/kernels/simd_generic.rs @@ -231,8 +231,14 @@ pub unsafe fn simd_gemm( assert!(depth > 0); let (a_ptr, a_row_stride) = match a { Lhs::Packed(data) => { - assert!(data.len() >= depth * MR); - (data.as_ptr(), 1) + let min_len = depth * MR * size_of::(); + assert!( + data.len() >= min_len, + "packed data len {} smaller than required {}", + data.len(), + min_len + ); + (data.as_ptr() as *const f32, 1) } Lhs::Unpacked { data, @@ -365,8 +371,14 @@ pub unsafe fn simd_gemm_tail 0); let (a_ptr, a_row_stride) = match a { Lhs::Packed(data) => { - assert!(data.len() >= depth * MR); - (data.as_ptr(), 1) + let min_len = depth * MR * size_of::(); + assert!( + data.len() >= min_len, + "packed data len {} smaller than required {}", + data.len(), + min_len + ); + (data.as_ptr() as *const f32, 1) } Lhs::Unpacked { data, diff --git a/src/gemm/kernels/wasm.rs b/src/gemm/kernels/wasm.rs index 6ff42614..c81ccc26 100644 --- a/src/gemm/kernels/wasm.rs +++ b/src/gemm/kernels/wasm.rs @@ -3,11 +3,12 @@ use 
std::ops::Range; use rten_simd::arch::wasm::v128f; use rten_simd::vec_count; -use rten_tensor::Matrix; +use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemm, simd_gemv}; -use super::{Kernel, Lhs, TempTile}; -use crate::gemm::packing::{pack_a_block, pack_b_block}; +use super::{Kernel, Lhs, PackedLayout, TempTile}; +use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info}; +use crate::number::{cast_pod_mut_slice, cast_pod_slice}; pub struct WasmKernel { _private: (), @@ -41,23 +42,35 @@ unsafe impl Kernel for WasmKernel { Self::NR } + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + let mut info = packed_a_info::(rows, cols); + info.must_pack = a.col_stride() != 1; + info + } + fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + packed_b_info::(rows, cols) + } + fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); } @@ -66,7 +79,7 @@ unsafe impl Kernel for WasmKernel { tile_ptr: *mut f32, tile_row_stride: usize, a: Lhs, - b: &[f32], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, @@ -83,7 +96,7 @@ unsafe impl Kernel for WasmKernel { tile_row_stride, a, used_rows, - b, + cast_pod_slice(b).unwrap(), depth, alpha, beta, @@ -95,7 +108,7 @@ unsafe impl Kernel for WasmKernel { NR, a, used_rows, - b, + cast_pod_slice(b).unwrap(), depth, alpha, 0., diff --git a/src/gemm/kernels/x86_64.rs b/src/gemm/kernels/x86_64.rs index 6de034e0..09952e6d 100644 --- a/src/gemm/kernels/x86_64.rs +++ b/src/gemm/kernels/x86_64.rs @@ -6,14 +6,15 @@ use std::ops::Range; use std::arch::x86_64::__m512; use rten_simd::vec_count; -use rten_tensor::Matrix; +use rten_tensor::{Matrix, MatrixLayout}; #[cfg(feature = "avx512")] use rten_simd::isa_detection::is_avx512_supported; use super::simd_generic::{simd_gemm, simd_gemv}; -use super::{Kernel, Lhs, TempTile}; -use crate::gemm::packing::{pack_a_block, pack_b_block}; +use super::{Kernel, Lhs, PackedLayout, TempTile}; +use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_info, packed_b_info}; +use crate::number::{cast_pod_mut_slice, cast_pod_slice}; /// Optimized kernel for x64 CPUs that support AVX + FMA instructions. pub struct FmaKernel { @@ -71,26 +72,40 @@ unsafe impl Kernel for FmaKernel { Self::NR } + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + let mut info = packed_a_info::(rows, cols); + info.must_pack = a.col_stride() != 1; + info + } + fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); + // Safety: Kernel can only be constructed if AVX is supported. unsafe { pack_a_block_avx::<{ Self::MR }>(out, a, rows, cols); } } + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + packed_b_info::(rows, cols) + } + fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).unwrap(); + // Safety: Kernel can only be constructed if AVX is supported. 
unsafe { pack_b_block_avx::<{ Self::NR }>(out, b, rows, cols); @@ -104,7 +119,7 @@ unsafe impl Kernel for FmaKernel { tile_ptr: *mut f32, tile_row_stride: usize, a: Lhs, - b: &[f32], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, @@ -115,6 +130,8 @@ unsafe impl Kernel for FmaKernel { const NR: usize = FmaKernel::NR; const NR_REGS: usize = vec_count::<__m256>(NR); + let b = cast_pod_slice(b).unwrap(); + // TODO - Replace temporary tile with masked loads and stores. let mut tmp_tile = TempTile::::new(); let (dest_ptr, dest_row_stride, dest_beta) = if used_cols == NR { @@ -209,26 +226,40 @@ unsafe impl Kernel for Avx512Kernel { Self::NR } + fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + let mut info = packed_a_info::(rows, cols); + info.must_pack = a.col_stride() != 1; + info + } + fn pack_a_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], a: Matrix, rows: Range, cols: Range, ) { - // Safety: We assume AVX-512 implies availability of AVX 2. + let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); + + // Safety: AVX-512 implies availability of AVX 2. unsafe { pack_a_block_avx::<{ Self::MR }>(out, a, rows, cols); } } + fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + packed_b_info::(rows, cols) + } + fn pack_b_block( &self, - out: &mut [MaybeUninit], + out: &mut [MaybeUninit], b: Matrix, rows: Range, cols: Range, ) { + let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); + // Safety: We assume AVX-512 implies availability of AVX 2. unsafe { pack_b_block_avx::<{ Self::NR }>(out, b, rows, cols); @@ -242,7 +273,7 @@ unsafe impl Kernel for Avx512Kernel { tile_ptr: *mut f32, tile_row_stride: usize, a: Lhs, - b: &[f32], + b: &[u8], used_rows: usize, used_cols: usize, depth: usize, @@ -253,6 +284,8 @@ unsafe impl Kernel for Avx512Kernel { const NR: usize = Avx512Kernel::NR; const NR_REGS: usize = vec_count::<__m512>(NR); + let b = cast_pod_slice(b).unwrap(); + // TODO - Replace temporary tile with masked loads and stores. let mut tmp_tile = TempTile::::new(); let (dest_ptr, dest_row_stride, dest_beta) = if used_cols == NR { diff --git a/src/gemm/packing.rs b/src/gemm/packing.rs index a86c79e0..75256b49 100644 --- a/src/gemm/packing.rs +++ b/src/gemm/packing.rs @@ -1,7 +1,18 @@ use std::mem::MaybeUninit; use std::ops::Range; -use rten_tensor::{Matrix, MatrixLayout, Storage}; +use rten_tensor::{Alloc, Matrix, MatrixLayout, Storage}; + +use super::kernels::PackedLayout; +use crate::number::{cast_pod_mut_slice, cast_pod_slice}; + +/// Return the required size and other metadata for packing an "A" matrix with +/// [`pack_a_block`]. +pub fn packed_a_info(rows: usize, cols: usize) -> PackedLayout { + let size = rows.next_multiple_of(MR) * cols * size_of::(); + let panel_stride = MR * cols * size_of::(); + PackedLayout::new(size, align_of::(), panel_stride) +} /// Pack a block of the "A" matrix for use by a GEMM kernel. /// @@ -89,6 +100,14 @@ pub fn pack_a_block( } } +/// Return the required size and other metadata for packing a "B" matrix with +/// [`pack_b_block`]. +pub fn packed_b_info(rows: usize, cols: usize) -> PackedLayout { + let size = cols.next_multiple_of(NR) * rows * size_of::(); + let panel_stride = NR * rows * size_of::(); + PackedLayout::new(size, align_of::(), panel_stride) +} + /// Pack a block of the "B" matrix for use by a GEMM kernel. 
/// /// The packed buffer is laid out as a sequence of `ceil(cols.len() / @@ -184,3 +203,140 @@ pub fn pack_b_block( } } } + +// Element type used by [`PackingBuffer`]. This must have an alignment that is +// at least as large as the alignment required by any of the kernels. +pub type PackElem = u32; + +/// Buffer used for storing a block of a packed matrix. +/// +/// The data type and layout of the contents is determined by the GEMM kernel, +/// subject to the constraints: +/// +/// - There is a maximum alignment the kernel can request. See [`PackElem`]. +/// - The stored data must all be plain `Copy` types for which any bit pattern +/// is valid. +#[derive(Clone)] +pub struct PackingBuffer { + buf: Vec, + used_len: usize, +} + +impl PackingBuffer { + /// Construct an empty packing buffer. + /// + /// No allocation happens until `alloc` is called. + pub const fn new() -> PackingBuffer { + PackingBuffer { + buf: Vec::new(), + used_len: 0, + } + } + + /// Clear the buffer and reserve space for a packed input. + /// + /// Returns an uninitialized slice of `layout.size()` bytes which the + /// caller must fill. + pub fn alloc(&mut self, layout: &PackedLayout) -> &mut [MaybeUninit] { + assert!(layout.align() <= align_of::()); + + let buf_len = layout.size().div_ceil(size_of::()); + self.buf.clear(); + self.buf.reserve(buf_len); + self.used_len = 0; + + let uninit_data = &mut self.buf.spare_capacity_mut()[..buf_len]; + cast_pod_mut_slice(uninit_data).unwrap() + } + + /// Clear the buffer and allocate a new one using `alloc`. + /// + /// When the packing buffer is no longer needed it can be extracted using + /// [`into_vec`](Self::into_vec) to be returned to the pool that `alloc` + /// allocates from. + pub fn alloc_in( + &mut self, + alloc: A, + layout: &PackedLayout, + ) -> &mut [MaybeUninit] { + assert!(layout.align() <= align_of::()); + + let buf_len = layout.size().div_ceil(size_of::()); + self.buf = alloc.alloc::(buf_len); + self.used_len = 0; + + let uninit_data = &mut self.buf.spare_capacity_mut()[..buf_len]; + cast_pod_mut_slice(uninit_data).unwrap() + } + + /// Set the number of bytes in the buffer which have been initialized. + pub unsafe fn set_len(&mut self, initialized_len: usize) { + let rounded_len = initialized_len.next_multiple_of(size_of::()); + assert_eq!(rounded_len, initialized_len); + + let buf_len = rounded_len / size_of::(); + assert!(buf_len <= self.buf.capacity()); + self.buf.set_len(buf_len); + self.used_len = initialized_len; + } + + /// Return the contents of the buffer as a slice of bytes. + pub fn as_bytes(&self) -> &[u8] { + &cast_pod_slice(&self.buf).unwrap()[..self.used_len] + } + + /// Extract the buffer from self. 
+    pub fn into_vec(self) -> Vec<PackElem> {
+        self.buf
+    }
+}
+
+impl Default for PackingBuffer {
+    fn default() -> Self {
+        PackingBuffer::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::mem::MaybeUninit;
+
+    use super::{PackedLayout, PackingBuffer};
+
+    #[test]
+    fn test_packing_buffer() {
+        struct Case {
+            size: usize,
+            align: usize,
+            panel_stride: usize,
+        }
+
+        let cases = [Case {
+            size: 256,
+            align: 4,
+            panel_stride: 64,
+        }];
+
+        for Case {
+            size,
+            align,
+            panel_stride,
+        } in cases
+        {
+            let mut buf = PackingBuffer::new();
+            assert_eq!(buf.as_bytes().len(), 0);
+
+            let layout = PackedLayout::new(size, align, panel_stride);
+            let uninit_data = buf.alloc(&layout);
+            assert_eq!(uninit_data.len(), layout.size());
+
+            uninit_data.fill(MaybeUninit::new(0));
+
+            unsafe {
+                buf.set_len(layout.size());
+            }
+
+            assert_eq!(buf.as_bytes().len(), layout.size());
+        }
+    }
+}
diff --git a/src/number.rs b/src/number.rs
index 04c24706..5cace137 100644
--- a/src/number.rs
+++ b/src/number.rs
@@ -232,6 +232,7 @@ impl Pod for i8 {}
 impl Pod for u8 {}
 impl Pod for f32 {}
 impl Pod for i32 {}
+impl Pod for u32 {}
 impl Pod for u64 {}
 impl<T: Pod> Pod for MaybeUninit<T> {}
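
Two short sketches follow to illustrate how the new packing API fits together and why the packed layout is now kernel-defined. The first condenses the `prepack_a` flow from this patch (the pool allocator and the `PackedAMatrix` bookkeeping are omitted, and the ungeneric `Kernel` spelling is shorthand for the f32 kernel trait): the kernel reports the byte layout it needs, the caller reserves that many bytes, the kernel fills them, and the buffer is later handed back to the kernel as an opaque `&[u8]`.

    // Condensed sketch of `prepack_a` from this patch.
    fn pack_lhs(kernel: &dyn Kernel, a: Matrix<f32>) -> PackingBuffer {
        // Ask the kernel how many bytes it needs, and with what alignment
        // and panel stride.
        let layout = kernel.packed_a_layout(a, a.rows(), a.cols());

        // Reserve the space and let the kernel fill it in its own format.
        let mut buf = PackingBuffer::new();
        let uninit = buf.alloc(&layout);
        kernel.pack_a_block(uninit, a, 0..a.rows(), 0..a.cols());

        // Safety: `pack_a_block` initialized `layout.size()` bytes.
        unsafe { buf.set_len(layout.size()) };

        // `buf.as_bytes()` is later passed back to the kernel as `Lhs::Packed`.
        buf
    }

The second sketch is hypothetical and not part of this patch: it shows how a future int8 kernel built on dot-product instructions might describe a K-blocked "B" layout. Such instructions consume groups of four consecutive K values per lane, so the depth dimension is padded to a multiple of four and the buffer is requested with `i32` alignment; the `NR` and `K_BLOCK` values here are illustrative only.

    // Hypothetical packed-B descriptor for an int8 dot-product kernel.
    fn int8_packed_b_layout(rows: usize, cols: usize) -> PackedLayout {
        const NR: usize = 8; // Column tile width (illustrative).
        const K_BLOCK: usize = 4; // K values consumed per dot-product step.

        // Pad the depth so every column panel holds whole K blocks.
        let padded_rows = rows.next_multiple_of(K_BLOCK);
        let panel_stride = NR * padded_rows * size_of::<i8>();
        let size = cols.div_ceil(NR) * panel_stride;

        // Request i32 alignment so the packed bytes can be loaded as
        // 32-bit groups of four int8 values.
        PackedLayout::new(size, align_of::<i32>(), panel_stride)
    }

Neither sketch is meant to compile against this patch as-is (imports and the kernel trait's type parameters are elided), but they describe the intended call sequence and the kind of layout variation the `PackedLayout` / `PackingBuffer` split is designed to accommodate.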