-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstreamlit_app.py
96 lines (75 loc) · 3.7 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.embeddings import OpenAIEmbeddings
from ensemble import ensemble_retriever_from_docs
from full_chain import create_full_chain, ask_question
from local_loader import load_txt_files
# Page-level setup: page_title is passed to set_page_config, and the same text
# is rendered as the on-page heading.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command in a script -- keep it ahead of st.title.
st.set_page_config(page_title="LangChain & Streamlit RAG")
st.title("LangChain & Streamlit RAG")
def show_ui(qa, prompt_to_user="How may I help you?"):
    """Render the chat transcript and answer the newest user message.

    The transcript lives in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts.

    Args:
        qa: Chain object forwarded unchanged to ``ask_question``.
        prompt_to_user: Assistant greeting used to seed an empty transcript.
    """
    # Seed the transcript the first time this session runs the script.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": prompt_to_user}]

    # Replay the stored conversation.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Record and echo a newly submitted prompt.
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

    # Answer whenever the last stored message is still waiting for a reply.
    pending = st.session_state.messages[-1]
    if pending["role"] != "assistant":
        # Read the question from the transcript, not from `prompt`: if an
        # earlier run failed after storing the user message, this rerun has no
        # fresh prompt (it is None) but the stored message still needs a reply.
        question = pending["content"]
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = ask_question(qa, question)
                st.markdown(response.content)
        st.session_state.messages.append(
            {"role": "assistant", "content": response.content}
        )
@st.cache_resource
def get_retriever(openai_api_key=None):
    """Build the ensemble retriever over the locally loaded text documents.

    Wrapped in ``st.cache_resource``, with ``openai_api_key`` as the only
    argument.

    Args:
        openai_api_key: Key handed to ``OpenAIEmbeddings``.

    Returns:
        Whatever ``ensemble_retriever_from_docs`` produces for the loaded
        documents and the embedding model.
    """
    documents = load_txt_files()
    embedder = OpenAIEmbeddings(
        model="text-embedding-3-small",
        openai_api_key=openai_api_key,
    )
    return ensemble_retriever_from_docs(documents, embeddings=embedder)
def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """Assemble the question-answering chain used by the chat UI.

    Args:
        openai_api_key: Key passed to both the retriever and the chain.
        huggingfacehub_api_token: Accepted for the caller's convenience; this
            function does not read it.

    Returns:
        The object returned by ``create_full_chain``.
    """
    retriever = get_retriever(openai_api_key=openai_api_key)
    # Chat memory is stored under the "langchain_messages" key.
    history = StreamlitChatMessageHistory(key="langchain_messages")
    return create_full_chain(
        retriever,
        openai_api_key=openai_api_key,
        chat_memory=history,
    )
def get_secret_or_input(secret_key, secret_name, info_link=None):
    """Return a credential from ``st.secrets`` or from a password input.

    When the secret is not configured, a password text input is rendered; a
    non-empty entry is also saved to ``st.session_state[secret_key]``.

    Args:
        secret_key: Key looked up in ``st.secrets`` and used as the
            ``st.session_state`` key for a typed-in value.
        secret_name: Human-readable label shown in the prompt and input box.
        info_link: Optional URL rendered as a "Get an ..." markdown link
            below the input.

    Returns:
        The configured secret, or the current content of the text input
        (empty until the user types something).
    """
    try:
        has_secret = secret_key in st.secrets
    except FileNotFoundError:
        # NOTE(review): Streamlit is understood to raise FileNotFoundError
        # (or a subclass) when no secrets file exists. Treat that as "not
        # configured" so the user is prompted instead of the app crashing.
        has_secret = False

    if has_secret:
        st.write(f"Found {secret_key} secret")
        return st.secrets[secret_key]

    st.write(f"Please provide your {secret_name}")
    secret_value = st.text_input(secret_name, key=f"input_{secret_key}", type="password")
    if secret_value:
        # Persist the entry so later reruns can read it from session state.
        st.session_state[secret_key] = secret_value
    if info_link:
        st.markdown(f"[Get an {secret_name}]({info_link})")
    return secret_value
def run():
    """Collect credentials, then show the chat UI or stop with warnings.

    Each credential is read from ``st.session_state`` first; any that is
    missing is requested in the sidebar via ``get_secret_or_input``. If either
    is still missing, a warning is shown for each and the script is stopped.
    """
    openai_api_key = st.session_state.get("OPENAI_API_KEY")
    huggingfacehub_api_token = st.session_state.get("HUGGINGFACEHUB_API_TOKEN")

    with st.sidebar:
        if not openai_api_key:
            openai_api_key = get_secret_or_input(
                'OPENAI_API_KEY',
                "OpenAI API key",
                info_link="https://platform.openai.com/account/api-keys",
            )
        if not huggingfacehub_api_token:
            huggingfacehub_api_token = get_secret_or_input(
                'HUGGINGFACEHUB_API_TOKEN',
                "HuggingFace Hub API Token",
                info_link="https://huggingface.co/docs/huggingface_hub/main/en/quick-start#authentication",
            )

    # Gather a warning for every credential that is still absent.
    credential_checks = (
        (openai_api_key, "Missing OPENAI_API_KEY"),
        (huggingfacehub_api_token, "Missing HUGGINGFACEHUB_API_TOKEN"),
    )
    warnings_to_show = [text for value, text in credential_checks if not value]

    if warnings_to_show:
        for text in warnings_to_show:
            st.warning(text)
        st.stop()
        return

    chain = get_chain(
        openai_api_key=openai_api_key,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )
    st.subheader("Ask me questions about this week's meal plan")
    show_ui(chain, "What would you like to know?")
# Script entry point: build the page each time this file is executed.
run()