hahafofo committed
Commit 390173a
1 Parent(s): cd1b772

add chatglm

Files changed (5):
  1. app.py +80 -5
  2. requirements.txt +2 -0
  3. utils/chatglm.py +192 -0
  4. utils/generator.py +18 -4
  5. utils/translate.py +40 -4
app.py CHANGED
@@ -1,11 +1,13 @@
 import gradio as gr
 import torch
-
+import mdtex2html
 from utils.exif import get_image_info
 from utils.generator import generate_prompt
 from utils.image2text import git_image2text, w14_image2text, clip_image2text
 from utils.translate import en2zh as translate_en2zh
 from utils.translate import zh2en as translate_zh2en
+from utils.chatglm import chat2text
+from utils.chatglm import models as chatglm_models
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -51,18 +53,85 @@ def image_generate_prompter(
     return "\n".join(prompter_list), "\n".join(prompter_zh_list)
 
 
+def translate_input(text: str, chatglm_text: str) -> str:
+    if chatglm_text is not None and len(chatglm_text) > 0:
+        return translate_zh2en(chatglm_text)
+    return translate_zh2en(text)
+
+
 with gr.Blocks(title="Prompt生成器") as block:
     with gr.Column():
+        with gr.Tab('Chat'):
+            def revise(history, latest_message):
+                history[-1] = (history[-1][0], latest_message)
+                return history, ''
+
+
+            def revoke(history):
+                if len(history) >= 1:
+                    history.pop()
+                return history
+
+
+            def interrupt(allow_generate):
+                allow_generate[0] = False
+
+
+            def reset_state():
+                return [], []
+
+
+            with gr.Row():
+                with gr.Column(scale=4):
+                    chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
+                with gr.Column(scale=1):
+                    with gr.Row():
+                        max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                        top_p = gr.Slider(0.01, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                        temperature = gr.Slider(0.01, 5, value=0.95, step=0.01, label="Temperature", interactive=True)
+                    with gr.Row():
+                        query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4).style(container=False)
+                        generate_button = gr.Button("生成")
+                    with gr.Row():
+                        continue_message = gr.Textbox(
+                            show_label=False, placeholder="Continue message", lines=2).style(container=False)
+                        continue_btn = gr.Button("续写")
+                        revise_message = gr.Textbox(
+                            show_label=False, placeholder="Revise message", lines=2).style(container=False)
+                        revise_btn = gr.Button("修订")
+                        revoke_btn = gr.Button("撤回")
+                        interrupt_btn = gr.Button("终止生成")
+                        reset_btn = gr.Button("清空")
+
+            history = gr.State([])
+            allow_generate = gr.State([True])
+            blank_input = gr.State("")
+            reset_btn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+            generate_button.click(
+                chatglm_models.chatglm.predict_continue,
+                inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history],
+                outputs=[chatbot, query]
+            )
+            revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message])
+            revoke_btn.click(revoke, inputs=[history], outputs=[chatbot])
+            continue_btn.click(
+                chatglm_models.chatglm.predict_continue,
+                inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history],
+                outputs=[chatbot, query, continue_message]
+            )
+            interrupt_btn.click(interrupt, inputs=[allow_generate])
         with gr.Tab('文本生成'):
             with gr.Row():
                 input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
+                chatglm_output = gr.Textbox(lines=6, label='ChatGLM', placeholder='在此输入内容...')
+
                 translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')
 
                 output = gr.Textbox(lines=6, label='优化的 Prompt')
                 output_zh = gr.Textbox(lines=6, label='优化的 Prompt(zh)')
             with gr.Row():
+                chatglm_btn = gr.Button('召唤ChatGLM')
                 translate_btn = gr.Button('翻译')
-
                 generate_prompter_btn = gr.Button('优化Prompt')
 
         with gr.Tab('从图片中生成'):
@@ -94,13 +163,14 @@ with gr.Blocks(title="Prompt生成器") as block:
                     'microsoft',
                     'mj',
                     'gpt2_650k',
+                    'gpt_neo_125m',
                 ],
                 value='gpt2_650k',
                 label='model_name'
             )
             prompt_min_length = gr.Slider(1, 512, 100, label='min_length', step=1)
             prompt_max_length = gr.Slider(1, 512, 200, label='max_length', step=1)
-            prompt_num_return_sequences = gr.Slider(1, 30, 6, label='num_return_sequences', step=1)
+            prompt_num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
 
         with gr.Accordion('BLIP参数', open=True):
             blip_max_length = gr.Slider(1, 512, 100, label='max_length', step=1)
@@ -145,9 +215,14 @@ with gr.Blocks(title="Prompt生成器") as block:
         ],
         outputs=[output_img_prompter, output_img_prompter_zh]
     )
-    translate_btn.click(
-        fn=translate_zh2en,
+    chatglm_btn.click(
+        fn=chatglm_models.chatglm.generator_image_text,
         inputs=input_text,
+        outputs=chatglm_output,
+    )
+    translate_btn.click(
+        fn=translate_input,
+        inputs=[input_text, chatglm_output],
         outputs=translate_output
     )
 
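The Chat tab's stop button relies on shared mutable state: allow_generate is a one-element list kept in gr.State, so the interrupt callback can flip the flag while predict_continue is still streaming. A minimal sketch of the pattern outside Gradio, using only names from the diff:

# allow_generate is a one-element list so interrupt() can mutate it in
# place while another callback is mid-generation.
allow_generate = [True]

def interrupt(allow_generate):
    allow_generate[0] = False

def stream_demo():
    for chunk in ('a', 'ab', 'abc'):  # stand-in for stream_chat_continue
        yield chunk
        if not allow_generate[0]:  # the same check predict_continue makes
            break

Pressing 终止生成 flips the flag; the generator exits on its next yield.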
 
requirements.txt CHANGED
@@ -10,3 +10,5 @@ protobuf<=3.20.1,>=3.12.2
 opencv-python==4.7.0.72
 huggingface-hub==0.13.2
 clip-interrogator==0.6.0
+cpm_kernels==1.0.11
+mdtex2html==1.2.0
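Both new pins serve the Chat tab: cpm_kernels provides the quantization kernels that the chatglm-6b-int4 checkpoint loads at runtime, and mdtex2html (now imported in app.py) converts Markdown-plus-TeX model output to HTML for chat display.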
utils/chatglm.py ADDED
@@ -0,0 +1,192 @@
+import time
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+from transformers import LogitsProcessor, LogitsProcessorList
+
+from .singleton import Singleton
+
+
+def parse_codeblock(text):
+    lines = text.split("\n")
+    for i, line in enumerate(lines):
+        if "```" in line:
+            if line != "```":
+                lines[i] = f'<pre><code class="{lines[i][3:]}">'
+            else:
+                lines[i] = '</code></pre>'
+        else:
+            if i > 0:
+                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
+    return "".join(lines)
+
+
+class BasePredictor(ABC):
+
+    @abstractmethod
+    def __init__(self, model_name):
+        self.model = None
+        self.tokenizer = None
+
+    @abstractmethod
+    def stream_chat_continue(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def predict_continue(self, query, latest_message, max_length, top_p,
+                         temperature, allow_generate, history, *args,
+                         **kwargs):
+        if history is None:
+            history = []
+        allow_generate[0] = True
+        history.append((query, latest_message))
+        for response in self.stream_chat_continue(
+                self.model,
+                self.tokenizer,
+                query=query,
+                history=history,
+                max_length=max_length,
+                top_p=top_p,
+                temperature=temperature):
+            history[-1] = (history[-1][0], response)
+            yield history, '', ''
+            if not allow_generate[0]:
+                break
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, start_pos=20005):
+        self.start_pos = start_pos
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., self.start_pos] = 5e4
+        return scores
+
+
+class ChatGLM(BasePredictor):
+
+    def __init__(self, model_name="THUDM/chatglm-6b-int4"):
+
+        print(f'Loading model {model_name}')
+        start = time.perf_counter()
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            resume_download=True
+        )
+        model = AutoModel.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            resume_download=True
+        ).half().to(self.device)
+
+        model = model.eval()
+        self.model = model
+        self.model_name = model_name
+        end = time.perf_counter()
+        print(
+            f'Successfully loaded model {model_name}, time cost: {end - start:.2f}s'
+        )
+
+    @torch.no_grad()
+    def generator_image_text(self, text):
+        response, history = self.model.chat(self.tokenizer, "描述画面:{}".format(text), history=[])
+        return response
+
+    @torch.no_grad()
+    def stream_chat_continue(self,
+                             model,
+                             tokenizer,
+                             query: str,
+                             history: List[Tuple[str, str]] = None,
+                             max_length: int = 2048,
+                             do_sample=True,
+                             top_p=0.7,
+                             temperature=0.95,
+                             logits_processor=None,
+                             **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        if len(history) > 0:
+            answer = history[-1][1]
+        else:
+            answer = ''
+        logits_processor.append(
+            InvalidScoreLogitsProcessor(
+                start_pos=20005 if 'slim' not in self.model_name else 5))
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs
+        }
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                if i != len(history) - 1:
+                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(
+                        i, old_query, response)
+                else:
+                    prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
+        batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
+        batch_input = batch_input.to(model.device)
+
+        batch_answer = tokenizer(answer, return_tensors="pt")
+        batch_answer = batch_answer.to(model.device)
+
+        input_length = len(batch_input['input_ids'][0])
+        final_input_ids = torch.cat(
+            [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]],
+            dim=-1).cuda()
+
+        attention_mask = model.get_masks(
+            final_input_ids, device=final_input_ids.device)
+
+        batch_input['input_ids'] = final_input_ids
+        batch_input['attention_mask'] = attention_mask
+
+        input_ids = final_input_ids
+        MASK, gMASK = self.model.config.bos_token_id - 4, self.model.config.bos_token_id - 3
+        mask_token = MASK if MASK in input_ids else gMASK
+        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
+        batch_input['position_ids'] = self.model.get_position_ids(
+            input_ids, mask_positions, device=input_ids.device)
+
+        for outputs in model.stream_generate(**batch_input, **gen_kwargs):
+            outputs = outputs.tolist()[0][input_length:]
+            response = tokenizer.decode(outputs)
+            response = model.process_response(response)
+            yield parse_codeblock(response)
+
+
+@Singleton
+class Models(object):
+
+    def __getattr__(self, item):
+        if item in self.__dict__:
+            return getattr(self, item)
+
+        if item == 'chatglm':
+            self.chatglm = ChatGLM("THUDM/chatglm-6b-int4")
+
+        return getattr(self, item)
+
+
+models = Models.instance()
+
+
+def chat2text(text: str) -> str:
+    return models.chatglm.generator_image_text(text)
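A minimal usage sketch of the new module, assuming the THUDM/chatglm-6b-int4 weights download successfully (a GPU is effectively required, since the model is loaded with .half()):

from utils.chatglm import models as chatglm_models

# One-shot helper behind the 召唤ChatGLM button: asks the model to
# describe a scene ("描述画面:...") for the given idea text.
print(chatglm_models.chatglm.generator_image_text('海边的日落'))

# Streaming chat as wired to the 生成 button: predict_continue yields the
# running history; the last tuple holds the partial response.
history, allow_generate = [], [True]
for history, _, _ in chatglm_models.chatglm.predict_continue(
        query='你好', latest_message='', max_length=2048, top_p=0.7,
        temperature=0.95, allow_generate=allow_generate, history=history):
    pass
print(history[-1][1])

The Singleton-wrapped Models class builds ChatGLM lazily on first attribute access, so importing the module stays cheap.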
utils/generator.py CHANGED
@@ -24,11 +24,16 @@ class Models(object):
         if item in ('gpt2_650k_pipe',):
             self.gpt2_650k_pipe = self.load_gpt2_650k_pipe()
 
+        if item in ('gpt_neo_125m',):
+            self.gpt_neo_125m = self.load_gpt_neo_125m()
         return getattr(self, item)
 
     @classmethod
-    def load_gpt2_650k_pipe(cls):
+    def load_gpt_neo_125m(cls):
+        return pipeline('text-generation', model='DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M')
 
+    @classmethod
+    def load_gpt2_650k_pipe(cls):
         return pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')
 
     @classmethod
@@ -62,7 +67,16 @@ def generate_prompt(
     model_name='microsoft',
 ):
     if model_name == 'gpt2_650k':
-        return generate_prompt_gpt2_650k(
+        return generate_prompt_pipe(
+            models.gpt2_650k_pipe,
+            prompt=plain_text,
+            min_length=min_length,
+            max_length=max_length,
+            num_return_sequences=num_return_sequences,
+        )
+    elif model_name == 'gpt_neo_125m':
+        return generate_prompt_pipe(
+            models.gpt_neo_125m,
             prompt=plain_text,
             min_length=min_length,
             max_length=max_length,
@@ -114,7 +128,7 @@ def generate_prompt_microsoft(
     return "\n".join(result)
 
 
-def generate_prompt_gpt2_650k(prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
+def generate_prompt_pipe(pipe, prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
     def get_valid_prompt(text: str) -> str:
         dot_split = text.split('.')[0]
         n_split = text.split('\n')[0]
@@ -130,7 +144,7 @@ def generate_prompt_gpt2_650k(prompt: str, min_length=60, max_length: int = 255,
 
     output += [
         get_valid_prompt(result['generated_text']) for result in
-        models.gpt2_650k_pipe(
+        pipe(
            prompt,
            max_new_tokens=rand_length(min_length, max_length),
            num_return_sequences=num_return_sequences
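After the refactor, both the GPT-2 650k and the new GPT-Neo routes go through the shared generate_prompt_pipe helper, differing only in which lazily loaded pipeline is passed in. A hedged sketch of the new route, with keyword names as they appear in the diff:

from utils.generator import generate_prompt

# The first call downloads
# DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M on demand.
print(generate_prompt(
    plain_text='a cat sitting on a windowsill',
    min_length=60,
    max_length=120,
    num_return_sequences=4,
    model_name='gpt_neo_125m',
))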
utils/translate.py CHANGED
@@ -1,6 +1,10 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 from .singleton import Singleton
+from transformers import (
+    EncoderDecoderModel,
+    AutoTokenizer
+)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -18,8 +22,18 @@ class Models(object):
         if item in ('en2zh_model', 'en2zh_tokenizer',):
             self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
 
+        if item in ('wenyanwen2modern_tokenizer', 'wenyanwen2modern_model',):
+            self.wenyanwen2modern_tokenizer, self.wenyanwen2modern_model = self.load_wenyanwen2modern_model()
+
         return getattr(self, item)
 
+    @classmethod
+    def load_wenyanwen2modern_model(cls):
+        PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
+        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
+        model = EncoderDecoderModel.from_pretrained(PRETRAINED)
+        return tokenizer, model
+
     @classmethod
     def load_en2zh_model(cls):
         en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
@@ -37,14 +51,35 @@ class Models(object):
 models = Models.instance()
 
 
-def zh2en(text):
+def wenyanwen2modern(text: str) -> str:
+    tk_kwargs = dict(
+        truncation=True,
+        max_length=128,
+        padding="max_length",
+        return_tensors='pt')
+
+    inputs = models.wenyanwen2modern_tokenizer([text, ], **tk_kwargs)
+    with torch.no_grad():
+        return models.wenyanwen2modern_tokenizer.batch_decode(
+            models.wenyanwen2modern_model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                num_beams=3,
+                max_length=256,
+                bos_token_id=101,
+                eos_token_id=models.wenyanwen2modern_tokenizer.sep_token_id,
+                pad_token_id=models.wenyanwen2modern_tokenizer.pad_token_id,
+            ), skip_special_tokens=True)[0].replace(" ", "")
+
+
+def zh2en(text: str) -> str:
     with torch.no_grad():
         encoded = models.zh2en_tokenizer([text], return_tensors="pt")
         sequences = models.zh2en_model.generate(**encoded)
         return models.zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
 
 
-def en2zh(text):
+def en2zh(text: str) -> str:
     with torch.no_grad():
         encoded = models.en2zh_tokenizer([text], return_tensors="pt")
         sequences = models.en2zh_model.generate(**encoded)
@@ -52,8 +87,9 @@ def en2zh(text):
 
 
 if __name__ == "__main__":
-    input = "青春不能回头,所以青春没有终点。 ——《火影忍者》"
-    en = zh2en(input)
+    input = "飞流直下三千尺,疑是银河落九天"
+    input_m = wenyanwen2modern(input)
+    en = zh2en(input_m)
     print(input, en)
     zh = en2zh(en)
     print(en, zh)
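A short sketch of the new classical-Chinese chain, mirroring the updated __main__ demo (the final .replace(" ", "") in wenyanwen2modern strips the spaces the BERT-style tokenizer puts between characters):

from utils.translate import wenyanwen2modern, zh2en, en2zh

modern = wenyanwen2modern('飞流直下三千尺,疑是银河落九天')  # classical -> modern Chinese
english = zh2en(modern)  # modern Chinese -> English
print(modern, english)
print(en2zh(english))  # round-trip back to Chinese

Note that nothing in this commit appears to wire wenyanwen2modern into app.py; only the module's demo exercises it.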