import time
from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import LogitsProcessor, LogitsProcessorList

from .singleton import Singleton


def parse_codeblock(text):
    """Convert ``` fenced blocks in streamed model output to HTML."""
    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # Opening fence: reuse the language tag as the CSS class.
                lines[i] = f'<pre><code class="{lines[i][3:]}">'
            else:
                lines[i] = '</code></pre>'
        else:
            if i > 0:
                # Escape angle brackets so model output cannot inject HTML.
                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
    return "".join(lines)


class BasePredictor(ABC):

    @abstractmethod
    def __init__(self, model_name):
        self.model = None
        self.tokenizer = None

    @abstractmethod
    def stream_chat_continue(self, *args, **kwargs):
        raise NotImplementedError

    def predict_continue(self, query, latest_message, max_length, top_p,
                         temperature, allow_generate, history, *args,
                         **kwargs):
        """Stream a continuation, updating the last history entry in place."""
        if history is None:
            history = []
        allow_generate[0] = True
        history.append((query, latest_message))
        for response in self.stream_chat_continue(
                self.model,
                self.tokenizer,
                query=query,
                history=history,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature):
            history[-1] = (history[-1][0], response)
            yield history, '', ''
            # The caller can flip this shared flag to stop generation early.
            if not allow_generate[0]:
                break
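
# Hedged usage sketch (the consuming UI loop is assumed, not in this file):
#
#   allow_generate = [True]
#   for history, _, _ in predictor.predict_continue(
#           query="Hello", latest_message="", max_length=2048, top_p=0.7,
#           temperature=0.95, allow_generate=allow_generate, history=[]):
#       ...  # render `history`; set allow_generate[0] = False to interrupt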


class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Recover from NaN/Inf logits by forcing a single known-safe token."""

    def __init__(self, start_pos=20005):
        self.start_pos = start_pos

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., self.start_pos] = 5e4
        return scores
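
# Sanity check with illustrative values (not from the original source):
#
#   proc = InvalidScoreLogitsProcessor(start_pos=5)
#   scores = torch.full((1, 10), float('nan'))
#   out = proc(torch.zeros(1, 1, dtype=torch.long), scores)
#   # out is zero everywhere except out[0, 5] == 5e4, so token 5 is chosen.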


class ChatGLM(BasePredictor):

    def __init__(self, model_name="THUDM/chatglm-6b-int4"):
        print(f'Loading model {model_name}')
        start = time.perf_counter()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            resume_download=True
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            resume_download=True
        ).half().to(self.device)
        model = model.eval()
        self.model = model
        self.model_name = model_name
        end = time.perf_counter()
        print(
            f'Successfully loaded model {model_name}, time cost: {end - start:.2f}s'
        )

    @torch.no_grad()
    def generator_image_text(self, text):
        # The prompt reads "Describe the scene: {text}" in Chinese.
        response, history = self.model.chat(
            self.tokenizer, "描述画面:{}".format(text), history=[])
        return response

    @torch.no_grad()
    def stream_chat_continue(self,
                             model,
                             tokenizer,
                             query: str,
                             history: List[Tuple[str, str]] = None,
                             max_length: int = 2048,
                             do_sample=True,
                             top_p=0.7,
                             temperature=0.95,
                             logits_processor=None,
                             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        if len(history) > 0:
            # Resume from the partial answer recorded for the latest round.
            answer = history[-1][1]
        else:
            answer = ''
        logits_processor.append(
            InvalidScoreLogitsProcessor(
                start_pos=20005 if 'slim' not in self.model_name else 5))
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs
        }
        if not history:
            prompt = query
        else:
            # Rebuild ChatGLM's multi-round prompt ("问" = question,
            # "答" = answer), leaving the final round open after "答:"
            # so generation continues that answer.
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                if i != len(history) - 1:
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(
                        i, old_query, response)
                else:
                    prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
        batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
        batch_input = batch_input.to(model.device)
        batch_answer = tokenizer(answer, return_tensors="pt")
        batch_answer = batch_answer.to(model.device)
        input_length = len(batch_input['input_ids'][0])
        # Append the tokens of the partial answer (minus the trailing special
        # tokens) so decoding picks up exactly where it stopped. Using
        # model.device instead of .cuda() keeps CPU-only machines working.
        final_input_ids = torch.cat(
            [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]],
            dim=-1).to(model.device)
        attention_mask = model.get_masks(
            final_input_ids, device=final_input_ids.device)
        batch_input['input_ids'] = final_input_ids
        batch_input['attention_mask'] = attention_mask
        input_ids = final_input_ids
        # ChatGLM's [MASK]/[gMASK] token ids sit at fixed offsets below
        # bos_token_id in this model's vocabulary.
        MASK, gMASK = (self.model.config.bos_token_id - 4,
                       self.model.config.bos_token_id - 3)
        mask_token = MASK if MASK in input_ids else gMASK
        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
        batch_input['position_ids'] = self.model.get_position_ids(
            input_ids, mask_positions, device=input_ids.device)
        for outputs in model.stream_generate(**batch_input, **gen_kwargs):
            # Drop the prompt tokens and emit only newly generated text.
            outputs = outputs.tolist()[0][input_length:]
            response = tokenizer.decode(outputs)
            response = model.process_response(response)
            yield parse_codeblock(response)


@Singleton
class Models(object):
    """Lazily create and cache models on first attribute access."""

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails, so the model is
        # built once and served from __dict__ on every later access.
        if item in self.__dict__:
            return getattr(self, item)
        if item == 'chatglm':
            self.chatglm = ChatGLM("THUDM/chatglm-6b-int4")
            return getattr(self, item)
        raise AttributeError(item)


models = Models.instance()


def chat2text(text: str) -> str:
    return models.chatglm.generator_image_text(text)
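
# Minimal usage sketch (hypothetical caller; assumes this module sits in a
# package next to singleton.py and enough GPU/CPU memory is available):
#
#   from .predictor import chat2text   # module name is an assumption
#
#   print(chat2text("a cat sleeping on a windowsill"))
#
# The first call lazily loads THUDM/chatglm-6b-int4 through the Models
# singleton, so it pays the full model-loading cost.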