Spaces:
Runtime error
Runtime error
File size: 7,621 Bytes
d663592 9c3fd77 d663592 9c3fd77 51fb7cc 967e32f 9c3fd77 9a6ada6 9c3fd77 51fb7cc d663592 f0ff84c d663592 51fb7cc 7321395 51fb7cc 1942313 51fb7cc 1942313 51fb7cc 9a6ada6 51fb7cc d663592 7321395 9a6ada6 d663592 51fb7cc d663592 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import gradio as gr
import torch
import numpy as np
import fasttext
import os
import urllib
import huggingface_hub
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
# Hugging Face Hub repo ids of the two fine-tuned mBART models, one per
# direction: Erzya -> other languages, and other languages -> Erzya.
MODEL_URL_MYV_MUL = 'slone/mbart-large-51-myv-mul-v1'
MODEL_URL_MUL_MYV = 'slone/mbart-large-51-mul-myv-v1'
# fastText language-identification model used by get_mean_lang_score to rerank
# translations into Erzya. It is downloaded from MODEL_URL_LANGID to the local
# file MODEL_PATH_LANGID on first start (see the __main__ section).
MODEL_URL_LANGID = 'https://huggingface.co/slone/fastText-LID-323/resolve/main/lid.323.ftz'
MODEL_PATH_LANGID = 'lid.323.ftz'
# Hub token used by the flagging callback; None when the env var is unset.
HF_TOKEN = os.getenv('HF_TOKEN')
# Flagging callback handed to the Interface below: stores flagged examples
# (ratings and user corrections) in a private Hub dataset, presumably
# slone/myv-translation-2022-demo-flags-v2.
# NOTE(review): this object is created at import time, and its keyword names
# (`dataset_name`, `organization`) follow an older gradio API that newer
# releases may not accept -- confirm against the installed gradio version.
hf_writer = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="myv-translation-2022-demo-flags-v2",
    organization="slone",
    private=True,
)
# Dropdown display label -> language code understood by the (patched)
# MBart50 tokenizer. Insertion order matters: the interface uses the first
# entry as the default source language and the second as the default target.
lang_to_code = dict([
    ('Эрзянь | Erzya', 'myv_XX'),
    ('Русский | Рузонь | Russian', 'ru_RU'),
    ('Suomi | Суоминь | Finnish', 'fi_FI'),
    ('Deutsch | Немецень | German', 'de_DE'),
    ('Español | Испанонь | Spanish', 'es_XX'),
    ('English | Англань ', 'en_XX'),
    ('हिन्दी | Хинди | Hindi', 'hi_IN'),
    ('漢語 | Китаень | Chinese', 'zh_CN'),
    ('Türkçe | Турконь | Turkish', 'tr_TR'),
    ('Українська | Украинань | Ukrainian', 'uk_UA'),
    ('Français | Французонь | French', 'fr_XX'),
    ('العربية | Арабонь | Arabic', 'ar_AR'),
])
def fix_tokenizer(tokenizer, extra_lang: str = 'myv_XX') -> None:
    """Add a new language id to a MBART 50 tokenizer (because it is not serialized) and shift the mask token id.

    Patches *tokenizer* in place and returns nothing. In this app it is called
    once on the MBart50Tokenizer loaded in the ``__main__`` section, before the
    interface is launched.

    Args:
        tokenizer: the MBart50Tokenizer to patch. This relies on its internal
            attributes (``lang_code_to_id``, ``id_to_lang_code``,
            ``fairseq_tokens_to_ids``, ``fairseq_ids_to_tokens``,
            ``fairseq_offset``, ``sp_model``, ``_additional_special_tokens``,
            ``added_tokens_encoder``). NOTE(review): several of these are
            private/legacy and may be missing or read-only in newer
            ``transformers`` releases -- confirm against the pinned version.
        extra_lang: the language code to register; the default ``'myv_XX'``
            is the Erzya code used by ``lang_to_code``.
    """
    # Vocabulary size *without* extra_lang: if the saved tokenizer already
    # carries it as an added token, don't count that entry.
    old_len = len(tokenizer) - int(extra_lang in tokenizer.added_tokens_encoder)
    # Give the new language the last id of that vocabulary (presumably the
    # slot <mask> used to occupy -- <mask> is moved on the next statement).
    tokenizer.lang_code_to_id[extra_lang] = old_len-1
    tokenizer.id_to_lang_code[old_len-1] = extra_lang
    # Re-place <mask> right after the sentencepiece pieces plus all language
    # codes (the count now includes extra_lang, so this must run after the
    # lang_code_to_id update above).
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    # Make every language code resolvable as a token, then rebuild the reverse
    # map so ids -> tokens agrees with tokens -> ids.
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # Register the code as an additional special token (presumably so that
    # batch_decode(..., skip_special_tokens=True) drops it like other codes).
    if extra_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(extra_lang)
    # Drop added-token entries so extra_lang is resolved through the ids set
    # above rather than a stale added-token id (NOTE(review): inferred intent).
    tokenizer.added_tokens_encoder = {}
def translate(
    text, model, tokenizer,
    src='ru_RU',
    trg='myv_XX',
    max_length='auto',
    num_beams=3,
    repetition_penalty=5.0,
    train_mode=False, n_out=None,
    **kwargs
):
    """Translate *text* from language code *src* into *trg* with beam search.

    Args:
        text: a string, or presumably a list of strings (it is passed
            straight to the tokenizer).
        model: a seq2seq model exposing ``generate``, ``train``/``eval`` and
            ``device``.
        tokenizer: a tokenizer exposing ``src_lang``, ``lang_code_to_id`` and
            ``batch_decode`` (in this app, the patched MBart50Tokenizer).
        src, trg: language codes such as the values of ``lang_to_code``.
        max_length: cap on generated length; ``'auto'`` means
            ``32 + 1.5 * <number of input tokens>``.
        num_beams, repetition_penalty: forwarded to ``model.generate``.
        train_mode: when truthy the model is switched to train mode before
            generating, otherwise to eval mode; the mode stays set afterwards.
        n_out: number of sequences to return; ``None`` (or 0) means one.
        **kwargs: any further ``generate`` arguments.

    Returns:
        A single string when *text* is a ``str`` and *n_out* is ``None``;
        otherwise the list of decoded sequences.
    """
    tokenizer.src_lang = src
    batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    # The automatic cap scales with the (truncated) input length.
    length_cap = int(32 + 1.5 * batch.input_ids.shape[1]) if max_length == 'auto' else max_length

    switch_mode = model.train if train_mode else model.eval
    switch_mode()

    model_inputs = batch.to(model.device)
    sequences = model.generate(
        **model_inputs,
        # Forces the first generated token to be the target-language code.
        forced_bos_token_id=tokenizer.lang_code_to_id[trg],
        max_length=length_cap,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        num_return_sequences=n_out or 1,
        **kwargs
    )
    decoded = tokenizer.batch_decode(sequences, skip_special_tokens=True)

    wants_single = n_out is None and isinstance(text, str)
    return decoded[0] if wants_single else decoded
def translate_rerank(
    text, model, tokenizer,
    src='ru_RU', trg='myv_XX', max_length='auto', num_beams=3, repetition_penalty=5.0, train_mode=False,
    n=5, diversity_penalty=3.0, lang='myv', max_score=0.3, order_penalty=0.01,
    verbose=False,
    **kwargs
):
    """Generate *n* diverse candidates and return the one that looks most like *lang*.

    Candidates come from ``translate`` with diverse beam search (*n* beams
    split into *n* groups, penalised by *diversity_penalty*). Each candidate
    is scored with ``get_mean_lang_score``, then ``order_penalty * rank`` is
    subtracted so that, at near-equal scores, earlier candidates win. The
    best candidate string is returned.

    Note: *num_beams* is accepted for signature parity with ``translate`` but
    is not used -- the beam count is always *n*.

    Args:
        text: the source text (a single string).
        src, trg: language codes.
        lang: fastText language label to reward (without ``__label__``).
        max_score: language-ID probability treated as a full score.
        verbose: print the candidates, raw scores and penalised scores.
        Remaining arguments are forwarded to ``translate``.
    """
    candidates = translate(
        text, model, tokenizer, src, trg,
        max_length=max_length,
        train_mode=train_mode,
        repetition_penalty=repetition_penalty,
        num_beams=n,
        num_beam_groups=n,
        diversity_penalty=diversity_penalty,
        n_out=n,
        **kwargs
    )
    lang_scores = [
        get_mean_lang_score(candidate, lang=lang, max_score=max_score)
        for candidate in candidates
    ]
    rank_penalty = order_penalty * np.arange(n)
    penalized = np.asarray(lang_scores) - rank_penalty
    if verbose:
        for item in (candidates, lang_scores, penalized):
            print(item)
    best_index = np.argmax(penalized)
    return candidates[best_index]
def get_mean_lang_score(text, lang='myv', k=300, max_score=0.3, langid=None):
    """Return how much *text* looks like language *lang*, as a number in [0, 1].

    Every whitespace-separated word of *text*, plus the whole text, is sent to
    a fastText language-ID model. For each of them, the probability of *lang*
    among the top-*k* predictions (0 if absent) is capped at *max_score* and
    divided by it. The result is the mean of those per-item scores.

    Args:
        text: the string to score.
        lang: fastText label to look for, without the ``__label__`` prefix.
        k: number of top predictions requested per item.
        max_score: probability treated as a full score (must be > 0).
        langid: object with a fastText-style ``predict(list_of_str, k=...)``.
            Defaults to the module-level ``langid_model`` loaded in the
            ``__main__`` section; passing one explicitly makes the function
            usable outside that script context.

    Returns:
        The mean score (a numpy float).
    """
    if langid is None:
        langid = langid_model
    # fastText's predict() raises ValueError for any input containing '\n',
    # so flatten the whole-text entry to one line. The split() words never
    # contain newlines.
    items = text.split() + [text.replace('\n', ' ')]
    prefix_len = len('__label__')
    res = []
    for labels, probs in zip(*langid.predict(items, k=k)):
        by_lang = dict(zip((label[prefix_len:] for label in labels), probs))
        res.append(min(by_lang.get(lang, 0), max_score) / max_score)
    return np.mean(res)
def translate_wrapper(text, src, trg, correct=None):
    """Gradio callback: translate *text* between two languages chosen by display label.

    Args:
        text: the text typed by the user.
        src, trg: dropdown labels, i.e. keys of ``lang_to_code``.
        correct: the user's suggested correct translation. It is not used for
            translating; it is only an interface input, presumably so that
            the flagging callback records it together with the output.

    Returns:
        The translation; ``''`` for empty or whitespace-only input; or an
        explanatory message when the language pair is unsupported (the same
        language twice, or neither side is Erzya).
    """
    # Nothing to translate: skip the expensive (reranked) beam search instead
    # of generating text from an empty prompt.
    if not text or not text.strip():
        return ''
    src_code = lang_to_code.get(src)
    trg_code = lang_to_code.get(trg)
    if src_code == trg_code:
        return 'Please choose two different languages'
    # Each model covers only one direction relative to Erzya.
    if src_code == 'myv_XX':
        model = model_myv_mul
    elif trg_code == 'myv_XX':
        model = model_mul_myv
    else:
        return 'Please translate to or from Erzya'
    print(text, src_code, trg_code)
    # Output in Erzya is reranked by how Erzya-like the candidates look.
    fn = translate_rerank if trg_code == 'myv_XX' else translate
    return fn(
        text=text,
        model=model,
        tokenizer=tokenizer,
        src=src_code,
        trg=trg_code,
    )
# Markdown shown under the interface: a short description in Erzya, Russian and
# English, with links to the paper and the code repository. The Russian and
# English paragraphs also ask users to rate translations with the emoji buttons.
article = """
Те эрзянь кельсэ автоматической васенце ютавтыця. Тонавкстнэ улить – [сёрмадовкссо](https://arxiv.org/abs/2209.09368).
Это первый автоматический переводчик для эрзянского языка. Подробности – в [статье](https://arxiv.org/abs/2209.09368).
Пожалуйста, оставляйте своё мнение о качестве переводов с помощью кнопок с эмодзи!
This is the first automatic translator for the Erzya language. The details are in the [paper](https://arxiv.org/abs/2209.09368).
Please leave your feedback about the quality of translations using the buttons with emojis.
The code and models for translation can be found in the repository: https://github.com/slone-nlp/myv-nmt.
"""
# Placeholder text for the "Correct translation" box (in Russian). In English:
# 'If the model's translation is wrong, type the correct text here, press
# "Исполнить" ("Execute") again, and then "bad 🙁". Then a mark that the
# translation was wrong, together with its correction, will reach our database.'
fix_instruction = 'Если перевод модели неправильный, впишите сюда правильный текст, снова нажмите "Исполнить", и затем "bad 🙁". ' \
    'Тогда к нам в базу попадёт пометка, что перевод был неверным, и его исправление.'
# The Gradio UI: inputs are the text, source/target dropdowns and an optional
# "correct translation" box; the output is a single text field. The dropdowns
# use type="value", so translate_wrapper receives the display labels (keys of
# lang_to_code) and maps them to codes itself. Defaults: the first key (Erzya)
# as source and the second key (Russian) as target.
# NOTE(review): `allow_flagging` / `flagging_options` / `flagging_callback`
# follow an older gradio API -- confirm against the installed gradio version.
interface = gr.Interface(
    translate_wrapper,
    [
        gr.Textbox(label="Text / текст", lines=2, placeholder='text to translate / текст ютавтозь'),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='source language / васенце кель', value=list(lang_to_code.keys())[0]),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='target language / эрявикс кель', value=list(lang_to_code.keys())[1]),
        gr.Textbox(label="Correct translation", lines=2, placeholder=fix_instruction),
    ],
    "text",
    # Manual flagging: users pick one of the rating buttons, and hf_writer
    # saves the flagged example to the Hub dataset it was configured with.
    allow_flagging="manual",
    flagging_options=["good 🙂", "50/50 😐", "bad 🙁"],
    flagging_callback=hf_writer,
    title='Эрзянь ютавтыця | Эрзянский переводчик | Erzya translator',
    article=article,
)
if __name__ == '__main__':
    # `import urllib` alone does not load the `urllib.request` submodule; it
    # only worked if some other library happened to import it first.
    import urllib.request

    # Load both translation directions. The globals defined here
    # (model_mul_myv, model_myv_mul, tokenizer, langid_model) are read by
    # translate_wrapper and get_mean_lang_score at request time.
    model_mul_myv = MBartForConditionalGeneration.from_pretrained(MODEL_URL_MUL_MYV)
    model_myv_mul = MBartForConditionalGeneration.from_pretrained(MODEL_URL_MYV_MUL)
    if torch.cuda.is_available():
        model_mul_myv.cuda()
        model_myv_mul.cuda()
    # A single tokenizer serves both models; it needs the Erzya code added.
    tokenizer = MBart50Tokenizer.from_pretrained(MODEL_URL_MYV_MUL)
    fix_tokenizer(tokenizer)

    if not os.path.exists(MODEL_PATH_LANGID):
        print('downloading LID model...')
        # Download under a temporary name and rename only on success, so an
        # interrupted download cannot leave a truncated file that the
        # exists() check above would accept on every later start.
        partial_path = MODEL_PATH_LANGID + '.part'
        urllib.request.urlretrieve(MODEL_URL_LANGID, partial_path)
        os.replace(partial_path, MODEL_PATH_LANGID)
    langid_model = fasttext.load_model(MODEL_PATH_LANGID)
    interface.launch()
|