Spaces:
Running
Running
import os | |
import streamlit as st | |
from langchain.llms import HuggingFaceHub | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
class UserInterface(): | |
def __init__(self, ): | |
st.warning("Warning: Some models may not work and some models may require GPU to run") | |
st.text("An Open Source Chat Application") | |
st.header("Open LLMs") | |
# self.API_KEY = st.sidebar.text_input( | |
# 'API Key', | |
# type='password', | |
# help="Type in your HuggingFace API key to use this app" | |
# ) | |
models_name = ( | |
"HuggingFaceH4/zephyr-7b-beta", | |
"Sharathhebbar24/chat_gpt2_dpo", | |
"Sharathhebbar24/chat_gpt2", | |
"Sharathhebbar24/math_gpt2_sft", | |
"Sharathhebbar24/math_gpt2", | |
"Sharathhebbar24/convo_bot_gpt_v1", | |
"Sharathhebbar24/Instruct_GPT", | |
"Sharathhebbar24/Mistral-7B-v0.1-sharded", | |
"Sharathhebbar24/llama_chat_small_7b", | |
"Deci/DeciCoder-6B", | |
"Deci/DeciLM-7B-instruct", | |
"Deci/DeciCoder-1b", | |
"Deci/DeciLM-7B-instruct-GGUF", | |
"Open-Orca/Mistral-7B-OpenOrca", | |
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", | |
"Sharathhebbar24/llama_7b_chat", | |
"CultriX/MistralTrix-v1", | |
"ahxt/LiteLlama-460M-1T", | |
"gorilla-llm/gorilla-7b-hf-delta-v0", | |
"codeparrot/codeparrot" | |
) | |
self.models = st.sidebar.selectbox( | |
label="Choose your models", | |
options=models_name, | |
help="Choose your model", | |
) | |
self.temperature = st.sidebar.slider( | |
label='Temperature', | |
min_value=0.1, | |
max_value=1.0, | |
step=0.1, | |
value=0.5, | |
help="Set the temperature to get accurate or random result" | |
) | |
self.max_token_length = st.sidebar.slider( | |
label="Token Length", | |
min_value=32, | |
max_value=2048, | |
step=16, | |
value=64, | |
help="Set max tokens to generate maximum amount of text output" | |
) | |
self.model_kwargs = { | |
"temperature": self.temperature, | |
"max_new_tokens": self.max_token_length | |
} | |
os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.getenv("HF_KEY") | |
def form_data(self): | |
try: | |
# if not self.API_KEY.startswith('hf_'): | |
# st.warning('Please enter your API key!', icon='⚠') | |
# text_input_visibility = True | |
# else: | |
# text_input_visibility = False | |
text_input_visibility = False | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
st.write(f"You are using {self.models} model") | |
for message in st.session_state.messages: | |
with st.chat_message(message.get('role')): | |
st.write(message.get("content")) | |
context = st.sidebar.text_input( | |
label="Context", | |
help="Context lets you know on what the answer should be generated" | |
) | |
question = st.chat_input( | |
key="question", | |
disabled=text_input_visibility | |
) | |
template = f"<|system|>\nYou are a intelligent chatbot and expertise in {context}.</s>\n<|user|>\n{question}.\n<|assistant|>" | |
# template = """ | |
# Answer the question based on the context, if you don't know then output "Out of Context" | |
# Context: {context} | |
# Question: {question} | |
# Answer: | |
# """ | |
prompt = PromptTemplate( | |
template=template, | |
input_variables=[ | |
'question', | |
'context' | |
] | |
) | |
llm = HuggingFaceHub( | |
repo_id = self.models, | |
model_kwargs = self.model_kwargs | |
) | |
if question: | |
llm_chain = LLMChain( | |
prompt=prompt, | |
llm=llm, | |
) | |
result = llm_chain.run({ | |
"question": question, | |
"context": context | |
}) | |
if "Out of Context" in result: | |
result = "Out of Context" | |
st.session_state.messages.append( | |
{ | |
"role":"user", | |
"content": f"Context: {context}\n\nQuestion: {question}" | |
} | |
) | |
with st.chat_message("user"): | |
st.write(f"Context: {context}\n\nQuestion: {question}") | |
if question.lower() == "clear": | |
del st.session_state.messages | |
return | |
st.session_state.messages.append( | |
{ | |
"role": "assistant", | |
"content": result | |
} | |
) | |
with st.chat_message('assistant'): | |
st.markdown(result) | |
except Exception as e: | |
st.error(e, icon="🚨") | |
model = UserInterface() | |
model.form_data() |