Skip to content

Commit

Permalink
不安装torch也可以使用部分功能
Browse files Browse the repository at this point in the history
  • Loading branch information
Tongjilibo committed Jul 10, 2024
1 parent 2ca5d27 commit a62b74d
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions bert4vector/snippets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,31 @@
import heapq
import queue
from typing import Union

import numpy as np
import torch
import torch.nn.functional
from torch4keras.snippets import is_torch_available


# Allow this module to be imported even when torch is not installed.
if is_torch_available():
    import torch
    import torch.nn.functional
    torch_Tensor = torch.Tensor
else:
    class torch_Tensor:
        """Placeholder for ``torch.Tensor`` when torch is not installed.

        It only keeps the type annotations and ``isinstance`` checks in this
        module resolvable at import time; no object is ever an instance of it.
        """


def cos_sim(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
def cos_sim(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]):
"""
Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = cos_sim(a[i], b[j])
"""
is_tensor = True
if not isinstance(a, torch.Tensor):
if not isinstance(a, torch_Tensor):
a = torch.tensor(a, dtype=torch.float)
is_tensor = False

if not isinstance(b, torch.Tensor):
if not isinstance(b, torch_Tensor):
b = torch.tensor(b, dtype=torch.float)
is_tensor = False

Expand All @@ -39,17 +47,17 @@ def cos_sim(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarra
return scores if is_tensor else scores.numpy()


def dot_score(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]):
def dot_score(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]):
"""
Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = dot_prod(a[i], b[j])
"""
is_tensor = True
if not isinstance(a, torch.Tensor):
if not isinstance(a, torch_Tensor):
a = torch.tensor(a)
is_tensor = False

if not isinstance(b, torch.Tensor):
if not isinstance(b, torch_Tensor):
b = torch.tensor(b)
is_tensor = False

Expand All @@ -63,44 +71,44 @@ def dot_score(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndar
return scores if is_tensor else scores.numpy()


def pairwise_dot_score(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]):
    """
    Computes the pairwise dot-product dot_prod(a[i], b[i])

    Inputs that are not already torch tensors are converted with
    ``torch.tensor`` first, so torch is required when this function is called.

    :param a: First batch of vectors (torch tensor, numpy array or array-like).
    :param b: Second batch of vectors, multiplied element-wise with ``a``.
    :return: Vector with res[i] = dot_prod(a[i], b[i]), as a torch tensor
    """
    if not isinstance(a, torch_Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch_Tensor):
        b = torch.tensor(b)

    # Element-wise product summed over the last axis gives one dot product per row.
    return (a * b).sum(dim=-1)


def pairwise_cos_sim(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]):
    """
    Computes the pairwise cossim cos_sim(a[i], b[i])

    Inputs that are not already torch tensors are converted with
    ``torch.tensor`` first, so torch is required when this function is called.

    :param a: First batch of vectors (torch tensor, numpy array or array-like).
    :param b: Second batch of vectors, compared row by row with ``a``.
    :return: Vector with res[i] = cos_sim(a[i], b[i])
    """
    if not isinstance(a, torch_Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch_Tensor):
        b = torch.tensor(b)

    # Cosine similarity is the dot product of the normalized vectors.
    return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b))


def normalize_embeddings(embeddings: torch_Tensor):
    """
    Normalizes the embeddings matrix, so that each sentence embedding has unit length

    :param embeddings: Torch tensor of embeddings; normalization runs along
        dim 1, so the input is expected to have at least two dimensions.
    :return: Tensor of the same shape, L2-normalized (p=2) along dim 1
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)


def semantic_search(
query_embeddings: Union[torch.Tensor, np.ndarray],
corpus_embeddings: Union[torch.Tensor, np.ndarray],
query_embeddings: Union[torch_Tensor, np.ndarray],
corpus_embeddings: Union[torch_Tensor, np.ndarray],
query_chunk_size: int = 100,
corpus_chunk_size: int = 500000,
top_k: int = 10,
Expand Down Expand Up @@ -176,7 +184,7 @@ def semantic_search(


def paraphrase_mining_embeddings(
embeddings: torch.Tensor,
embeddings: torch_Tensor,
query_chunk_size: int = 5000,
corpus_chunk_size: int = 100000,
max_pairs: int = 500000,
Expand Down Expand Up @@ -249,7 +257,7 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch
Returns only communities that are larger than min_community_size. The communities are returned
in decreasing order. The first element in each list is the central point in the community.
"""
if not isinstance(embeddings, torch.Tensor):
if not isinstance(embeddings, torch_Tensor):
embeddings = torch.tensor(embeddings)

threshold = torch.tensor(threshold, device=embeddings.device)
Expand Down

0 comments on commit a62b74d

Please sign in to comment.