add chatglm
- app.py +80 -5
- requirements.txt +2 -0
- utils/chatglm.py +192 -0
- utils/generator.py +18 -4
- utils/translate.py +40 -4
app.py
CHANGED
@@ -1,11 +1,13 @@
 import gradio as gr
 import torch
-
+import mdtex2html
 from utils.exif import get_image_info
 from utils.generator import generate_prompt
 from utils.image2text import git_image2text, w14_image2text, clip_image2text
 from utils.translate import en2zh as translate_en2zh
 from utils.translate import zh2en as translate_zh2en
+from utils.chatglm import chat2text
+from utils.chatglm import models as chatglm_models
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -51,18 +53,85 @@ def image_generate_prompter(
     return "\n".join(prompter_list), "\n".join(prompter_zh_list)
 
 
+def translate_input(text: str, chatglm_text: str) -> str:
+    if chatglm_text is not None and len(chatglm_text) > 0:
+        return translate_zh2en(chatglm_text)
+    return translate_zh2en(text)
+
+
 with gr.Blocks(title="Prompt生成器") as block:
     with gr.Column():
+        with gr.Tab('Chat'):
+            def revise(history, latest_message):
+                history[-1] = (history[-1][0], latest_message)
+                return history, ''
+
+
+            def revoke(history):
+                if len(history) >= 1:
+                    history.pop()
+                return history
+
+
+            def interrupt(allow_generate):
+                allow_generate[0] = False
+
+
+            def reset_state():
+                return [], []
+
+
+            with gr.Row():
+                with gr.Column(scale=4):
+                    chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
+                with gr.Column(scale=1):
+                    with gr.Row():
+                        max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                        top_p = gr.Slider(0.01, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                        temperature = gr.Slider(0.01, 5, value=0.95, step=0.01, label="Temperature", interactive=True)
+                    with gr.Row():
+                        query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4).style(container=False)
+                        generate_button = gr.Button("生成")
+                    with gr.Row():
+                        continue_message = gr.Textbox(
+                            show_label=False, placeholder="Continue message", lines=2).style(container=False)
+                        continue_btn = gr.Button("续写")
+                        revise_message = gr.Textbox(
+                            show_label=False, placeholder="Revise message", lines=2).style(container=False)
+                        revise_btn = gr.Button("修订")
+                        revoke_btn = gr.Button("撤回")
+                        interrupt_btn = gr.Button("终止生成")
+                        reset_btn = gr.Button("清空")
+
+            history = gr.State([])
+            allow_generate = gr.State([True])
+            blank_input = gr.State("")
+            reset_btn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+            generate_button.click(
+                chatglm_models.chatglm.predict_continue,
+                inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history],
+                outputs=[chatbot, query]
+            )
+            revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message])
+            revoke_btn.click(revoke, inputs=[history], outputs=[chatbot])
+            continue_btn.click(
+                chatglm_models.chatglm.predict_continue,
+                inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history],
+                outputs=[chatbot, query, continue_message]
+            )
+            interrupt_btn.click(interrupt, inputs=[allow_generate])
         with gr.Tab('文本生成'):
             with gr.Row():
                 input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
+                chatglm_output = gr.Textbox(lines=6, label='ChatGLM', placeholder='在此输入内容...')
+
                 translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')
 
                 output = gr.Textbox(lines=6, label='优化的 Prompt')
                 output_zh = gr.Textbox(lines=6, label='优化的 Prompt(zh)')
             with gr.Row():
+                chatglm_btn = gr.Button('召唤ChatGLM')
                 translate_btn = gr.Button('翻译')
-
                 generate_prompter_btn = gr.Button('优化Prompt')
 
         with gr.Tab('从图片中生成'):
@@ -94,13 +163,14 @@ with gr.Blocks(title="Prompt生成器") as block:
                     'microsoft',
                    'mj',
                    'gpt2_650k',
+                    'gpt_neo_125m',
                ],
                value='gpt2_650k',
                label='model_name'
            )
            prompt_min_length = gr.Slider(1, 512, 100, label='min_length', step=1)
            prompt_max_length = gr.Slider(1, 512, 200, label='max_length', step=1)
-            prompt_num_return_sequences = gr.Slider(1, 30,
+            prompt_num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
 
            with gr.Accordion('BLIP参数', open=True):
                blip_max_length = gr.Slider(1, 512, 100, label='max_length', step=1)
@@ -145,9 +215,14 @@ with gr.Blocks(title="Prompt生成器") as block:
         ],
        outputs=[output_img_prompter, output_img_prompter_zh]
    )
-
-        fn=
+    chatglm_btn.click(
+        fn=chatglm_models.chatglm.generator_image_text,
         inputs=input_text,
+        outputs=chatglm_output,
+    )
+    translate_btn.click(
+        fn=translate_input,
+        inputs=[input_text, chatglm_output],
         outputs=translate_output
    )
 
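For context, the Chat tab's 生成 button is wired to chatglm_models.chatglm.predict_continue, a generator that streams partial answers by rewriting history[-1]. A rough sketch of driving it outside Gradio, assuming the ChatGLM weights can actually be downloaded and loaded on the host (the query string is only illustrative):

from utils.chatglm import models as chatglm_models

history = []
allow_generate = [True]
# predict_continue yields (history, '', '') after each decoding step,
# with the latest partial reply stored in history[-1][1].
for hist, _, _ in chatglm_models.chatglm.predict_continue(
        query="画面:雨中撑伞的柴犬",
        latest_message="",
        max_length=2048,
        top_p=0.7,
        temperature=0.95,
        allow_generate=allow_generate,
        history=history):
    print(hist[-1][1])  # streamed partial response
# Setting allow_generate[0] = False (what the 终止生成 button does) ends the loop
# after the next yielded step.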
requirements.txt
CHANGED
@@ -10,3 +10,5 @@ protobuf<=3.20.1,>=3.12.2
 opencv-python==4.7.0.72
 huggingface-hub==0.13.2
 clip-interrogator==0.6.0
+cpm_kernels==1.0.11
+mdtex2html==1.2.0
utils/chatglm.py
ADDED
@@ -0,0 +1,192 @@
+import time
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+from transformers import LogitsProcessor, LogitsProcessorList
+
+from .singleton import Singleton
+
+
+def parse_codeblock(text):
+    lines = text.split("\n")
+    for i, line in enumerate(lines):
+        if "```" in line:
+            if line != "```":
+                lines[i] = f'<pre><code class="{lines[i][3:]}">'
+            else:
+                lines[i] = '</code></pre>'
+        else:
+            if i > 0:
+                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
+    return "".join(lines)
+
+
+class BasePredictor(ABC):
+
+    @abstractmethod
+    def __init__(self, model_name):
+        self.model = None
+        self.tokenizer = None
+
+    @abstractmethod
+    def stream_chat_continue(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def predict_continue(self, query, latest_message, max_length, top_p,
+                         temperature, allow_generate, history, *args,
+                         **kwargs):
+        if history is None:
+            history = []
+        allow_generate[0] = True
+        history.append((query, latest_message))
+        for response in self.stream_chat_continue(
+                self.model,
+                self.tokenizer,
+                query=query,
+                history=history,
+                max_length=max_length,
+                top_p=top_p,
+                temperature=temperature):
+            history[-1] = (history[-1][0], response)
+            yield history, '', ''
+            if not allow_generate[0]:
+                break
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, start_pos=20005):
+        self.start_pos = start_pos
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., self.start_pos] = 5e4
+        return scores
+
+
+class ChatGLM(BasePredictor):
+
+    def __init__(self, model_name="THUDM/chatglm-6b-int4"):
+
+        print(f'Loading model {model_name}')
+        start = time.perf_counter()
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            resume_download=True
+        )
+        model = AutoModel.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            resume_download=True
+        ).half().to(self.device)
+
+        model = model.eval()
+        self.model = model
+        self.model_name = model_name
+        end = time.perf_counter()
+        print(
+            f'Successfully loaded model {model_name}, time cost: {end - start:.2f}s'
+        )
+
+    @torch.no_grad()
+    def generator_image_text(self, text):
+        response, history = self.model.chat(self.tokenizer, "描述画面:{}".format(text), history=[])
+        return response
+
+    @torch.no_grad()
+    def stream_chat_continue(self,
+                             model,
+                             tokenizer,
+                             query: str,
+                             history: List[Tuple[str, str]] = None,
+                             max_length: int = 2048,
+                             do_sample=True,
+                             top_p=0.7,
+                             temperature=0.95,
+                             logits_processor=None,
+                             **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        if len(history) > 0:
+            answer = history[-1][1]
+        else:
+            answer = ''
+        logits_processor.append(
+            InvalidScoreLogitsProcessor(
+                start_pos=20005 if 'slim' not in self.model_name else 5))
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs
+        }
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                if i != len(history) - 1:
+                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(
+                        i, old_query, response)
+                else:
+                    prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
+        batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
+        batch_input = batch_input.to(model.device)
+
+        batch_answer = tokenizer(answer, return_tensors="pt")
+        batch_answer = batch_answer.to(model.device)
+
+        input_length = len(batch_input['input_ids'][0])
+        final_input_ids = torch.cat(
+            [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]],
+            dim=-1).cuda()
+
+        attention_mask = model.get_masks(
+            final_input_ids, device=final_input_ids.device)
+
+        batch_input['input_ids'] = final_input_ids
+        batch_input['attention_mask'] = attention_mask
+
+        input_ids = final_input_ids
+        MASK, gMASK = self.model.config.bos_token_id - 4, self.model.config.bos_token_id - 3
+        mask_token = MASK if MASK in input_ids else gMASK
+        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
+        batch_input['position_ids'] = self.model.get_position_ids(
+            input_ids, mask_positions, device=input_ids.device)
+
+        for outputs in model.stream_generate(**batch_input, **gen_kwargs):
+            outputs = outputs.tolist()[0][input_length:]
+            response = tokenizer.decode(outputs)
+            response = model.process_response(response)
+            yield parse_codeblock(response)
+
+
+@Singleton
+class Models(object):
+
+    def __getattr__(self, item):
+        if item in self.__dict__:
+            return getattr(self, item)
+
+        if item == 'chatglm':
+            self.chatglm = ChatGLM("THUDM/chatglm-6b-int4")
+
+        return getattr(self, item)
+
+
+models = Models.instance()
+
+
+def chat2text(text: str) -> str:
+    return models.chatglm.generator_image_text(text)
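The new module keeps a lazily built singleton (Models), so the int4 ChatGLM weights are only loaded the first time models.chatglm is touched; chat2text is the small wrapper the rest of the app imports. A minimal usage sketch, assuming the machine has enough memory for THUDM/chatglm-6b-int4 (the example text is illustrative):

from utils.chatglm import chat2text, models

# First access constructs ChatGLM("THUDM/chatglm-6b-int4") and caches it.
predictor = models.chatglm

# chat2text wraps generator_image_text: it asks the model to describe a
# scene ("描述画面:...") and returns the reply string.
print(chat2text("星空下的富士山"))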
utils/generator.py
CHANGED
@@ -24,11 +24,16 @@ class Models(object):
         if item in ('gpt2_650k_pipe',):
             self.gpt2_650k_pipe = self.load_gpt2_650k_pipe()
 
+        if item in ('gpt_neo_125m',):
+            self.gpt_neo_125m = self.load_gpt_neo_125m()
         return getattr(self, item)
 
     @classmethod
-    def load_gpt2_650k_pipe(cls):
+    def load_gpt_neo_125m(cls):
+        return pipeline('text-generation', model='DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M')
 
+    @classmethod
+    def load_gpt2_650k_pipe(cls):
         return pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')
 
     @classmethod
@@ -62,7 +67,16 @@ def generate_prompt(
         model_name='microsoft',
 ):
     if model_name == 'gpt2_650k':
-        return
+        return generate_prompt_pipe(
+            models.gpt2_650k_pipe,
+            prompt=plain_text,
+            min_length=min_length,
+            max_length=max_length,
+            num_return_sequences=num_return_sequences,
+        )
+    elif model_name == 'gpt_neo_125m':
+        return generate_prompt_pipe(
+            models.gpt_neo_125m,
             prompt=plain_text,
             min_length=min_length,
             max_length=max_length,
@@ -114,7 +128,7 @@ def generate_prompt_microsoft(
     return "\n".join(result)
 
 
-def generate_prompt_gpt2_650k(prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
+def generate_prompt_pipe(pipe, prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
     def get_valid_prompt(text: str) -> str:
         dot_split = text.split('.')[0]
         n_split = text.split('\n')[0]
@@ -130,7 +144,7 @@ def generate_prompt_gpt2_650k(prompt: str, min_length=60, max_length: int = 255,
 
     output += [
         get_valid_prompt(result['generated_text']) for result in
-        models.gpt2_650k_pipe(
+        pipe(
             prompt,
             max_new_tokens=rand_length(min_length, max_length),
             num_return_sequences=num_return_sequences
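With the GPT-2 650k branch generalized into generate_prompt_pipe, the newly added GPT-Neo 125M generator reuses the same sampling and cleanup path. A small usage sketch (the prompt text and parameter values are illustrative, and it assumes both Hugging Face models can be downloaded):

from utils.generator import generate_prompt

# Same entry point as before, now dispatching through generate_prompt_pipe:
print(generate_prompt("a cat sitting on a windowsill",
                      min_length=60, max_length=200,
                      num_return_sequences=4, model_name='gpt2_650k'))

# The new DrishtiSharma GPT-Neo 125M prompt generator:
print(generate_prompt("a cat sitting on a windowsill",
                      min_length=60, max_length=200,
                      num_return_sequences=4, model_name='gpt_neo_125m'))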
utils/translate.py
CHANGED
@@ -1,6 +1,10 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 from .singleton import Singleton
+from transformers import (
+    EncoderDecoderModel,
+    AutoTokenizer
+)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -18,8 +22,18 @@ class Models(object):
         if item in ('en2zh_model', 'en2zh_tokenizer',):
             self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
 
+        if item in ('wenyanwen2modern_tokenizer', 'wenyanwen2modern_model',):
+            self.wenyanwen2modern_tokenizer, self.wenyanwen2modern_model = self.load_wenyanwen2modern_model()
+
         return getattr(self, item)
 
+    @classmethod
+    def load_wenyanwen2modern_model(cls):
+        PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
+        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
+        model = EncoderDecoderModel.from_pretrained(PRETRAINED)
+        return tokenizer, model
+
     @classmethod
     def load_en2zh_model(cls):
         en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
@@ -37,14 +51,35 @@ class Models(object):
 models = Models.instance()
 
 
-def zh2en(text):
+def wenyanwen2modern(text: str) -> str:
+    tk_kwargs = dict(
+        truncation=True,
+        max_length=128,
+        padding="max_length",
+        return_tensors='pt')
+
+    inputs = models.wenyanwen2modern_tokenizer([text, ], **tk_kwargs)
+    with torch.no_grad():
+        return models.wenyanwen2modern_tokenizer.batch_decode(
+            models.wenyanwen2modern_model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                num_beams=3,
+                max_length=256,
+                bos_token_id=101,
+                eos_token_id=models.wenyanwen2modern_tokenizer.sep_token_id,
+                pad_token_id=models.wenyanwen2modern_tokenizer.pad_token_id,
+            ), skip_special_tokens=True)[0].replace(" ", "")
+
+
+def zh2en(text: str) -> str:
     with torch.no_grad():
         encoded = models.zh2en_tokenizer([text], return_tensors="pt")
         sequences = models.zh2en_model.generate(**encoded)
         return models.zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
 
 
-def en2zh(text):
+def en2zh(text: str) -> str:
     with torch.no_grad():
         encoded = models.en2zh_tokenizer([text], return_tensors="pt")
         sequences = models.en2zh_model.generate(**encoded)
@@ -52,8 +87,9 @@ def en2zh(text):
 
 
 if __name__ == "__main__":
-    input = "
-
+    input = "飞流直下三千尺,疑是银河落九天"
+    input_m = wenyanwen2modern(input)
+    en = zh2en(input_m)
     print(input, en)
     zh = en2zh(en)
     print(en, zh)
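The __main__ block above exercises the full chain: Classical Chinese is first normalized to Modern Chinese with wenyanwen2modern (beam search, spaces stripped from the decoded output), then round-tripped through zh2en and en2zh. A condensed sketch of using the new helper on its own (the outputs depend on the raynardj model, so they are indicative only):

from utils.translate import wenyanwen2modern, zh2en

modern = wenyanwen2modern("飞流直下三千尺,疑是银河落九天")  # Classical -> Modern Chinese
print(modern)
print(zh2en(modern))  # then into English, as the prompt pipeline expects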