Skip to content

Commit

Permalink
Merge pull request rustformers#311 from pixelspark/pedal-to-the-metal
Browse files Browse the repository at this point in the history
Implement Metal support
  • Loading branch information
philpax authored Jun 21, 2023
2 parents 275ba35 + 0921b1c commit 7927d0d
Show file tree
Hide file tree
Showing 26 changed files with 1,491 additions and 1,180 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
8 changes: 7 additions & 1 deletion binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ pub struct Generate {
/// option will override this if specified.
#[arg(long, default_value_t = false)]
pub ignore_eos: bool,

/// Whether to use GPU acceleration when available
#[arg(long, default_value_t = false)]
pub use_gpu: bool,
}
impl Generate {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
Expand Down Expand Up @@ -301,6 +305,7 @@ impl Generate {
InferenceSessionConfig {
memory_k_type: mem_typ,
memory_v_type: mem_typ,
use_gpu: self.use_gpu,
}
}

Expand Down Expand Up @@ -403,11 +408,12 @@ pub struct ModelLoad {
pub lora_paths: Option<Vec<PathBuf>>,
}
impl ModelLoad {
pub fn load<M: llm::KnownModel + 'static>(&self) -> Result<Box<dyn Model>> {
pub fn load<M: llm::KnownModel + 'static>(&self, use_gpu: bool) -> Result<Box<dyn Model>> {
let params = ModelParameters {
prefer_mmap: !self.no_mmap,
context_size: self.num_ctx_tokens,
lora_adapters: self.lora_paths.clone(),
use_gpu,
};

let mut sp = Some(spinoff::Spinner::new(
Expand Down
25 changes: 13 additions & 12 deletions binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Resul
fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(args.generate.use_gpu)?;

let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
Expand Down Expand Up @@ -119,7 +119,7 @@ fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(args.generate.use_gpu)?;
let (mut session, _) = snapshot::read_or_create_session(
model.as_ref(),
None,
Expand Down Expand Up @@ -184,7 +184,7 @@ fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {

fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(false)?;
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
Expand Down Expand Up @@ -231,8 +231,8 @@ fn interactive<M: llm::KnownModel + 'static>(
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
let model = args.model_load.load::<M>(args.generate.use_gpu)?;
let (mut session, mut session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
args.generate.load_session.as_deref(),
Expand All @@ -250,11 +250,6 @@ fn interactive<M: llm::KnownModel + 'static>(
let readline = rl.readline(">> ");
match readline {
Ok(raw_line) => {
let session_backup = if chat_mode {
None
} else {
Some(session.clone())
};
let line = raw_line.replace("\\\n", "\n");

let prompt = prompt_file
Expand Down Expand Up @@ -302,8 +297,14 @@ fn interactive<M: llm::KnownModel + 'static>(
log::error!("Reply exceeds context window length");
}

if let Some(session_backup) = session_backup {
session = session_backup;
// Reload session in REPL mode
if !chat_mode {
(session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
args.generate.load_session.as_deref(),
inference_session_config,
);
}
}
Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
Expand Down
1 change: 1 addition & 0 deletions crates/ggml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license = "MIT"
[dependencies]
thiserror = { workspace = true }
ggml-sys = { path = "sys", version = "0.2.0-dev" }
memmap2 = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
61 changes: 50 additions & 11 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::{
os::raw::{c_int, c_void},
ptr::NonNull,
sync::Arc,
};
use std::{os::raw::c_int, ptr::NonNull, sync::Arc};

use memmap2::Mmap;

use crate::{sys, usize_to_i32, usize_to_i64, Buffer, ComputationGraph, Tensor, Type};

Expand All @@ -13,23 +11,65 @@ pub struct Context {
/// allocated tensors. Tensors are owned by the object, so a [`Tensor`]
/// contains a `Weak` reference underneath and doesn't let you do anything
/// with it if the underlying context has been deallocated.
ptr: Arc<NonNull<sys::ggml_context>>,
pub ptr: Arc<NonNull<sys::ggml_context>>,

/// Memory mapping information
pub mmap: Option<Mmap>,

/// Backing buffer (in case we own it)
pub buffer: Option<Buffer>,
}

impl Context {
/// Creates a new [Context] backed by a caller-supplied [Buffer].
///
/// Ownership of `buffer` moves into the returned context, which keeps it
/// alive for as long as ggml needs the backing memory.
pub fn init_buffer(buffer: Buffer) -> Self {
    let params = sys::ggml_init_params {
        mem_size: buffer.size(),
        mem_buffer: buffer.data,
        // ggml writes into our buffer instead of allocating its own arena.
        no_alloc: false,
    };

    // SAFETY: `params` points at memory owned by `buffer`, which we store in
    // the returned struct so it outlives the ggml context.
    let ctx_ptr = unsafe { sys::ggml_init(params) };

    Self {
        ptr: Arc::new(NonNull::new(ctx_ptr).expect("Should not be null")),
        mmap: None,
        buffer: Some(buffer),
    }
}

/// Creates a new [Context] whose tensor data lives in the given
/// memory-mapped file.
///
/// The mapping is stored on the context so it stays valid for the
/// context's entire lifetime.
pub fn init_mmap(mmap: Mmap) -> Self {
    let params = sys::ggml_init_params {
        mem_size: mmap.len(),
        mem_buffer: std::ptr::null_mut(),
        // The data is memory-mapped, so ggml must not allocate for us.
        no_alloc: true,
    };

    // SAFETY: ggml allocates nothing here (`no_alloc`); tensor data will be
    // resolved against the mapping we retain below.
    let ctx_ptr = unsafe { sys::ggml_init(params) };

    Self {
        ptr: Arc::new(NonNull::new(ctx_ptr).expect("Should not be null")),
        mmap: Some(mmap),
        buffer: None,
    }
}

/// Creates a new [Context] with the specified `mem_size` as a working area.
///
/// When `alloc` is `false`, ggml is told not to allocate tensor data
/// itself (`no_alloc`), only metadata.
pub fn init(mem_size: usize, alloc: bool) -> Self {
    let params = sys::ggml_init_params {
        mem_size,
        // A null buffer means ggml owns (and allocates) this memory itself.
        mem_buffer: std::ptr::null_mut(),
        no_alloc: !alloc,
    };

    // SAFETY: with a null `mem_buffer`, ggml manages its own arena; nothing
    // on the Rust side aliases it.
    let ctx_ptr = unsafe { sys::ggml_init(params) };

    Self {
        ptr: Arc::new(NonNull::new(ctx_ptr).expect("Should not be null")),
        mmap: None,
        buffer: None,
    }
}

Expand Down Expand Up @@ -391,7 +431,7 @@ impl Context {
/// If `scratch_buffer` is `None`, the scratch buffer will be disabled.
pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a mut Buffer>) {
let (size, data) = if let Some(buffer) = scratch_buffer {
(buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
(buffer.size(), buffer.data)
} else {
(0, std::ptr::null_mut())
};
Expand Down Expand Up @@ -432,8 +472,7 @@ impl Context {

impl Drop for Context {
fn drop(&mut self) {
// SAFETY: The only non-weak copy of ptr is no longer accessible after
// this drop call.
// SAFETY: The only non-weak copy of ptr is no longer accessible after this drop call.
unsafe {
sys::ggml_free(self.ptr.as_ptr());
}
Expand Down
36 changes: 27 additions & 9 deletions crates/ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
//! All [Tensor]s are nodes in this computational graph, and values cannot be retrieved until computation is completed.
#![deny(missing_docs)]

use std::os::raw::{c_int, c_void};
use std::{
alloc::Layout,
os::raw::{c_int, c_void},
};

mod context;
mod tensor;
Expand All @@ -23,6 +26,9 @@ pub(crate) use ggml_sys as sys;
#[cfg(test)]
mod tests;

#[cfg(feature = "metal")]
pub mod metal;

/// The type of a tensor element.
pub type ElementType = Type;

Expand Down Expand Up @@ -218,23 +224,35 @@ impl Type {
///
/// See [Context::use_scratch].
pub struct Buffer {
    // Raw pointer to the aligned allocation backing this buffer. For a
    // zero-sized buffer this is a dangling (but aligned) pointer that is
    // never dereferenced or deallocated.
    data: *mut c_void,
    // Allocation layout, retained so `Drop` can deallocate with the exact
    // same size/alignment it was allocated with.
    layout: Layout,
}

// Alignment required by ggml backends (e.g. Metal) for scratch buffers.
const BUFFER_ALIGN: usize = 16384;

impl Buffer {
    /// Creates a new buffer of the specified size, aligned to `BUFFER_ALIGN`.
    ///
    /// The contents are intentionally uninitialized: the buffer is handed to
    /// the ggml C API, which fills it with data.
    ///
    /// # Panics
    /// Panics if `size` overflows `Layout` constraints, and aborts via
    /// `handle_alloc_error` if the allocator fails.
    pub fn new(size: usize) -> Self {
        let layout = Layout::from_size_align(size, BUFFER_ALIGN).unwrap();

        let data = if layout.size() == 0 {
            // `alloc` is UB for zero-sized layouts; use an aligned dangling
            // pointer instead. `Drop` skips deallocation for this case.
            BUFFER_ALIGN as *mut c_void
        } else {
            // SAFETY: `layout` is non-zero-sized and well-formed.
            let ptr = unsafe { std::alloc::alloc(layout) };
            if ptr.is_null() {
                // Allocation failed; report it through the standard hook
                // instead of returning a null pointer to ggml.
                std::alloc::handle_alloc_error(layout);
            }
            ptr.cast()
        };

        Buffer { data, layout }
    }

    /// Returns the size of the buffer in bytes.
    pub fn size(&self) -> usize {
        self.layout.size()
    }
}

impl Drop for Buffer {
    fn drop(&mut self) {
        // Zero-sized buffers hold a dangling pointer that was never allocated.
        if self.layout.size() != 0 {
            // SAFETY: `data` was allocated with exactly `layout` in `new`.
            unsafe {
                std::alloc::dealloc(self.data.cast(), self.layout);
            }
        }
    }
}
Expand Down
Loading

0 comments on commit 7927d0d

Please sign in to comment.