Allow label selection on attribute extraction tasks (#908)
* Allow label selection on attribute extraction tasks

* Simplify logic for when to use label selection

* label selection dict --> int

---------

Co-authored-by: Nihit <[email protected]>
rajasbansal and nihit authored Oct 2, 2024
1 parent d1417b2 commit c3436f7
Showing 9 changed files with 140 additions and 103 deletions.
34 changes: 18 additions & 16 deletions src/autolabel/configs/config.py
@@ -53,11 +53,10 @@ class AutolabelConfig(BaseConfig):
     OUTPUT_GUIDELINE_KEY = "output_guidelines"
     OUTPUT_FORMAT_KEY = "output_format"
     CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
-    LABEL_SELECTION_KEY = "label_selection"
-    LABEL_SELECTION_COUNT_KEY = "label_selection_count"
-    LABEL_SELECTION_THRESHOLD = "label_selection_threshold"
     ATTRIBUTES_KEY = "attributes"
     TRANSFORM_KEY = "transforms"
+    LABEL_SELECTION_KEY = "label_selection"
+    LABEL_SELECTION_COUNT_KEY = "label_selection_count"
 
     # Dataset generation config keys (config["dataset_generation"][<key>])
     DATASET_GENERATION_GUIDELINES_KEY = "guidelines"
@@ -243,22 +242,25 @@ def chain_of_thought(self) -> bool:
 
     def label_selection(self) -> bool:
         """Returns true if label selection is enabled. Label selection is the process of
-        narrowing down the list of possible labels by similarity to a given input. Useful for
-        classification tasks with a large number of possible classes."""
-        return self._prompt_config.get(self.LABEL_SELECTION_KEY, False)
+        narrowing down the list of possible labels by similarity to a given input."""
+        for attribute in self.attributes():
+            if attribute.get(self.LABEL_SELECTION_KEY, False):
+                return True
+        return False
+
+    def label_selection_attribute(self) -> str:
+        """Returns the attribute to use for label selection"""
+        for attribute in self.attributes():
+            if attribute.get(self.LABEL_SELECTION_KEY, False):
+                return attribute["name"]
+        return None
 
     def max_selected_labels(self) -> int:
         """Returns the number of labels to select in LabelSelector"""
-        k = self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)
-        if k < 1:
-            return len(self.labels_list())
-        return k
-
-    def label_selection_threshold(self) -> float:
-        """Returns the threshold for label selection in LabelSelector
-        If the similarity score ratio with the top Score is above this threshold,
-        the label is selected."""
-        return self._prompt_config.get(self.LABEL_SELECTION_THRESHOLD, 0.0)
+        for attribute in self.attributes():
+            if attribute.get(self.LABEL_SELECTION_KEY):
+                return attribute.get(self.LABEL_SELECTION_KEY, 10)
+        return None
 
     def attributes(self) -> List[Dict]:
         """Returns a list of attributes to extract from the text."""
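With this change, label selection is configured per attribute instead of at the top level of the prompt config, and (per the commit message "label selection dict --> int") the "label_selection" value is now an int: the number of candidate labels to keep. A rough sketch of what such a config might look like; the task and attribute names here are illustrative assumptions, not taken from this commit:

    {
        "task_name": "ProductCategorization",
        "task_type": "attribute_extraction",
        "prompt": {
            "attributes": [
                {
                    "name": "category",
                    "description": "The category of the product",
                    "options": ["electronics", "apparel", "home", "toys"],
                    "label_selection": 10
                }
            ]
        }
    }

For a config shaped like this, label_selection_attribute() would return "category" and max_selected_labels() would return 10.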
2 changes: 0 additions & 2 deletions src/autolabel/configs/schema.py
@@ -119,8 +119,6 @@ def populate_few_shot_selection() -> List[str]:
         },
         "few_shot_num": {"type": ["number", "null"]},
         "chain_of_thought": {"type": ["boolean", "null"]},
-        "label_selection": {"type": ["boolean", "null"]},
-        "label_selection_count": {"type": ["number", "null"]},
         "attributes": {
             "anyOf": [
                 {"type": "array", "items": {"type": "object"}},
2 changes: 2 additions & 0 deletions src/autolabel/few_shot/__init__.py
@@ -19,6 +19,8 @@
     LabelDiversitySimilarityExampleSelector,
 )
 from .vector_store import VectorStoreWrapper
+from .base_label_selector import BaseLabelSelector
+from .label_selector import LabelSelector
 
 ALGORITHM_TO_IMPLEMENTATION: Dict[FewShotAlgorithm, BaseExampleSelector] = {
     FewShotAlgorithm.FIXED: FixedExampleSelector,
8 changes: 8 additions & 0 deletions src/autolabel/few_shot/base_label_selector.py
@@ -0,0 +1,8 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class BaseLabelSelector(ABC):
+    @abstractmethod
+    def select_labels(self, input: str) -> List[str]:
+        pass
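The new BaseLabelSelector is a minimal interface: a selector only has to map an input string to a list of candidate labels. As a hypothetical illustration (not part of this commit), any subclass just implements select_labels:

    from typing import List

    from autolabel.few_shot.base_label_selector import BaseLabelSelector


    class StaticLabelSelector(BaseLabelSelector):
        """Toy selector that always returns the same fixed candidates."""

        def __init__(self, labels: List[str]) -> None:
            self.labels = labels

        def select_labels(self, input: str) -> List[str]:
            # No similarity computation; just echo the configured labels.
            return self.labels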
63 changes: 32 additions & 31 deletions src/autolabel/few_shot/label_selector.py
@@ -1,16 +1,14 @@
 from __future__ import annotations
 
-import bisect
-from collections.abc import Callable
-from typing import Dict, List, Optional, Tuple, Union
-
-from sqlalchemy.sql import text as sql_text
+from typing import Dict, List, Optional, Union
 
 from autolabel.configs import AutolabelConfig
-from autolabel.few_shot.vector_store import VectorStoreWrapper, cos_sim
+from autolabel.few_shot.base_label_selector import BaseLabelSelector
+from autolabel.few_shot.vector_store import VectorStoreWrapper
 
 
-class LabelSelector:
+class LabelSelector(BaseLabelSelector):
     """Returns the most similar labels to a given input. Used for
     classification tasks with a large number of possible classes."""
 
@@ -34,37 +32,40 @@ def __init__(
         cache: bool = True,
     ) -> None:
         self.config = config
-        self.labels = self.config.labels_list()
-        self.label_descriptions = self.config.label_descriptions()
+        self.label_selection_attribute = self.config.label_selection_attribute()
+        self.labels = []
+        self.label_descriptions = None
+
+        attributes = self.config.attributes()
+        matching_attribute = next(
+            (
+                attr
+                for attr in attributes
+                if attr["name"] == self.label_selection_attribute
+            ),
+            None,
+        )
+
+        if matching_attribute is None:
+            raise ValueError(
+                f"No attribute found with name '{self.label_selection_attribute}'"
+            )
+
+        self.labels = matching_attribute.get("options", [])
+        if not self.labels:
+            raise ValueError(
+                f"Attribute '{self.label_selection_attribute}' does not have any options"
+            )
+
         self.k = min(self.config.max_selected_labels(), len(self.labels))
-        self.threshold = self.config.label_selection_threshold()
         self.cache = cache
         self.vectorStore = VectorStoreWrapper(
             embedding_function=embedding_func, cache=self.cache
         )
-
-        # Get the embeddings of the labels
-        if self.label_descriptions is not None:
-            (labels, descriptions) = zip(*self.label_descriptions.items())
-            self.labels = list(labels)
-            self.labels_embeddings = torch.Tensor(
-                self.vectorStore._get_embeddings(descriptions)
-            )
-        else:
-            self.labels_embeddings = torch.Tensor(
-                self.vectorStore._get_embeddings(self.labels)
-            )
+        self.vectorStore.add_texts(self.labels)
 
     def select_labels(self, input: str) -> List[str]:
         """Select which labels to use based on the similarity to input"""
-        input_embedding = torch.Tensor(self.vectorStore._get_embeddings([input]))
-        scores = cos_sim(input_embedding, self.labels_embeddings).view(-1)
-        scores = list(zip(scores, self.labels))
-        scores.sort(key=lambda x: x[0])
-
-        # remove labels with similarity score less than self.threshold*topScore
-        return [
-            label
-            for (score, label) in scores[-self.k :]
-            if score > self.threshold * scores[-1][0]
-        ]
+        documents = self.vectorStore.similarity_search(input, k=self.k)
+        return [doc.page_content for doc in documents]
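The rewritten LabelSelector now resolves its labels from the configured attribute's "options", embeds them once via add_texts, and delegates select_labels to a plain similarity search. A rough usage sketch; the config file contents and the embedding function are assumptions for illustration, not an API guarantee:

    from autolabel.configs import AutolabelConfig
    from autolabel.few_shot.label_selector import LabelSelector

    # Assumed: config.json defines an attribute with "options" and "label_selection" set
    config = AutolabelConfig("config.json")
    # embeddings is a stand-in for whatever embedding function/object the caller supplies
    selector = LabelSelector(config=config, embedding_func=embeddings)

    # Returns the k options most similar to the input text
    candidates = selector.select_labels("Wireless noise-cancelling headphones")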
70 changes: 24 additions & 46 deletions src/autolabel/few_shot/vector_store.py
@@ -17,13 +17,6 @@
 
 logger = logging.getLogger(__name__)
 
-try:
-    import torch
-except ImportError:
-    logger.warning(
-        "Torch is not installed. Please install torch to use the VectorStoreWrapper."
-    )
-
 from sqlalchemy.sql import text as sql_text
 
 EMBEDDINGS_TABLE = "autolabel_embeddings"
@@ -46,21 +39,21 @@ def cos_sim(a, b):
     Returns:
         cos_sim: Matrix with res(i)(j) = cos_sim(a[i], b[j])
     """
-    if not isinstance(a, torch.Tensor):
-        a = torch.tensor(a)
+    if not isinstance(a, np.ndarray):
+        a = np.array(a)
 
-    if not isinstance(b, torch.Tensor):
-        b = torch.tensor(b)
+    if not isinstance(b, np.ndarray):
+        b = np.array(b)
 
     if len(a.shape) == 1:
-        a = a.unsqueeze(0)
+        a = a.reshape(1, -1)
 
     if len(b.shape) == 1:
-        b = b.unsqueeze(0)
+        b = b.reshape(1, -1)
 
-    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
-    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
-    return torch.mm(a_norm, b_norm.transpose(0, 1))
+    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
+    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
+    return np.dot(a_norm, b_norm.T)
 
 
 def semantic_search(
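cos_sim is ported from torch to numpy: rows of a and b are L2-normalized, and their matrix product is the full cosine-similarity matrix. A standalone sanity check of the same arithmetic, independent of the repo:

    import numpy as np

    a = np.array([[1.0, 0.0], [1.0, 1.0]])  # two "query" vectors
    b = np.array([[1.0, 0.0], [0.0, 1.0]])  # two "corpus" vectors

    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)

    print(np.dot(a_norm, b_norm.T))
    # [[1.         0.        ]
    #  [0.70710678 0.70710678]]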
@@ -75,22 +68,14 @@
     Semantic similarity search based on cosine similarity score. Implementation from this project: https://github.com/UKPLab/sentence-transformers
     """
 
-    if isinstance(query_embeddings, (np.ndarray, np.generic)):
-        query_embeddings = torch.from_numpy(query_embeddings)
-    elif isinstance(query_embeddings, list):
-        query_embeddings = torch.stack(query_embeddings)
+    if isinstance(query_embeddings, list):
+        query_embeddings = np.array(query_embeddings)
 
     if len(query_embeddings.shape) == 1:
-        query_embeddings = query_embeddings.unsqueeze(0)
-
-    if isinstance(corpus_embeddings, (np.ndarray, np.generic)):
-        corpus_embeddings = torch.from_numpy(corpus_embeddings)
-    elif isinstance(corpus_embeddings, list):
-        corpus_embeddings = torch.stack(corpus_embeddings)
+        query_embeddings = query_embeddings.reshape(1, -1)
 
-    # Check that corpus and queries are on the same device
-    if corpus_embeddings.device != query_embeddings.device:
-        query_embeddings = query_embeddings.to(corpus_embeddings.device)
+    if isinstance(corpus_embeddings, list):
+        corpus_embeddings = np.array(corpus_embeddings)
 
     queries_result_list = [[] for _ in range(len(query_embeddings))]
 
@@ -106,15 +91,10 @@
     )
 
     # Get top-k scores
-    cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
-        cos_scores,
-        min(top_k, len(cos_scores[0])),
-        dim=1,
-        largest=True,
-        sorted=False,
-    )
-    cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
-    cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
+    cos_scores_top_k_values = np.sort(cos_scores, axis=1)[:, -top_k:][:, ::-1]
+    cos_scores_top_k_idx = np.argsort(cos_scores, axis=1)[:, -top_k:][:, ::-1]
+    cos_scores_top_k_values = cos_scores_top_k_values.tolist()
+    cos_scores_top_k_idx = cos_scores_top_k_idx.tolist()
 
     for query_itr in range(len(cos_scores)):
         for sub_corpus_id, score in zip(
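torch.topk is emulated with np.sort/np.argsort: sort each row ascending, keep the last top_k columns, and reverse them so scores come out descending. Slicing also handles top_k larger than the row width (it simply returns every column), which matches the old min(top_k, len(cos_scores[0])) clamp. A standalone example:

    import numpy as np

    cos_scores = np.array([[0.2, 0.9, 0.5, 0.7]])
    top_k = 2

    values = np.sort(cos_scores, axis=1)[:, -top_k:][:, ::-1]
    idx = np.argsort(cos_scores, axis=1)[:, -top_k:][:, ::-1]

    print(values)  # [[0.9 0.7]]
    print(idx)     # [[1 3]]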
@@ -256,9 +236,9 @@ def add_texts(
         if self._embedding_function is not None:
             embeddings = self._get_embeddings(texts)
 
-        self._corpus_embeddings = torch.tensor(embeddings)
+        self._corpus_embeddings = np.array(embeddings)
         self._texts = texts
-        self._metadatas = metadatas
+        self._metadatas = metadatas or [{} for _ in texts]
         return metadatas
 
     def similarity_search(
@@ -295,7 +275,7 @@ def similarity_search_with_score(
             List[Tuple[Document, float]]: List of documents most similar to the query
             text with distance in float.
         """
-        query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
+        query_embeddings = np.array([self._get_embeddings([query])[0]])
         result_ids_and_scores = semantic_search(
             corpus_embeddings=self._corpus_embeddings,
             query_embeddings=query_embeddings,
@@ -347,7 +327,7 @@ def label_diversity_similarity_search_with_score(
             List[Tuple[Document, float]]: List of documents most similar to the query
             text with distance in float.
         """
-        query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
+        query_embeddings = np.array([self._get_embeddings([query])[0]])
         data = []
         data = zip(self._corpus_embeddings, self._texts, self._metadatas)
         sorted_data = sorted(data, key=lambda item: item[2].get(label_key))
@@ -395,7 +375,7 @@ def max_marginal_relevance_search_by_vector(
         **kwargs: Any,
     ) -> List[Document]:
         query_embedding = self._get_embeddings([query])[0]
-        query_embeddings = torch.tensor([query_embedding])
+        query_embeddings = np.array([query_embedding])
         result_ids_and_scores = semantic_search(
             corpus_embeddings=self._corpus_embeddings,
             query_embeddings=query_embeddings,
@@ ... @@
         result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
         scores = [result["score"] for result in result_ids_and_scores[0]]
 
-        fetched_embeddings = torch.index_select(
-            input=self._corpus_embeddings, dim=0, index=torch.tensor(result_ids)
-        ).tolist()
+        fetched_embeddings = self._corpus_embeddings[result_ids].tolist()
         mmr_selected = maximal_marginal_relevance(
             np.array([query_embedding], dtype=np.float32),
             fetched_embeddings,
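Since the corpus embeddings are now held as a numpy array, torch.index_select is replaced by plain fancy indexing. A standalone illustration of the equivalence:

    import numpy as np

    corpus = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
    result_ids = [2, 0]

    # Fetches the same rows torch.index_select(corpus, dim=0, index=ids) would
    print(corpus[result_ids].tolist())  # [[0.5, 0.6], [0.1, 0.2]]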