fix
- utils/__init__.py +0 -0
- utils/dbimutils.py +54 -0
- utils/exif.py +54 -0
- utils/html.py +8 -0
- utils/image2text.py +195 -0
- utils/singleton.py +37 -0
- utils/translate.py +59 -0
utils/__init__.py
ADDED
File without changes
utils/dbimutils.py
ADDED
@@ -0,0 +1,54 @@
# DanBooru IMage Utility functions

import cv2
import numpy as np
from PIL import Image


def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
    # Read an image from disk; GIFs go through PIL because cv2.imread cannot decode them.
    if img.endswith(".gif"):
        img = Image.open(img)
        img = img.convert("RGB")
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    else:
        img = cv2.imread(img, flag)
    return img


def smart_24bit(img):
    # Normalize to 8-bit, 3-channel BGR: scale down 16-bit images,
    # expand grayscale, and flatten transparent pixels onto white.
    if img.dtype == np.uint16:
        img = (img / 257).astype(np.uint8)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        trans_mask = img[:, :, 3] == 0
        img[trans_mask] = [255, 255, 255, 255]
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


def make_square(img, target_size):
    # Pad with white borders so the image becomes a square of at least target_size.
    old_size = img.shape[:2]
    desired_size = max(old_size)
    desired_size = max(desired_size, target_size)

    delta_w = desired_size - old_size[1]
    delta_h = desired_size - old_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [255, 255, 255]
    new_im = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
    )
    return new_im


def smart_resize(img, size):
    # Assumes the image has already gone through make_square
    if img.shape[0] > size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    elif img.shape[0] < size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    return img
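
A minimal usage sketch showing how these helpers chain together when preparing an image for a fixed-size tagger input (the file name "sample.png" and the size 448 are illustrative assumptions, not part of the module):

from utils import dbimutils

# Read, normalize to 3-channel 8-bit BGR, pad to a square, then resize.
img = dbimutils.smart_imread("sample.png")
img = dbimutils.smart_24bit(img)
img = dbimutils.make_square(img, 448)
img = dbimutils.smart_resize(img, 448)
print(img.shape)  # (448, 448, 3)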
utils/exif.py
ADDED
@@ -0,0 +1,54 @@
import piexif
import piexif.helper
from .html import plaintext_to_html


def get_image_info(rawimage):
    items = rawimage.info
    geninfo = ""

    if "exif" in rawimage.info:
        exif = piexif.load(rawimage.info["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode("utf8", errors="ignore")

        items["exif comment"] = exif_comment
        geninfo = exif_comment

    # Drop format/metadata keys that are not generation parameters.
    for field in [
        "jfif",
        "jfif_version",
        "jfif_unit",
        "jfif_density",
        "dpi",
        "exif",
        "loop",
        "background",
        "timestamp",
        "duration",
    ]:
        items.pop(field, None)

    geninfo = items.get("parameters", geninfo)

    info = f"""
<p><h4>PNG Info</h4></p>
"""
    for key, text in items.items():
        info += (
            f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()
            + "\n"
        )

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"
    return info
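
A minimal sketch of calling get_image_info on a PIL image (the file name "generated.png" is an illustrative assumption):

import PIL.Image
from utils.exif import get_image_info

image = PIL.Image.open("generated.png")
html_info = get_image_info(image)  # HTML summary of EXIF / PNG text metadata
print(html_info)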
utils/html.py
ADDED
@@ -0,0 +1,8 @@
import html


def plaintext_to_html(text):
    text = (
        "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split("\n")]) + "</p>"
    )
    return text
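
A small illustrative example of the escaping and line-break behavior (expected output shown in comments):

from utils.html import plaintext_to_html

print(plaintext_to_html("a < b\nsecond line"))
# <p>a &lt; b<br>
# second line</p>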
utils/image2text.py
ADDED
@@ -0,0 +1,195 @@
from __future__ import annotations

import PIL.Image
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor

from . import dbimutils
from .singleton import Singleton

from clip_interrogator import Config, Interrogator

device = "cuda" if torch.cuda.is_available() else "cpu"


@Singleton
class Models(object):
    # WD14 models
    SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
    CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
    CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
    VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

    MODEL_FILENAME = "model.onnx"
    LABEL_FILENAME = "selected_tags.csv"

    # CLIP models
    VIT_H_14_MODEL_REPO = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"  # Stable Diffusion 2.X
    VIT_L_14_MODEL_REPO = "openai/clip-vit-large-patch14"  # Stable Diffusion 1.X

    def __init__(self):
        pass

    @classmethod
    def load_clip_model(cls, model_repo):
        config = Config()
        config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        config.blip_offload = False if torch.cuda.is_available() else True
        config.chunk_size = 2048
        config.flavor_intermediate_count = 512
        config.blip_num_beams = 64
        config.clip_model_name = model_repo

        ci = Interrogator(config)
        return ci

    def __getattr__(self, item):
        # Lazy-load each model the first time its attribute is requested.
        if item in self.__dict__:
            return getattr(self, item)
        print(f"Loading {item}...")
        if item in ('clip_vit_h_14_model',):
            self.clip_vit_h_14_model = self.load_clip_model(self.VIT_H_14_MODEL_REPO)

        if item in ('clip_vit_l_14_model',):
            self.clip_vit_l_14_model = self.load_clip_model(self.VIT_L_14_MODEL_REPO)

        if item in ('swinv2_model',):
            self.swinv2_model = self.load_model(self.SWIN_MODEL_REPO, self.MODEL_FILENAME)
        if item in ('convnext_model',):
            self.convnext_model = self.load_model(self.CONV_MODEL_REPO, self.MODEL_FILENAME)
        if item in ('vit_model',):
            self.vit_model = self.load_model(self.VIT_MODEL_REPO, self.MODEL_FILENAME)
        if item in ('convnextv2_model',):
            self.convnextv2_model = self.load_model(self.CONV2_MODEL_REPO, self.MODEL_FILENAME)

        if item in ('git_model', 'git_processor'):
            self.git_model, self.git_processor = self.load_git_model()

        if item in ('tag_names', 'rating_indexes', 'general_indexes', 'character_indexes'):
            self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = self.load_w14_labels()

        return getattr(self, item)

    @classmethod
    def load_git_model(cls):
        model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")

        return model, processor

    @staticmethod
    def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
        path = huggingface_hub.hf_hub_download(
            model_repo, model_filename,
        )
        model = rt.InferenceSession(path)
        return model

    @classmethod
    def load_w14_labels(cls) -> list:
        path = huggingface_hub.hf_hub_download(
            cls.CONV2_MODEL_REPO, cls.LABEL_FILENAME
        )
        df = pd.read_csv(path)

        tag_names = df["name"].tolist()
        rating_indexes = list(np.where(df["category"] == 9)[0])
        general_indexes = list(np.where(df["category"] == 0)[0])
        character_indexes = list(np.where(df["category"] == 4)[0])
        return [tag_names, rating_indexes, general_indexes, character_indexes]


models = Models.instance()


def clip_image2text(image, mode_type='best', model_name='vit_h_14'):
    image = image.convert('RGB')
    model = getattr(models, f'clip_{model_name}_model')
    if mode_type == 'classic':
        prompt = model.interrogate_classic(image)
    elif mode_type == 'fast':
        prompt = model.interrogate_fast(image)
    elif mode_type == 'negative':
        prompt = model.interrogate_negative(image)
    else:
        prompt = model.interrogate(image)  # default to best
    return prompt


def git_image2text(input_image, max_length=50):
    image = input_image.convert('RGB')
    pixel_values = models.git_processor(images=image, return_tensors="pt").to(device).pixel_values

    generated_ids = models.git_model.to(device).generate(pixel_values=pixel_values, max_length=max_length)
    generated_caption = models.git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption


def w14_image2text(
    image: PIL.Image.Image,
    model_name: str,
    general_threshold: float,
    character_threshold: float,
):
    tag_names: list[str] = models.tag_names
    rating_indexes: list[np.int64] = models.rating_indexes
    general_indexes: list[np.int64] = models.general_indexes
    character_indexes: list[np.int64] = models.character_indexes
    model_name = "{}_model".format(model_name.lower())
    model = getattr(models, model_name)

    _, height, width, _ = model.get_inputs()[0].shape

    # Alpha to white
    image = image.convert("RGBA")
    new_image = PIL.Image.new("RGBA", image.size, "WHITE")
    new_image.paste(image, mask=image)
    image = new_image.convert("RGB")
    image = np.asarray(image)

    # PIL RGB to OpenCV BGR
    image = image[:, :, ::-1]

    image = dbimutils.make_square(image, height)
    image = dbimutils.smart_resize(image, height)
    image = image.astype(np.float32)
    image = np.expand_dims(image, 0)

    input_name = model.get_inputs()[0].name
    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {input_name: image})[0]

    labels = list(zip(tag_names, probs[0].astype(float)))

    # First 4 labels are actually ratings: pick one with argmax
    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)

    # Then we have general tags: pick any where prediction confidence > threshold
    general_names = [labels[i] for i in general_indexes]
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)

    # Everything else is characters: pick any where prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

    b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
    a = (
        ", ".join(list(b.keys()))
        .replace("_", " ")
        .replace("(", "\\(")
        .replace(")", "\\)")
    )
    c = ", ".join(list(b.keys()))
    d = " ".join(list(b.keys()))

    return a, c, d, rating, character_res, general_res
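
A minimal usage sketch for the three captioning paths exposed by this module (model weights are fetched from the Hugging Face Hub on first access; the file name "photo.jpg" and the threshold values are illustrative assumptions):

import PIL.Image
from utils import image2text

image = PIL.Image.open("photo.jpg")

# GIT captioning (microsoft/git-large-coco)
caption = image2text.git_image2text(image, max_length=50)

# WD14 tagger: returns escaped tags, comma-joined tags, space-joined tags,
# plus rating / character / general score dictionaries
tags, tags_plain, tags_spaced, rating, characters, general = image2text.w14_image2text(
    image, model_name="convnextv2", general_threshold=0.35, character_threshold=0.85
)

# CLIP Interrogator prompt (ViT-H-14 by default)
prompt = image2text.clip_image2text(image, mode_type="best", model_name="vit_h_14")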
utils/singleton.py
ADDED
@@ -0,0 +1,37 @@
class Singleton:
    """
    A non-thread-safe helper class to ease implementing singletons.
    This should be used as a decorator -- not a metaclass -- to the
    class that should be a singleton.

    The decorated class can define one `__init__` function that
    takes only the `self` argument. Also, the decorated class cannot be
    inherited from. Other than that, there are no restrictions that apply
    to the decorated class.

    To get the singleton instance, use the `instance` method. Trying
    to use `__call__` will result in a `TypeError` being raised.

    """

    def __init__(self, decorated):
        self._decorated = decorated

    def instance(self):
        """
        Returns the singleton instance. Upon its first call, it creates a
        new instance of the decorated class and calls its `__init__` method.
        On all subsequent calls, the already created instance is returned.

        """
        try:
            return self._instance
        except AttributeError:
            self._instance = self._decorated()
            return self._instance

    def __call__(self):
        raise TypeError('Singletons must be accessed through `instance()`.')

    def __instancecheck__(self, inst):
        return isinstance(inst, self._decorated)
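
A minimal sketch of the intended usage pattern (the AppConfig class is a hypothetical example, not part of the repository):

from utils.singleton import Singleton


@Singleton
class AppConfig:
    def __init__(self):
        self.debug = False


config = AppConfig.instance()           # first call constructs the object
assert config is AppConfig.instance()   # later calls return the same object
# AppConfig() would raise TypeError: singletons are accessed via instance()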
utils/translate.py
ADDED
@@ -0,0 +1,59 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from .singleton import Singleton

device = "cuda" if torch.cuda.is_available() else "cpu"


@Singleton
class Models(object):

    def __getattr__(self, item):
        # Lazy-load the translation models the first time they are requested.
        if item in self.__dict__:
            return getattr(self, item)

        if item in ('zh2en_model', 'zh2en_tokenizer',):
            self.zh2en_model, self.zh2en_tokenizer = self.load_zh2en_model()

        if item in ('en2zh_model', 'en2zh_tokenizer',):
            self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()

        return getattr(self, item)

    @classmethod
    def load_en2zh_model(cls):
        en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
        en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
        return en2zh_model, en2zh_tokenizer

    @classmethod
    def load_zh2en_model(cls):
        zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
        zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')

        return zh2en_model, zh2en_tokenizer


models = Models.instance()


def zh2en(text):
    with torch.no_grad():
        encoded = models.zh2en_tokenizer([text], return_tensors="pt")
        sequences = models.zh2en_model.generate(**encoded)
    return models.zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def en2zh(text):
    with torch.no_grad():
        encoded = models.en2zh_tokenizer([text], return_tensors="pt")
        sequences = models.en2zh_model.generate(**encoded)
    return models.en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


if __name__ == "__main__":
    # Round-trip demo: Chinese -> English -> Chinese.
    input_text = "青春不能回头,所以青春没有终点。 ——《火影忍者》"
    en = zh2en(input_text)
    print(input_text, en)
    zh = en2zh(en)
    print(en, zh)