diff --git a/fastrag/embedders/ipex_embedder.py b/fastrag/embedders/ipex_embedder.py
index 89140b1..5f6cd7d 100644
--- a/fastrag/embedders/ipex_embedder.py
+++ b/fastrag/embedders/ipex_embedder.py
@@ -1,5 +1,6 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
+import torch
 
 from haystack.components.embedders import (
     SentenceTransformersDocumentEmbedder,
     SentenceTransformersTextEmbedder,
@@ -33,7 +34,7 @@ def __init__(
         import sentence_transformers
 
         class _IPEXSTTransformers(sentence_transformers.models.Transformer):
-            def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
+            def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args):
                 print("Loading IPEX ST Transformer model")
                 optimized_intel_import.check()
                 self.auto_model = IPEXModel.from_pretrained(
@@ -89,23 +90,39 @@ def _load_auto_model(
                 cache_folder: Optional[str],
                 revision: Optional[str] = None,
                 trust_remote_code: bool = False,
+                local_files_only: bool = False,
+                model_kwargs: Optional[Dict[str, Any]] = None,
+                tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+                config_kwargs: Optional[Dict[str, Any]] = None,
             ):
                 """
                 Creates a simple Transformer + Mean Pooling model and returns the modules
                 """
+
+                shared_kwargs = {
+                    "token": token,
+                    "trust_remote_code": trust_remote_code,
+                    "revision": revision,
+                    "local_files_only": local_files_only,
+                }
+                model_kwargs = (
+                    shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+                )
+                tokenizer_kwargs = (
+                    shared_kwargs
+                    if tokenizer_kwargs is None
+                    else {**shared_kwargs, **tokenizer_kwargs}
+                )
+                config_kwargs = (
+                    shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+                )
+
                 transformer_model = _IPEXSTTransformers(
                     model_name_or_path,
                     cache_dir=cache_folder,
-                    model_args={
-                        "token": token,
-                        "trust_remote_code": trust_remote_code,
-                        "revision": revision,
-                    },
-                    tokenizer_args={
-                        "token": token,
-                        "trust_remote_code": trust_remote_code,
-                        "revision": revision,
-                    },
+                    model_args=model_kwargs,
+                    tokenizer_args=tokenizer_kwargs,
+                    config_args=config_kwargs,
                 )
                 pooling_model = sentence_transformers.models.Pooling(
                     transformer_model.get_word_embedding_dimension(), "mean"
@@ -114,7 +131,7 @@ def _load_auto_model(
 
             @property
             def device(self):
-                return "cpu"
+                return torch.device("cpu")
 
         self.model = _IPEXSentenceTransformer(
             model_name_or_path=model,
@@ -132,7 +149,7 @@ def ipex_model_warm_up(self):
     """
    Initializes the component.
     """
-    if not hasattr(self, "embedding_backend"):
+    if not getattr(self, "embedding_backend", None):
         self.embedding_backend = _IPEXSentenceTransformersEmbeddingBackend(
             model=self.model,
             device=self.device.to_torch_str(),
diff --git a/fastrag/rankers/bi_encoder_ranker.py b/fastrag/rankers/bi_encoder_ranker.py
index 960d720..1e5d18d 100644
--- a/fastrag/rankers/bi_encoder_ranker.py
+++ b/fastrag/rankers/bi_encoder_ranker.py
@@ -110,7 +110,7 @@ def run(
         scores = torch.tensor(query_vector) @ doc_vectors.T  ## perhaps need to break it into chunks
         scores = scores.reshape(len(documents))
         # Store scores in documents_with_vectors
-        for doc, score in zip(documents_with_vectors, scores.tolist()):
+        for doc, score in zip(documents_with_vectors, scores.tolist()):
             doc.score = score
 
         indices = scores.cpu().sort(descending=True).indices
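
Aside (not part of the patch): a minimal sketch of the dict-merge precedence that the new kwargs handling in _load_auto_model relies on. The names shared_kwargs and tokenizer_kwargs mirror the diff; the override value used here is hypothetical.

# Sketch only: caller-supplied kwargs override the shared defaults, because
# later entries win when two dicts are unpacked into one dict literal.
shared_kwargs = {
    "token": None,
    "trust_remote_code": False,
    "revision": None,
    "local_files_only": False,
}
tokenizer_kwargs = {"revision": "main"}  # hypothetical per-tokenizer override

merged = {**shared_kwargs, **tokenizer_kwargs}
print(merged["revision"])           # "main" -- the override wins
print(merged["local_files_only"])   # False  -- shared default is kept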