import streamlit as st
import torch
import time
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer
)
# App title
st.set_page_config(page_title="😶🌫️ FuseChat Model")
root_path = "FuseAI"
model_name = "FuseChat-Qwen-2.5-7B-Instruct"
@st.cache_resource
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        f"{root_path}/{model_name}",
        trust_remote_code=True,
    )
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
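    # Load the model in 4-bit via bitsandbytes to fit the 7B weights in GPU
    # memory. Note: passing load_in_4bit=True directly is deprecated in newer
    # transformers releases; a rough equivalent (untested here) would be to
    # build a BitsAndBytesConfig and pass it as quantization_config:
    #   bnb_config = BitsAndBytesConfig(load_in_4bit=True,
    #                                   bnb_4bit_compute_dtype=torch.bfloat16)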
    model = AutoModelForCausalLM.from_pretrained(
        f"{root_path}/{model_name}",
        device_map="auto",
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.eval()
    return model, tokenizer
with st.sidebar:
    st.title('😶🌫️ FuseChat-3.0')
    st.write('This chatbot is created using FuseChat, a model developed by FuseAI.')
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.7, step=0.01)
    top_p = st.slider('top_p', min_value=0.1, max_value=1.0, value=0.8, step=0.05)
    top_k = st.slider('top_k', min_value=1, max_value=1000, value=20, step=1)
    repetition_penalty = st.slider('repetition penalty', min_value=1.0, max_value=2.0, value=1.05, step=0.05)
    max_length = st.slider('max_length', min_value=32, max_value=4096, value=2048, step=8)
with st.spinner('Loading model...'):
    model, tokenizer = load_model(model_name)
# Store LLM generated responses
if "messages" not in st.session_state.keys():
# st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
st.session_state.messages = []
# Display or clear chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])
def set_query(query):
    st.session_state.messages.append({"role": "user", "content": query})
# Create a list of candidate questions
candidate_questions = [
    "Is boiling water (100 degrees) an obtuse angle (larger than 90 degrees)?",
    "Write a quicksort code in Python.",
    "笼子里有好几只鸡和兔子。笼子里有72个头,200只腿。里面有多少只鸡和兔子",
]
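# The third candidate prompt is Chinese; in English: "A cage holds chickens
# and rabbits. There are 72 heads and 200 legs in the cage. How many chickens
# and how many rabbits are there?"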
# Display the chat interface with a list of clickable question buttons
for question in candidate_questions:
    st.sidebar.button(label=question, on_click=set_query, args=[question])
def clear_chat_history():
    st.session_state.messages = []
    # st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
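# Generate an assistant response for the current conversation, prefixed with
# the FuseChat-3.0 system prompt.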
@torch.no_grad()
def generate_fusechat_response():
    conversations = []
    conversations.append({"role": "system", "content": "You are FuseChat-3.0, created by Sun Yat-sen University. You are a helpful assistant."})
    for dict_message in st.session_state.messages:
        if dict_message["role"] == "user":
            conversations.append({"role": "user", "content": dict_message["content"]})
        else:
            conversations.append({"role": "assistant", "content": dict_message["content"]})
    string_dialogue = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(string_dialogue, return_tensors="pt").input_ids.to('cuda')
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
    return "".join(outputs)
# User-provided prompt
if prompt := st.chat_input("Do androids dream of electric sheep?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
# Generate a new response if last message is not from assistant
if len(st.session_state.messages) > 0 and st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_fusechat_response()
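            # Generation is already complete at this point; replay the response
            # character by character as a typewriter effect, with "▌" as cursor.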
            placeholder = st.empty()
            full_response = ''
            for item in response:
                full_response += item
                time.sleep(0.05)
                placeholder.markdown(full_response + "▌")
            placeholder.markdown(full_response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)