Commit

refactor: switch back to Arc for sampler
philpax committed May 31, 2023
1 parent 223fbee commit b70df9e
Showing 9 changed files with 44 additions and 59 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -8,12 +8,12 @@
- `llm::InferenceRequest` no longer implements `Default::default`.
- The `infer` callback now provides an `InferenceResponse` instead of a string to disambiguate the source of the token. Additionally, it now returns an `InferenceFeedback` to control whether or not the generation should continue.
- Several fields have been renamed:
-- `n_context_tokens` -> `context_size`
+- `n_context_tokens` -> `context_size`

# 0.1.1 (2023-05-08)

- Fix an issue with the binary build of `llm-cli`.

# 0.1.0 (2023-05-08)

-Initial release.
+Initial release.
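
The changelog entry above notes that the `infer` callback now receives an `llm::InferenceResponse` and returns an `llm::InferenceFeedback`. A rough, non-authoritative sketch of that callback shape, assuming the `InferredToken`/`EotToken` variants and `Continue`/`Halt` feedback values from this version of the crate:

```rust
use std::convert::Infallible;

// Sketch of an `infer` callback under the API described in the changelog:
// the `InferenceResponse` says where a token came from, and the returned
// `InferenceFeedback` decides whether generation keeps going.
fn print_generated_tokens(
    response: llm::InferenceResponse,
) -> Result<llm::InferenceFeedback, Infallible> {
    match response {
        // Print only tokens the model generated, not prompt playback.
        llm::InferenceResponse::InferredToken(token) => {
            print!("{token}");
            Ok(llm::InferenceFeedback::Continue)
        }
        // Stop once the end-of-text token appears.
        llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
        _ => Ok(llm::InferenceFeedback::Continue),
    }
}
```
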
39 changes: 16 additions & 23 deletions binaries/llm-cli/src/cli_args.rs
@@ -1,4 +1,4 @@
-use std::{fmt, ops::Deref, path::PathBuf};
+use std::{fmt, ops::Deref, path::PathBuf, sync::Arc};

use clap::{Parser, Subcommand, ValueEnum};
use color_eyre::eyre::{bail, Result, WrapErr};
@@ -306,31 +306,24 @@ impl Generate {
}
}

-pub fn sampler(&self, eot: llm::TokenId) -> llm::samplers::TopPTopK {
-llm::samplers::TopPTopK {
-top_k: self.top_k,
-top_p: self.top_p,
-repeat_penalty: self.repeat_penalty,
-temperature: self.temperature,
-bias_tokens: self.token_bias.clone().unwrap_or_else(|| {
-if self.ignore_eos {
-TokenBias::new(vec![(eot, -1.0)])
-} else {
-TokenBias::default()
-}
-}),
-repetition_penalty_last_n: self.repeat_last_n,
-}
-}
-
-pub fn inference_parameters<'a>(
-&self,
-sampler: &'a dyn llm::Sampler,
-) -> InferenceParameters<'a> {
+pub fn inference_parameters(&self, eot: llm::TokenId) -> InferenceParameters {
InferenceParameters {
n_threads: self.num_threads(),
n_batch: self.batch_size,
-sampler,
+sampler: Arc::new(llm::samplers::TopPTopK {
+top_k: self.top_k,
+top_p: self.top_p,
+repeat_penalty: self.repeat_penalty,
+temperature: self.temperature,
+bias_tokens: self.token_bias.clone().unwrap_or_else(|| {
+if self.ignore_eos {
+TokenBias::new(vec![(eot, -1.0)])
+} else {
+TokenBias::default()
+}
+}),
+repetition_penalty_last_n: self.repeat_last_n,
+}),
}
}
}
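
With the sampler now owned behind an `Arc`, downstream code can build `InferenceParameters` directly instead of keeping a separate sampler binding alive for a borrow. A hedged sketch of that calling pattern, using only field names visible in this diff (the values are placeholders, and the struct-update syntax assumes `TopPTopK` implements `Default` as its documentation states):

```rust
use std::sync::Arc;

// Sketch: building `InferenceParameters` by hand with a customised sampler,
// the same shape that `Generate::inference_parameters` now returns. The
// remaining `TopPTopK` fields fall back to their defaults.
fn custom_parameters() -> llm::InferenceParameters {
    llm::InferenceParameters {
        n_threads: 4,
        n_batch: 8,
        sampler: Arc::new(llm::samplers::TopPTopK {
            top_k: 40,
            top_p: 0.95,
            ..Default::default()
        }),
    }
}
```
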
9 changes: 3 additions & 6 deletions binaries/llm-cli/src/main.rs
@@ -57,8 +57,7 @@ fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
args.generate.load_session.as_deref(),
inference_session_config,
);
-let sampler = args.generate.sampler(model.eot_token_id());
-let parameters = args.generate.inference_parameters(&sampler);
+let parameters = args.generate.inference_parameters(model.eot_token_id());

let mut rng = args.generate.rng();
let res = session.infer::<Infallible>(
@@ -125,8 +124,7 @@ fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Result<()> {
args.generate.load_session.as_deref(),
inference_session_config,
);
-let sampler = args.generate.sampler(model.eot_token_id());
-let parameters = args.generate.inference_parameters(&sampler);
+let parameters = args.generate.inference_parameters(model.eot_token_id());

session.perplexity(
model.as_ref(),
@@ -226,8 +224,7 @@ fn interactive<M: llm::KnownModel + 'static>(
args.generate.load_session.as_deref(),
inference_session_config,
);
-let sampler = args.generate.sampler(model.eot_token_id());
-let parameters = args.generate.inference_parameters(&sampler);
+let parameters = args.generate.inference_parameters(model.eot_token_id());

let mut rng = args.generate.rng();
let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
2 changes: 1 addition & 1 deletion crates/llm-base/src/inference_session.rs
@@ -578,7 +578,7 @@ pub struct InferenceRequest<'a> {
/// The prompt to feed to the model.
pub prompt: Prompt<'a>,
/// The parameters to use during this inference attempt.
-pub parameters: &'a InferenceParameters<'a>,
+pub parameters: &'a InferenceParameters,
/// Whether or not to call the callback with the previous tokens
/// that were encountered in this session.
///
15 changes: 13 additions & 2 deletions crates/llm-base/src/lib.rs
@@ -17,6 +17,8 @@ pub mod model;
pub mod samplers;
pub mod util;

+use std::sync::Arc;
+
pub use ggml;
pub use ggml::Type as ElementType;

@@ -47,7 +49,7 @@ pub use vocabulary::{
///
/// This needs to be provided during all inference calls,
/// but can be changed between calls.
-pub struct InferenceParameters<'a> {
+pub struct InferenceParameters {
/// The number of threads to use. This is dependent on your user's system,
/// and should be selected accordingly.
///
@@ -81,5 +83,14 @@ pub struct InferenceParameters<'a> {
///
/// A recommended default sampler is [TopPTopK](samplers::TopPTopK), which is a standard
/// sampler that offers a [Default](samplers::TopPTopK::default) implementation.
-pub sampler: &'a dyn Sampler,
+pub sampler: Arc<dyn Sampler>,
}
+impl Default for InferenceParameters {
+fn default() -> Self {
+Self {
+n_threads: 8,
+n_batch: 8,
+sampler: Arc::new(samplers::TopPTopK::default()),
+}
+}
+}
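
The core of the refactor is visible here: `sampler` changes from a borrowed `&'a dyn Sampler` to an owned `Arc<dyn Sampler>`, which removes the lifetime parameter from `InferenceParameters`, allows the `Default` impl above, and lets the parameters be cloned and shared across threads. A standalone sketch of that pattern with stand-in types (the trait and structs below are illustrative only, not the crate's real definitions):

```rust
use std::sync::Arc;
use std::thread;

// Stand-in trait mirroring the role of `llm::Sampler`.
trait Sampler: Send + Sync {
    fn sample(&self, logits: &[f32]) -> usize;
}

// A trivial greedy sampler as the concrete implementation.
struct Greedy;

impl Sampler for Greedy {
    fn sample(&self, logits: &[f32]) -> usize {
        logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(index, _)| index)
            .unwrap_or(0)
    }
}

// Parameters that own their sampler, as `InferenceParameters` does after this
// commit: no lifetime parameter, cheap to clone, safe to send to threads.
#[derive(Clone)]
struct Parameters {
    sampler: Arc<dyn Sampler>,
}

fn main() {
    let params = Parameters {
        sampler: Arc::new(Greedy),
    };

    // Each worker clones the parameters; the sampler is shared, not copied.
    let handles: Vec<_> = (0..2)
        .map(|_| {
            let params = params.clone();
            thread::spawn(move || params.sampler.sample(&[0.1, 0.7, 0.2]))
        })
        .collect();

    for handle in handles {
        assert_eq!(handle.join().unwrap(), 1);
    }
}
```
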
16 changes: 6 additions & 10 deletions crates/llm/examples/embeddings.rs
@@ -64,11 +64,7 @@ fn main() {
.unwrap_or_else(|err| {
panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
});
-let inference_parameters = llm::InferenceParameters {
-n_threads: 8,
-n_batch: 8,
-sampler: &llm::samplers::TopPTopK::default(),
-};
+let inference_parameters = llm::InferenceParameters::default();

// Generate embeddings for query and comparands
let query_embeddings = get_embeddings(model.as_ref(), &inference_parameters, query);
@@ -101,7 +97,7 @@ fn main() {
.map(|(text, embeddings)| {
(
text.as_str(),
-cosine_similarity(&query_embeddings, &embeddings),
+cosine_similarity(&query_embeddings, embeddings),
)
})
.collect();
@@ -133,7 +129,7 @@ fn get_embeddings(
.iter()
.map(|(_, tok)| *tok)
.collect::<Vec<_>>();
-let _ = model.evaluate(
+model.evaluate(
&mut session,
inference_parameters,
&query_token_ids,
@@ -143,9 +139,9 @@
}

fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
-let dot_product = dot(&v1, &v2);
-let magnitude1 = magnitude(&v1);
-let magnitude2 = magnitude(&v2);
+let dot_product = dot(v1, v2);
+let magnitude1 = magnitude(v1);
+let magnitude2 = magnitude(v2);

dot_product / (magnitude1 * magnitude2)
}
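
`cosine_similarity` relies on `dot` and `magnitude` helpers that this diff does not show. A self-contained sketch, assuming those helpers compute the usual dot product and Euclidean norm, with the borrow-free calls from the new version:

```rust
// Assumed implementations of the helpers used by `cosine_similarity`; they do
// not appear in this diff, so these bodies are conventional reconstructions.
fn dot(v1: &[f32], v2: &[f32]) -> f32 {
    v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum()
}

fn magnitude(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
    let dot_product = dot(v1, v2);
    let magnitude1 = magnitude(v1);
    let magnitude2 = magnitude(v2);

    dot_product / (magnitude1 * magnitude2)
}

fn main() {
    // Vectors pointing in the same direction give a similarity of 1.0.
    let similarity = cosine_similarity(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);
    assert!((similarity - 1.0).abs() < 1e-6);
    println!("cosine similarity: {similarity}");
}
```
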
6 changes: 1 addition & 5 deletions crates/llm/examples/inference.rs
@@ -61,11 +61,7 @@ fn main() {
&mut rand::thread_rng(),
&llm::InferenceRequest {
prompt: prompt.into(),
-parameters: &llm::InferenceParameters {
-n_threads: 8,
-n_batch: 8,
-sampler: &llm::samplers::TopPTopK::default(),
-},
+parameters: &llm::InferenceParameters::default(),
play_back_previous_tokens: false,
maximum_token_count: None,
},
6 changes: 1 addition & 5 deletions crates/llm/examples/vicuna-chat.rs
@@ -52,11 +52,7 @@
{character_name}: Paris is the capital of France."
);

-let inference_parameters = llm::InferenceParameters {
-n_threads: 8,
-n_batch: 8,
-sampler: &llm::samplers::TopPTopK::default(),
-};
+let inference_parameters = llm::InferenceParameters::default();

session
.feed_prompt(
6 changes: 1 addition & 5 deletions crates/llm/src/lib.rs
@@ -41,11 +41,7 @@
//! // inference parameters
//! &llm::InferenceRequest {
//! prompt: "Rust is a cool programming language because".into(),
-//! parameters: &llm::InferenceParameters {
-//! n_threads: 8,
-//! n_batch: 8,
-//! sampler: &llm::samplers::TopPTopK::default(),
-//! },
+//! parameters: &llm::InferenceParameters::default(),
//! play_back_previous_tokens: false,
//! maximum_token_count: None,
//! },
