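# app.py for the Hugging Face Space: a Gradio demo of scientific acronym
# disambiguation, based on the model described in https://arxiv.org/abs/2102.08818.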
import gradio as gr
import json
import os
import numpy as np
import torch
import transformers
import tokenizers
from model import BertAD
# Resources live in the Space's model/ directory: the acronym dictionary,
# the WordPiece vocabulary, and the fine-tuned checkpoint.
DICTIONARY = json.load(open('model/dict.json'))
TOKENIZER = tokenizers.BertWordPieceTokenizer("model/vocab.txt", lowercase=True)
MAX_LEN = 256

MODEL = BertAD()
# Make sure the checkpoint carries a position_ids buffer compatible with the
# current transformers version before loading it.
vec = MODEL.state_dict()['bert.embeddings.position_ids']
chkp = torch.load(os.path.join('model', 'model_0.bin'), map_location='cpu')
chkp['bert.embeddings.position_ids'] = vec
MODEL.load_state_dict(chkp)
MODEL.eval()  # inference only: disable dropout
del chkp, vec
def sample_text(text, acronym, max_len):
    # Keep a window of at most max_len words centred on the acronym.
    text = text.split()
    idx = text.index(acronym)
    left_idx = max(0, idx - max_len // 2)
    right_idx = min(len(text), idx + max_len // 2)
    sampled_text = text[left_idx:right_idx]
    return ' '.join(sampled_text)
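# Disambiguation is cast as span extraction: the acronym's candidate expansions
# are packed in front of the context, and the model predicts the start/end tokens
# of the correct expansion inside that candidate segment.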
def process_data(text, acronym, expansion, tokenizer, max_len):
    text = str(text)
    expansion = str(expansion)
    acronym = str(acronym)

    # Trim long contexts to a window of ~120 words around the acronym.
    n_tokens = len(text.split())
    if n_tokens > 120:
        text = sample_text(text, acronym, 120)

    # The "answer" segment holds the acronym followed by all of its candidate expansions.
    answers = acronym + ' ' + ' '.join(DICTIONARY[acronym])

    # Character-level mask marking where the target expansion sits in the answer segment.
    start = answers.find(expansion)
    end = start + len(expansion)
    char_mask = [0] * len(answers)
    for i in range(start, end):
        char_mask[i] = 1

    # Tokenize the answer segment and drop the special [CLS]/[SEP] tokens.
    tok_answer = tokenizer.encode(answers)
    answer_ids = tok_answer.ids[1:-1]
    answer_offsets = tok_answer.offsets[1:-1]

    # Token-level start/end of the target expansion inside the answer segment.
    target_idx = []
    for i, (off1, off2) in enumerate(answer_offsets):
        if sum(char_mask[off1:off2]) > 0:
            target_idx.append(i)
    start = target_idx[0]
    end = target_idx[-1]

    # Build the sentence-pair input: [CLS] answers [SEP] context [SEP]
    text_ids = tokenizer.encode(text).ids[1:-1]
    token_ids = [101] + answer_ids + [102] + text_ids + [102]
    offsets = [(0, 0)] + answer_offsets + [(0, 0)] * (len(text_ids) + 2)
    mask = [1] * len(token_ids)
    token_type = [0] * (len(answer_ids) + 1) + [1] * (2 + len(text_ids))
    text = answers + text
    start = start + 1  # shift by one for the leading [CLS]
    end = end + 1

    # Pad or truncate everything to max_len.
    padding = max_len - len(token_ids)
    if padding >= 0:
        token_ids = token_ids + ([0] * padding)
        token_type = token_type + [1] * padding
        mask = mask + ([0] * padding)
        offsets = offsets + ([(0, 0)] * padding)
    else:
        token_ids = token_ids[0:max_len]
        token_type = token_type[0:max_len]
        mask = mask[0:max_len]
        offsets = offsets[0:max_len]

    assert len(token_ids) == max_len
    assert len(mask) == max_len
    assert len(offsets) == max_len
    assert len(token_type) == max_len

    return {
        'ids': token_ids,
        'mask': mask,
        'token_type': token_type,
        'offset': offsets,
        'start': start,
        'end': end,
        'text': text,
        'expansion': expansion,
        'acronym': acronym,
    }
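# Post-processing: the predicted token span is mapped back to characters and then
# matched against the dictionary candidates by word-level Jaccard overlap.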
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))
def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
    filtered_output = ""
    for ix in range(idx_start, idx_end + 1):
        filtered_output += text[offsets[ix][0]: offsets[ix][1]]
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            filtered_output += " "
    candidates = DICTIONARY[acronym]
    candidate_jaccards = [jaccard(w.strip(), filtered_output.strip()) for w in candidates]
    idx = np.argmax(candidate_jaccards)
    return candidate_jaccards[idx], candidates[idx]
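# End-to-end inference used by the Gradio interface: encode the (context, acronym)
# pair, run the model, and decode the best-scoring span into a dictionary expansion.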
def disambiguate(text, acronym):
    # The acronym itself is passed as a placeholder expansion; the model's
    # predicted span is what actually selects the expansion.
    inputs = process_data(text, acronym, acronym, TOKENIZER, MAX_LEN)
    ids = torch.tensor(inputs['ids']).unsqueeze(0)
    mask = torch.tensor(inputs['mask']).unsqueeze(0)
    token_type = torch.tensor(inputs['token_type']).unsqueeze(0)
    offsets = inputs['offset']
    expansion = inputs['expansion']
    acronym = inputs['acronym']

    with torch.no_grad():
        start_logits, end_logits = MODEL(ids, mask, token_type)
    start_prob = torch.softmax(start_logits, dim=-1).numpy()
    end_prob = torch.softmax(end_logits, dim=-1).numpy()
    start_idx = np.argmax(start_prob[0, :])
    end_idx = np.argmax(end_prob[0, :])

    # The offsets index into the concatenated "answers + context" string built in
    # process_data, so decode the span against that string rather than the raw context.
    _, exp = evaluate_jaccard(inputs['text'], expansion, acronym, offsets, start_idx, end_idx)
    return exp
text = gr.Textbox(lines=5, label="Context",
                  value="Particularly , we explore four CNN architectures , AlexNet , GoogLeNet , VGG-16 , and ResNet to derive features for all images in our dataset , which are labeled as private or public .")
acronym = gr.Dropdown(choices=sorted(DICTIONARY.keys()), label="Acronym", value="CNN")
expansion = gr.Textbox(label="Expansion")
iface = gr.Interface(fn=disambiguate, inputs=[text, acronym], outputs=expansion,
                     title="Scientific Acronym Disambiguation",
                     description="Demo of model based on https://arxiv.org/abs/2102.08818")
iface.launch()