# coding=utf-8
# Copyright 2023 The GlotLID Authors.
# Lint as: python3
"""GlotLID Space.

This space is built based on the AMR-KELEG/ALDi space.
"""

import base64

import altair as alt
import fasttext
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from altair import X, Y, Scale
from GlotScript import get_script_predictor
from huggingface_hub import hf_hub_download

import constants

ORANGE_COLOR = "#FF8000"


@st.cache_resource
def load_sp():
    """Create the GlotScript script predictor once per server process."""
    sp = get_script_predictor()
    return sp


sp = load_sp()


def get_script(text):
    """Get the writing system of given text.

    Args:
        text: The text to be preprocessed.

    Returns:
        The writing system of text (the first element of the predictor's output).
    """
    return sp(text)[0]


@st.cache_data
def render_svg(svg):
    """Render the given SVG string inline as a base64-encoded <img>."""
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # NOTE(review): the original HTML markup around {b64} was lost; confirm
    # this wrapper matches the intended logo layout.
    html = rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}"/> </p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_data
def convert_df(df):
    """Serialize ``df`` to UTF-8 CSV bytes without the index column."""
    # IMPORTANT: Cache the conversion to prevent computation on every rerun
    return df.to_csv(index=False).encode("utf-8")


@st.cache_resource
def load_model(model_name):
    """Download ``model.bin`` from the Hub repo ``model_name`` and load it with fastText."""
    model_path = hf_hub_download(repo_id=model_name, filename="model.bin")
    model = fasttext.load_model(model_path)
    return model


model = load_model(constants.MODEL_NAME)


def compute(sentences):
    """Compute the language labels for the given sentences.

    Args:
        sentences: A list of sentences. Each must be a single line of text
            (no newline characters).

    Returns:
        A ``(probs, labels)`` tuple of lists aligned with ``sentences``: the
        top prediction's probability for each sentence, clamped to ``[0, 1]``,
        and the corresponding top label.
    """
    progress_text = "Computing Language..."
    my_bar = st.progress(0, text=progress_text)

    BATCH_SIZE = 1
    total = len(sentences)
    probs = []
    labels = []
    for first_index in range(0, total, BATCH_SIZE):
        batch = sentences[first_index : first_index + BATCH_SIZE]
        # For a list input, element 0 holds one label list per sentence and
        # element 1 one probability array per sentence. Iterate over all of
        # them so that a BATCH_SIZE > 1 does not drop sentences.
        batch_labels, batch_probs = model.predict(batch)
        for sentence_labels, sentence_probs in zip(batch_labels, batch_probs):
            labels.append(sentence_labels[0])
            # Clamp so the confidence bar and the chart's [0, 1] y-domain
            # always hold the value.
            probs.append(max(min(sentence_probs[0], 1), 0))
        my_bar.progress(
            min((first_index + BATCH_SIZE) / total, 1),
            text=progress_text,
        )
    my_bar.empty()
    return probs, labels


def _plot_confidence(prob, label):
    """Build a single horizontal bar showing ``prob``, titled with ``label``."""
    fig, ax = plt.subplots(figsize=(8, 1))
    # Transparent backgrounds so the plot blends with the page.
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")

    ax.spines["left"].set_color(ORANGE_COLOR)
    ax.spines["bottom"].set_color(ORANGE_COLOR)
    ax.tick_params(axis="x", colors=ORANGE_COLOR)
    ax.spines[["right", "top"]].set_visible(False)

    ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"Language is: {label}", color=ORANGE_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("Confidence", color=ORANGE_COLOR)
    return fig


def _read_sentences(uploaded_file):
    """Return the non-blank lines of an uploaded UTF-8 text file.

    The file is read as plain text rather than with ``pd.read_csv`` for three
    reasons. Tabs, quote characters and tokens such as "NA" or "123" would
    otherwise be parsed as extra columns, multi-line fields, NaN or numbers.
    The model needs each sentence as a single-line string.

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    text = uploaded_file.getvalue().decode("utf-8-sig")  # -sig drops a leading BOM
    return [line for line in text.splitlines() if line.strip()]


def _render_file_results(uploaded_file):
    """Predict every line of ``uploaded_file`` and show a chart, a table and a CSV download."""
    try:
        sentences = _read_sentences(uploaded_file)
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        return
    if not sentences:
        st.warning("The uploaded file contains no sentences.")
        return

    df = pd.DataFrame({"Sentence": sentences})
    df["Probs"], df["Language"] = compute(sentences)

    # A horizontal rule
    st.markdown("""---""")

    chart = (
        alt.Chart(df.reset_index())
        .mark_area(color="darkorange", opacity=0.5)
        .encode(
            x=X(field="index", title="Sentence Index"),
            y=Y("Probs", scale=Scale(domain=[0, 1])),
        )
    )
    st.altair_chart(chart.interactive(), use_container_width=True)

    col1, col2 = st.columns([4, 1])

    with col1:
        # Display the output
        st.table(df)

    with col2:
        # Add a download button
        csv = convert_df(df)
        st.download_button(
            label=":file_folder: Download predictions as CSV",
            data=csv,
            file_name="GlotLID.csv",
            mime="text/csv",
        )


with open("assets/GlotLID_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input("Sentence:", placeholder="Enter a sentence.")
    # Clicking only triggers a rerun; the prediction below runs whenever the
    # text box is non-empty.
    st.button("Submit")

    if sent:
        probs, labels = compute([sent])
        fig = _plot_confidence(probs[0], labels[0])
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed, and this script
        # reruns on every interaction, so release the figure explicitly.
        plt.close(fig)

        print(sent)
        # Explicit UTF-8: input may be in any script, and the platform
        # default encoding may not be able to represent it.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")

with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        _render_file_results(file)