import glob
import logging
import os
import shutil
import time
import zipfile
from pathlib import Path
from typing import Callable, List, Optional, Set

import accelerate
import einops
import huggingface_hub
import numpy as np
import pandas as pd
import torch
import transformers
import yaml
from h2o_wave import Q, data, ui
from sqlitedict import SqliteDict

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.hugging_face_utils import (
    get_model_card,
    publish_model_to_hugging_face,
)
from llm_studio.app_utils.sections.chat import chat_tab, load_cfg_model_tokenizer
from llm_studio.app_utils.sections.common import clean_dashboard
from llm_studio.app_utils.utils import (
    add_model_type,
    flatten_dict,
    get_cfg_list_items,
    get_data_dir,
    get_download_link,
    get_experiment_status,
    get_experiments,
    get_model_types,
    get_problem_categories,
    get_problem_types,
    get_ui_elements,
    get_unique_name,
    hf_repo_friendly_name,
    parse_ui_elements,
    remove_model_type,
    set_env,
    start_experiment,
)
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df, wave_theme
from llm_studio.python_configs.cfg_checks import check_config_for_errors
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.tooltips import tooltips
from llm_studio.src.utils.config_utils import (
    NON_GENERATION_PROBLEM_TYPES,
    load_config_py,
    load_config_yaml,
    save_config_yaml,
)
from llm_studio.src.utils.exceptions import LLMResourceException
from llm_studio.src.utils.export_utils import (
    check_available_space,
    get_artifact_path_path,
    get_logs_path,
    get_model_path,
    get_predictions_path,
    save_logs,
    save_prediction_outputs,
)
from llm_studio.src.utils.logging_utils import write_flag
from llm_studio.src.utils.modeling_utils import unwrap_model
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS
from llm_studio.src.utils.utils import add_file_to_zip, kill_child_processes

logger = logging.getLogger(__name__)


async def experiment_start(q: Q) -> None:
    """Display experiment start cards."""

    await clean_dashboard(q, mode="experiment_start", exclude=["experiment/start"])
    q.client["nav/active"] = "experiment/start"

    show_update_warnings = True
    is_create_experiment = False
    # reset certain configs if new experiment start session
    if (
        q.args.__wave_submission_name__ == "experiment/start"
        or q.args.__wave_submission_name__ == "experiment/start_experiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment/from_current"
        or q.args.__wave_submission_name__ == "experiment/list/new"
    ):
        q.client["experiment/start/cfg_experiment_prev"] = None
        q.client["experiment/start/cfg_file_prev"] = None
        q.client["experiment/start/prev_dataset"] = None
        q.client["experiment/start/cfg_sub"] = None
        show_update_warnings = False
        is_create_experiment = True

    # get all the datasets available
    df_datasets = q.client.app_db.get_datasets_df()
    # Hide inference only datasets
    df_datasets = df_datasets.loc[df_datasets["train_rows"].notna()]
    if (
        not q.client["experiment/start/dataset"]
        or q.client["experiment/start/dataset"]
        not in df_datasets.id.astype(str).values
    ):
        if len(df_datasets) >= 1:
            q.client["experiment/start/dataset"] = str(df_datasets["id"].iloc[-1])
        else:
            q.client["experiment/start/dataset"] = "1"

    warning_message = "Experiment settings might be updated after changing {}"

    items = [
        ui.separator(name="general_expander", label="General settings"),
        ui.dropdown(
            name="experiment/start/dataset",
            label="Dataset",
            required=True,
value=q.client["experiment/start/dataset"], choices=[ ui.choice(str(row["id"]), str(row["name"])) for _, row in df_datasets.iterrows() ], trigger=True, tooltip=tooltips["experiments_dataset"], ), ] if ( show_update_warnings and q.client["experiment/start/dataset_prev"] != q.client["experiment/start/dataset"] ): items += [ ui.message_bar(type="warning", text=warning_message.format("Dataset")) ] show_update_warnings = False if ( q.client["experiment/start/cfg_file"] is None or q.client["experiment/start/dataset_prev"] != q.client["experiment/start/dataset"] ) and q.client["experiment/start/cfg_category"] != "experiment": dataset = q.client.app_db.get_dataset(q.client["experiment/start/dataset"]) if dataset is not None: problem_type = dataset.config_file.replace(dataset.path + "/", "").replace( ".yaml", "" ) else: problem_type = default_cfg.cfg_file q.client["experiment/start/cfg_file"] = problem_type q.client["experiment/start/cfg_category"] = problem_type.split("_")[0] if q.client["experiment/start/cfg_category"] == "experiment": q.client["experiment/start/cfg_file"] = "experiment" # get all experiments df_experiments = get_experiments(q, mode="train") # get all problem category choices choices_problem_categories = [ ui.choice(name, label) for name, label in get_problem_categories() ] if len(df_experiments["id"]) > 0: choices_problem_categories += [ui.choice("experiment", "From Experiment")] # set default value of problem type if no match to category if ( q.client["experiment/start/cfg_category"] not in q.client["experiment/start/cfg_file"] ): if q.client["experiment/start/cfg_category"] != "experiment": q.client["experiment/start/cfg_file"] = get_problem_types( category=q.client["experiment/start/cfg_category"] )[0][0] # get all problem type choices choices_problem_types = [ ui.choice(name, label) for name, label in get_problem_types( category=q.client["experiment/start/cfg_category"] ) ] # remove model type if present in cfg file name here q.client["experiment/start/cfg_file"] = remove_model_type( q.client["experiment/start/cfg_file"] ) if len(df_experiments["id"]) > 0: if q.client["experiment/start/cfg_experiment"] is None: q.client["experiment/start/cfg_experiment"] = str( df_experiments["id"].iloc[0] ) # Default pretrained from the previous experiment to False if ( q.client["experiment/start/cfg_experiment_pretrained"] is None or is_create_experiment ): q.client["experiment/start/cfg_experiment_pretrained"] = False if q.client["experiment/start/cfg_category"] != "experiment": items += [ ui.dropdown( name="experiment/start/cfg_file", label="Problem Type", required=True, choices=choices_problem_types, value=q.client["experiment/start/cfg_file"], trigger=True, tooltip=tooltips["experiments_problem_type"], ) ] model_types = get_model_types(q.client["experiment/start/cfg_file"]) if len(model_types) > 0: choices = [ui.choice(name, label) for name, label in model_types] if q.client["experiment/start/cfg_sub"] in [None, ""]: q.client["experiment/start/cfg_sub"] = model_types[0][0] items += [ ui.dropdown( name="experiment/start/cfg_sub", label="Model Type", required=True, choices=choices, value=q.client["experiment/start/cfg_sub"], trigger=True, ) ] else: q.client["experiment/start/cfg_sub"] = "" # add model type to cfg file name here q.client["experiment/start/cfg_file"] = add_model_type( q.client["experiment/start/cfg_file"], q.client["experiment/start/cfg_sub"] ) if ( show_update_warnings and q.client["experiment/start/cfg_file_prev"] != q.client["experiment/start/cfg_file"] and 
q.client["experiment/start/cfg_category"] != "experiment" ): items += [ ui.message_bar(type="warning", text=warning_message.format("Problem Type")) ] show_update_warnings = False if q.client["experiment/start/cfg_category"] == "experiment": items += [ ui.dropdown( name="experiment/start/cfg_experiment", label="Experiment", required=True, choices=[ ui.choice(str(row.id), row["name"]) for _, row in df_experiments.iterrows() ], value=q.client["experiment/start/cfg_experiment"], trigger=True, ) ] if ( show_update_warnings and q.client["experiment/start/cfg_experiment_prev"] != q.client["experiment/start/cfg_experiment"] ): items += [ ui.message_bar( type="warning", text=warning_message.format("previous Experiment") ) ] # Show pretrained weights toggle only for successfully finished experiments if ( df_experiments.loc[ df_experiments.id == int(q.client["experiment/start/cfg_experiment"]), "status", ].values[0] == "finished" ): items += [ ui.toggle( name="experiment/start/cfg_experiment_pretrained", label="Use previous experiment weights", value=q.client["experiment/start/cfg_experiment_pretrained"], trigger=True, ) ] # only show yaml option, when not starting from another experiment if q.client["experiment/start/cfg_category"] != "experiment": items += [ ui.toggle( name="experiment/start/from_yaml", label="Import config from YAML", value=False, trigger=True, tooltip=tooltips["experiments_import_config_from_yaml"], ) ] if q.args["experiment/start/from_yaml"]: items += [ ui.file_upload( name="experiment/upload_yaml", label="Upload!", multiple=False, file_extensions=["yaml"], ) ] if q.args["experiment/upload_yaml"] is not None: # reset previous, so the UI will be reloaded q.client["experiment/start/cfg_file_prev"] = None await config_import_uploaded_file(q) logger.info( f"PREV {q.client['experiment/start/cfg_file_prev']} " f"{q.client['experiment/start/cfg_file']} " f"{q.client['experiment/start/dataset_prev']} " f"{q.client['experiment/start/dataset']} " f"{q.client['experiment/start/cfg_experiment_prev']} " f"{q.client['experiment/start/cfg_experiment']} " ) # set mode to training q.client["experiment/start/cfg_mode/mode"] = "train" if q.client["experiment/start/cfg_category"] == "experiment": logger.info("Starting from experiment") # reset previous config file q.client["experiment/start/cfg_file_prev"] = None q.client["experiment/start/experiment"] = q.client.app_db.get_experiment( q.client["experiment/start/cfg_experiment"] ) parent_path = os.path.dirname(q.client["experiment/start/experiment"].path) parent_exp_name = parent_path.split("/")[-1] parent_experiment = f"{parent_exp_name}" old_config = load_config_yaml(f"{parent_path}/cfg.yaml") old_config._parent_experiment = parent_experiment q.client["experiment/start/cfg"] = old_config # set pretrained weights if q.client["experiment/start/cfg_experiment_pretrained"]: prev_weights = os.path.join( q.client["experiment/start/experiment"].path, "checkpoint.pth", ) if os.path.exists(prev_weights): q.client["experiment/start/cfg"].architecture.pretrained_weights = ( prev_weights ) q.client["experiment/start/cfg"].architecture._visibility[ "pretrained_weights" ] = -1 experiments_df = q.client.app_db.get_experiments_df() output_dir = os.path.abspath( os.path.join(q.client["experiment/start/cfg"].output_directory, "..") ) q.client["experiment/start/cfg"].experiment_name = get_unique_name( q.client["experiment/start/cfg"].experiment_name, experiments_df["name"].values, lambda x: os.path.exists(os.path.join(output_dir, x)), ) # Configuration flags: # 
        # from_dataset -- take the values from the dataset config
        # from_cfg -- take the values from the configuration file
        # from_default -- take the values from the default settings
        # from_dataset_args -- take the values from the dataset's q.args
        # Otherwise -- take the values from the q.args (user input)

        # pick default values from config
        if (
            q.client["experiment/start/cfg_experiment_prev"]
            != q.client["experiment/start/cfg_experiment"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False

            q.client["experiment/start/dataset"] = str(
                q.client["experiment/start/experiment"].dataset
            )

            items[1].dropdown.value = q.client["experiment/start/dataset"]
        # pick default values from config or dataset
        elif (
            q.client["experiment/start/dataset_prev"]
            != q.client["experiment/start/dataset"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = True
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False
        # pick default values from args
        else:
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = False
            q.client["experiment/start/cfg_mode/from_dataset_args"] = True

        q.client["experiment/start/cfg_mode/from_default"] = False
        q.client["experiment/start/cfg_experiment_prev"] = q.client[
            "experiment/start/cfg_experiment"
        ]

    else:
        logger.info("Starting from CFG")
        # reset previous experiment
        q.client["experiment/start/cfg_experiment_prev"] = None

        # pick default values from dataset or config
        if (
            q.client["experiment/start/cfg_file_prev"]
            != q.client["experiment/start/cfg_file"]
        ) or (
            q.client["experiment/start/dataset_prev"]
            != q.client["experiment/start/dataset"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = True
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_default"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False
        # pick default values from args
        else:
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = False
            q.client["experiment/start/cfg_mode/from_default"] = False
            q.client["experiment/start/cfg_mode/from_dataset_args"] = True

        q.client["experiment/start/cfg_file_prev"] = q.client[
            "experiment/start/cfg_file"
        ]

        config_path = (
            f"llm_studio/python_configs/{q.client['experiment/start/cfg_file']}"
        )

        q.client["experiment/start/cfg"] = load_config_py(
            config_path=config_path, config_name="ConfigProblemBase"
        )

    q.client["experiment/start/dataset_prev"] = q.client["experiment/start/dataset"]
    logger.info(f"From dataset {q.client['experiment/start/cfg_mode/from_dataset']}")
    logger.info(f"From cfg {q.client['experiment/start/cfg_mode/from_cfg']}")
    logger.info(f"From default {q.client['experiment/start/cfg_mode/from_default']}")
    logger.info(f"Config file: {q.client['experiment/start/cfg_file']}")

    option_items = get_ui_elements(cfg=q.client["experiment/start/cfg"], q=q)
    items.extend(option_items)

    if q.client["experiment/start/cfg_mode/from_cfg"]:
        q.page["experiment/start"] = ui.form_card(box="content", items=items)
    else:
        q.page["experiment/start"].items = items

    q.client.delete_cards.add("experiment/start")

    q.page["experiment/start/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(
                items=[
                    ui.button(
                        name="experiment/start/run",
                        label="Run experiment",
                        primary=True,
                    )
                ],
                justify="start",
            )
        ],
    )
    q.client.delete_cards.add("experiment/start/footer")
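# Sketch of how the form built above is consumed (names taken from `experiment_run`
# below); this is a rough illustration, not additional behavior. The submitted q.args
# are folded back into the ConfigProblemBase instance before launch, roughly:
#
#     cfg = q.client["experiment/start/cfg"]
#     cfg = parse_ui_elements(cfg=cfg, q=q, pre="experiment/start/cfg/")
#     start_experiment(cfg=cfg, q=q, pre="experiment/start")
#
# so a UI element registered under "experiment/start/cfg/<setting>" writes to the
# matching attribute of `cfg`; "<setting>" is a placeholder, not a concrete key.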
async def experiment_run(q: Q, pre: str = "experiment/start"): """Start an experiment. Args: q: Q pre: prefix for client key """ # import here to avoid circular imports from llm_studio.app_utils.sections.project import list_current_experiments logger.info("Starting experiment") logger.info(f"{pre}/cfg_file") logger.info(f"CFG: {q.client[f'{pre}/cfg_file']}") if q.client[f"{pre}/cfg_category"] == "experiment": q.client[f"{pre}/cfg_file"] = q.client[f"{pre}/experiment"].config_file cfg = q.client[f"{pre}/cfg"] cfg = parse_ui_elements(cfg=cfg, q=q, pre=f"{pre}/cfg/") cfg.experiment_name = cfg.experiment_name.replace("/", "-") errors = check_config_for_errors(cfg) if errors["title"] and not q.args["experiment/start/error/proceed"]: title = ( errors["title"][0] if len(errors["title"]) == 1 else "The following configuration mismatches were found:" ) error_text = [ui.text(message) for message in errors["message"]] q.page["meta"].dialog = ui.dialog( title=title, name="experiment/start/error/dialog", items=error_text + [ ui.buttons( [ ui.button( name="experiment/start/error/ok", label="Ok", primary=True ), ui.button( name="experiment/start/error/proceed", label="I want to proceed anyhow", primary=False, ), ] ) ], closable=True, ) q.client["keep_meta"] = True else: start_experiment(cfg=cfg, q=q, pre=pre) await list_current_experiments(q) def get_experiment_table( q, df_viz, predictions, height="calc(100vh - 245px)", actions=None ): col_remove = [ "id", "path", "mode", "seed", "process_id", "gpu_list", "loss", "eta", "epoch", "config_file", ] if predictions: col_remove += ["epoch", "val metric"] for col in col_remove: if col in df_viz: del df_viz[col] # df_viz = df_viz.rename( # columns={"process_id": "pid", "config_file": "problem type"}, # ) # df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "") if actions == "experiment" and q.client["experiment/list/mode"] == "train": actions_dict = { "experiment/list/new": "New experiment", "experiment/list/rename": "Rename experiment", "experiment/list/stop/table": "Stop experiment", "experiment/list/delete/table/dialog": "Delete experiment", } else: actions_dict = {} min_widths = { "name": "350", "dataset": "150", # "problem type": "190", "metric": "75", "val metric": "102", "progress": "85", "status": "90", "info": "115", "actions": "5" if predictions else "5", } if predictions: for k, v in min_widths.items(): min_widths[k] = str(int(np.ceil(int(v) * 1.05))) return ui_table_from_df( q=q, df=df_viz, name="experiment/list/table", sortables=["val metric"], filterables=["name", "dataset", "problem type", "metric", "status"], searchables=["name", "dataset"], numerics=["val metric"], tags=["status"], progresses=["progress"], min_widths=min_widths, link_col="name", height=height, actions=actions_dict, ) async def experiment_list( q: Q, reset: bool = True, allowed_statuses: Optional[List[str]] = None, actions: bool = True, ) -> None: """List all experiments.""" if q.client["experiment/list/mode"] is None: q.client["experiment/list/mode"] = "train" if q.client["experiment/list/mode"] == "train": q.client["nav/active"] = "experiment/list" else: q.client["nav/active"] = "experiment/list_predictions" if reset: await clean_dashboard(q, mode="full") q.client["experiment/list/df_experiments"] = get_experiments( q, mode=q.client["experiment/list/mode"], status=allowed_statuses, ) df_viz = q.client["experiment/list/df_experiments"].copy() table = get_experiment_table( q, df_viz, q.client["experiment/list/mode"] == "predict", actions="experiment" if actions 
else None, ) message_bar = get_experiment_list_message_bar(q) items = [table, message_bar] q.page["experiment/list"] = ui.form_card(box="content", items=items) q.client.delete_cards.add("experiment/list") buttons = [ ui.button(name="experiment/list/refresh", label="Refresh", primary=True), ui.button( name="experiment/list/compare", label="Compare experiments", primary=False, ), ui.button(name="experiment/list/stop", label="Stop experiments", primary=False), ui.button( name="experiment/list/delete", label="Delete experiments", primary=False ), ] q.page["dataset/display/footer"] = ui.form_card( box="footer", items=[ui.inline(items=buttons, justify="start")] ) q.client.delete_cards.add("dataset/display/footer") def get_table_and_message_item_indices(q): table_item_idx, message_item_idx = 0, 1 return table_item_idx, message_item_idx async def experiment_compare(q: Q, selected_rows: list): if q.client["experiment/compare/tab"] is None: q.client["experiment/compare/tab"] = "experiment/compare/charts" if q.args["experiment/compare/charts"] is not None: q.client["experiment/compare/tab"] = "experiment/compare/charts" if q.args["experiment/compare/config"] is not None: q.client["experiment/compare/tab"] = "experiment/compare/config" experiment_ids = [ q.client["experiment/list/df_experiments"]["id"].iloc[int(idx)] for idx in selected_rows ] await clean_dashboard(q, mode=q.client["experiment/compare/tab"]) tabs = [ ui.tab(name="experiment/compare/charts", label="Charts"), ui.tab(name="experiment/compare/config", label="Config"), ] q.page["experiment/compare/tab"] = ui.tab_card( box="nav2", link=True, items=tabs, value=q.client["experiment/compare/tab"] ) q.client.delete_cards.add("experiment/compare/tab") if q.client["experiment/compare/tab"] == "experiment/compare/charts": charts = [] experiment_names = [] for experiment_id in experiment_ids: experiment = q.client.app_db.get_experiment(experiment_id) experiment_path = experiment.path charts.append(load_charts(experiment_path)) current_name = f" {experiment.name}" experiment_names.append(current_name) await charts_tab(q, charts, experiment_names) elif q.client["experiment/compare/tab"] == "experiment/compare/config": if q.client["experiment/compare/diff_toggle"] is None: q.client["experiment/compare/diff_toggle"] = False settings = pd.DataFrame() for experiment_id in experiment_ids: experiment = q.client.app_db.get_experiment(experiment_id) experiment_path = experiment.path experiment_cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) items = get_cfg_list_items(experiment_cfg) act_df = pd.Series({item.label: item.value for item in items}) settings[experiment.name] = act_df settings.index.name = "setting" if q.client["experiment/compare/diff_toggle"]: val_counts = settings.T.nunique() drop_idx = val_counts[val_counts == 1].index.values settings = settings.drop(drop_idx) items = [ ui.toggle( name="experiment/compare/diff_toggle", label="Show differences only", value=q.client["experiment/compare/diff_toggle"], trigger=True, ), ui_table_from_df( q=q, df=settings.reset_index(), name="experiment/compare/summary/table", link_col="setting", height="calc(100vh - 315px)", ), ] q.page["experiment/compare/config"] = ui.form_card(box="first", items=items) q.client.delete_cards.add("experiment/compare/config") buttons = [ ui.button(name="experiment/compare", label="Refresh", primary=True), ui.button(name="experiment/list/current", label="Back", primary=False), ] q.page["experiment/compare/footer"] = ui.form_card( box="footer", 
items=[ui.inline(items=buttons, justify="start")] ) q.client.delete_cards.add("experiment/compare/footer") async def experiment_rename_form(q: Q, error: str = "") -> None: experiment = q.client.app_db.get_experiment(q.client["experiment/rename/id"]) experiment_name = experiment.name items = [ ui.textbox( name="experiment/rename/name", label=f"New name for {experiment_name}", value=experiment_name, required=True, ) ] if error: items.append(ui.message_bar(type="error", text=error)) q.page["experiment/list"].items = items buttons = [ ui.button(name="experiment/rename/action", label="Rename", primary=True), ui.button(name="experiment/list/current", label="Abort", primary=False), ] q.page["dataset/display/footer"] = ui.form_card( box="footer", items=[ui.inline(items=buttons, justify="start")] ) q.client.delete_cards.add("dataset/display/footer") async def experiment_rename_ui_workflow(q: Q): selected_row = q.args["experiment/list/rename"] rename_id = q.client["experiment/list/df_experiments"]["id"].iloc[int(selected_row)] q.client["experiment/rename/id"] = rename_id await experiment_rename_form(q) async def experiment_rename_action(q, experiment, new_name): """Rename experiment with `current_id` id in DB to `new_name`""" old_name = experiment.name old_path = experiment.path new_path = old_path.replace(old_name, new_name) if old_path != new_path: old_exp_path = f"{old_path}" exp_path = f"{new_path}" logger.info(f"Renaming {old_exp_path} to {exp_path}") shutil.move(os.path.abspath(old_exp_path), os.path.abspath(exp_path)) # update the experiment name in the DB with SqliteDict(os.path.join(new_path, "charts.db")) as charts: for k1 in PLOT_ENCODINGS: if k1 == "df": # this is required to properly overwrite it df = charts[k1].copy() for k2, v2 in df.items(): logger.info( f"Renaming charts {v2} to {v2.replace(old_name, new_name)}" ) df[k2] = v2.replace(old_name, new_name) charts[k1] = df charts.commit() for config_file in ["cfg.yaml"]: config_path = os.path.join(exp_path, config_file) if os.path.exists(config_path): experiment_cfg = load_config_yaml(config_path) experiment_cfg.experiment_name = new_name experiment_cfg.output_directory = new_path save_config_yaml(config_path, experiment_cfg) rename_files = ["preds"] for file in rename_files: old_file = get_artifact_path_path(old_name, exp_path, file) new_file = get_artifact_path_path(new_name, exp_path, file) if os.path.exists(old_file): logger.info(f"Renaming {old_file} to {new_file}") shutil.move(os.path.abspath(old_file), os.path.abspath(new_file)) delete_files = ["logs"] # will be generated on demand with updates for file in delete_files: file = get_artifact_path_path(old_name, exp_path, file) if os.path.exists(file): logger.info(f"Deleting {file}") os.remove(file) q.client.app_db.rename_experiment(experiment.id, new_name, new_path) async def experiment_delete(q: Q, experiment_ids: List[int]) -> None: """Delete selected experiments. Args: q: Q experiment_ids: list of experiment ids to delete """ for experiment_id in experiment_ids: experiment = q.client.app_db.get_experiment(experiment_id) q.client.app_db.delete_experiment(experiment.id) shutil.rmtree(f"{experiment.path}") async def experiment_stop(q: Q, experiment_ids: List[int]) -> None: """Stop selected experiments. 
Args: q: Q experiment_ids: list of experiment ids to stop """ for experiment_id in experiment_ids: experiment = q.client.app_db.get_experiment(experiment_id) try: ret = kill_child_processes(int(experiment.process_id)) if ret: flag_path = os.path.join(experiment.path, "flags.json") write_flag(flag_path, "status", "stopped") except Exception as e: logger.error(f"Error while stopping the experiment: {e}") pass def load_charts(experiment_path): try: with SqliteDict(os.path.join(experiment_path, "charts.db")) as charts: charts = dict(charts) except Exception: charts = {} logger.warning("Too early, wait for the charts to appear") return charts async def experiment_display(q: Q) -> None: """Display a selected experiment.""" experiment_id = q.client["experiment/list/df_experiments"]["id"].iloc[ q.client["experiment/display/id"] ] q.client["experiment/display/experiment_id"] = experiment_id experiment = q.client.app_db.get_experiment(experiment_id) q.client["experiment/display/experiment"] = experiment q.client["experiment/display/experiment_path"] = experiment.path status, _ = get_experiment_status(experiment.path) charts = load_charts(q.client["experiment/display/experiment_path"]) q.client["experiment/display/charts"] = charts if experiment.mode == "train": if q.client["experiment/display/tab"] is None: q.client["experiment/display/tab"] = "experiment/display/charts" else: if q.client["experiment/display/tab"] is None: q.client["experiment/display/tab"] = "experiment/display/summary" if q.args["experiment/display/charts"] is not None: q.client["experiment/display/tab"] = "experiment/display/charts" if q.args["experiment/display/summary"] is not None: q.client["experiment/display/tab"] = "experiment/display/summary" if q.args["experiment/display/train_data_insights"] is not None: q.client["experiment/display/tab"] = "experiment/display/train_data_insights" if q.args["experiment/display/validation_prediction_insights"] is not None: q.client["experiment/display/tab"] = ( "experiment/display/validation_prediction_insights" ) if q.args["experiment/display/config"] is not None: q.client["experiment/display/tab"] = "experiment/display/config" if q.args["experiment/display/deployment"] is not None: q.client["experiment/display/tab"] = "experiment/display/deployment" if q.args["experiment/display/logs"] is not None: q.client["experiment/display/tab"] = "experiment/display/logs" if q.args["experiment/display/chat"] is not None: q.client["experiment/display/tab"] = "experiment/display/chat" await clean_dashboard(q, mode=q.client["experiment/display/tab"]) tabs = [ ui.tab(name="experiment/display/charts", label="Charts"), ui.tab(name="experiment/display/summary", label="Summary"), ] # html for legacy experiments has_train_data_insights = any( [ charts.get(plot_encoding, dict()).get("train_data") is not None for plot_encoding in PLOT_ENCODINGS ] ) if has_train_data_insights: tabs += [ ui.tab( name="experiment/display/train_data_insights", label="Train Data Insights", ) ] has_validation_prediction_insights = any( [ charts.get(plot_encoding, dict()).get("validation_predictions") is not None for plot_encoding in PLOT_ENCODINGS ] ) if has_validation_prediction_insights: tabs += [ ui.tab( name="experiment/display/validation_prediction_insights", label="Validation Prediction Insights", ) ] tabs += [ ui.tab(name="experiment/display/logs", label="Logs"), ui.tab(name="experiment/display/config", label="Config"), ] if status == "finished": tabs += [ui.tab(name="experiment/display/chat", label="Chat")] 
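    # Note on tab visibility: the Train Data Insights / Validation Prediction Insights
    # tabs are only appended when the charts database already holds a "train_data" /
    # "validation_predictions" entry for one of the PLOT_ENCODINGS, and the Chat tab is
    # only offered once the experiment status is "finished"; other experiments keep the
    # remaining tabs (Charts, Summary, Logs, Config).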
q.page["experiment/display/tab"] = ui.tab_card( box="nav2", link=True, items=tabs, value=q.client["experiment/display/tab"] ) q.client.delete_cards.add("experiment/display/tab") if q.client["experiment/display/tab"] == "experiment/display/charts": await charts_tab(q, [charts], [""]) elif q.client["experiment/display/tab"] in [ "experiment/display/train_data_insights", "experiment/display/validation_prediction_insights", ]: await insights_tab(charts, q) elif q.client["experiment/display/tab"] in ["experiment/display/summary"]: await summary_tab(experiment_id, q) elif q.client["experiment/display/tab"] in ["experiment/display/config"]: await configs_tab(q) elif q.client["experiment/display/tab"] in ["experiment/display/logs"]: await logs_tab(q) elif q.client["experiment/display/tab"] in ["experiment/display/chat"]: await chat_tab(q) await q.page.save() buttons = [ ui.button(name="experiment/display/refresh", label="Refresh", primary=True) ] buttons += [ ui.button( name="experiment/display/download_logs", label="Download logs/config", primary=False, ) ] if status == "finished": buttons += [ ui.button( name="experiment/display/download_predictions", label="Download predictions", primary=False, disabled=False, tooltip=None, ), ui.button( name="experiment/display/download_model", label="Download model", primary=False, disabled=False, tooltip=None, ), ui.button( name="experiment/display/push_to_huggingface", label="Push checkpoint to huggingface", primary=False, disabled=False, tooltip=None, ), ] buttons += [ui.button(name="experiment/list/current", label="Back", primary=False)] q.page["experiment/display/footer"] = ui.form_card( box="footer", items=[ ui.inline(items=buttons, justify="start"), ], ) q.client.delete_cards.add("experiment/display/footer") async def insights_tab(charts, q): if q.client["experiment/display/tab"] == "experiment/display/train_data_insights": key = "train_data" elif ( q.client["experiment/display/tab"] == "experiment/display/validation_prediction_insights" ): key = "validation_predictions" for k1 in PLOT_ENCODINGS: if k1 not in charts: continue for k2, v2 in charts[k1].items(): if k2 != key: continue if k1 == "html": q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.markup_card( box="first", title="", content=v2 ) q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") continue elif k1 == "image": q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.image_card( box="first", title="", type="png", image=v2 ) q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") continue elif k1 == "df": df = pd.read_parquet(v2) min_widths = { col: "350" for col in df.columns if "text" in str(col).lower() } # if key == "train_data": min_widths["Content"] = "800" q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.form_card( box="first", items=[ ui_table_from_df( q=q, df=df, name=f"experiment/display/charts/{k1}_{k2}", sortables=[ col for col in df.columns if col.startswith("Metric") ], markdown_cells=[ col for col in df.columns if not col.startswith("Metric") ], searchables=list(df.columns), downloadable=True, resettable=True, min_widths=min_widths, height="calc(100vh - 245px)", max_char_length=50_000, cell_overflow="tooltip", ) ], ) q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}") continue async def summary_tab(experiment_id, q): experiment_df = get_experiments(q) input_dict = experiment_df[experiment_df.id == experiment_id].iloc[0].to_dict() cfg = load_config_yaml( os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml") ) _ = 
get_tokenizer(cfg)

    # experiment card
    card_name = "experiment/display/summary/experiment"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Experiment"),
            ui.stats(
                [
                    ui.stat(
                        value=cfg.experiment_name,
                        label="Name",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=input_dict["config_file"],
                        label="Problem Type",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # datasets card
    card_name = "experiment/display/summary/datasets"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Datasets"),
            ui.stats(
                [
                    ui.stat(
                        value=Path(cfg.dataset.train_dataframe).stem,
                        label="Training Dataset",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=(
                            "-"
                            if cfg.dataset.validation_dataframe in ["", "None", None]
                            else Path(cfg.dataset.validation_dataframe).stem
                        ),
                        label="Validation Dataset",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # score card
    card_name = "experiment/display/summary/score"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Score"),
            ui.stats(
                [
                    ui.stat(
                        value=input_dict["metric"],
                        label="Metric",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=(
                            "-"
                            if input_dict["val metric"] in ["", "None", None]
                            else str(input_dict["val metric"])
                        ),
                        label="Validation Score",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # main configs card
    card_name = "experiment/display/summary/main_configs"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="second"),
        items=[
            ui.separator("Main Configurations"),
            ui.stats(
                [
                    ui.stat(
                        value=cfg.llm_backbone,
                        label="LLM Backbone",
                    ),
                    ui.stat(
                        value=str(cfg.training.lora),
                        label="Lora",
                    ),
                    ui.stat(
                        value=str(cfg.training.epochs),
                        label="Epochs",
                    ),
                    ui.stat(
                        value=str(cfg.training.batch_size),
                        label="Batch Size",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=str(input_dict["loss"]),
                        label="Loss Function",
                    ),
                    ui.stat(
                        value=cfg.architecture.backbone_dtype,
                        label="Backbone Dtype",
                    ),
                    ui.stat(
                        value=str(cfg.architecture.gradient_checkpointing),
                        label="Gradient Checkpointing",
                    ),
                    ui.stat(
                        value=input_dict["gpu_list"],
                        label="GPU List",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # code card
    card_name = "experiment/display/summary/code"
    content = get_experiment_summary_code_card(cfg=cfg)
    q.page[card_name] = ui.markdown_card(
        box=ui.box(zone="third"),
        title="",
        content=content,
    )
    q.client.delete_cards.add(card_name)


async def configs_tab(q):
    experiment_cfg = load_config_yaml(
        os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml")
    )
    items = get_cfg_list_items(experiment_cfg)
    q.page["experiment/display/config"] = ui.stat_list_card(
        box="first", items=items, title=""
    )
    q.client.delete_cards.add("experiment/display/config")


async def logs_tab(q):
    logs_path = f"{q.client['experiment/display/experiment_path']}/logs.log"

    text = ""
    in_pre = 0
    # Read log file only if it already exists
    if os.path.exists(logs_path):
        with open(logs_path, "r") as f:
            for line in f.readlines():
                # wrap plain log lines in divs, keep <pre> blocks intact
                if in_pre == 0:
                    text += "<div>"
                if "<pre>" in line:
                    in_pre += 1
                text += line
                if "</pre>" in line:
                    in_pre -= 1
                if in_pre == 0:
                    text += "</div>"