Add Embedding and rerank services #15

Merged
merged 2 commits on Jan 18, 2025
4 changes: 4 additions & 0 deletions .gitignore
@@ -1 +1,5 @@
src/parse/db/*
*.onnx
__pycache__
logs
bce_model
10 changes: 10 additions & 0 deletions requirements.txt
@@ -0,0 +1,10 @@
concurrent-log-handler==0.9.25
sanic==23.6.0
sanic_ext==23.6.0
onnxruntime==1.17.1
numpy==1.24.3
transformers==4.36.2
langchain-core==0.1.50
langchain==0.1.9
langchain-openai==0.0.8
langchain_elasticsearch==0.2.2
Empty file.
204 changes: 204 additions & 0 deletions src/client/embedding/client.py
@@ -0,0 +1,204 @@

import sys
import os
import time
# Absolute path of the current script
current_script_path = os.path.abspath(__file__)

# Add the project root directory to sys.path
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_script_path))))

sys.path.append(root_dir)
from typing import List
from src.utils.log_handler import debug_logger, embed_logger
from src.utils.general_utils import get_time_async, get_time
from langchain_core.embeddings import Embeddings
from src.configs.configs import LOCAL_EMBED_SERVICE_URL, LOCAL_EMBED_BATCH
import traceback
import aiohttp
import asyncio
import requests

# Remove extra newlines and drop lines starting with ![figure] or ![equation]
def _process_query(query):
return '\n'.join([line for line in query.split('\n') if
not line.strip().startswith('![figure]') and
not line.strip().startswith('![equation]')])


class SBIEmbeddings(Embeddings):
# Build the URL of the embedding service
def __init__(self):
self.url = f"http://{LOCAL_EMBED_SERVICE_URL}/embedding"
self.session = requests.Session()
super().__init__()
# Asynchronously request embeddings for a batch of texts
async def _get_embedding_async(self, session, texts):
# Strip extra newlines and special markers
data = {'texts': [_process_query(text) for text in texts]}
async with session.post(self.url, json=data) as response:
return await response.json()

@get_time_async
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# Batch size for each request
batch_size = LOCAL_EMBED_BATCH
# Number of batches, rounded up
embed_logger.info(f'embedding batch number: {(len(texts) + batch_size - 1) // batch_size}')
all_embeddings = []
# Request embeddings in batches
async with aiohttp.ClientSession() as session:
tasks = [self._get_embedding_async(session, texts[i:i + batch_size])
for i in range(0, len(texts), batch_size)]
# Collect the results of all tasks.
# asyncio.gather runs the tasks concurrently but returns the results in the same order as the task list,
# so even if a later batch finishes first, results stays aligned with tasks.
results = await asyncio.gather(*tasks)
# Merge the results of all batches
for result in results:
all_embeddings.extend(result)
debug_logger.info(f'success embedding number: {len(all_embeddings)}')
# Return the merged embeddings
return all_embeddings
# Embed a single query: wrap the text in a list and return the first (and only) embedding
async def aembed_query(self, text: str) -> List[float]:
return (await self.aembed_documents([text]))[0]
# Synchronous variant
def _get_embedding_sync(self, texts):
# Strip extra newlines and special markers, same as the async path
data = {'texts': [_process_query(text) for text in texts]}
try:
response = self.session.post(self.url, json=data)
response.raise_for_status()
result = response.json()
return result
except Exception as e:
debug_logger.error(f'sync embedding error: {traceback.format_exc()}')
return None

# @get_time
# Synchronous batch embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._get_embedding_sync(texts)

@get_time
# Synchronous single-text embedding
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# return self._get_embedding([text])['embeddings'][0]
return self._get_embedding_sync([text])[0]

async def test_async_methods():
"""Test the async embedding methods"""
embedder = SBIEmbeddings()

# Single-text embedding
debug_logger.info("\nTesting async single-text embedding:")
single_text = "人工智能正在改变我们的生活方式。"
single_embedding = await embedder.aembed_query(single_text)
debug_logger.info(f"text: {single_text}")
debug_logger.info(f"embedding dimension: {len(single_embedding)}")

# Batch embedding
debug_logger.info("\nTesting async batch embedding:")
texts = [
"深度学习是人工智能的一个重要分支。",
"自然语言处理技术正在不断进步。",
"机器学习算法可以从数据中学习规律。"
]

embeddings = await embedder.aembed_documents(texts)
for text, embedding in zip(texts, embeddings):
debug_logger.info(f"text: {text}")
debug_logger.info(f"embedding dimension: {len(embedding)}")


def test_sync_methods():
"""Test the synchronous methods"""
embedder = SBIEmbeddings()

# Single-text embedding
debug_logger.info("\nTesting sync single-text embedding:")
single_text = "这是一个测试文本。"
single_embedding = embedder.embed_query(single_text)
debug_logger.info(f"text: {single_text}")
debug_logger.info(f"embedding dimension: {len(single_embedding)}")

# Batch embedding
debug_logger.info("\nTesting sync batch embedding:")
texts = [
"第一个测试文本",
"第二个测试文本",
"第三个测试文本"
]
embeddings = embedder.embed_documents(texts)
for text, embedding in zip(texts, embeddings):
debug_logger.info(f"text: {text}")
debug_logger.info(f"embedding dimension: {len(embedding)}")


def test_error_handling():
"""Test error handling"""
embedder = SBIEmbeddings()

debug_logger.info("\nTesting error handling:")
# Empty text
try:
embedding = embedder.embed_query("")
debug_logger.info("empty text handled successfully")
except Exception as e:
debug_logger.error(f"empty text handling failed: {str(e)}")

# None value
try:
embedding = embedder.embed_documents([None])
debug_logger.info("None value handled successfully")
except Exception as e:
debug_logger.error(f"None value handling failed: {str(e)}")


async def performance_test():
"""Performance test"""
embedder = SBIEmbeddings()

debug_logger.info("\nRunning performance test:")
# Prepare test data
test_sizes = [10, 50, 100]

for size in test_sizes:
texts = [f"这是第{i}个性能测试文本。" for i in range(size)]

# Synchronous performance
start_time = time.time()
embeddings = embedder.embed_documents(texts)
sync_time = time.time() - start_time
debug_logger.info(f"sync processing of {size} texts took {sync_time:.2f}s")

# Asynchronous performance
start_time = time.time()
embeddings = await embedder.aembed_documents(texts)
async_time = time.time() - start_time
debug_logger.info(f"async processing of {size} texts took {async_time:.2f}s")


async def main():
"""Main test entry point"""
debug_logger.info("Starting embedding client tests...")

# Async methods
await test_async_methods()

# # Sync methods
# test_sync_methods()

# # Error handling
# test_error_handling()

# # Performance test
# await performance_test()

debug_logger.info("embedding client tests finished")


if __name__ == "__main__":
asyncio.run(main())
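
Note: SBIEmbeddings above POSTs {'texts': [...]} to http://{LOCAL_EMBED_SERVICE_URL}/embedding and extends its result list with the raw JSON response, so the service is expected to return a plain list of vectors, one per input text. The embedding server itself is not part of this diff; the following is only a minimal sketch of a compatible Sanic endpoint (sanic is pinned in requirements.txt), where embed_texts is a hypothetical helper around the local ONNX model, not an existing function in this repo.

from sanic import Sanic
from sanic.response import json as sanic_json

app = Sanic("embedding_server")

@app.post("/embedding")
async def embedding(request):
    # The client sends {'texts': [...]} and extends its result list with the raw response,
    # so the handler should return a plain JSON list of vectors, one per input text.
    texts = request.json["texts"]
    vectors = embed_texts(texts)  # hypothetical helper wrapping the local ONNX model
    return sanic_json(vectors)

if __name__ == "__main__":
    # Port matches LOCAL_EMBED_SERVICE_URL = "localhost:9001" in src/configs/configs.py
    app.run(host="0.0.0.0", port=9001, single_process=True)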
Empty file added src/client/rerank/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions src/client/rerank/client.py
@@ -0,0 +1,83 @@
import sys
import os

# Absolute path of the current script
current_script_path = os.path.abspath(__file__)

# Add the project root directory to sys.path
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_script_path))))

sys.path.append(root_dir)

from typing import List
from src.utils.log_handler import debug_logger
from src.utils.general_utils import get_time_async, get_time
from src.configs.configs import LOCAL_RERANK_BATCH, LOCAL_RERANK_SERVICE_URL
from langchain.schema import Document
import traceback
import aiohttp
import asyncio
import requests


class SBIRerank:
def __init__(self):
"""初始化重排序客户端"""
self.url = f"http://{LOCAL_RERANK_SERVICE_URL}/rerank"
# 不支持异步的session
self.session = requests.Session()

async def _get_rerank_async(self, query: str, passages: List[str]) -> List[float]:
"""异步请求重排序服务"""
data = {'query': query, 'passages': passages}
try:
async with aiohttp.ClientSession() as session:
async with session.post(self.url, json=data) as response:
return await response.json()
except Exception as e:
debug_logger.error(f'async rerank error: {traceback.format_exc()}')
return [0.0] * len(passages)

@get_time_async
async def arerank_documents(self, query: str, source_documents: List[Document]) -> List[Document]:
"""Embed search docs using async calls, maintaining the original order."""
batch_size = LOCAL_RERANK_BATCH # 增大客户端批处理大小
all_scores = [0 for _ in range(len(source_documents))]
passages = [doc.page_content for doc in source_documents]

tasks = []
for i in range(0, len(passages), batch_size):
task = asyncio.create_task(self._get_rerank_async(query, passages[i:i + batch_size]))
tasks.append((i, task))

for start_index, task in tasks:
res = await task
if res is None:
return source_documents
all_scores[start_index:start_index + batch_size] = res
debug_logger.info(f'rerank scores: {res}')

for idx, score in enumerate(all_scores):
source_documents[idx].metadata['score'] = round(float(score), 2)
source_documents = sorted(source_documents, key=lambda x: x.metadata['score'], reverse=True)

return source_documents



# Usage example
async def main():
reranker = SBIRerank()
query = "什么是人工智能"
documents = [Document(page_content="阿爸巴sss啊啊啊啊s巴爸爸"),
Document(page_content="AI技术在各领域广泛应用"),
Document(page_content="机器学习是AI的核心技术。"),
Document(page_content="人工智能是计算机科学的一个分支。")] # 示例文档
reranked_docs = await reranker.arerank_documents(query, documents)
return reranked_docs


# 运行异步主函数
if __name__ == "__main__":
reranked_docs = asyncio.run(main())
print(reranked_docs)
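
Note: SBIRerank posts {'query': ..., 'passages': [...]} to http://{LOCAL_RERANK_SERVICE_URL}/rerank and splices the returned scores back into all_scores by batch offset, so the service is expected to return one float per passage, in order. As with the embedding service, the server is not in this diff; a minimal sketch of a compatible Sanic endpoint might look like this, where rerank_scores is a hypothetical helper, not code from this PR.

from sanic import Sanic
from sanic.response import json as sanic_json

app = Sanic("rerank_server")

@app.post("/rerank")
async def rerank(request):
    query = request.json["query"]
    passages = request.json["passages"]
    # One float per passage, in the same order, so the client can splice them into all_scores.
    scores = rerank_scores(query, passages)  # hypothetical helper around the local model
    return sanic_json(scores)

if __name__ == "__main__":
    # Port matches LOCAL_RERANK_SERVICE_URL = "localhost:8001" in src/configs/configs.py
    app.run(host="0.0.0.0", port=8001, single_process=True)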
27 changes: 27 additions & 0 deletions src/client/rerank/test.py
@@ -0,0 +1,27 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# your query and corresponding passages
query = '什么是人工智能'
passages = ['阿爸巴sss啊啊啊啊s巴爸爸', 'AI技术在各领域广泛应用',
'机器学习是AI的核心技术。',
'人工智能是计算机科学的一个分支。']

# construct sentence pairs
sentence_pairs = [[query, passage] for passage in passages]
# init model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-reranker-base_v1')
model = AutoModelForSequenceClassification.from_pretrained('maidalun1020/bce-reranker-base_v1')

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
model.to(device)

# get inputs
inputs = tokenizer(sentence_pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
inputs_on_device = {k: v.to(device) for k, v in inputs.items()}

# calculate scores
logits = model(**inputs_on_device, return_dict=True).logits
print(logits.shape)
scores = logits.view(-1,).float()
scores = torch.sigmoid(scores)
print(scores)
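
This test scores with the PyTorch checkpoint, while requirements.txt pins onnxruntime and configs.py points LOCAL_RERANK_MODEL_PATH at an ONNX file. A rough sketch of the same scoring through onnxruntime is shown below; the model path and the graph's input/output names depend on how model.onnx was actually exported, so inputs are filtered against the session rather than assumed.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Assumed path: adjust to the actual export under src/server/rerank_server/bce_model.
session = ort.InferenceSession("src/server/rerank_server/bce_model/model.onnx",
                               providers=["CPUExecutionProvider"])
tokenizer = AutoTokenizer.from_pretrained("maidalun1020/bce-reranker-base_v1")

query = "什么是人工智能"
passages = ["AI技术在各领域广泛应用", "人工智能是计算机科学的一个分支。"]
inputs = tokenizer([[query, p] for p in passages], padding=True, truncation=True,
                   max_length=512, return_tensors="np")

# Only feed tensors the graph actually declares (some exports omit token_type_ids).
graph_inputs = {i.name for i in session.get_inputs()}
ort_inputs = {k: v for k, v in inputs.items() if k in graph_inputs}

logits = session.run(None, ort_inputs)[0].reshape(-1)
scores = 1.0 / (1.0 + np.exp(-logits))  # sigmoid, mirroring the PyTorch test above
print(scores)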
Empty file added src/configs/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions src/configs/configs.py
@@ -0,0 +1,18 @@
import os

current_script_path = os.path.abspath(__file__)
root_path = os.path.dirname(os.path.dirname(os.path.dirname(current_script_path)))

LOCAL_EMBED_SERVICE_URL = "localhost:9001"
EMBED_MODEL_PATH = "maidalun1020/bce-embedding-base_v1"
LOCAL_EMBED_PATH = os.path.join(root_path, 'src/server/embedding_server', 'bce_model')
LOCAL_EMBED_MODEL_PATH = os.path.join(LOCAL_EMBED_PATH, "model.onnx")
LOCAL_EMBED_BATCH = 1
LOCAL_EMBED_THREADS = 1
LOCAL_RERANK_SERVICE_URL = "localhost:8001"
RERANK_MODEL_PATH = "maidalun1020/bce-reranker-base_v1"
LOCAL_RERANK_BATCH = 1
LOCAL_RERANK_THREADS = 1
LOCAL_RERANK_MAX_LENGTH = 512
LOCAL_RERANK_PATH = os.path.join(root_path, 'src/server/rerank_server', 'bce_model')
LOCAL_RERANK_MODEL_PATH = os.path.join(LOCAL_RERANK_PATH, "model.onnx")
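
LOCAL_EMBED_MODEL_PATH and LOCAL_RERANK_MODEL_PATH point at model.onnx files under bce_model/ (which .gitignore now excludes), but the export step itself is not part of this diff. One possible way to produce the reranker export with plain torch.onnx.export is sketched below; the opset version, axis names, and output path are assumptions rather than something this PR prescribes.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "maidalun1020/bce-reranker-base_v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()
model.config.return_dict = False  # export a plain tuple (logits,) instead of a ModelOutput

# A single query/passage pair is enough to trace the graph; dynamic axes cover the rest.
sample = tokenizer([["什么是人工智能", "人工智能是计算机科学的一个分支。"]],
                   padding=True, truncation=True, max_length=512, return_tensors="pt")

torch.onnx.export(
    model,
    (sample["input_ids"], sample["attention_mask"]),
    "src/server/rerank_server/bce_model/model.onnx",  # assumed LOCAL_RERANK_MODEL_PATH
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                  "attention_mask": {0: "batch", 1: "seq"},
                  "logits": {0: "batch"}},
    opset_version=14,
)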
3 changes: 1 addition & 2 deletions src/parse/files2db.py
@@ -152,7 +152,6 @@ def parse_txt(file_path):
return ""

def file2db(uploaded_file, filename):

file_hash = hashlib.md5(uploaded_file.read()).hexdigest()
uploaded_file.seek(0) # 重置文件指针
print(f"Processing {filename} to db...")
@@ -167,7 +166,7 @@ def file2db(uploaded_file, filename):
return
embedding_path = os.path.join(db_path, 'index.pkl')
os.makedirs(db_path, exist_ok=True)
embeddings = OllamaEmbeddings(base_url="http://localhost:11434", model="qwen:7b")
embeddings = OllamaEmbeddings(base_url="http://localhost:11434", model="nomic-embed-text")
# 检查向量数据库是否存在
if os.path.exists(db_path) and os.path.exists(embedding_path):
print(f"Vector database for {file} already exists.")
2 changes: 1 addition & 1 deletion src/parse/files2elasticsearch.py
@@ -17,7 +17,7 @@ def __init__(self,
es_host: str = "localhost",
es_port: int = 9200,
index_name: str = "documents",
embedding_model: str = "qwen:7b"):
embedding_model: str = "nomic-embed-text"):
"""
Initialize the FileToElasticSearch class.
