Qingnan Duan committed
Commit: f786a98
Parent(s): 220f772

Normalize response with locale hint

Files changed: modeling_chatglm.py (+61 -6)

modeling_chatglm.py CHANGED
@@ -46,6 +46,17 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
 ]
 
+QUERY_KEYWORDS = {
+    'chinese-simplified': {
+        'question': '问:',
+        'answer': '答:',
+    },
+    'english': {
+        'question': 'Q:',
+        'answer': 'A:',
+    }
+}
+
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -1087,9 +1098,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def chat(self, *args, **kwargs):
+        return self.chat_chinese_simplified(*args, **kwargs)
+
+    def chat_chinese_simplified(self, *args, **kwargs):
+        return self.chat_internal(*args, **kwargs, locale='chinese-simplified')
+
+    def chat_english(self, *args, **kwargs):
+        return self.chat_internal(*args, **kwargs, locale='english')
+
     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+    def chat_internal(self, tokenizer, query: str, locale: str,
+                      history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
+                      do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
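Usage sketch (not part of this commit): with the change applied, the locale-specific entry points are called like the upstream chat API. The loading path below assumes the usual THUDM/chatglm-6b setup with trust_remote_code.

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# chat() keeps its old behavior by delegating to the Simplified Chinese path
response, history = model.chat(tokenizer, "你好", history=[])
# chat_english() builds a Q:/A: prompt and normalizes to half-width punctuation
response, history = model.chat_english(tokenizer, "Hello", history=[])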
@@ -1097,20 +1118,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        format_query_keyword_question = QUERY_KEYWORDS[locale]['question']
+        format_query_keyword_answer = QUERY_KEYWORDS[locale]['answer']
         if not history:
             prompt = query
         else:
             prompt = ""
             for i, (old_query, response) in enumerate(history):
-                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
-            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+                prompt += f"[Round {i}]\n{format_query_keyword_question}{old_query}\n{format_query_keyword_answer}{response}\n"
+            prompt += f"[Round {len(history)}]\n{format_query_keyword_question}{query}\n{format_query_keyword_answer}"
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.post_process(response, locale=locale)
         history = history + [(query, response)]
         return response, history
 
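Illustration (hypothetical values, not part of this commit): for locale='english' with a one-turn history, chat_internal now builds the prompt as follows, where the old code hardcoded the Chinese 问:/答: keywords.

history = [("Hi", "Hello!")]
query = "How are you?"
# -> "[Round 0]\nQ:Hi\nA:Hello!\n[Round 1]\nQ:How are you?\nA:"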
@@ -1165,6 +1187,39 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
+    def post_process(self, response: str, locale: str) -> str:
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+
+        if locale == 'chinese-simplified':
+            import re
+            # CJK Unified Ideographs + CJK Unified Ideographs Extension A
+            cjk_regex = r'([\u4e00-\u9fff]|[\u3400-\u4dbf])'
+            regex_mapping = {
+                cjk_regex + ',': r'\1,',
+                cjk_regex + r'\.': r'\1。',
+                cjk_regex + r'\?': r'\1?',
+                cjk_regex + '!': r'\1!',
+                cjk_regex + ':': r'\1:',
+                cjk_regex + ';': r'\1;',
+            }
+            for pattern in regex_mapping:
+                response = re.sub(pattern, regex_mapping[pattern], response)
+            # Nested parentheses not supported.
+            response = re.sub(r'\(([^\(\)]*(?:[\u4e00-\u9fff]|[\u3400-\u4dbf])[^\(\)]*)\)', r'(\1)', response)
+        elif locale == 'english':
+            mapping = {
+                ',': ',',
+                '。': '.',
+                '?': '?',
+                '!': '!',
+                ':': ':',
+                ';': ';',
+            }
+            for char in mapping:
+                response = response.replace(char, mapping[char])
+        return response
+
     def quantize(self, bits: int):
         from .quantization import quantize
         self.transformer = quantize(self.transformer, bits)
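Behavior sketch (hypothetical strings, not part of this commit): the two post_process paths normalize punctuation in opposite directions.

# chinese-simplified: half-width punctuation directly after a CJK character
# becomes full-width; parenthesized runs containing CJK get full-width parens.
#   "你好,世界." -> "你好,世界。"
# english: full-width punctuation is replaced with half-width everywhere.
#   "Hello,world。" -> "Hello,world."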