diff --git a/bert4vector/snippets/util.py b/bert4vector/snippets/util.py index 23d6e6c..003c853 100644 --- a/bert4vector/snippets/util.py +++ b/bert4vector/snippets/util.py @@ -7,23 +7,31 @@ import heapq import queue from typing import Union - import numpy as np -import torch -import torch.nn.functional +from torch4keras.snippets import is_torch_available + + +# 允许不安装torch来使用 +if is_torch_available(): + import torch + import torch.nn.functional + torch_Tensor = torch.Tensor +else: + class torch_Tensor: + pass -def cos_sim(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): +def cos_sim(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]): """ Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) """ is_tensor = True - if not isinstance(a, torch.Tensor): + if not isinstance(a, torch_Tensor): a = torch.tensor(a, dtype=torch.float) is_tensor = False - if not isinstance(b, torch.Tensor): + if not isinstance(b, torch_Tensor): b = torch.tensor(b, dtype=torch.float) is_tensor = False @@ -39,17 +47,17 @@ def cos_sim(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarra return scores if is_tensor else scores.numpy() -def dot_score(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): +def dot_score(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]): """ Computes the dot-product dot_prod(a[i], b[j]) for all i and j. :return: Matrix with res[i][j] = dot_prod(a[i], b[j]) """ is_tensor = True - if not isinstance(a, torch.Tensor): + if not isinstance(a, torch_Tensor): a = torch.tensor(a) is_tensor = False - if not isinstance(b, torch.Tensor): + if not isinstance(b, torch_Tensor): b = torch.tensor(b) is_tensor = False @@ -63,35 +71,35 @@ def dot_score(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndar return scores if is_tensor else scores.numpy() -def pairwise_dot_score(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): +def pairwise_dot_score(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]): """ Computes the pairwise dot-product dot_prod(a[i], b[i]) :return: Vector with res[i] = dot_prod(a[i], b[i]) """ - if not isinstance(a, torch.Tensor): + if not isinstance(a, torch_Tensor): a = torch.tensor(a) - if not isinstance(b, torch.Tensor): + if not isinstance(b, torch_Tensor): b = torch.tensor(b) return (a * b).sum(dim=-1) -def pairwise_cos_sim(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray]): +def pairwise_cos_sim(a: Union[torch_Tensor, np.ndarray], b: Union[torch_Tensor, np.ndarray]): """ Computes the pairwise cossim cos_sim(a[i], b[i]) :return: Vector with res[i] = cos_sim(a[i], b[i]) """ - if not isinstance(a, torch.Tensor): + if not isinstance(a, torch_Tensor): a = torch.tensor(a) - if not isinstance(b, torch.Tensor): + if not isinstance(b, torch_Tensor): b = torch.tensor(b) return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b)) -def normalize_embeddings(embeddings: torch.Tensor): +def normalize_embeddings(embeddings: torch_Tensor): """ Normalizes the embeddings matrix, so that each sentence embedding has unit length """ @@ -99,8 +107,8 @@ def normalize_embeddings(embeddings: torch.Tensor): def semantic_search( - query_embeddings: Union[torch.Tensor, np.ndarray], - corpus_embeddings: Union[torch.Tensor, np.ndarray], + query_embeddings: Union[torch_Tensor, np.ndarray], + corpus_embeddings: Union[torch_Tensor, np.ndarray], query_chunk_size: int = 100, corpus_chunk_size: int = 500000, top_k: int = 10, @@ -176,7 +184,7 @@ def semantic_search( def paraphrase_mining_embeddings( - embeddings: torch.Tensor, + embeddings: torch_Tensor, query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000, @@ -249,7 +257,7 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch Returns only communities that are larger than min_community_size. The communities are returned in decreasing order. The first element in each list is the central point in the community. """ - if not isinstance(embeddings, torch.Tensor): + if not isinstance(embeddings, torch_Tensor): embeddings = torch.tensor(embeddings) threshold = torch.tensor(threshold, device=embeddings.device)