Add search to assistant for much faster response
Signed-off-by: elronbandel <[email protected]>
elronbandel committed Jan 19, 2025
1 parent b66c23f commit 440f805
Showing 6 changed files with 249 additions and 7,643 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@ packages = {find = {where = ["src"]}}
include-package-data = true

[tool.setuptools.package-data]
-unitxt = ["catalog/**/*.json", "ui/banner.png", "assistant/context.txt"]
+unitxt = ["catalog/**/*.json", "ui/banner.png", "assistant/embeddings.npz", "assistant/metadata.parquet"]

[tool.setuptools.dynamic]
version = {attr = "unitxt.version.version"}
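The change above swaps the packaged context.txt dump for a prebuilt retrieval index: an embedding matrix (embeddings.npz) and a companion metadata table (metadata.parquet). The script that builds those files is not rendered in this view; the sketch below is only a plausible offline build step, assuming the same watsonx embedding model that app.py queries with, and a hypothetical build_index helper fed (path, document) chunks in the schema that load_data() in app.py reads.

# Hypothetical offline build step (not part of the rendered diff). It embeds
# documentation chunks and writes them in the layout that load_data() expects.
import litellm
import numpy as np
import pandas as pd


def build_index(chunks, out_dir="src/unitxt/assistant"):
    # chunks: list of {"path": ..., "document": ...} dicts (schema inferred from app.py)
    response = litellm.embedding(
        model="watsonx/intfloat/multilingual-e5-large",
        input=[chunk["document"] for chunk in chunks],
    )
    embeddings = np.array([item["embedding"] for item in response.data], dtype=np.float32)
    np.savez_compressed(f"{out_dir}/embeddings.npz", embeddings=embeddings)
    pd.DataFrame(chunks).to_parquet(f"{out_dir}/metadata.parquet")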
src/unitxt/assistant/app.py (164 changes: 93 additions & 71 deletions)

@@ -1,66 +1,108 @@
import datetime
-import glob
import json
import os
-import uuid  # to generate a unique session ID
-from functools import lru_cache
+import uuid

import litellm
+import numpy as np
+import pandas as pd
import streamlit as st
+import torch
+from transformers import AutoTokenizer


-def combine_files_into_string(glob_pattern):
-    combined_content = []
-
-    for file_path in glob.glob(glob_pattern):
-        with open(file_path, encoding="utf-8") as file:
-            combined_content.append("# " + file_path)
-            combined_content.append(file.read())
-
-    return "\n".join(combined_content)
-
-
-def read_catalog_files():
-    combined_content = []
-
-    for file_path in glob.glob("src/unitxt/catalog/**/*.json"):
-        with open(file_path, encoding="utf-8") as file:
-            data = json.load(file)
-            id = (
-                file_path.replace("src/unitxt/catalog/", "")
-                .replace(".json", "")
-                .replace("/", ".")
-            )
-            item = f'id: {id}, type: {data["__type__"]}'
-            if "__description__" in data:
-                item += f', desc: {data["__description__"]}'
-            combined_content.append(item)
-
-    return "\n".join(combined_content)
-
-
-def make_context():
-    context = "\n# Tutorials: \n"
-    context += combine_files_into_string("docs/*.rst")
-    context += combine_files_into_string("docs/docs/*.rst")
-    context += "\n# Examples: \n"
-    context += combine_files_into_string("examples/*.py")
-    context += "\n# Catalog: \n"
-    context += read_catalog_files()
-    return context
-
-
-@lru_cache
-def get_context():
-    context_file_path = os.path.join(os.path.dirname(__file__), "context.txt")
-    if not os.path.exists(context_file_path):
-        context = make_context()
-        with open(context_file_path, "w", encoding="utf-8") as f:
-            f.write(context)
-    else:
-        with open(context_file_path, encoding="utf-8") as f:
-            context = f.read()
-    return context
+@st.cache_resource
+def load_data():
+    current_file_dir = os.path.dirname(os.path.abspath(__file__))
+
+    metadata_df = pd.read_parquet(os.path.join(current_file_dir, "metadata.parquet"))
+    embeddings = np.load(os.path.join(current_file_dir, "embeddings.npz"))["embeddings"]
+    return metadata_df, embeddings
+
+
+def search(query, metadata_df, embeddings, max_tokens=5000, min_text_length=50):
+    # Generate embedding for the query using litellm
+    response = litellm.embedding(
+        model="watsonx/intfloat/multilingual-e5-large",
+        input=[query],
+    )
+
+    query_embedding = torch.tensor(response.data[0]["embedding"], dtype=torch.float32)
+    embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32)
+
+    # Compute cosine similarity
+    similarities = torch.nn.functional.cosine_similarity(
+        query_embedding.unsqueeze(0), embeddings_tensor
+    )
+
+    # Sort indices by similarity
+    sorted_indices = torch.argsort(similarities, descending=True).numpy()
+
+    # Initialize tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+
+    # Collect results until max_tokens is reached
+    total_tokens = 0
+    results = {}
+
+    for idx in sorted_indices:
+        row = metadata_df.iloc[idx]
+        path = row["path"]
+        if path in results:
+            results[path]["count"] += 1
+            continue
+
+        text = row["document"]
+
+        if len(text) < min_text_length:
+            continue
+
+        token_count = len(tokenizer.tokenize(text))
+
+        if total_tokens + token_count > max_tokens:
+            break
+
+        total_tokens += token_count
+
+        results[row["path"]] = {
+            "count": 1,
+            "text": text,
+            "path": row["path"],
+            "similarity": similarities[idx].item(),
+        }
+
+    return results
+
+
+def generate_response(messages, metadata_df, embeddings, model, max_tokens=500):
+    user_query = messages[-1]["content"]  # Use the latest user message as the query
+    search_results = search(user_query, metadata_df, embeddings, max_tokens=5000)
+
+    # Combine top results as context
+    context = "\n\n".join(
+        [f"Path: {v['path']}\nText: {v['text']}" for v in search_results.values()]
+    )
+
+    system_prompt = (
+        "Your job is to assist users with Unitxt Library and Catalog. "
+        "Refuse to do anything else.\n\n"
+        "# Answer only based on the following Information:\n\n" + context
+    )
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        *messages,
+    ]
+
+    response = litellm.completion(
+        model=model,
+        messages=messages,
+        max_tokens=max_tokens,
+        stream=True,
+    )
+
+    for chunk in response:
+        yield chunk.choices[0].delta.content or ""


def save_messages_to_disk(
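The new search() embeds the query with litellm, ranks every precomputed document embedding by cosine similarity, then walks the ranking, deduplicating by path and skipping very short chunks, until the running Llama-3 token count would exceed max_tokens. A tiny self-contained illustration of the ranking step, with random vectors standing in for the 1024-dimensional e5-large embeddings:

# Toy illustration of the broadcast search() relies on: one query vector
# against an (N, D) matrix yields one similarity score per document.
import torch

query = torch.randn(1024)      # stand-in for the e5-large query embedding
corpus = torch.randn(5, 1024)  # stand-in for the precomputed document embeddings

similarities = torch.nn.functional.cosine_similarity(query.unsqueeze(0), corpus)
order = torch.argsort(similarities, descending=True)
print(order.tolist())          # document indices, most similar first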
@@ -97,8 +139,6 @@ def save_feedback_to_disk(feedback, session_id, output_dir="feedback"):
    json.dump(data, f, ensure_ascii=False, indent=4)


-context = get_context()
-
st.set_page_config(
    page_title="Unitxt Assistant", page_icon="🦄", initial_sidebar_state="collapsed"
)
@@ -110,28 +150,7 @@ def save_feedback_to_disk(feedback, session_id, output_dir="feedback"):
if "pending_user_content" not in st.session_state:
    st.session_state.pending_user_content = None


-def generate_response(messages, model, max_tokens=500):
-    messages = [
-        {
-            "role": "system",
-            "content": (
-                "Your job is to assist users with Unitxt Library and Catalog. "
-                "Refuse to do anything else.\n\n"
-                "# Answer only based on the following Information:\n\n" + context
-            ),
-        },
-        *messages,
-    ]
-    response = litellm.completion(
-        model=model,
-        messages=messages,
-        max_tokens=max_tokens,
-        stream=True,
-    )
-    for chunk in response:
-        yield chunk.choices[0].delta.content or ""
+metadata_df, embeddings = load_data()

with st.sidebar:
    st.title("Assistant")
@@ -219,7 +238,11 @@ def generate_response(messages, model, max_tokens=500):
)

stream = generate_response(
-    st.session_state.messages, model=model, max_tokens=max_tokens
+    st.session_state.messages,
+    metadata_df,
+    embeddings,
+    model=model,
+    max_tokens=max_tokens,
)

response = placeholder.write_stream(stream)
@@ -234,7 +257,6 @@ def generate_response(messages, model, max_tokens=500):
)

st.session_state.pending_user_content = None
-
else:
    with chat_container:
        for msg in st.session_state.messages:
(The remaining four changed files in this commit were not loaded in this view.)
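One detail of search() worth calling out: context selection is a greedy prefix over the similarity ranking, stopping at the first document whose token count would overflow the budget rather than skipping it and trying cheaper ones. A minimal sketch of that policy in isolation, using a hypothetical select_within_budget helper with len() standing in for the Llama-3 tokenizer:

# Hypothetical helper mirroring the budget logic in search().
def select_within_budget(docs_in_ranked_order, count_tokens, max_tokens=5000):
    picked, total = [], 0
    for doc in docs_in_ranked_order:
        cost = count_tokens(doc)
        if total + cost > max_tokens:
            break  # greedy prefix: stop at the first overflow, as search() does
        total += cost
        picked.append(doc)
    return picked


print(select_within_budget(["aaa", "bb", "cccc"], len, max_tokens=6))  # ['aaa', 'bb']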
