from __future__ import annotations import reprlib from pathlib import Path from tempfile import mkdtemp import streamlit as st from huggingface_hub import hf_hub_download from skops import card from skops.card._model_card import PlotSection, split_subsection_names from utils import iterate_key_section_content, process_card_for_rendering from tasks import AddSectionTask, AddFigureTask, DeleteSectionTask, TaskState, UpdateFigureTask, UpdateSectionTask arepr = reprlib.Repr() arepr.maxstring = 24 tmp_path = Path(mkdtemp(prefix="skops-")) # temporary files hf_path = Path(mkdtemp(prefix="skops-")) # hf repo def load_model_card_from_repo(repo_id: str) -> card.Card: print("downloading model card") path = hf_hub_download(repo_id, "README.md") model_card = card.parse_modelcard(path) return model_card def _update_model_card( model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool, ) -> None: # This is a very roundabout way to update the model card but it's necessary # because of how streamlit handles session state. Basically, there have to # be "key" arguments, which have to be retrieved from the session_state, as # they are up-to-date. Just getting the Python variables is not enough, as # they can be out of date. # key names must match with those used in form new_title = st.session_state[f"{key}.title"] new_content = st.session_state[f"{key}.content"] # determine if title is the same old_title_split = split_subsection_names(section_name) new_title_split = old_title_split[:-1] + [new_title] is_title_same = old_title_split == new_title_split # determine if content is the same if is_fig: if isinstance(new_content, PlotSection): is_content_same = content == new_content else: is_content_same = not bool(new_content) else: is_content_same = content == new_content if is_title_same and is_content_same: return if is_fig: fpath = None if new_content: # new figure uploaded fname = new_content.name.replace(" ", "_") fpath = tmp_path / fname task = UpdateFigureTask( model_card, key=key, old_name=section_name, new_name=new_title, data=new_content, path=fpath, ) else: task = UpdateSectionTask( model_card, key=key, old_name=section_name, new_name=new_title, old_content=content, new_content=new_content, ) st.session_state.task_state.add(task) def _add_section(model_card: card.Card, key: str) -> None: section_name = f"{key}/Untitled" task = AddSectionTask(model_card, title=section_name, content="[More Information Needed]") st.session_state.task_state.add(task) def _add_figure(model_card: card.Card, key: str) -> None: section_name = f"{key}/Untitled" task = AddFigureTask(model_card, title=section_name, content="cat.png") st.session_state.task_state.add(task) def _delete_section(model_card: card.Card, key: str) -> None: task = DeleteSectionTask(model_card, key=key) st.session_state.task_state.add(task) def _add_section_form( model_card: card.Card, key: str, section_name: str, old_title: str, content: str ) -> None: with st.form(key, clear_on_submit=False): st.header(section_name) # setting the 'key' argument below to update the session_state st.text_input("Section name", value=old_title, key=f"{key}.title") st.text_area("Content", value=content, key=f"{key}.content") is_fig = False st.form_submit_button( "Update", on_click=_update_model_card, args=(model_card, key, section_name, content, is_fig), ) def _add_fig_form( model_card: card.Card, key: str, section_name: str, old_title: str, content: str ) -> None: with st.form(key, clear_on_submit=False): st.header(section_name) # setting the 'key' argument below to update the session_state st.text_input("Section name", value=old_title, key=f"{key}.title") st.file_uploader("Upload image", key=f"{key}.content") is_fig = True st.form_submit_button( "Update", on_click=_update_model_card, args=(model_card, key, section_name, content, is_fig), ) def create_form_from_section( model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool = False ) -> None: split_sections = split_subsection_names(section_name) old_title = split_sections[-1] if is_fig: _add_fig_form( model_card=model_card, key=key, section_name=section_name, old_title=old_title, content=content, ) else: _add_section_form( model_card=model_card, key=key, section_name=section_name, old_title=old_title, content=content, ) col_0, col_1, col_2 = st.columns([4, 2, 2]) with col_0: st.button( f"delete '{arepr.repr(old_title)}'", on_click=_delete_section, args=(model_card, key), key=f"{key}.delete", ) with col_1: st.button( "add section below", on_click=_add_section, args=(model_card, key), key=f"{key}.add", ) with col_2: st.button( "add figure below", on_click=_add_figure, args=(model_card, key), key=f"{key}.fig", ) def display_sections(model_card: card.Card) -> None: for key, section_name, content, is_fig in iterate_key_section_content(model_card._data): create_form_from_section(model_card, key, section_name, content, is_fig) def display_model_card(model_card: card.Card) -> None: rendered = model_card.render() metadata, rendered = process_card_for_rendering(rendered) # strip metadata with st.expander("show metadata"): st.text(metadata) st.markdown(rendered, unsafe_allow_html=True) def reset_model_card() -> None: if "task_state" not in st.session_state: return if "model_card" not in st.session_state: del st.session_state["model_card"] while st.session_state.task_state.done_list: st.session_state.task_state.undo() def delete_model_card() -> None: if "model_card" in st.session_state: del st.session_state["model_card"] if "task_state" in st.session_state: st.session_state.task_state.reset() def undo_last(): st.session_state.task_state.undo() display_model_card(st.session_state.model_card) def redo_last(): st.session_state.task_state.redo() display_model_card(st.session_state.model_card) def add_download_model_card_button(): model_card = st.session_state.get("model_card") download_disabled = not bool(model_card) data = model_card.render() st.download_button( "Save (md)", data=data, disabled=download_disabled ) def edit_input_form(): if "task_state" not in st.session_state: st.session_state.task_state = TaskState() with st.sidebar: col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2]) undo_disabled = not bool(st.session_state.task_state.done_list) redo_disabled = not bool(st.session_state.task_state.undone_list) with col_0: name = f"UNDO ({len(st.session_state.task_state.done_list)})" st.button(name, on_click=undo_last, disabled=undo_disabled) with col_1: name = f"REDO ({len(st.session_state.task_state.undone_list)})" st.button(name, on_click=redo_last, disabled=redo_disabled) with col_2: st.button("Reset", on_click=reset_model_card) col_0, col_1, *_ = st.columns([2, 2, 2, 2]) with col_0: add_download_model_card_button() with col_1: st.button("Delete", on_click=delete_model_card) if "model_card" in st.session_state: display_sections(st.session_state.model_card) if "model_card" in st.session_state: display_model_card(st.session_state.model_card)