Spaces:
Running
Running
import numpy as np | |
import torch | |
import pandas as pd | |
import translate | |
import gradio as gr | |
# Corpus loaded once at import time: the phrase table and its precomputed
# embedding matrix. knn() ranks rows of ``embeddings`` and then selects the
# same positions from ``data`` via .iloc, so row i of one is presumably row i
# of the other (NOTE(review): confirm both files were generated together).
data = pd.read_csv(filepath_or_buffer="./embedding_data.csv")
embeddings = np.load(file="./embeddings.npy")
def normalize_vector(v):
    """Scale ``v`` to unit Euclidean (Frobenius for >1-D) norm.

    A zero-norm input cannot be normalized; it is returned unchanged (the
    very same object) rather than producing NaNs from a division by zero.
    """
    magnitude = np.linalg.norm(v)
    return v if magnitude == 0 else v / magnitude
def embed_one(model, tokenizer, text, normalize=True):
    """Encode ``text`` into a single vector by mean-pooling encoder states.

    ``text`` is tokenized (truncated to the tokenizer's max length) into
    PyTorch tensors and run through ``model.model.encoder`` without gradient
    tracking. The last hidden state is averaged over axis 1 (presumably the
    token axis of a (batch, seq, dim) tensor — TODO confirm for this model),
    and the first batch entry is returned as a NumPy array.

    Args:
        model: object exposing ``.model.encoder`` (e.g. a seq2seq model).
        tokenizer: callable returning keyword arguments for the encoder.
        text: the input string.
        normalize: when True, scale the result to unit norm.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        hidden = model.model.encoder(**encoded).last_hidden_state
        pooled = hidden.mean(dim=1)
    vector = pooled.detach().numpy()[0]
    if not normalize:
        return vector
    return normalize_vector(vector)
def knn(query_embedding, embeddings, df, k=5, hebrew=True):
    """Return the ``k`` rows of ``df`` whose embeddings best match the query.

    Similarity is the dot product between each row of ``embeddings`` and
    ``query_embedding``. If both are unit-normalized (``embed_one`` normalizes
    by default — confirm for the stored matrix) this is cosine similarity.

    Args:
        query_embedding: 1-D query vector, or a (1, dim) row vector.
        embeddings: matrix whose rows align positionally with ``df``.
        df: DataFrame with ``arabic``, ``validated`` and ``hebrew``/``english``.
        k: number of neighbours; ``k <= 0`` yields an empty frame and
            ``k > len(df)`` yields every row.
        hebrew: select the ``hebrew`` column if True, else ``english``.

    Returns:
        The selected columns for the top-``k`` rows, most similar first.
    """
    target_col = "hebrew" if hebrew else "english"
    sims = np.dot(embeddings, query_embedding.T)
    # Clamp k: the old ``[-k:]`` slice turned k=0 into ``[0:]`` (every row)
    # and a negative k into an arbitrary prefix.
    k = max(int(k), 0)
    # Descending order first, then take the head: identical to the previous
    # "last k, reversed" for k >= 1, and correctly empty for k == 0.
    order = np.argsort(sims, axis=0)[::-1][:k]
    select = order.ravel()
    return df.iloc[select][["arabic", target_col, "validated"]]
def run_knn(text, k=5):
    """Embed Arabic ``text`` and return its ``k`` nearest rows from ``data``.

    The query is encoded with the Arabic-source model and tokenizer exposed by
    the ``translate`` module, then matched against the module-level
    ``embeddings`` with ``hebrew`` as the target column.
    """
    print(text)  # echo the raw query to stdout
    encoder_model = translate.model_from_ar
    encoder_tokenizer = translate.tokenizer_from_ar
    query_vec = embed_one(encoder_model, encoder_tokenizer, text)
    return knn(query_vec, embeddings, data, k=k, hebrew=True)
def style_dataframe(df): | |
styled_df = df.style.set_properties(**{ | |
'font-family': 'Arial, sans-serif', | |
'font-size': '20px', | |
'text-align': 'right', | |
'direction': 'rtl', | |
'align': 'right' | |
}).set_table_styles([ | |
{'selector': 'th', 'props': [('text-align', 'right')]} | |
]) | |
return styled_df | |
def style_dataframe(df): | |
return df.style.set_table_styles([ | |
{'selector': 'thead', 'props': [('text-align', 'right')]}, | |
{'selector': '.index_name', 'props': [('text-align', 'right')]}, | |
]).set_properties(**{ | |
'text-align': 'right', | |
}) # Replace 'column_name' with your actual column name | |
def update_df(hidden_arabic):
    """Build the Gradio results table for an Arabic query.

    Runs a 100-nearest-neighbour search for ``hidden_arabic`` and renders the
    ``validated`` flag as a check mark / cross emoji. It then renames the
    columns to Hebrew labels and returns a visible ``gr.DataFrame`` that wraps
    the styled table.
    """
    # .copy(): the search result is a column selection of an .iloc slice, and
    # assigning into it directly triggers pandas' SettingWithCopyWarning.
    df = run_knn(hidden_arabic, 100).copy()
    # A missing flag (NaN) is truthy in Python. Without the notna check, an
    # unvalidated row with no value would be displayed as validated.
    df["validated"] = df["validated"].apply(
        lambda flag: "✅" if pd.notna(flag) and bool(flag) else "❌"
    )
    # Hebrew column headers: "validated", "Arabic", "Hebrew".
    df = df.rename(columns={
        "validated": "מאומת",
        "arabic": "ערבית",
        "hebrew": "עברית",
    })
    return gr.DataFrame(value=style_dataframe(df), visible=True)