"""Gradio chat demo for the TinyStoriesAdv causal language model.

The conversation is serialised into a plain-text "问:" / "答:" transcript, fed
to the model, and the reply is streamed back until the decoded text ends with
a stop marker ("问:" or "meta_tag:", with either ASCII or full-width colon).
"""
import os
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Work relative to this file so a local checkpoint path (see the commented-out
# model_name below) resolves regardless of the launch directory. abspath()
# guards against __file__ being a bare filename, where dirname() returns ""
# and os.chdir("") would raise FileNotFoundError.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# model_name = "./92M_low_kv_dropout_v3_hf"
model_name = "fzmnm/TinyStoriesAdv_v2_92M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
# Reuse the EOS token as the padding token during generation.
model.generation_config.pad_token_id = tokenizer.eos_token_id

# Only the most recent `max_tokens` prompt tokens are kept (older history is
# dropped from the left).
max_tokens = 512
# Upper bound on the number of tokens generated for a single reply.
max_new_tokens = 512

# Full-width colon (U+FF1A). Text is normalised to the ASCII ":" before looking
# for stop markers, so either colon form ends a reply.
_FULLWIDTH_COLON = "\uff1a"
# Suffixes that end the model's answer: the start of the next user turn, or a
# metadata tag.
_STOP_SUFFIXES = ("问:", "meta_tag:")


def _normalize_colons(text: str) -> str:
    """Return *text* with every full-width colon replaced by an ASCII colon."""
    return text.replace(_FULLWIDTH_COLON, ":")


def build_input_str(message: str, history: "list[dict]") -> str:
    """Serialise the chat history plus the new user message into a prompt.

    Args:
        message: The new user message.
        history: Gradio ``type='messages'`` history, a list of dicts with
            ``'role'`` and ``'content'`` keys. Only the ``'user'`` and
            ``'assistant'`` roles are used; other entries are skipped.

    Returns:
        ``问:…`` / ``答:…`` turns, each followed by a blank line, ending with
        the new user turn.
    """
    prefixes = {"user": "问:", "assistant": "答:"}
    parts = [
        f"{prefixes[turn['role']]}{turn['content']}\n\n"
        for turn in history
        if turn["role"] in prefixes
    ]
    parts.append(f"问:{message}\n\n")
    return "".join(parts)


def stop_criteria(input_str: str) -> bool:
    """Return True when *input_str* ends with a stop marker (either colon form)."""
    return _normalize_colons(input_str).endswith(_STOP_SUFFIXES)


def remove_ending(input_str: str) -> str:
    """Strip a trailing stop marker from *input_str*, if present.

    Applies the same colon normalisation as ``stop_criteria`` for every
    suffix, so whatever marker stopped generation is also hidden from the
    displayed reply.
    """
    normalized = _normalize_colons(input_str)
    for suffix in _STOP_SUFFIXES:
        if normalized.endswith(suffix):
            # Normalisation maps one char to one char, so lengths line up.
            return input_str[: -len(suffix)]
    return input_str


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once a stop marker appears.

    Each step decodes the first sequence of the batch (prompt included) and
    applies ``stop_criteria``. The prompt itself ends with a blank line, so it
    cannot trigger a stop on its own.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return stop_criteria(text)


def chat(message, history):
    """Stream the model's reply to *message*, given the Gradio *history*.

    Yields:
        The reply accumulated so far after each streamed chunk, with any
        trailing stop marker removed.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here after the stream closes, instead of
            surfacing as a streamer timeout.
    """
    input_str = build_input_str(message, history)
    encoded = tokenizer(input_str, return_tensors="pt")
    # Keep only the most recent tokens; ids and mask are truncated together.
    input_ids = encoded["input_ids"][:, -max_tokens:]
    attention_mask = encoded["attention_mask"][:, -max_tokens:]

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        max_new_tokens=max_new_tokens,
        top_p=0.9,
        do_sample=True,
        temperature=0.7,
    )

    errors = []

    def _generate():
        # Runs in the worker thread. On failure, remember the error and close
        # the streamer so the consuming loop below finishes immediately rather
        # than blocking until the streamer's timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as err:  # re-raised in the caller's thread below
            errors.append(err)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    output_str = ""
    for new_str in streamer:
        output_str += new_str
        yield remove_ending(output_str)

    worker.join()
    if errors:
        raise errors[0]


# 'messages' mode: history arrives as role/content dicts (see build_input_str).
app = gr.ChatInterface(
    fn=chat,
    type='messages',
    examples=['什么是鹦鹉?', '什么是大象?', '谁是李白?', '什么是黑洞?'],
    title='聊天机器人',
)
app.launch()