forked from openai/chatgpt-retrieval-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjja_query.py
47 lines (40 loc) · 1.33 KB
/
jja_query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import openai
import qdrant_client
import requests
def describe_collection(collection_name, base_url="http://localhost:6333", timeout=10):
    """Fetch a Qdrant collection's description through the REST API.

    Args:
        collection_name: Name of the collection to describe.
        base_url: Root URL of the Qdrant REST endpoint. Defaults to the
            local instance this script was written against.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON body when the server answers HTTP 200. Otherwise
        the status code and response text are printed and ``None`` is
        returned.

    Raises:
        requests.RequestException: Connection failures and timeouts
        propagate to the caller unchanged.
    """
    response = requests.get(
        f"{base_url.rstrip('/')}/collections/{collection_name}",
        timeout=timeout,
    )
    if response.status_code != 200:
        print(f"Error {response.status_code}: {response.text}")
        return None
    return response.json()
collection_name = "document_chunks"

# Collection metadata fetched over REST (None if the request did not return 200).
collection_config = describe_collection(collection_name)

# Module-level Qdrant client; query_qdrant performs its searches through it.
client = qdrant_client.QdrantClient(host="localhost", prefer_grpc=True)

if collection_config is not None:
    print(collection_config)
def query_qdrant(
    query,
    collection_name,
    top_k=20,
    vector_name=None,
):
    """Embed *query* with OpenAI and search Qdrant for the nearest points.

    Args:
        query: Free-text query to embed.
        collection_name: Qdrant collection to search.
        top_k: Maximum number of hits to return.
        vector_name: Name of the vector to search when the collection stores
            named vectors. ``None`` (the default) searches the collection's
            single unnamed vector. NOTE(review): confirm which layout the
            target collection uses.

    Returns:
        The hits returned by the module-level ``client.search``.
    """
    # Pre-1.0 openai module-level embeddings interface.
    embedded_query = openai.Embedding.create(
        input=query,
        model="text-embedding-ada-002",
    )["data"][0]["embedding"]

    # A named vector is addressed with a (name, vector) tuple; the default
    # unnamed vector takes the bare embedding.
    if vector_name is None:
        query_vector = embedded_query
    else:
        query_vector = (vector_name, embedded_query)

    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=top_k,
    )
query_results = query_qdrant("messages about customers", collection_name)

for rank, article in enumerate(query_results, start=1):
    # A hit's payload may lack a "title" key (or be None), so fall back to
    # the point id instead of raising KeyError/TypeError mid-listing.
    title = (article.payload or {}).get("title", article.id)
    print(f"{rank}. {title} (Score: {round(article.score, 3)})")