Spaces:
Runtime error
Runtime error
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import os | |
import gradio as gr | |
from transformers import AutoModel, AutoTokenizer | |
from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer | |
# from transformers.utils.quantization_config import BitsAndBytesConfig | |
import torch | |
from project_settings import project_path | |
def get_args(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--train_subset", default="train.jsonl", type=str) | |
parser.add_argument("--valid_subset", default="valid.jsonl", type=str) | |
parser.add_argument( | |
"--pretrained_model_name_or_path", | |
# default="YeungNLP/firefly-chatglm2-6b", | |
default=(project_path / "trained_models/firefly_chatglm2_6b_intent").as_posix(), | |
type=str | |
) | |
parser.add_argument("--output_file", default="result.xlsx", type=str) | |
parser.add_argument("--max_new_tokens", default=512, type=int) | |
parser.add_argument("--top_p", default=0.9, type=float) | |
parser.add_argument("--temperature", default=0.35, type=float) | |
parser.add_argument("--repetition_penalty", default=1.0, type=float) | |
parser.add_argument('--device', default="cuda" if torch.cuda.is_available() else "cpu", type=str) | |
args = parser.parse_args() | |
return args | |
description = """ | |
## ChatGLM-6B | |
基于 [firefly-chatglm2-6b](https://huggingface.co/YeungNLP/firefly-chatglm2-6b) 模型, 在 [telemarketing_intent](https://huggingface.co/datasets/qgyd2021/telemarketing_intent/tree/main/data/prompt) 的 prompt 数据集上训练, 目的是实现 `电话营销` 场景的 1-shot 意图识别. | |
该分类任务有一百多个类别, 但标注数据总是只有 3 万, 并且有一半是 "无关领域", 实现思路是: | |
1. 首先采用传统算法做硬分类, 然后提取概率 top 10 的标签. | |
2. 将 top 10 的标签作为候选标签, 并为每个标签提供一个句子示例. | |
3. 要求 LLM 输出目标句子的类别. | |
Gradio 布署代码参考了: https://huggingface.co/spaces/aodianyun/ChatGLM-6B | |
""" | |
examples = [ | |
"""我们在做电话营销场景的意图识别任务, 可选的意图如下: | |
否定(不是); 礼貌用语; 否定答复; 肯定(需要); 用户正忙; 否定(不需要); 无关领域; 否定(没有); 否定(不用了); 价格太高 | |
如果你认为给定的句子不属于这些意图中的任务一个, 你可以回答: 不知道. | |
Tips: | |
1. 如果候选意图中有 "无关领域", 当你不知道时, 则它有可能属于无关领域. | |
Examples: | |
--------- | |
ExampleSentence: 其实不是 | |
ExampleIntent: 否定(不是) | |
ExampleSentence: 嗯!嘿嘿!早点休息,晚安咯 | |
ExampleIntent: 礼貌用语 | |
ExampleSentence: 没问诶 | |
ExampleIntent: 否定答复 | |
ExampleSentence: 不好意思都需要谢谢 | |
ExampleIntent: 肯定(需要) | |
ExampleSentence: 对呀我在忙 | |
ExampleIntent: 用户正忙 | |
ExampleSentence: 。嗯也也不需要吧唉呀现在不需要那个啊嗯 | |
ExampleIntent: 否定(不需要) | |
ExampleSentence: 我的处理器需要很少的电源。 | |
ExampleIntent: 无关领域 | |
ExampleSentence: 。呃我好像没有在太平洋买过保险,吧拜拜 | |
ExampleIntent: 否定(没有) | |
ExampleSentence: 嗯不用谢谢 | |
ExampleIntent: 否定(不用了) | |
ExampleSentence: 费用贵。 | |
ExampleIntent: 价格太高 | |
--------- | |
Sentence: 。嗯各位不需要,啊谢谢 | |
Intent:""", | |
"""我们在做电话营销场景的意图识别任务, 可选的意图如下: | |
语音信箱; 无关领域; 查物品信息; 污言秽语; 疑问(时间); 疑问(数值); 答时间; 查收费方式; 价格太高; 答数值 | |
如果你认为给定的句子不属于这些意图中的任务一个, 你可以回答: 不知道. | |
Tips: | |
1. 如果候选意图中有 "无关领域", 当你不知道时, 则它有可能属于无关领域. | |
Examples: | |
--------- | |
ExampleSentence: 我们留言。 | |
ExampleIntent: 语音信箱 | |
ExampleSentence: 很刚刚打 | |
ExampleIntent: 无关领域 | |
ExampleSentence: 什么东西我听 | |
ExampleIntent: 查物品信息 | |
ExampleSentence: 知道!AV女优!日本人的骄傲! | |
ExampleIntent: 污言秽语 | |
ExampleSentence: 最后期限 | |
ExampleIntent: 疑问(时间) | |
ExampleSentence: 一共借了多少钱 | |
ExampleIntent: 疑问(数值) | |
ExampleSentence: 22号 | |
ExampleIntent: 答时间 | |
ExampleSentence: 运费 | |
ExampleIntent: 查收费方式 | |
ExampleSentence: 利息高 | |
ExampleIntent: 价格太高 | |
ExampleSentence: 20。 | |
ExampleIntent: 答数值 | |
--------- | |
Sentence: 。对啊什么东西啊我6月份出来的 | |
Intent:""" | |
] | |
def main(): | |
args = get_args() | |
use_cpu = os.environ.get("USE_CPU", "all") | |
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True) | |
# QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|> | |
if tokenizer.__class__.__name__ == "QWenTokenizer": | |
tokenizer.pad_token_id = tokenizer.eod_id | |
tokenizer.bos_token_id = tokenizer.eod_id | |
tokenizer.eos_token_id = tokenizer.eod_id | |
if not use_cpu: | |
model = AutoModel.from_pretrained( | |
args.pretrained_model_name_or_path, | |
trust_remote_code=True | |
).half().cuda() | |
else: | |
model = AutoModelForCausalLM.from_pretrained( | |
args.pretrained_model_name_or_path, | |
trust_remote_code=True, | |
low_cpu_mem_usage=True, | |
torch_dtype=torch.bfloat16, | |
device_map="auto", | |
offload_folder="./offload", | |
offload_state_dict=True, | |
# load_in_4bit=True, | |
) | |
model = model.eval() | |
def fn(inputs, history=None): | |
if history is None: | |
history = list() | |
with torch.no_grad(): | |
response, history = model.chat(tokenizer, inputs, history) | |
return history, history | |
with gr.Blocks() as blocks: | |
gr.Markdown(value=description) | |
state = gr.State([]) | |
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400) | |
with gr.Row(): | |
with gr.Column(scale=4): | |
text = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False) | |
with gr.Column(scale=1): | |
button = gr.Button("Generate") | |
gr.Examples(examples, text) | |
text.submit(fn, [text, state], [chatbot, state]) | |
button.click(fn, [text, state], [chatbot, state]) | |
blocks.queue().launch() | |
return | |
if __name__ == '__main__': | |
main() | |