Skip to content

Commit

Permalink
merged main
Browse files Browse the repository at this point in the history
  • Loading branch information
cpetersen committed Mar 12, 2024
1 parent 7dd6f4c commit b989c2a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
3 changes: 3 additions & 0 deletions ext/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;

let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?;
rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?;
rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?;
rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?;

Ok(())
}
54 changes: 37 additions & 17 deletions ext/candle/src/model/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ use crate::model::{
use candle_core::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::jina_bert::{BertModel, Config};
use core::result::Result;
use magnus::Error;
use crate::model::RbResult;
use tokenizers::Tokenizer;

#[magnus::wrap(class = "Candle::Model", free_immediately, size)]
pub struct ModelConfig {
pub struct ModelConfig(pub ModelConfigInner);

pub struct ModelConfigInner {
device: Device,

tokenizer_path: Option<String>,
Expand All @@ -25,17 +27,33 @@ pub struct ModelConfig {
}

impl ModelConfig {
pub fn build() -> ModelConfig {
ModelConfig {
pub fn new() -> RbResult<Self> {
Ok(ModelConfig(ModelConfigInner {
device: Device::Cpu,
model_path: None,
tokenizer_path: None,
}
}))
}

/// Builds the default configuration: CPU device with no explicit model or
/// tokenizer path (`None` means the paths are resolved elsewhere at load time).
pub fn build() -> ModelConfig {
    let defaults = ModelConfigInner {
        tokenizer_path: None,
        model_path: None,
        device: Device::Cpu,
    };
    ModelConfig(defaults)
}

pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
/// Computes the embedding tensor for `input` using the configured model
/// and tokenizer. (Previous comment was a copy-paste error mentioning `sin`.)
/// &RETURNS&: Tensor
pub fn embedding(&self, input: String) -> RbResult<RbTensor> {
    // Load the model/tokenizer from THIS instance's configuration. The old
    // code called `ModelConfig::build()` here, which always produced a fresh
    // default config and silently ignored any model_path/tokenizer_path the
    // caller had configured on `self`.
    let (model, tokenizer) = self.build_model_and_tokenizer()?;
    Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?))
}

fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model_path = match &self.model_path {
let model_path = match &self.0.model_path {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()
.map_err(wrap_hf_err)?
Expand All @@ -46,7 +64,7 @@ impl ModelConfig {
.get("model.safetensors")
.map_err(wrap_hf_err)?,
};
let tokenizer_path = match &self.tokenizer_path {
let tokenizer_path = match &self.0.tokenizer_path {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()
.map_err(wrap_hf_err)?
Expand All @@ -61,19 +79,13 @@ impl ModelConfig {
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(wrap_std_err)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.device)
VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.0.device)
.map_err(wrap_candle_err)?
};
let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
Ok((model, tokenizer))
}

pub fn embedding(&self, input: String) -> Result<RbTensor, Error> {
let config = ModelConfig::build();
let (model, tokenizer) = config.build_model_and_tokenizer()?;
return Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?));
}

fn compute_embedding(
&self,
prompt: String,
Expand All @@ -91,7 +103,7 @@ impl ModelConfig {
.map_err(wrap_std_err)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &self.device)
let token_ids = Tensor::new(&tokens[..], &self.0.device)
.map_err(wrap_candle_err)?
.unsqueeze(0)
.map_err(wrap_candle_err)?;
Expand All @@ -100,7 +112,15 @@ impl ModelConfig {
let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
println!("{result}");
println!("Took {:?}", start.elapsed());
return Ok(result);
Ok(result)
}

/// Representation shown by Ruby's `inspect`: the model path, or the literal
/// string "None" when no explicit path has been configured.
pub fn __repr__(&self) -> String {
    let path = match self.0.model_path.as_deref() {
        Some(p) => p,
        None => "None",
    };
    format!("Candle::Model(path={path})")
}

/// Ruby's `to_s` uses the same text as `inspect`.
pub fn __str__(&self) -> String {
    self.__repr__()
}
}

Expand Down

0 comments on commit b989c2a

Please sign in to comment.