Feature: chat to chihiro.
Browse files
- app.py +58 -24
- modeling_chatglm.py +68 -30
app.py
CHANGED
@@ -58,29 +58,63 @@ def evaluate(context, temperature, top_p, top_k):
     )
     out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
     return out_text
+
+def evaluate_stream(msg, history, temperature, top_p):
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        #repetition_penalty=1.1,
+        num_beams=1,
+        do_sample=True,
+    )
+
+    history.append([msg, None])
+
+    context = ""
+    if len(history) > 5:
+        history.pop(0)
+
+    for j in range(len(history)):
+        history[j][0] = history[j][0].replace("<br>", "")
+
+    # concatenate context
+    for h in history[:-1]:
+        context += h[0] + "\n" + h[1] + "\n"
+
+    context += history[-1][0]
+    context = context.replace(r'<br>', '')
+
+    h = []
+    print("History:", history)
+    print("Context:", context)
+    for response, h in model.stream_chat(tokenizer, context, h, max_length=160, top_p=top_p, temperature=temperature):
+        history[-1][1] = response
+        yield history, ""
+
+    #return response
 
 import gradio as gr
- ... (24 removed lines: the previous gr.-based interface definition; their content is truncated in this diff view)
+with gr.Blocks() as demo:
+    state = gr.State()
+    with gr.Row():
+        with gr.Column(scale=2):
+            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.9, label="Temperature",
+                                        info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。")
+            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.97, label="Top-p",
+                                         info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+            #code = gr.Textbox(label="temp_output", info="解码器输出")
+            #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
+            #                             info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label="聊天框", info="")
+            msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
+                             info="输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。聊天会追随上下文,如果要换个话题建议按下按钮清除聊天。")
+            clear = gr.Button("清除聊天")
+
+    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
+    clear.click(lambda: None, None, chatbot, queue=False)
+
+
+demo.queue()
+demo.launch(debug=False)
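For context, a minimal consumption sketch of the `model.stream_chat` generator that `evaluate_stream` wraps. Assumptions: `model` and `tokenizer` are the ChatGLM model and tokenizer loaded earlier in app.py (not shown in this diff), and the query string is arbitrary.

# Hedged sketch: each iteration yields the partial response decoded so far plus the updated history.
history = []
for response, history in model.stream_chat(tokenizer, "最近过得怎么样?", history,
                                            max_length=160, top_p=0.97, temperature=0.9):
    print(response)  # the partial response grows until generation stops

`evaluate_stream` does the same thing, but writes each partial response into `history[-1][1]` and yields `(history, "")` so that Gradio refreshes the Chatbot and clears the textbox on every step.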
modeling_chatglm.py
CHANGED
@@ -4,6 +4,8 @@ import math
 import copy
 import os
 import warnings
+import re
+import sys
 
 import torch
 import torch.utils.checkpoint
@@ -31,10 +33,12 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
+
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)
 
 logger = logging.get_logger(__name__)
 
@@ -51,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[..., 20005] = 5e4
+            scores[..., 20005] = 5e4
         return scores
 
 
@@ -265,7 +269,7 @@ def attention_fn(
         if not (attention_mask == 0).all():
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)
-        dtype = attention_scores.dtype
+        dtype = attention_scores.dtype
        attention_scores = attention_scores.float()
        attention_scores = attention_scores * query_key_layer_scaling_coeff
 
@@ -610,8 +614,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     a simple interface for downloading and loading pretrained models.
     """
 
-    is_parallelizable = False
-    supports_gradient_checkpointing = False
+    is_parallelizable = False
+    supports_gradient_checkpointing = False
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLM6BBlock"]
@@ -619,13 +623,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
 
-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights."""
         return
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (GLMBlock)):
-            module.gradient_checkpointing = value
-
 
 CHATGLM_6B_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -722,7 +723,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.inner_hidden_size = config.inner_hidden_size
         self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
         self.position_encoding_2d = config.position_encoding_2d
-        self.model_parallel = True
 
         self.word_embeddings = skip_init(
             torch.nn.Embedding,
@@ -757,9 +757,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    def get_masks(self, seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -770,9 +769,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -827,14 +826,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
 
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -842,6 +835,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
@@ -940,12 +938,12 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
         attention_mask = torch.ones((1, context_length, context_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        attention_mask[..., :context_length - 1] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -983,7 +981,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1091,6 +1089,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", "，"],
+            ["!", "！"],
+            [":", "："],
+            [";", "；"],
+            ["\?", "？"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1113,11 +1126,35 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
+    @torch.no_grad()
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = self.process_response(response)
+            new_history = history + [(query, response)]
+            yield response, new_history
 
     @torch.no_grad()
     def stream_generate(
@@ -1220,6 +1257,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 break
             yield input_ids
+
     def quantize(self, bits: int):
         from .quantization import quantize
         self.transformer = quantize(self.transformer, bits)
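For reference, a quick illustration of what the new `process_response` helper does. The call below is a hypothetical usage sketch, assuming `model` is a loaded `ChatGLMForConditionalGeneration` instance: the method strips whitespace, substitutes the `[[训练时间]]` placeholder, and rewrites half-width punctuation adjacent to CJK characters to its full-width form.

# Hypothetical usage sketch; `model` is assumed to be a loaded ChatGLMForConditionalGeneration.
text = "模型训练于[[训练时间]],你好吗?"
print(model.process_response(text))
# -> "模型训练于2023年，你好吗？"  (placeholder replaced; the "," and "?" next to Chinese
#    characters are converted to full-width by the regex substitutions shown in the diff)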