Joshua Sundance Bailey
initial commit
171c1a6
raw
history blame
2.67 kB
import os
import random
from typing import Sequence
import datasets
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from setfit import SetFitModel
# Page-level Streamlit configuration for the demo app.
st.set_page_config(
    page_title="mtg-coloridentity-multilabel-classification",
    # Mage emoji; the literal had been stored as mis-decoded UTF-8 bytes.
    page_icon="🧙",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)

# Hugging Face cache directory: honour the HF_HOME environment variable when
# it is set, otherwise fall back to ~/.cache/huggingface.
default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
HF_HOME = os.environ.get("HF_HOME", default_hf_home)

# Hub repository id; used as the default for both get_model and get_data.
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"

# One-letter color codes and the matching color names, in the same order.
# `labels` is also passed as the seaborn palette in prob_bars, so each entry
# has to be a color name matplotlib understands.
# NOTE(review): presumably this order matches the model's output order — confirm.
colors = ["B", "G", "R", "U", "W"]
labels = ["black", "green", "red", "blue", "white"]

# Apply seaborn's default theme to all matplotlib figures.
sns.set()

# Two-column layout: col1 holds the input widgets, col2 the chart.
col1, col2 = st.columns(2)
@st.cache_resource
def get_model(
    model_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> SetFitModel:
    """Load the SetFit model, cached by Streamlit as a shared resource.

    Args:
        model_id: Identifier passed to ``SetFitModel.from_pretrained``.
        cache_dir: Directory handed to ``from_pretrained`` as ``cache_dir``.
        **kwargs: Forwarded unchanged to ``SetFitModel.from_pretrained``.

    Returns:
        The loaded ``SetFitModel``.
    """
    load_options = {"cache_dir": cache_dir, **kwargs}
    return SetFitModel.from_pretrained(model_id, **load_options)
@st.cache_data
def get_data(
    repo_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> datasets.Dataset:
    """Load the dataset and merge all of its splits into one ``Dataset``.

    Args:
        repo_id: Identifier passed to ``datasets.load_dataset``.
        cache_dir: Directory handed to ``load_dataset`` as ``cache_dir``.
        **kwargs: Forwarded unchanged to ``datasets.load_dataset``.
            NOTE(review): the result must expose ``.values()`` (a mapping of
            splits), so passing e.g. ``split=...`` here would presumably
            break the concatenation — confirm before relying on it.

    Returns:
        The concatenation of every split returned by ``load_dataset``.
    """
    splits = datasets.load_dataset(repo_id, cache_dir=cache_dir, **kwargs)
    every_split = [split for split in splits.values()]
    return datasets.concatenate_datasets(every_split)
def get_random_text() -> str:
    """Return the ``"text"`` field of a uniformly random row of ``dataset``.

    Reads the module-level ``dataset``.

    Raises:
        ValueError: If ``dataset`` is empty (raised by ``random.randrange``).
    """
    # randrange excludes the upper bound. randint(0, len(dataset)) could
    # return len(dataset) itself, which is one past the last valid index.
    index = random.randrange(len(dataset))  # nosec
    return dataset.select([index])[0]["text"]
@st.cache_data
def get_preds(input_text: str) -> Sequence[float]:
    """Return the model's predicted probabilities for ``input_text``.

    Calls ``predict_proba`` on the module-level ``model``; results are
    cached by Streamlit via ``st.cache_data``.
    """
    probabilities = model.predict_proba(input_text)
    return probabilities
def prob_bars(preds: Sequence[float]) -> None:
    """Draw a bar chart of per-color probabilities and render it in Streamlit.

    Args:
        preds: Probabilities paired positionally with the module-level
            ``labels``; each value is converted with ``float()``. Pairing
            uses ``zip``, so extra entries on either side are dropped.
    """
    df = pd.DataFrame(
        list(zip(labels, (float(p) for p in preds))),
        columns=["Color", "Probability"],
    )

    # Use an explicit figure instead of pyplot's implicit "current figure":
    # every Streamlit rerun would otherwise leave one more open figure in
    # pyplot's registry, and the global figure is shared between sessions.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.barplot(x="Color", y="Probability", data=df, palette=labels, ax=ax)

        # Add data labels on each bar
        for bar in ax.patches:
            height = bar.get_height()
            ax.annotate(
                format(height, ".4f"),
                (bar.get_x() + bar.get_width() / 2.0, height),
                ha="center",
                va="center",
                xytext=(0, 9),
                textcoords="offset points",
            )

        ax.set_title("Prediction Probabilities")
        ax.set_xlabel("Color")
        ax.set_ylabel("Probability")
        st.pyplot(fig)
    finally:
        # Release the figure once rendered (or if rendering failed).
        plt.close(fig)
# Load the cached model and dataset; get_preds and get_random_text read these
# module-level names.
model = get_model()
dataset = get_data()

# Seed the text box once per session. Drawing the random text inside the
# guard avoids a dataset lookup on every rerun when the value already exists.
if "input_text" not in st.session_state:
    st.session_state.input_text = get_random_text()

with col1:
    # Replace the stored text with a fresh random card when the button is
    # clicked. The label's dice emoji had been stored as mis-decoded UTF-8.
    if st.button("🎲 Roll the Dice"):
        st.session_state.input_text = get_random_text()
    input_text = st.text_area(
        "Card name and text",
        st.session_state.input_text,
        height=400,
    )

# Score whatever is currently in the text box.
preds = get_preds(input_text)

with col2:
    prob_bars(preds)