Spaces:
Running
Running
import gradio | |
import inseq | |
from inseq.data.aggregator import AggregatorPipeline, SubwordAggregator, SequenceAttributionAggregator, PairAggregator | |
import torch | |
from os.path import exists | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Optional eager pre-load of the Hu-En model at startup (currently disabled; the handlers load models on demand instead).
# model_hu_en = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients") | |
def swap_pronoun(sentence):
    """Swap the capitalised pronouns "He" and "She" in *sentence*.

    Only whole words are swapped, so words that merely start with those letters
    (e.g. "Hello", "Her", "Shelter") are left untouched.  Both pronouns are
    swapped simultaneously, so "He said She left." becomes "She said He left.".
    Sentences without either pronoun are returned unchanged.

    Args:
        sentence: The (English) sentence to rewrite.

    Returns:
        The sentence with every standalone "He" replaced by "She" and vice versa.
    """
    import re

    return re.sub(
        r"\b(He|She)\b",
        lambda match: "She" if match.group(1) == "He" else "He",
        sentence,
    )
def run_counterfactual(occupation):
    """Attribute a masculine/feminine contrastive pair for a Hu->En translation.

    The gender-neutral Hungarian sentence "Ő {occupation}." is translated with
    ``Helsinki-NLP/opus-mt-hu-en``.  Two targets are then forced from the
    generated translation, one starting with "He" and one with "She".  Both are
    attributed with integrated gradients and rendered side by side as HTML
    using ``PairAggregator``.

    Args:
        occupation: Dropdown value such as "nővér (nurse)"; only the Hungarian
            part before " (" is used.

    Returns:
        The HTML visualisation.  It is cached under ``results/`` and served
        from disk on later calls with the same occupation.

    Raises:
        ValueError: if ``occupation`` contains a path separator, which would let
            the cache path escape the ``results/`` directory.
    """
    import os
    import re

    occupation = occupation.split(" (")[0]
    # The occupation is interpolated into a file path; refuse anything that
    # could point outside the results directory.
    if "/" in occupation or "\\" in occupation:
        raise ValueError(f"Invalid occupation: {occupation!r}")
    result_fp = f"results/counterfactual_{occupation}.html"
    if exists(result_fp):
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    # "egy" means something like "a", but is used less frequently than in
    # English, so the article is left out of the source sentence.
    source = f"Ő {occupation}."
    model = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
    model.device = DEVICE
    target = model.generate(source)[0]
    # Force each pronoun variant of the generated translation.  Word boundaries
    # keep words such as "Shelter" or "Head" from being rewritten.
    target_masculine = re.sub(r"\bShe\b", "He", target)
    target_feminine = re.sub(r"\bHe\b", "She", target)
    out = model.attribute(
        [source, source],
        [target_masculine, target_feminine],
        n_steps=150,
        return_convergence_delta=False,
        attribute_target=False,
        step_scores=["probability"],
        internal_batch_size=100,
        include_eos_baseline=False,
        device=DEVICE,
    )
    squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
    masculine = out.sequence_attributions[0].aggregate(aggregator=squeezesum)
    feminine = out.sequence_attributions[1].aggregate(aggregator=squeezesum)
    html = masculine.show(aggregator=PairAggregator, paired_attr=feminine, return_html=True, display=True)

    # Create the cache directory on first use; open(..., "w") fails otherwise.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def run_simple(occupation, lang, aggregate):
    """Translate a gender-neutral Hungarian sentence and attribute the output.

    The sentence "Ő {occupation}." is translated with
    ``Helsinki-NLP/opus-mt-hu-{lang}``.  It is then attributed with integrated
    gradients (``attribute_target=True``) and rendered as HTML.

    Args:
        occupation: Dropdown value such as "nővér (nurse)"; only the Hungarian
            part before " (" is used.
        lang: Target-language code appended to the model name (e.g. "en").
        aggregate: "yes" to apply the subword/sequence aggregation pipeline
            before rendering; any other value renders the raw attributions.

    Returns:
        The HTML visualisation.  It is cached under ``results/`` and served
        from disk on later calls with the same arguments.

    Raises:
        ValueError: if ``occupation`` or ``lang`` contains a path separator,
            which would let the cache path escape the ``results/`` directory.
    """
    import os

    aggregate = aggregate == "yes"
    occupation = occupation.split(" (")[0]
    # Both values are interpolated into a file path; refuse anything that could
    # point outside the results directory.
    for value in (occupation, lang):
        if "/" in value or "\\" in value:
            raise ValueError(f"Invalid selection: {value!r}")
    suffix = "_aggregate" if aggregate else ""
    result_fp = f"results/simple_{occupation}_{lang}{suffix}.html"
    if exists(result_fp):
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    model_name = f"Helsinki-NLP/opus-mt-hu-{lang}"
    # "egy" means something like "a", but is used less frequently than in
    # English, so the article is left out of the source sentence.
    source = f"Ő {occupation}."
    model = inseq.load_model(model_name, "integrated_gradients")
    out = model.attribute(
        [source],
        attribute_target=True,
        n_steps=150,
        device=DEVICE,
        return_convergence_delta=False,
    )
    if aggregate:
        squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
        html = out.show(return_html=True, display=True, aggregator=squeezesum)
    else:
        html = out.show(return_html=True, display=True)

    # Create the cache directory on first use; open(..., "w") fails otherwise.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def _read_text(path):
    """Return the full contents of the UTF-8 encoded text file at *path*."""
    # An explicit encoding keeps non-ASCII (e.g. Hungarian) text independent of
    # the server's locale.
    with open(path, encoding="utf-8") as fh:
        return fh.read()


# Markdown snippets rendered in the Gradio interface.
desc = _read_text("description.md")
simple_translation = _read_text("simple_translation.md")
contrastive_pair = _read_text("contrastive_pair.md")
notice = _read_text("notice.md")
# (Hungarian noun, English gloss) pairs.  The dropdown shows
# "<hungarian> (<english>)", and the handlers keep only the Hungarian part
# (everything before " (").
_OCCUPATION_GLOSSES = (
    ("nő", "woman"),
    ("férfi", "man"),
    ("nővér", "nurse"),
    ("tudós", "scientist"),
    ("mérnök", "engineer"),
    ("pék", "baker"),
    ("tanár", "teacher"),
    ("esküvőszervező", "wedding organizer"),
    ("vezérigazgató", "CEO"),
)
OCCUPATIONS = [f"{hungarian} ({english})" for hungarian, english in _OCCUPATION_GLOSSES]

# Target-language codes, used as the suffix of Helsinki-NLP/opus-mt-hu-<lang>.
LANGS = "en fr de".split()
# Build the Gradio UI: two accordions with explanations, then one tab per
# experiment (simple attribution and contrastive He/She pair).
with gradio.Blocks(title="Gender Bias in MT: Hungarian to English") as iface:
    gradio.Markdown(desc)
    with gradio.Accordion("Simple translation", open=True):
        gradio.Markdown(simple_translation)
    with gradio.Accordion("Contrastive pair", open=False):
        gradio.Markdown(contrastive_pair)
    gradio.Markdown("**Does the model seem to rely on gender stereotypes in its translations?**")

    with gradio.Tab("Simple translation"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                simple_occupation = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
            with gradio.Column(scale=4):
                target_lang = gradio.Dropdown(label="Target Language", choices=LANGS, value=LANGS[0])
            aggregate_subwords = gradio.Radio(
                ["yes", "no"], label="Aggregate subwords?", value="yes"
            )
        simple_button = gradio.Button("Translate & Attribute")
        simple_output = gradio.HTML()
        # Input order must match run_simple(occupation, lang, aggregate).
        simple_button.click(
            run_simple,
            inputs=[simple_occupation, target_lang, aggregate_subwords],
            outputs=simple_output,
        )

    with gradio.Tab("Contrastive pair"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                pair_occupation = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
        pair_button = gradio.Button("Translate & Attribute")
        pair_output = gradio.HTML()
        pair_button.click(run_counterfactual, inputs=[pair_occupation], outputs=pair_output)

    with gradio.Accordion("Notes & References", open=False):
        gradio.Markdown(notice)

iface.launch()