Spaces:
Runtime error
Runtime error
""" | |
Name: Karan Kumar Pathak
Email: 2020fc04335@wilp.bits-pilani.com
""" | |
import os | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
from prettytable import PrettyTable | |
import pandas as pd | |
import torch | |
import traceback | |
# Settings dictionary handed to the Loren constructor further down the script.
# Three directory entries (fc_dir, mrc_dir, er_dir) are added later, once the
# model snapshot has been downloaded.
config = dict(
    # Backbone selection.
    model_type="roberta",
    model_name_or_path="roberta-large",
    # Numeric / categorical options forwarded unchanged to the project code.
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=1,
    # Length and count limits.
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices torch reports when this module is executed.
    n_gpu=torch.cuda.device_count(),
)
# Pull the project repository, drop its data/results/models directories and
# move the remaining files into the working directory. Exit codes of the
# shell commands are not inspected.
for _shell_cmd in (
    'git clone https://github.com/kkpathak91/project_metch/',
    'rm -r project_metch/data/',
    'rm -r project_metch/results/',
    'rm -r project_metch/models/',
    'mv project_metch/* ./',
):
    os.system(_shell_cmd)

# Download the model snapshot from the Hugging Face Hub and record where each
# component lives inside it.
model_dir = snapshot_download('kkpathak91/FVM')
config.update(
    fc_dir=os.path.join(model_dir, 'fact_checking/roberta-large/'),
    mrc_dir=os.path.join(model_dir, 'mrc_seq2seq/bart-base/'),
    er_dir=os.path.join(model_dir, 'evidence_retrieval/'),
)

# Imported here rather than at the top of the file: presumably `src` only
# exists after the clone/move above -- verify against the cloned repository.
from src.loren import Loren

loren = Loren(config, verbose=False)

# Start-up smoke check: run one claim through the pipeline and abort the
# script with a ValueError if anything fails.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as startup_error:
    raise ValueError(startup_error)
def highlight_phrase(text, phrase):
    """Fill the ``<mask>`` slot(s) of *text* with an emphasized *phrase*.

    The text is first passed through the fact-checking tokenizer's
    ``clean_up_tokenization``; every ``<mask>`` in the result is then
    replaced by *phrase* wrapped in ``<i><b>...</b></i>``.
    """
    cleaned = loren.fc_client.tokenizer.clean_up_tokenization(text)
    emphasized = f'<i><b>{phrase}</b></i>'
    return cleaned.replace('<mask>', emphasized)
def highlight_entity(text, entity):
    """Return *text* with every occurrence of *entity* emphasized.

    Each occurrence is wrapped in ``<i><b>...</b></i>``. When *entity* is
    empty the text is returned unchanged: ``str.replace`` with an empty
    search string would otherwise insert the markup between every character.
    """
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')
def gradio_formatter(js, output_type):
    """Render one view of a ``loren.check`` result as an HTML table.

    Args:
        js: Result mapping. For ``'e'`` the keys ``'evidence'`` and
            ``'entities'`` are read; for ``'z'`` the keys
            ``'phrase_veracity'``, ``'claim_phrases'``, ``'cloze_qs'`` and
            ``'evidential'`` are read.
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-phrase table (claim phrase, local premise, p_SUP/p_REF/p_NEI).

    Returns:
        An HTML string: a ``<head>`` with zebra-striping CSS followed by the
        table markup.

    Raises:
        NotImplementedError: If *output_type* is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = (
        "\n"
        "tr:nth-child(even) {\n"
        "background: #f1f1f1;\n"
        "}\n"
        "thead{\n"
        "background: #f1f1f1;\n"
        "}"
    )

    if output_type == 'e':
        data = {
            'Evidence': [
                highlight_entity(evidence, entity)
                for evidence, entity in zip(js['evidence'], js['entities'])
            ],
        }
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for scores in js['phrase_veracity']:
            # Emphasize the largest score of each phrase.
            max_idx = torch.argmax(torch.tensor(scores)).tolist()
            cells = ['%.4f' % score for score in scores]
            cells[max_idx] = f'<i><b>{cells[max_idx]}</b></i>'
            # Index 2 feeds the p_SUP column, 0 -> p_REF, 1 -> p_NEI.
            p_sup.append(cells[2])
            p_ref.append(cells[0])
            p_nei.append(cells[1])
        data = {
            'Claim Phrase': js['claim_phrases'],
            'Local Premise': [
                highlight_phrase(question, answers[0])
                for question, answers in zip(js['cloze_qs'], js['evidential'])
            ],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(f'Unsupported output_type: {output_type!r}')

    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=1, vrules=1)
    for row in data.values:
        pt.add_row(row)

    table_html = pt.get_html_string(attributes={
        'style': 'border-width: 2px; bordercolor: black'
    }, format=True)
    table_html = (
        f'<head> <style type="text/css"> {zebra_css} </style> </head>\n'
        + table_html
    )
    # PrettyTable HTML-escapes cell text, which turns the <i><b> highlight
    # tags into &lt;...&gt;. Convert those entities back so the emphasis
    # renders as markup instead of literal text.
    table_html = table_html.replace('&lt;', '<').replace('&gt;', '>')
    return table_html
def run(claim):
    """Check *claim* with LOREN and return the three Gradio outputs.

    Returns a tuple ``(label, phrase_table_html, evidence_table_html)``.
    If ``loren.check`` raises, the claim and the error (with traceback) are
    logged and ``('Oops, something went wrong.', '', '')`` is returned.
    Exceptions raised while formatting the tables are not caught here.
    """
    try:
        result = loren.check(claim)
    except Exception as error_msg:
        trace_text = traceback.format_exc()
        report = f'[Error]: {error_msg}.\n[Traceback]: {trace_text}'
        loren.logger.error(claim)
        loren.logger.error(report)
        return 'Oops, something went wrong.', '', ''
    else:
        verdict = result['claim_veracity']
        # Successful results are logged at warning level together with the
        # full result mapping.
        loren.logger.warning(verdict + str(result))
        evidence_table = gradio_formatter(result, 'e')
        phrase_table = gradio_formatter(result, 'z')
        return verdict, phrase_table, evidence_table
# Web UI: one text input (the claim) mapped by `run` to a text label and two
# HTML tables. Flagged samples are written under results/flagged/.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['text', 'html', 'html'],
    title="A Framework for Data-Driven Document Evaluation and Scoring",
    description=(
        "[Student Name: Karan Kumar Pathak] "
        " [Roll No.: 2020fc04334] "
    ),
    examples=[
        'Kanpur is a city in Nepal',
        'PV Sindhu is an Indian Badminton Player.',
    ],
    layout='horizontal',
    allow_flagging='auto',
    flagging_dir='results/flagged/',
    flagging_options=[
        'Interesting!',
        'Error: Claim Phrase Parsing',
        'Error: Local Premise',
        'Error: Require Commonsense',
        'Error: Evidence Retrieval',
    ],
)

# Start serving the interface with request queueing enabled.
iface.launch(enable_queue=True)