diff --git a/ext/candle/src/lib.rs b/ext/candle/src/lib.rs
index d7645f2..57f7b5d 100644
--- a/ext/candle/src/lib.rs
+++ b/ext/candle/src/lib.rs
@@ -92,7 +92,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;
 
     let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
+    rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?;
     rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?;
+    rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?;
+    rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?;
 
     Ok(())
 }
diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/config.rs
index 02f8a40..f74a91f 100644
--- a/ext/candle/src/model/config.rs
+++ b/ext/candle/src/model/config.rs
@@ -11,12 +11,14 @@ use crate::model::{
 use candle_core::{DType, Device, Module, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::models::jina_bert::{BertModel, Config};
-use core::result::Result;
 use magnus::Error;
+use crate::model::RbResult;
 use tokenizers::Tokenizer;
 
 #[magnus::wrap(class = "Candle::Model", free_immediately, size)]
-pub struct ModelConfig {
+pub struct ModelConfig(pub ModelConfigInner);
+
+pub struct ModelConfigInner {
     device: Device,
     tokenizer_path: Option<String>,
@@ -25,17 +27,33 @@ pub struct ModelConfig {
 }
 
 impl ModelConfig {
-    pub fn build() -> ModelConfig {
-        ModelConfig {
+    pub fn new() -> RbResult<Self> {
+        Ok(ModelConfig(ModelConfigInner {
             device: Device::Cpu,
             model_path: None,
             tokenizer_path: None,
-        }
+        }))
+    }
+
+    pub fn build() -> ModelConfig {
+        ModelConfig(ModelConfigInner {
+            device: Device::Cpu,
+            model_path: None,
+            tokenizer_path: None
+        })
     }
 
-    pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
+    /// Computes the embedding of the given input text.
+    /// &RETURNS&: Tensor
+    pub fn embedding(&self, input: String) -> RbResult<RbTensor> {
+        let config = ModelConfig::build();
+        let (model, tokenizer) = config.build_model_and_tokenizer()?;
+        Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?))
+    }
+
+    fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
-        let model_path = match &self.model_path {
+        let model_path = match &self.0.model_path {
             Some(model_file) => std::path::PathBuf::from(model_file),
             None => Api::new()
                 .map_err(wrap_hf_err)?
@@ -46,7 +64,7 @@ impl ModelConfig {
                 .get("model.safetensors")
                 .map_err(wrap_hf_err)?,
         };
-        let tokenizer_path = match &self.tokenizer_path {
+        let tokenizer_path = match &self.0.tokenizer_path {
             Some(file) => std::path::PathBuf::from(file),
             None => Api::new()
                 .map_err(wrap_hf_err)?
@@ -61,19 +79,13 @@ impl ModelConfig {
         let config = Config::v2_base();
         let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(wrap_std_err)?;
         let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.device)
+            VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.0.device)
                 .map_err(wrap_candle_err)?
         };
         let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
         Ok((model, tokenizer))
     }
 
-    pub fn embedding(&self, input: String) -> Result<RbTensor, Error> {
-        let config = ModelConfig::build();
-        let (model, tokenizer) = config.build_model_and_tokenizer()?;
-        return Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?));
-    }
-
     fn compute_embedding(
         &self,
         prompt: String,
@@ -91,7 +103,7 @@ impl ModelConfig {
             .map_err(wrap_std_err)?
             .get_ids()
             .to_vec();
-        let token_ids = Tensor::new(&tokens[..], &self.device)
+        let token_ids = Tensor::new(&tokens[..], &self.0.device)
            .map_err(wrap_candle_err)?
            .unsqueeze(0)
            .map_err(wrap_candle_err)?;
@@ -100,7 +112,7 @@ impl ModelConfig {
         let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
         println!("{result}");
         println!("Took {:?}", start.elapsed());
-        return Ok(result);
+        Ok(result)
+    }
+
+    pub fn __repr__(&self) -> String {
+        format!("Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None"))
+    }
+
+    pub fn __str__(&self) -> String {
+        self.__repr__()
     }
 }
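For reference, below is a minimal standalone sketch of the wrapper pattern this patch introduces: the magnus-wrapped ModelConfig becomes a newtype around a plain ModelConfigInner, fields are reached through self.0, and __str__ delegates to __repr__. The snippet uses hypothetical stand-in types (no magnus, candle, or hf-hub dependencies, and only the model_path field), so it compiles on its own; it illustrates the pattern rather than reproducing the extension's actual code.

// Trimmed-down stand-in for the inner config struct (only model_path kept).
struct ModelConfigInner {
    model_path: Option<String>,
}

// The Ruby-visible type wraps the inner struct, so field access goes through
// `self.0`, mirroring the `self.model_path` -> `self.0.model_path` rewrites above.
struct ModelConfig(ModelConfigInner);

impl ModelConfig {
    fn build() -> Self {
        ModelConfig(ModelConfigInner { model_path: None })
    }

    // Bound to Ruby's `inspect` in lib.rs.
    fn __repr__(&self) -> String {
        format!(
            "Candle::Model(path={})",
            self.0.model_path.as_deref().unwrap_or("None")
        )
    }

    // Bound to Ruby's `to_s`; simply delegates to `__repr__`.
    fn __str__(&self) -> String {
        self.__repr__()
    }
}

fn main() {
    let config = ModelConfig::build();
    println!("{}", config.__repr__()); // Candle::Model(path=None)
    assert_eq!(config.__str__(), config.__repr__());
}

On the Ruby side, the registrations in lib.rs are expected to expose Candle::Model.new, #embedding(text), #to_s, and #inspect.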