|
import json |
|
import requests |
|
import re |
|
from relbert import RelBERT |
|
import gradio as gr |
|
|
|
# Identifier of the RelBERT checkpoint to load.
_RELBERT_CHECKPOINT = 'relbert/relbert-roberta-large'

# Single shared model instance, created once at module load; greet() calls
# its get_embedding() method to embed word pairs.
model = RelBERT(model=_RELBERT_CHECKPOINT)
|
|
|
|
|
def get_example():
    """Download the SAT analogy test split and return its parsed records.

    The file is fetched as JSON Lines: every non-blank line is decoded as one
    JSON document.

    Returns:
        list: one decoded JSON object per non-blank line of the remote file.

    Raises:
        requests.RequestException: if the connection fails, the request times
            out, or the server answers with an HTTP error status.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    url = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
    # requests applies no timeout unless one is given, so bound the wait
    # instead of letting application start-up hang on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail with the HTTP status rather than trying to JSON-decode an error page.
    response.raise_for_status()
    # Split on '\n' only (not splitlines()) so unusual separators that may
    # appear inside a JSON string never break a record apart; strip() skips
    # blank or whitespace-only lines such as a trailing '\r'.
    return [
        json.loads(line)
        for line in response.content.decode().split('\n')
        if line.strip()
    ]
|
|
|
|
|
def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity of two numeric sequences.

    Args:
        a: first vector, as an iterable of numbers.
        b: second vector, as an iterable of numbers.
        zero_vector_mask: value returned instead of a similarity when the
            product of the two Euclidean norms is zero.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``zero_vector_mask`` when that
        denominator is zero.
    """
    length_a = sum(x * x for x in a) ** 0.5
    length_b = sum(y * y for y in b) ++ 0.0 if False else sum(y * y for y in b) ** 0.5
    denominator = length_a * length_b
    if denominator == 0:
        # The cosine is undefined for a zero-length vector.
        return zero_vector_mask
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / denominator
|
|
|
|
|
def clean(text):
    """Return ``text`` with leading and trailing whitespace removed.

    Interior whitespace is left untouched.
    """
    # str.strip() with no argument removes the same whitespace characters that
    # the regex class \s matches for str patterns, so this replaces the two
    # anchored re.sub() passes with a single built-in call.
    return text.strip()
|
|
|
|
|
def _is_blank(text):
    """Tell whether an input field was left empty (None, '' or only whitespace)."""
    return text is None or clean(text) == ''


def _parse_pair(text, label):
    """Split ``text`` on commas into exactly two non-empty, cleaned words.

    ``label`` names the field in error messages, e.g. ``'query'`` or
    ``'candidate 3'``. Raises ValueError when the field does not contain
    exactly two non-empty words.
    """
    words = [clean(word) for word in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    if not all(words):
        # "a," splits into ['a', ''], which has the right length but no real
        # second word; reject it instead of embedding an empty string.
        raise ValueError(f'ERROR: {label} contains an empty word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6):
    """Score candidate word pairs against a query word pair.

    Every field is a string holding two comma-separated words. Candidate
    fields that are blank are ignored; the remaining candidates keep the
    number of the field they came from (1-6).

    Args:
        query: the query word pair, e.g. ``"word1, word2"``.
        candidate_1 .. candidate_6: candidate word pairs; at least two must be
            filled in.

    Returns:
        dict: maps ``'candidate N: [w1, w2]'`` to the cosine similarity
        between that candidate's embedding and the query's embedding.

    Raises:
        ValueError: if the query is blank, if any filled-in field does not
            hold exactly two non-empty words, or if fewer than two candidates
            are given.
    """
    # ''.split(',') yields [''], never [], so emptiness has to be tested on
    # the raw text rather than on the length of the split result.
    if _is_blank(query):
        raise ValueError(f'ERROR: query is empty {query!r}')
    query_pair = _parse_pair(query, 'query')

    pairs = []
    pairs_id = []
    candidates = (
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
    )
    for n, raw in enumerate(candidates, start=1):
        if _is_blank(raw):
            continue
        pairs.append(_parse_pair(raw, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # The query is embedded in the same call as the candidates and placed
    # last, so the final vector is the query's.
    *candidate_vectors, vector_q = model.get_embedding(pairs + [query_pair])
    return {
        f'candidate {i}: [{p[0]}, {p[1]}]': cosine_similarity(v, vector_q)
        for i, p, v in zip(pairs_id, pairs, candidate_vectors)
    }
|
|
|
|
|
# Build the example rows shown under the form from the first 15 downloaded
# questions: the stem first, then one column per candidate box.
examples = []
for question in get_example()[:15]:
    choices = [','.join(choice) for choice in question['choice']]
    # Pad with empty strings so a question with fewer than six choices still
    # supplies a value for every candidate textbox.
    blanks = [''] * (6 - len(choices))
    examples.append([','.join(question['stem'])] + choices + blanks)

# One textbox for the query pair followed by six candidate textboxes, in the
# same order as greet()'s parameters.
query_box = gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)")
candidate_boxes = [
    gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {n} (separate by comma)")
    for n in range(1, 7)
]

demo = gr.Interface(
    fn=greet,
    inputs=[query_box, *candidate_boxes],
    outputs="label",
    examples=examples
)

# show_error=True is passed so that errors raised by greet() are shown in the UI.
demo.launch(show_error=True)
|
|