#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

from project_settings import project_path
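

# Gradio demo for Qwen-7B fine-tuned on the chinese_modern_poetry dataset:
# given a prompt listing poetic images, the model writes a modern Chinese
# poem and streams it to the chat window.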
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default=(project_path / "trained_models/qwen_7b_chinese_modern_poetry").as_posix(),
        type=str
    )
    parser.add_argument("--output_file", default="result.xlsx", type=str)
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
    args = parser.parse_args()
    return args
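

# Example launch (script name assumed; all flags are optional, defaults above):
#   python3 main.py --temperature 0.35 --top_p 0.9 --max_new_tokens 512
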
description = """ | |
## Qwen-7B | |
基于 [Qwen-7B](https://huggingface.co/qgyd2021/Qwen-7B) 模型, 在 [chinese_modern_poetry](https://huggingface.co/datasets/Iess/chinese_modern_poetry) 数据集上训练了 2 个 epoch. | |
可用于生成现代诗. 如下: | |
使用下列意象写一首现代诗:智慧,刀刃. | |
""" | |
examples = [
    "使用下列意象写一首现代诗:石头,森林",
    "使用下列意象写一首现代诗:花,纱布",
    "使用下列意象写一首现代诗:山壁,彩虹,诗句,山坡,泪",
    "使用下列意象写一首现代诗:味道,黄金,名字,银子,女人",
    "使用下列意象写一首现代诗:乳房,触感,车速,星星,路灯",
]


def main():
    args = get_args()

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)

    # QWenTokenizer is unusual: pad_token_id, bos_token_id and eos_token_id are
    # all None; eod_id maps to the <|endoftext|> token, so reuse it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id
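
    # device_map="auto" shards the weights across the available GPU(s) and CPU,
    # spilling whatever does not fit to ./offload on disk.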
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
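    # Keep the model in bfloat16 and switch to inference mode (disables dropout).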
    model = model.bfloat16().eval()
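
    # One-shot generation (not wired into the UI below, kept for reference):
    # wrap the prompt in bos/eos, generate, then decode only the new tokens.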
    def fn_non_stream(text: str):
        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)

        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Drop the prompt tokens and decode only the generated continuation.
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        return [(text, response)]
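
    # Streaming generation: model.generate() blocks, so it runs in a worker
    # thread while TextIteratorStreamer hands decoded chunks back to this
    # generator, which yields partial output for the Chatbot to display.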
    def fn_stream(text: str):
        text = str(text).strip()

        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)

        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
        # skip_prompt=True keeps the echoed prompt out of the stream, which is
        # more robust than stripping it from the decoded text afterwards.
        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
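
        # Accumulate chunks as they arrive; each yield refreshes the Chatbot.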
output = "" | |
for output_ in streamer: | |
output_ = output_.replace(text, "") | |
output_ = output_.replace(tokenizer.eos_token, "") | |
output += output_ | |
result = [(text, output)] | |
chatbot.value = result | |
yield result | |
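
    # UI: description, chat history, a text box with submit/clear buttons,
    # and the canned example prompts.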
    with gr.Blocks() as blocks:
        gr.Markdown(value=description)
        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button("🗑️Clear", variant="secondary")
        gr.Examples(examples, text_box)

        text_box.submit(fn_stream, [text_box], [chatbot])
        submit_button.click(fn_stream, [text_box], [chatbot])
        clear_button.click(
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()