chatbot / app.py
aisuko's picture
Init commit
0f60bae
raw
history blame contribute delete
No virus
1.9 kB
import gradio as gr
import torch
from torch import LongTensor, FloatTensor
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
# Hugging Face Hub id of the chat model; shared by the tokenizer and model loads
# so the two can never drift apart.
MODEL_NAME = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"

# Loaded once at import time and shared by every predict() call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Weights are loaded in bfloat16 (2 bytes per parameter).
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation when the newest token is a stop id.

    Args:
        stop_ids: Token ids that end generation. Defaults to ``(29, 0)``, the
            ids this app was originally hard-coded with for the RedPajama-INCITE
            tokenizer (presumably ``"<"`` and the end-of-text token -- TODO
            confirm against the tokenizer vocabulary).
    """

    def __init__(self, stop_ids=(29, 0)):
        super().__init__()
        # Built once here instead of on every generation step; frozenset gives
        # O(1) membership and cannot be mutated by accident.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id.

        Only row 0 is inspected; predict() always tokenizes a single prompt,
        so the batch has one sequence.
        """
        return int(input_ids[0][-1]) in self.stop_ids
def predict(message, history):
    """Stream a chat reply to ``message`` for ``gr.ChatInterface``.

    The conversation is rendered in the ``<human>:`` / ``<bot>:`` turn format.
    ``model.generate`` runs on a background thread, and the decoded text is
    yielded incrementally as it arrives from a ``TextIteratorStreamer``.

    Args:
        message: The latest user message.
        history: Prior turns as ``[user_text, bot_text]`` pairs, as supplied by
            Gradio (assumed pair-list format -- TODO confirm for the pinned
            Gradio version). A ``None`` slot is treated as an empty string.

    Yields:
        The cumulative reply text after each streamed chunk. If anything fails,
        the error is logged and a fixed apology string is yielded instead.
    """
    import logging

    try:
        # Append the new user turn with an empty bot slot so the prompt ends in
        # "\n<bot>:", cueing the model to answer.
        turns = [*(history or []), [message, ""]]
        prompt = "".join(
            "\n<human>:" + (turn[0] or "") + "\n<bot>:" + (turn[1] or "")
            for turn in turns
        )
        model_inputs = tokenizer([prompt], return_tensors="pt")

        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=10.0,  # max seconds to wait for each chunk from the queue
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generate_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.95,
            top_k=1000,
            temperature=1.0,
            num_beams=1,
            stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        )

        # generate() blocks until it finishes, so run it on a worker thread and
        # consume its output here as the streamer receives it.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        partial_message = ""
        for new_text in streamer:
            # Skip a chunk that is exactly "<": presumably the start of the next
            # "<human>:" tag, emitted just before StopOnTokens halts generation.
            if new_text != "<":
                partial_message += new_text
                yield partial_message

        # The streamer is exhausted only after generate() returns, so this join
        # is immediate; it just reaps the worker thread.
        thread.join()
    except Exception:
        # Top-level UI boundary: keep the user-facing reply, but record why it
        # failed instead of silently discarding the error.
        logging.getLogger(__name__).exception("Chat generation failed")
        yield "Sorry, I don't understand that."
# Chat UI around the streaming predict() generator. It is kept as a module-level
# `demo` so tools that import this file (e.g. Gradio's `gradio app.py` reload
# mode) can find it without a server being started at import time.
demo = gr.ChatInterface(predict)

if __name__ == "__main__":
    # Enable Gradio's request queue (required for streaming generator output
    # in older Gradio versions), then start the server.
    demo.queue().launch()