# text-aug-demo / app.py
# Author: maximuspowers -- last commit e36bf2f ("Update app.py"), 3.92 kB.
import gradio as gr
import gensim
print(gensim.__version__)
import transformers
import sacremoses # for back translation tokenizer
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import nlpaug.flow as nafc
from nlpaug.util import Action
from nlpaug.util.file.download import DownloadUtil
def _download_if_missing(expected_file, download, **kwargs):
    """Call ``download(**kwargs)`` unless ``expected_file`` already exists.

    This code runs at module level, i.e. on every start of the app, and the
    embedding archives are large, so re-fetching them on every boot wastes
    time and bandwidth. If ``expected_file`` does not match what the archive
    actually extracts to, the guard simply never fires and the download runs
    every time -- the same behavior as an unguarded call.
    """
    import os  # local import keeps the file-level import block untouched

    if os.path.exists(expected_file):
        print(f"Found {expected_file}; skipping download.")
        return
    download(**kwargs)


# Pretrained vectors for 'Word Embedding Substitution'. The word2vec file name
# is the UI's default model path; the fastText/GloVe names are assumptions
# (NOTE(review): confirm the extracted names on the Space).
_download_if_missing(
    "GoogleNews-vectors-negative300.bin",
    DownloadUtil.download_word2vec,
    dest_dir='.',
)
# Possible values are 'wiki-news-300d-1M', 'wiki-news-300d-1M-subword',
# 'crawl-300d-2M' and 'crawl-300d-2M-subword'.
_download_if_missing(
    "crawl-300d-2M.vec",
    DownloadUtil.download_fasttext,
    dest_dir='.',
    model_name='crawl-300d-2M',
)
# GloVe vectors back 'Word Embedding Substitution' with model type 'glove'.
# (They are NOT used by 'Synonym Replacement', which uses WordNet.)
_download_if_missing(
    "glove.6B.300d.txt",
    DownloadUtil.download_glove,
    dest_dir='.',
    model_name='glove.6B',
)
# augmentations
def augment_text(text, aug_type, model_type=None, model_path=None, aug_p=0.25, aug_max=3):
if aug_type == 'Word Embedding Substitution':
aug = naw.WordEmbsAug(
model_type=model_type,
model_path=model_path,
action="substitute",
aug_p=aug_p
)
elif aug_type == 'Contextual Insertion':
aug = naw.ContextualWordEmbsAug(
model_path='bert-base-uncased',
action="insert",
aug_p=aug_p
)
elif aug_type == 'Synonym Replacement':
aug = naw.SynonymAug(
aug_src="wordnet",
aug_max=aug_max
)
elif aug_type == 'Back Translation':
aug = naw.BackTranslationAug(
from_model_name='facebook/wmt19-en-de',
to_model_name='facebook/wmt19-de-en'
)
else:
return text
augmented_text = aug.augment(text)
return augmented_text
with gr.Blocks() as iface:
text_input = gr.Textbox(label="Input Text")
aug_type_input = gr.Radio(
choices=['Word Embedding Substitution', 'Contextual Insertion', 'Synonym Replacement', 'Back Translation'],
label="Augmentation Type",
value='Word Embedding Substitution'
)
model_type_input = gr.Dropdown(
choices=['word2vec', 'fasttext', 'glove'],
label="Model Type (for Word Embedding Substitution)",
value='word2vec',
visible=True
)
model_path_input = gr.Textbox(
label="Model Path (for Word Embedding Substitution)",
value="GoogleNews-vectors-negative300.bin",
visible=True
)
aug_p_input = gr.Slider(
minimum=0, maximum=1, step=0.05, value=0.25,
label="Probability of Augmentation (for Embedding Substitution or Contextual Insertion)"
)
aug_max_input = gr.Slider(
minimum=1, maximum=10, step=1, value=3,
label="Max Number of Words to Change (for Synonym Replacement)",
visible=False
)
augmented_output = gr.Textbox(label="Augmented Text")
# update input block visibility based on aug type
def update_inputs(aug_type):
if aug_type == 'Word Embedding Substitution':
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
elif aug_type == 'Contextual Insertion':
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
elif aug_type == 'Synonym Replacement':
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
elif aug_type == 'Back Translation':
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
# update inputs when aug type changes
aug_type_input.change(
update_inputs,
inputs=[aug_type_input],
outputs=[model_type_input, model_path_input, aug_max_input]
)
apply_button = gr.Button("Apply Augmentation")
apply_button.click(
augment_text,
inputs=[text_input, aug_type_input, model_type_input, model_path_input, aug_p_input, aug_max_input],
outputs=[augmented_output]
)
iface.launch()