This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Rust 1.72 fixes #416

Merged 2 commits on Aug 27, 2023
3 changes: 2 additions & 1 deletion crates/ggml/Cargo.toml
@@ -7,8 +7,9 @@ description = "Semi-idiomatic Rust bindings for the ggml library (from `ggml-sys
 license = "MIT"

 [dependencies]
-thiserror = { workspace = true }
 ggml-sys = { path = "sys", version = "0.2.0-dev" }

+thiserror = { workspace = true }
+memmap2 = { workspace = true }

 [dev-dependencies]
13 changes: 11 additions & 2 deletions crates/ggml/src/context.rs
@@ -56,6 +56,13 @@ impl PartialEq for ContextInner {
 impl Eq for ContextInner {}
 impl ContextInner {
     pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
+        // This context can only be used from one thread at a time - hence why
+        // it doesn't implement `Send`/`Sync` - but higher-level abstractions may
+        // choose to layer their own synchronization on top, which can offer
+        // thread-safety guarantees. To ensure that we don't break those, we
+        // still use an `Arc` here.
+        // TODO: check if this is correct?
+        #[allow(clippy::arc_with_non_send_sync)]
         Arc::new(Self {
             ptr: NonNull::new(ptr).expect("Should not be null"),
             offloaded_tensors: Default::default(),
@@ -118,7 +125,9 @@ impl PartialEq for ContextStorage {
 impl Eq for ContextStorage {}

 impl Context {
-    /// Creates a new [Context] with the given storage..
+    // See explanation in [`ContextInner::new`].
+    #[allow(clippy::arc_with_non_send_sync)]
+    /// Creates a new [Context] with the given storage.
     pub fn new(storage: ContextStorage) -> Self {
         let init_params = match &storage {
             ContextStorage::Buffer(buffer) => sys::ggml_init_params {
@@ -296,7 +305,7 @@ impl Context {
         self.new_tensor_raw(tensor)
     }

     /// Repeats the `a` tensor along the first dimension of the `b` tensor.
     pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor {
         let tensor = unsafe { sys::ggml_repeat(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
         self.new_tensor_raw(tensor)
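Note for reviewers: the `#[allow(clippy::arc_with_non_send_sync)]` attributes throughout this PR silence a lint that entered the default set with Rust 1.72's clippy. A minimal sketch of what the lint fires on; `Ctx` here is a hypothetical stand-in for `ggml::Context`, which is `!Send + !Sync` because it wraps a raw pointer:

```rust
use std::{ptr::NonNull, sync::Arc};

// Hypothetical stand-in for `ggml::Context`: it holds a raw pointer,
// so the compiler does not auto-implement `Send` or `Sync` for it.
struct Ctx {
    #[allow(dead_code)]
    ptr: NonNull<u8>,
}

fn main() {
    let mut byte = 0u8;
    // Clippy 1.72 warns here: wrapping a `!Send + !Sync` type in `Arc`
    // is usually a mistake, and the lint suggests `Rc` instead. This PR
    // keeps the `Arc` on purpose (higher layers add their own
    // synchronization) and silences the lint with `#[allow]`.
    #[allow(clippy::arc_with_non_send_sync)]
    let ctx = Arc::new(Ctx {
        ptr: NonNull::new(&mut byte).expect("should not be null"),
    });
    let _shared = Arc::clone(&ctx);
}
```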
2 changes: 1 addition & 1 deletion crates/ggml/src/format/loader.rs
@@ -167,7 +167,7 @@ pub fn load<E: Error, R: BufRead + Seek>(
     match container_type {
         ContainerType::Ggml
         | ContainerType::Ggmf(1)
-        | ContainerType::Ggjt(1 | 2 | 3)
+        | ContainerType::Ggjt(1..=3)
         | ContainerType::Ggla(1) => {}
         _ => return Err(LoadError::InvalidFormatVersion(container_type)),
     }
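On the hunk above: `1..=3` is a range pattern that accepts exactly the same set of versions as `1 | 2 | 3`, so this is a purely stylistic cleanup. A quick standalone check:

```rust
fn main() {
    for version in 0u32..=4 {
        let with_or = matches!(version, 1 | 2 | 3);
        let with_range = matches!(version, 1..=3);
        // The two patterns always agree.
        assert_eq!(with_or, with_range);
        println!("Ggjt({version}) accepted: {with_range}");
    }
}
```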
12 changes: 8 additions & 4 deletions crates/llm-base/src/inference_session.rs
@@ -8,8 +8,8 @@ use tracing::{instrument, log};
 use ggml::accelerator::metal::MetalContext;

 use crate::{
-    mulf, util, InferenceParameters, Model, ModelParameters, OutputRequest, Prompt, TokenId,
-    TokenUtf8Buffer, TokenizationError,
+    mulf, util, InferenceParameters, Model, ModelContext, ModelParameters, OutputRequest, Prompt,
+    TokenId, TokenUtf8Buffer, TokenizationError,
 };

 // The size of a scratch buffer used for inference. This is used for temporary
@@ -148,6 +148,10 @@ impl InferenceSession {
             ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
         }

+        // TODO: revisit this with `Rc`, maybe? We should be able to prove that the session
+        // context is only accessed from one thread at a time, but I've already spent enough
+        // time on this as-is.
+        #[allow(clippy::arc_with_non_send_sync)]
        let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size));

         // Initialize key + value memory tensors
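Regarding the `Rc` idea in that TODO: `Rc` is the type the lint itself suggests and needs no `#[allow]`, but it would also make any containing type `!Send` through auto-trait propagation, which is presumably why it was not pursued here. A lint-clean sketch of that alternative, again with a hypothetical `Ctx` in place of `ggml::Context`:

```rust
use std::{ptr::NonNull, rc::Rc};

// Hypothetical stand-in for `ggml::Context`.
struct Ctx {
    #[allow(dead_code)]
    ptr: NonNull<u8>,
}

fn main() {
    let mut byte = 0u8;
    // `Rc` communicates single-threaded shared ownership, so
    // `clippy::arc_with_non_send_sync` has nothing to complain about.
    let ctx = Rc::new(Ctx {
        ptr: NonNull::new(&mut byte).expect("should not be null"),
    });
    let _shared = Rc::clone(&ctx);
}
```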
@@ -215,7 +219,7 @@ impl InferenceSession {
     /// Compute a model (possibly building a graph in the provided closure when called for the first time and/or when parameters have changed)
     pub fn compute<F>(
         &mut self,
-        #[allow(unused_variables)] model_context: Arc<Context>,
+        #[allow(unused_variables)] model_context: ModelContext,
         input_tokens: &[TokenId],
         builder: F,
     ) -> GraphOutputs
@@ -242,7 +246,7 @@
         #[cfg(feature = "metal")]
         {
             if let Some(ref mut metal_context) = self.metal_context {
-                metal_context.add_context(model_context);
+                metal_context.add_context(model_context.0);
             }
         }

2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
@@ -35,7 +35,7 @@ pub use loader::{
 };
 pub use lora::{LoraAdapter, LoraParameters};
 pub use memmap2::Mmap;
-pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
+pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest};
 pub use quantize::{quantize, QuantizeError, QuantizeProgress};
 pub use regex::Regex;
 pub use tokenizer::{
32 changes: 12 additions & 20 deletions crates/llm-base/src/loader.rs
@@ -5,11 +5,12 @@ use std::{
     fs::File,
     io::{BufRead, BufReader, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
+    sync::Arc,
 };

 use crate::{
-    util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId,
-    Tokenizer, TokenizerLoadError, TokenizerSource,
+    util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters,
+    TokenId, Tokenizer, TokenizerLoadError, TokenizerSource,
 };
 pub use ggml::{format::FormatMagic, ContainerType};
 use ggml::{
@@ -398,7 +399,7 @@ pub trait TensorLoader<E: std::error::Error> {
     /// Gets a tensor from the loader.
     fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
     /// Finish loading the model, returning the context.
-    fn finish(self) -> Context;
+    fn finish(self) -> ModelContext;
 }

 /// Load a GGML model from the `path` and configure it per the `params`. The status
@@ -653,12 +654,7 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
             path: Default::default(),
         })?;

-        let mut main_context = FileContext::new(
-            &self.context,
-            &mut self.file,
-            &self.path,
-            self.context.storage().as_mmap(),
-        );
+        let mut main_context = FileContext::new(&self.context, &mut self.file, &self.path);

         let mut tensor = main_context.get_tensor(info)?;

@@ -681,29 +677,25 @@
         Ok(tensor)
     }

-    fn finish(self) -> Context {
-        self.context
+    fn finish(self) -> ModelContext {
+        // We can ignore this warning as it's OK to share this particular
+        // context around, given that it is immutable.
+        #[allow(clippy::arc_with_non_send_sync)]
+        ModelContext(Arc::new(self.context))
     }
 }

 pub(crate) struct FileContext<'a> {
     context: &'a Context,
     file: &'a mut File,
     path: &'a Path,
-    mmap: Option<&'a Mmap>,
 }
 impl<'a> FileContext<'a> {
-    pub(crate) fn new(
-        context: &'a Context,
-        file: &'a mut File,
-        path: &'a Path,
-        mmap: Option<&'a Mmap>,
-    ) -> Self {
+    pub(crate) fn new(context: &'a Context, file: &'a mut File, path: &'a Path) -> Self {
         Self {
             context,
             file,
             path,
-            mmap,
         }
     }
@@ -738,7 +730,7 @@ impl<'a> FileContext<'a> {
             }
         };

-        match self.mmap {
+        match self.context.storage().as_mmap() {
             Some(mmap) => unsafe {
                 let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                 tensor.set_data(ptr as *mut std::ffi::c_void);
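The `mmap` parameter can be dropped because the context's own storage already knows whether it is mmap-backed. A sketch of the accessor this hunk relies on; this is an assumed shape for illustration, not the actual definition in `crates/ggml`:

```rust
use memmap2::Mmap;

/// Hypothetical sketch of the storage enum behind a `ggml::Context`.
pub enum ContextStorage {
    /// Context data lives in an owned allocation.
    Allocated(Vec<u8>),
    /// Context data is memory-mapped from a model file.
    Mmap(Mmap),
}

impl ContextStorage {
    /// Returns the backing mmap if, and only if, this context is mmap-backed.
    pub fn as_mmap(&self) -> Option<&Mmap> {
        match self {
            Self::Mmap(mmap) => Some(mmap),
            _ => None,
        }
    }
}
```

The design win is that the mmap can no longer disagree with the context it belongs to, which the old extra parameter made possible.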
2 changes: 1 addition & 1 deletion crates/llm-base/src/lora.rs
@@ -106,7 +106,7 @@ impl LoraAdapter {
         // Create a temporary context for the patching operations
         // TODO: test if GPU can be enabled (make it configurable)
         let patch_context = ggml::Context::new_with_allocate(patch_context_size);
-        let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None);
+        let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path);

         // Load the A and B tensors
         let a = patch_file.get_tensor(&a_info)?;
11 changes: 11 additions & 0 deletions crates/llm-base/src/model/mod.rs
@@ -5,6 +5,7 @@ use std::{
     fmt::Debug,
     io::{BufRead, Write},
     path::{Path, PathBuf},
+    sync::Arc,
 };

 use ggml::accelerator::Backend;
@@ -263,3 +264,13 @@ pub struct OutputRequest {
     /// `n_batch * n_embd`.
     pub embeddings: Option<Vec<f32>>,
 }
+
+/// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync`
+/// to allow for the free transfer of models; this is made possible by this
+/// context being effectively inert after creation, so that it cannot be
+/// modified across threads.
+#[derive(Clone)]
+#[allow(clippy::arc_with_non_send_sync)]
+pub struct ModelContext(pub(crate) Arc<ggml::Context>);
+unsafe impl Send for ModelContext {}
+unsafe impl Sync for ModelContext {}
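This wrapper is the heart of the PR: the `unsafe impl`s assert that a fully-built `ggml::Context` is never mutated again, so sharing it between threads cannot race. A compilable sketch of the same pattern, with a hypothetical `RawCtx` standing in for `ggml::Context`:

```rust
use std::{ptr::NonNull, sync::Arc, thread};

// Hypothetical stand-in for `ggml::Context`: `!Send + !Sync` because
// it wraps a raw pointer.
struct RawCtx {
    #[allow(dead_code)]
    ptr: NonNull<u8>,
}

/// Sketch of `ModelContext`: a cloneable handle to an inert context.
#[derive(Clone)]
struct ModelContext(Arc<RawCtx>);

// SAFETY: mirrors the PR's claim - the inner context is never mutated
// after construction, so no data race is possible. Nothing in the type
// system enforces this; the burden of proof is on these two lines.
unsafe impl Send for ModelContext {}
unsafe impl Sync for ModelContext {}

fn main() {
    let mut byte = 0u8;
    #[allow(clippy::arc_with_non_send_sync)]
    let ctx = ModelContext(Arc::new(RawCtx {
        ptr: NonNull::new(&mut byte).expect("should not be null"),
    }));

    // Without the `unsafe impl`s above, this `move` would not compile:
    // `Arc<RawCtx>` on its own is not `Send`.
    let handle = ctx.clone();
    thread::spawn(move || drop(handle)).join().unwrap();
}
```

Note that the "inert after creation" invariant is purely a convention; if a future change mutates the shared context, these impls silently become unsound.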
8 changes: 3 additions & 5 deletions crates/models/bloom/src/lib.rs
@@ -2,13 +2,11 @@
 //! for the `llm` ecosystem.
 #![deny(missing_docs)]

-use std::sync::Arc;
-
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };

 /// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
@@ -37,7 +35,7 @@ pub struct Bloom {
     layers: Vec<Layer>,

     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }

 unsafe impl Send for Bloom {}
@@ -101,7 +99,7 @@ impl KnownModel for Bloom {
             output_norm_bias,
             output,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }

8 changes: 3 additions & 5 deletions crates/models/falcon/src/lib.rs
@@ -7,14 +7,12 @@
 //! supported. It is currently only available as a preview.
 #![deny(missing_docs)]

-use std::sync::Arc;
-
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };

 /// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae)
@@ -39,7 +37,7 @@ pub struct Falcon {
     layers: Vec<Layer>,

     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }

 unsafe impl Send for Falcon {}
@@ -138,7 +136,7 @@ impl KnownModel for Falcon {
             output_norm_b,
             lm_head,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }

8 changes: 3 additions & 5 deletions crates/models/gpt2/src/lib.rs
@@ -1,14 +1,12 @@
 //! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
 #![deny(missing_docs)]

-use std::sync::Arc;
-
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };

 /// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
@@ -38,7 +36,7 @@ pub struct Gpt2 {
     layers: Vec<Layer>,

     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }

 unsafe impl Send for Gpt2 {}
@@ -123,7 +121,7 @@ impl KnownModel for Gpt2 {
             wte,
             wpe,
             lm_head,
-            context: Arc::new(context),
+            context,
         })
     }

8 changes: 4 additions & 4 deletions crates/models/gptj/src/lib.rs
@@ -1,14 +1,14 @@
 //! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
 #![deny(missing_docs)]

-use std::{error::Error, sync::Arc};
+use std::error::Error;

 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
 };

 /// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
@@ -35,7 +35,7 @@ pub struct GptJ {
     layers: Vec<Layer>,

     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }

 unsafe impl Send for GptJ {}
@@ -117,7 +117,7 @@ impl KnownModel for GptJ {
             lmh_g,
             lmh_b,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }

8 changes: 4 additions & 4 deletions crates/models/gptneox/src/lib.rs
@@ -2,14 +2,14 @@
 //! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model.
 #![deny(missing_docs)]

-use std::{error::Error, sync::Arc};
+use std::error::Error;

 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
 };

 /// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
@@ -35,7 +35,7 @@ pub struct GptNeoX {
     layers: Vec<Layer>,

     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }

 unsafe impl Send for GptNeoX {}
@@ -137,7 +137,7 @@ impl KnownModel for GptNeoX {
             wte,
             lmh_g,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
