ljsabc committed
Commit 300a211
1 Parent(s): f494a07

Feature: chat to chihiro.

Files changed (2)
  1. app.py +58 -24
  2. modeling_chatglm.py +68 -30
app.py CHANGED
@@ -58,29 +58,63 @@ def evaluate(context, temperature, top_p, top_k):
     )
     out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
     return out_text
+
+def evaluate_stream(msg, history, temperature, top_p):
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        #repetition_penalty=1.1,
+        num_beams=1,
+        do_sample=True,
+    )
+
+    history.append([msg, None])
+
+    context = ""
+    if len(history) > 5:
+        history.pop(0)
+
+    for j in range(len(history)):
+        history[j][0] = history[j][0].replace("<br>", "")
+
+    # concatenate context
+    for h in history[:-1]:
+        context += h[0] + "\n" + h[1] + "\n"
+
+    context += history[-1][0]
+    context = context.replace(r'<br>', '')
+
+    h = []
+    print("History:", history)
+    print("Context:", context)
+    for response, h in model.stream_chat(tokenizer, context, h, max_length=160, top_p=top_p, temperature=temperature):
+        history[-1][1] = response
+        yield history, ""
+
+    #return response

 import gradio as gr
-gr.Interface(
-    fn=evaluate,
-    inputs=[
-        gr.components.Textbox(
-            lines=2, label="问题", placeholder="最近过得怎么样?",
-            info="可以在这里输入你的问题。也可以什么都不填写生成随机数据。"
-        ),
-        #gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-        gr.components.Slider(minimum=0, maximum=1.1, value=1.0, label="Temperature",
-                             info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。"),
-        gr.components.Slider(minimum=0.5, maximum=1.0, value=0.98, label="Top p",
-                             info="top-p参数,只输出前p>top-p的文字,建议不要修改。"),
-        gr.components.Slider(minimum=1, maximum=200, step=1, value=40, label="Top k",
-                             info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。"),
-    ],
-    outputs=[
-        gr.inputs.Textbox(
-            lines=5,
-            label="Output",
-        )
-    ],
-    title="李萌萌(Alter Ego)",
-    description="这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。因为是在CPU上进行运行,速度会比较慢。",
-).launch()
+with gr.Blocks() as demo:
+    state = gr.State()
+    with gr.Row():
+        with gr.Column(scale=2):
+            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.9, label="Temperature",
+                                        info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。")
+            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.97, label="Top-p",
+                                         info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+            #code = gr.Textbox(label="temp_output", info="解码器输出")
+            #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
+            #                             info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label="聊天框", info="")
+            msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
+                             info="输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。聊天会追随上下文,如果要换个话题建议按下按钮清除聊天。")
+            clear = gr.Button("清除聊天")
+
+    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
+    clear.click(lambda: None, None, chatbot, queue=False)
+
+
+demo.queue()
+demo.launch(debug=False)
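Note: the new app.py relies on Gradio's generator-callback streaming. `evaluate_stream` yields an updated `history` (plus an empty string to clear the textbox) on every partial response, and `msg.submit` feeds those yields into the `Chatbot`. Below is a minimal, self-contained sketch of the same pattern; `fake_stream` is a hypothetical stand-in for `model.stream_chat`, not part of this commit.

import time
import gradio as gr

def fake_stream(message):
    # Stand-in for model.stream_chat: yield progressively longer replies.
    reply = ""
    for word in ("Echo: " + message).split():
        reply += word + " "
        time.sleep(0.1)
        yield reply

def respond(message, history):
    history.append([message, None])
    for partial in fake_stream(message):
        history[-1][1] = partial   # update the last chatbot turn in place
        yield history, ""          # second output clears the input textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(respond, [msg, chatbot], [chatbot, msg])

demo.queue()     # queuing is required for generator (streaming) callbacks
demo.launch()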
modeling_chatglm.py CHANGED
@@ -4,6 +4,8 @@ import math
 import copy
 import os
 import warnings
+import re
+import sys
 
 import torch
 import torch.utils.checkpoint
@@ -31,10 +33,12 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
 from configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
+
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)
 
 logger = logging.get_logger(__name__)
 
@@ -51,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[..., 20005] = 1e5
+            scores[..., 20005] = 5e4
         return scores
 
 
@@ -265,7 +269,7 @@ def attention_fn(
    if not (attention_mask == 0).all():
        # if auto-regressive, skip
        attention_scores.masked_fill_(attention_mask, -10000.0)
-   dtype = attention_scores.type()
+   dtype = attention_scores.dtype
    attention_scores = attention_scores.float()
    attention_scores = attention_scores * query_key_layer_scaling_coeff
 
@@ -610,8 +614,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     a simple interface for downloading and loading pretrained models.
     """
 
-    is_parallelizable = True
-    supports_gradient_checkpointing = True
+    is_parallelizable = False
+    supports_gradient_checkpointing = False
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLM6BBlock"]
@@ -619,13 +623,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
 
-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights."""
         return
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (GLMBlock)):
-            module.gradient_checkpointing = value
-
 
 CHATGLM_6B_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -722,7 +723,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.inner_hidden_size = config.inner_hidden_size
         self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
         self.position_encoding_2d = config.position_encoding_2d
-        self.model_parallel = True
 
         self.word_embeddings = skip_init(
             torch.nn.Embedding,
@@ -757,9 +757,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    @staticmethod
-    def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -770,9 +769,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = len(seq)
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -827,14 +826,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 150000, 150001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
 
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -842,6 +835,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
@@ -940,12 +938,12 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
         attention_mask = torch.ones((1, context_length, context_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :mask_position - 1] = 1
+        attention_mask[..., :context_length - 1] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -983,7 +981,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1091,6 +1089,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 for layer_past in past
             )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1113,11 +1126,35 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
+    @torch.no_grad()
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = self.process_response(response)
+            new_history = history + [(query, response)]
+            yield response, new_history
 
     @torch.no_grad()
     def stream_generate(
@@ -1220,6 +1257,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 break
             yield input_ids
+
    def quantize(self, bits: int):
        from .quantization import quantize
        self.transformer = quantize(self.transformer, bits)
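Note: the new `stream_chat` builds the `[Round i]\n问:…\n答:…` prompt from `history` itself and yields `(response, new_history)` pairs as decoding progresses, which is what app.py's `evaluate_stream` iterates over. A minimal consumer sketch follows; the checkpoint name and the `.half().cuda()` placement are placeholders (this Space actually runs its own fine-tuned weights on CPU), not part of this commit.

from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint; substitute the repository's own weights as appropriate.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

history = []
# stream_chat is a generator: each iteration yields the full response so far
# plus the updated history, so a UI can repaint the chat incrementally.
for response, history in model.stream_chat(tokenizer, "你好", history, max_length=160):
    print(response)  # grows with every yielded chunk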