kargaranamir's picture
upload.
e94a434
raw
history blame
4.76 kB
# coding=utf-8
# Copyright 2023 The GlotLID Authors.
# Lint as: python3
"""
GlotLID Space
"""
""" This space is built based on AMR-KELEG/ALDi space """
import constants
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from GlotScript import get_script_predictor
import matplotlib.pyplot as plt
import fasttext
import altair as alt
from altair import X, Y, Scale
import base64
@st.cache_resource
def load_sp():
    """Build the GlotScript script predictor.

    Wrapped in ``st.cache_resource`` so the predictor is created once and then
    reused across Streamlit reruns instead of being rebuilt every time.

    Returns:
        The object returned by ``get_script_predictor()``.
    """
    return get_script_predictor()


# Module-level script predictor, consumed by get_script().
sp = load_sp()
def get_script(text):
    """Return the writing system detected for ``text``.

    Args:
        text: The text whose writing system should be identified.

    Returns:
        The first element of the module-level ``sp`` predictor's result,
        i.e. the writing system of ``text``.
    """
    prediction = sp(text)
    return prediction[0]
@st.cache_data
def render_svg(svg):
    """Display an SVG string centered on the page as an inline base64 image.

    Args:
        svg: The SVG markup to display.
    """
    encoded = base64.b64encode(svg.encode("utf-8"))
    b64 = encoded.decode("utf-8")
    markup = f'<p align="center"> <img src="data:image/svg+xml;base64,{b64}"/> </p>'
    st.container().write(markup, unsafe_allow_html=True)
@st.cache_data
def convert_df(df):
    """Serialize ``df`` to UTF-8 encoded CSV bytes without the index column.

    Cached with ``st.cache_data`` so the conversion is not recomputed on every
    Streamlit rerun.

    Args:
        df: The DataFrame to serialize.

    Returns:
        The CSV text encoded as UTF-8 ``bytes`` (as passed to
        ``st.download_button``).
    """
    # ``index`` is documented as a bool; pass False explicitly rather than
    # relying on None happening to be falsy.
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Download ``model.bin`` from a Hugging Face Hub repo and load it with fastText.

    Wrapped in ``st.cache_resource`` so the download and load happen once per
    ``model_name`` rather than on every rerun.

    Args:
        model_name: Hub repository id that contains a ``model.bin`` file.

    Returns:
        The model object returned by ``fasttext.load_model``.
    """
    return fasttext.load_model(
        hf_hub_download(repo_id=model_name, filename="model.bin")
    )


# Module-level fastText model, consumed by compute().
model = load_model(constants.MODEL_NAME)
def compute(sentences):
    """Compute language labels and confidences for the given sentences.

    Sentences are sent to the module-level fastText ``model`` in batches of
    ``BATCH_SIZE``; a Streamlit progress bar is advanced after every batch and
    removed when all sentences are done.

    Args:
        sentences: A list of sentences (strings). fastText's ``predict`` rejects
            strings that contain a newline character.

    Returns:
        A ``(probs, labels)`` tuple of lists, in input order: the confidence of
        each predicted label (clamped to ``[0, 1]``) and the predicted labels.
    """
    progress_text = "Computing Language..."
    my_bar = st.progress(0, text=progress_text)

    BATCH_SIZE = 1
    total = len(sentences)
    probs = []
    labels = []
    for first_index in range(0, total, BATCH_SIZE):
        batch = sentences[first_index : first_index + BATCH_SIZE]
        batch_labels, batch_probs = model.predict(batch)
        # For list input, predict returns one label list and one probability
        # list per sentence; walk all of them so a BATCH_SIZE > 1 loses nothing.
        for sentence_labels, sentence_probs in zip(batch_labels, batch_probs):
            # Clamp into [0, 1]; presumably guards against fastText reporting
            # confidences marginally above 1 due to float rounding.
            probs.extend(max(min(p, 1), 0) for p in sentence_probs)
            labels.extend(sentence_labels)
        my_bar.progress(
            min((first_index + BATCH_SIZE) / total, 1),
            text=progress_text,
        )
    my_bar.empty()
    return probs, labels
# Show the GlotLID logo at the top of the page. Use a context manager so the
# file handle is closed, and decode explicitly as UTF-8 instead of relying on
# the platform's default encoding.
with open("assets/GlotLID_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())
tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )
    # The button's return value is not read below; clicking it triggers a
    # Streamlit rerun, giving users an explicit way to submit their sentence.
    clicked = st.button("Submit")

    if sent:
        probs, labels = compute([sent])
        prob = probs[0]
        label = labels[0]

        # Draw a single horizontal bar showing the model's confidence.
        ORANGE_COLOR = "#FF8000"
        fig, ax = plt.subplots(figsize=(8, 1))
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")
        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)
        ax.spines[["right", "top"]].set_visible(False)
        ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"Language is: {label}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("Confidence", color=ORANGE_COLOR)
        st.pyplot(fig)
        # pyplot keeps every figure created via plt.subplots alive until it is
        # closed; close it so figures do not accumulate across reruns.
        plt.close(fig)

        print(sent)
        # Inputs are multilingual, so write the log as UTF-8 explicitly rather
        # than with the platform's default encoding.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")
with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        # Read the upload as plain UTF-8 text, one sentence per non-blank line.
        # pd.read_csv is deliberately avoided here: it splits lines on tabs
        # (breaking the single "Sentence" column), interprets quote characters,
        # raises on an empty file, and converts lines such as "NA", "null" or
        # "123" into NaN/numbers, which fastText's predict cannot handle.
        sentences = None
        try:
            text = file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.error("The uploaded file must be UTF-8 encoded text.")
        else:
            sentences = [line for line in text.splitlines() if line.strip()]
            if not sentences:
                st.warning("The uploaded file does not contain any sentences.")

        if sentences:
            df = pd.DataFrame({"Sentence": sentences})
            df["Probs"], df["Language"] = compute(df["Sentence"].tolist())

            # A horizontal rule
            st.markdown("""---""")

            # Confidence per sentence, plotted against the sentence's row index.
            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("Probs", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)

            col1, col2 = st.columns([4, 1])

            with col1:
                # Display the output
                st.table(
                    df,
                )

            with col2:
                # Add a download button
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="GlotLID.csv",
                    mime="text/csv",
                )