-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembedding.py
32 lines (27 loc) · 917 Bytes
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# embedding.py
from typing import List, Union

import numpy as np
from sentence_transformers import SentenceTransformer

from constants import EMBEDDING_MODEL_NAME
class Embedder:
    """Compute sentence embeddings with one shared SentenceTransformer model.

    The model named by ``EMBEDDING_MODEL_NAME`` is constructed once, when this
    class body is executed (i.e. when the module is imported), and every call
    below reuses it.
    """

    # Built eagerly at class-definition time and shared by all callers.
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

    @staticmethod
    def get_embedding(
        phrase: Union[str, List[str]],
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Embed a single phrase, or a batch of phrases.

        Previously this class defined ``get_embedding`` twice. The second
        (batch) definition silently replaced the first, so the single-phrase
        code path never ran. This method now dispatches on the argument type,
        so both call styles keep working.

        Args:
            phrase: One phrase to embed, or a list of phrases.

        Returns:
            For a ``str``: the embedding vector for that phrase.
            For a list: the same value as :meth:`get_embeddings`.
        """
        if isinstance(phrase, str):
            # Encode as a batch of one and unwrap it, so a single phrase
            # always yields exactly one vector.
            return Embedder.embedding_model.encode([phrase])[0]
        return Embedder.get_embeddings(phrase)

    @staticmethod
    def get_embeddings(phrases: List[str]) -> List[np.ndarray]:
        """Embed a list of phrases in one batched call.

        Args:
            phrases: The phrases to embed.

        Returns:
            Whatever ``SentenceTransformer.encode`` returns for a list input,
            with one embedding per phrase in input order. With library
            defaults this is a 2-D NumPy array, one row per phrase; iterating
            it yields one vector per phrase. Confirm against the installed
            sentence-transformers version.
        """
        return Embedder.embedding_model.encode(phrases)