import functools
import hashlib
import logging
import os
import re
import shutil
import time
import traceback
from typing import List, Optional

import pandas as pd
from h2o_wave import Q, ui
from h2o_wave.types import FormCard, ImageCard, MarkupCard, StatListItem, Tab

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Dataset
from llm_studio.app_utils.sections.common import clean_dashboard
from llm_studio.app_utils.sections.experiment import experiment_start
from llm_studio.app_utils.sections.histogram_card import histogram_card
from llm_studio.app_utils.utils import (
    add_model_type,
    azure_download,
    azure_file_options,
    check_valid_upload_content,
    clean_error,
    dir_file_table,
    get_data_dir,
    get_dataset_elements,
    get_datasets,
    get_experiments_status,
    get_frame_stats,
    get_model_types,
    get_problem_types,
    get_unique_dataset_name,
    kaggle_download,
    local_download,
    make_label,
    parse_ui_elements,
    remove_temp_files,
    s3_download,
    s3_file_options,
)
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.utils.config_utils import (
    load_config_py,
    load_config_yaml,
    save_config_yaml,
)
from llm_studio.src.utils.data_utils import (
    get_fill_columns,
    read_dataframe,
    read_dataframe_drop_missing_labels,
    sanity_check,
)
from llm_studio.src.utils.plot_utils import PlotData

logger = logging.getLogger(__name__)


def file_extension_is_compatible(q):
    cfg = q.client["dataset/import/cfg"]
    allowed_extensions = cfg.dataset._allowed_file_extensions

    is_correct_extension = []
    for mode in ["train", "validation"]:
        dataset_name = q.client[f"dataset/import/cfg/{mode}_dataframe"]

        if dataset_name is None or dataset_name == "None":
            continue
        is_correct_extension.append(dataset_name.endswith(allowed_extensions))
    return all(is_correct_extension)


async def dataset_import(
    q: Q,
    step: int,
    edit: Optional[bool] = False,
    error: Optional[str] = "",
    warning: Optional[str] = "",
    info: Optional[str] = "",
    allow_merge: bool = True,
) -> None:
    """Display dataset import cards.

    Args:
        q: Q
        step: current step of wizard
        edit: whether in edit mode
        error: optional error message
        warning: optional warning message
        info: optional info message
        allow_merge: whether to allow merging dataset when importing
    """

    await clean_dashboard(q, mode="full")
    q.client["nav/active"] = "dataset/import"
    if step == 1:
        q.page["dataset/import"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/import")

        if q.client["dataset/import/source"] is None:
            q.client["dataset/import/source"] = "Upload"

        import_choices = [
            ui.choice("Upload", "Upload"),
            ui.choice("Local", "Local"),
            ui.choice("S3", "AWS S3"),
            ui.choice("Azure", "Azure Datalake"),
            ui.choice("Kaggle", "Kaggle"),
        ]

        items = [
            ui.text_l("Import dataset"),
            ui.dropdown(
                name="dataset/import/source",
                label="Source",
                value=(
                    "Upload"
                    if q.client["dataset/import/source"] is None
                    else q.client["dataset/import/source"]
                ),
                choices=import_choices,
                trigger=True,
                tooltip="Source of dataset import",
            ),
        ]

        if (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "S3"
        ):
            if q.client["dataset/import/s3_bucket"] is None:
                q.client["dataset/import/s3_bucket"] = q.client[
                    "default_aws_bucket_name"
                ]
            if q.client["dataset/import/s3_access_key"] is None:
                q.client["dataset/import/s3_access_key"] = q.client[
                    "default_aws_access_key"
                ]
            if q.client["dataset/import/s3_secret_key"] is None:
                q.client["dataset/import/s3_secret_key"] = q.client[
                    "default_aws_secret_key"
                ]

            files = s3_file_options(
                q.client["dataset/import/s3_bucket"],
                q.client["dataset/import/s3_access_key"],
                q.client["dataset/import/s3_secret_key"],
            )

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                if default_cfg.s3_filename in files:
                    default_file = default_cfg.s3_filename
                else:
                    default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/s3_bucket",
                    label="S3 bucket name",
                    value=q.client["dataset/import/s3_bucket"],
                    trigger=True,
                    required=True,
                    tooltip="S3 bucket name including relative paths",
                ),
                ui.textbox(
                    name="dataset/import/s3_access_key",
                    label="AWS access key",
                    value=q.client["dataset/import/s3_access_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS access key; empty for anonymous access.",
                ),
                ui.textbox(
                    name="dataset/import/s3_secret_key",
                    label="AWS secret key",
                    value=q.client["dataset/import/s3_secret_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS secret key; empty for anonymous access.",
                ),
                ui_filename,
            ]

        elif (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "Azure"
        ):
            if q.client["dataset/import/azure_conn_string"] is None:
                q.client["dataset/import/azure_conn_string"] = q.client[
                    "default_azure_conn_string"
                ]
            if q.client["dataset/import/azure_container"] is None:
                q.client["dataset/import/azure_container"] = q.client[
                    "default_azure_container"
                ]

            files = azure_file_options(
                q.client["dataset/import/azure_conn_string"],
                q.client["dataset/import/azure_container"],
            )
            logger.debug(f"Azure files: {files}")

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/azure_conn_string",
                    label="Datalake connection string",
                    value=q.client["dataset/import/azure_conn_string"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Azure connection string to connect to Datalake storage",
                ),
                ui.textbox(
                    name="dataset/import/azure_container",
                    label="Datalake container name",
                    value=q.client["dataset/import/azure_container"],
                    trigger=True,
                    required=True,
                    tooltip="Azure Datalake container name including relative paths",
                ),
                ui_filename,
            ]

|
        elif q.client["dataset/import/source"] == "Upload":
            items += [
                ui.file_upload(
                    name="dataset/import/local_upload",
                    label="Upload!",
                    multiple=False,
                    file_extensions=default_cfg.allowed_file_extensions,
                )
            ]

        elif q.client["dataset/import/source"] == "Local":
            current_path = (
                q.client["dataset/import/local_path_current"]
                if q.client["dataset/import/local_path_current"] is not None
                else os.path.expanduser("~")
            )

            if q.args.__wave_submission_name__ == "dataset/import/local_path_list":
                idx = int(q.args["dataset/import/local_path_list"][0])
                options = q.client["dataset/import/local_path_list_last"]
                new_path = os.path.abspath(os.path.join(current_path, options[idx]))
                if os.path.exists(new_path):
                    current_path = new_path

            results_df = dir_file_table(current_path)
            files_list = results_df[current_path].tolist()
            q.client["dataset/import/local_path_list_last"] = files_list
            q.client["dataset/import/local_path_current"] = current_path

            items += [
                ui.textbox(
                    name="dataset/import/local_path",
                    label="File location",
                    value=current_path,
                    required=True,
                    tooltip="Location of file to be imported",
                ),
                ui_table_from_df(
                    q=q,
                    df=results_df,
                    name="dataset/import/local_path_list",
                    sortables=[],
                    searchables=[],
                    min_widths={current_path: "400"},
                    link_col=current_path,
                    height="calc(65vh)",
                ),
            ]

        elif q.client["dataset/import/source"] == "Kaggle":
            if q.client["dataset/import/kaggle_access_key"] is None:
                q.client["dataset/import/kaggle_access_key"] = q.client[
                    "default_kaggle_username"
                ]
            if q.client["dataset/import/kaggle_secret_key"] is None:
                q.client["dataset/import/kaggle_secret_key"] = q.client[
                    "default_kaggle_secret_key"
                ]

            items += [
                ui.textbox(
                    name="dataset/import/kaggle_command",
                    label="Kaggle API command",
                    value=default_cfg.kaggle_command,
                    required=True,
                    tooltip="Kaggle API command to be executed",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_access_key",
                    label="Kaggle username",
                    value=q.client["dataset/import/kaggle_access_key"],
                    required=True,
                    password=False,
                    tooltip="Kaggle username for API authentication",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_secret_key",
                    label="Kaggle secret key",
                    value=q.client["dataset/import/kaggle_secret_key"],
                    required=True,
                    password=True,
                    tooltip="Kaggle secret key for API authentication",
                ),
            ]

        allowed_types = ", ".join(default_cfg.allowed_file_extensions)
        allowed_types = " or".join(allowed_types.rsplit(",", 1))
        items += [
            ui.message_bar(type="info", text=info + f"Must be a {allowed_types} file."),
            ui.message_bar(type="error", text=error),
            ui.message_bar(type="warning", text=warning),
        ]

        q.page["dataset/import"].items = items

        buttons = [ui.button(name="dataset/list", label="Abort")]
        if q.client["dataset/import/source"] != "Upload":
            buttons.insert(
                0, ui.button(name="dataset/import/2", label="Continue", primary=True)
            )

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

        q.client["dataset/import/id"] = None
        q.client["dataset/import/cfg_file"] = None

    elif step == 2:
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])
        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                if q.client["dataset/import/source"] == "S3":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await s3_download(
                        q,
                        q.client["dataset/import/s3_bucket"],
                        q.client["dataset/import/s3_filename"],
                        q.client["dataset/import/s3_access_key"],
                        q.client["dataset/import/s3_secret_key"],
                    )
                elif q.client["dataset/import/source"] == "Azure":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await azure_download(
                        q,
                        q.client["dataset/import/azure_conn_string"],
                        q.client["dataset/import/azure_container"],
                        q.client["dataset/import/azure_filename"],
                    )
                elif q.client["dataset/import/source"] in ("Upload", "Local"):
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await local_download(q, q.client["dataset/import/local_path"])
                elif q.client["dataset/import/source"] == "Kaggle":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await kaggle_download(
                        q,
                        q.client["dataset/import/kaggle_command"],
                        q.client["dataset/import/kaggle_access_key"],
                        q.client["dataset/import/kaggle_secret_key"],
                    )

            q.client["dataset/import/edit"] = edit

            for trigger_key in default_cfg.dataset_trigger_keys:
                if q.client[f"dataset/import/cfg/{trigger_key}"]:
                    del q.client[f"dataset/import/cfg/{trigger_key}"]

            await dataset_import(
                q,
                step=3,
                edit=edit,
                error=error,
                warning=warning,
                allow_merge=allow_merge,
            )
        except Exception:
            logger.error("Dataset error:", exc_info=True)
            error = (
                "Dataset import failed. Please make sure all required "
                "fields are filled correctly."
            )
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 3:
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])
        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                q.client["dataset/import/name"] = get_unique_dataset_name(
                    q, q.client["dataset/import/name"]
                )
            q.page["dataset/import"] = ui.form_card(box="content", items=[])
            q.client.delete_cards.add("dataset/import")

            wizard = q.page["dataset/import"]

            title = "Configure dataset"

            items = [
                ui.text_l(title),
                ui.textbox(
                    name="dataset/import/name",
                    label="Dataset name",
                    value=q.client["dataset/import/name"],
                    required=True,
                ),
            ]

            choices_problem_types = [
                ui.choice(name, label) for name, label in get_problem_types()
            ]

            if q.client["dataset/import/cfg_file"] is None:
                # Guess the problem type from the dataset name: pick the problem type
                # whose name is the longest substring of the dataset name.
                max_substring_len = 0
                for c in choices_problem_types:
                    problem_type_name = c.name.replace("_config", "")
                    if problem_type_name in q.client["dataset/import/name"]:
                        if len(problem_type_name) > max_substring_len:
                            q.client["dataset/import/cfg_file"] = c.name
                            q.client["dataset/import/cfg_category"] = c.name.split("_")[
                                0
                            ]
                            max_substring_len = len(problem_type_name)
                if q.client["dataset/import/cfg_file"] is None:
                    q.client["dataset/import/cfg_file"] = default_cfg.cfg_file
                q.client["dataset/import/cfg_category"] = q.client[
                    "dataset/import/cfg_file"
                ].split("_")[0]

            if (
                q.client["dataset/import/cfg_category"]
                not in q.client["dataset/import/cfg_file"]
            ):
                q.client["dataset/import/cfg_file"] = get_problem_types(
                    category=q.client["dataset/import/cfg_category"]
                )[0][0]

            model_types = get_model_types(q.client["dataset/import/cfg_file"])
            if len(model_types) > 0:
                q.client["dataset/import/cfg_file"] = add_model_type(
                    q.client["dataset/import/cfg_file"], model_types[0][0]
                )
            if not edit:
                q.client["dataset/import/cfg"] = load_config_py(
                    config_path=(
                        f"llm_studio/python_configs/"
                        f"{q.client['dataset/import/cfg_file']}"
                    ),
                    config_name="ConfigProblemBase",
                )

            option_items = get_dataset_elements(cfg=q.client["dataset/import/cfg"], q=q)
            items.extend(option_items)
            items.append(ui.message_bar(type="error", text=error))
            items.append(ui.message_bar(type="warning", text=warning))
            if file_extension_is_compatible(q):
                ui_nav_name = "dataset/import/4/edit" if edit else "dataset/import/4"
                buttons = [
                    ui.button(name=ui_nav_name, label="Continue", primary=True),
                    ui.button(name="dataset/list", label="Abort"),
                ]
                if allow_merge:
                    datasets_df = q.client.app_db.get_datasets_df()
                    if datasets_df.shape[0]:
                        label = "Merge With Existing Dataset"
                        buttons.insert(1, ui.button(name="dataset/merge", label=label))
            else:
                problem_type = make_label(
                    re.sub("_config.*", "", q.client["dataset/import/cfg_file"])
                )
                items += [
                    ui.text(
                        "<b> The chosen file extension is not "
                        f"compatible with {problem_type}.</b> "
                    )
                ]
                buttons = [
                    ui.button(name="dataset/list", label="Abort"),
                ]
            q.page["dataset/import/footer"] = ui.form_card(
                box="footer", items=[ui.inline(items=buttons, justify="start")]
            )

            wizard.items = items

            q.client.delete_cards.add("dataset/import/footer")

        except Exception as exception:
            logger.error("Dataset error:", exc_info=True)
            error = clean_error(str(exception))
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 4:
        dataset_name = q.client["dataset/import/name"]
        original_name = q.client["dataset/import/original_name"]
        valid_dataset_name = get_unique_dataset_name(q, dataset_name)
        if valid_dataset_name != dataset_name and not (
            q.client["dataset/import/edit"] and dataset_name == original_name
        ):
            err = f"Dataset <strong>{dataset_name}</strong> already exists"
            q.client["dataset/import/name"] = valid_dataset_name
            await dataset_import(q, 3, edit=edit, error=err)
        else:
            await dataset_import(q, 5, edit=edit)

    elif step == 5:
        header = "<h2>Sample Data Visualization</h2>"
        valid_visualization = False
        try:
            cfg = q.client["dataset/import/cfg"]
            cfg = parse_ui_elements(
                cfg=cfg, q=q, limit=default_cfg.dataset_keys, pre="dataset/import/cfg/"
            )

            q.client["dataset/import/cfg"] = cfg
            plot = cfg.logging.plots_class.plot_data(cfg)
            text = (
                "Data Validity Check. Click <strong>Continue</strong> if the input "
                "data and labels appear correctly."
            )
            if plot.encoding == "image":
                plot_item = ui.image(title="", type="png", image=plot.data)
            elif plot.encoding == "html":
                plot_item = ui.markup(content=plot.data)
            elif plot.encoding == "df":
                df = pd.read_parquet(plot.data)
                df = df.iloc[:2000]
                min_widths = {"Content": "800"}
                plot_item = ui_table_from_df(
                    q=q,
                    df=df,
                    name="experiment/display/table",
                    markdown_cells=list(df.columns),
                    searchables=list(df.columns),
                    downloadable=False,
                    resettable=False,
                    min_widths=min_widths,
                    height="calc(100vh - 245px)",
                    max_char_length=5_000,
                    cell_overflow="tooltip",
                )
            else:
                raise ValueError(f"Unknown plot encoding `{plot.encoding}`")

            items = [ui.markup(content=header), ui.message_bar(text=text), plot_item]
            valid_visualization = True

            await busy_dialog(
                q=q,
                title="Performing sanity checks on the data",
                text="Please be patient...",
            )

            time.sleep(1)
            sanity_check(cfg)

        except Exception as exception:
            logger.error(
                f"Error while plotting data preview: {exception}", exc_info=True
            )
            text = (
                "Error occurred while visualizing the data. Please go back and verify "
                "whether the problem type and other settings were set properly."
            )
            items = [
                ui.markup(content=header),
                ui.message_bar(text=text, type="error"),
                ui.expander(
                    name="expander",
                    label="Expand Error Traceback",
                    items=[ui.markup(f"<pre>{traceback.format_exc()}</pre>")],
                ),
            ]

        buttons = [
            ui.button(
                name="dataset/import/6", label="Continue", primary=valid_visualization
            ),
            ui.button(
                name="dataset/import/3/edit",
                label="Back",
                primary=not valid_visualization,
            ),
            ui.button(name="dataset/list", label="Abort"),
        ]

        q.page["dataset/import"] = ui.form_card(box="content", items=items)
        q.client.delete_cards.add("dataset/import")

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

    elif step == 6:
        if q.client["dataset/import/name"] == "":
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=2, error="Please enter all required fields!")

        else:
            # Rename the imported folder to match the chosen dataset name.
            folder_name = q.client["dataset/import/path"].split("/")[-1]
            new_folder = q.client["dataset/import/name"]
            act_path = q.client["dataset/import/path"]
            new_path = new_folder.join(act_path.rsplit(folder_name, 1))

            try:
                shutil.move(q.client["dataset/import/path"], new_path)

                cfg = q.client["dataset/import/cfg"]

                for k in default_cfg.dataset_folder_keys:
                    old_path = getattr(cfg.dataset, k, None)
                    if old_path is not None:
                        setattr(
                            cfg.dataset,
                            k,
                            old_path.replace(q.client["dataset/import/path"], new_path),
                        )

                if cfg.dataset.validation_dataframe != "None":
                    cfg.dataset.validation_strategy = "custom"
                cfg_path = f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml"
                save_config_yaml(cfg_path, cfg)

                train_rows = None
                if os.path.exists(cfg.dataset.train_dataframe):
                    train_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.train_dataframe, cfg
                    ).shape[0]
                validation_rows = None
                if os.path.exists(cfg.dataset.validation_dataframe):
                    validation_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.validation_dataframe, cfg
                    ).shape[0]

                dataset = Dataset(
                    id=q.client["dataset/import/id"],
                    name=q.client["dataset/import/name"],
                    path=new_path,
                    config_file=cfg_path,
                    train_rows=train_rows,
                    validation_rows=validation_rows,
                )
                if q.client["dataset/import/id"] is not None:
                    q.client.app_db.delete_dataset(dataset.id)
                q.client.app_db.add_dataset(dataset)
                await dataset_list(q)

            except Exception as exception:
                logger.error("Dataset error:", exc_info=True)
                q.client.app_db._session.rollback()
                error = clean_error(str(exception))
                await clean_dashboard(q, mode="full")
                await dataset_import(q, step=2, error=str(error))


async def dataset_merge(q: Q, step, error=""):
    if step == 1:
        await clean_dashboard(q, mode="full")
        q.client["nav/active"] = "dataset/merge"

        q.page["dataset/merge"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/merge")

        datasets_df = q.client.app_db.get_datasets_df()
        import_choices = [
            ui.choice(x["path"], x["name"]) for idx, x in datasets_df.iterrows()
        ]

        items = [
            ui.text_l("Merge current dataset with an existing dataset"),
            ui.dropdown(
                name="dataset/merge/target",
                label="Dataset",
                value=datasets_df.iloc[0]["path"],
                choices=import_choices,
                trigger=False,
                tooltip="Existing dataset to merge the current dataset into",
            ),
        ]

        if error:
            items.append(ui.message_bar(type="error", text=error))

        q.page["dataset/merge"].items = items

        buttons = [
            ui.button(name="dataset/merge/action", label="Merge", primary=True),
            ui.button(name="dataset/import/3", label="Back", primary=False),
            ui.button(name="dataset/list", label="Abort"),
        ]

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

    elif step == 2:
        current_dir = q.client["dataset/import/path"]
        target_dir = q.args["dataset/merge/target"]

        if current_dir == target_dir:
            await dataset_merge(q, step=1, error="Cannot merge dataset with itself")
            return

        datasets_df = q.client.app_db.get_datasets_df().set_index("path")
        has_dataset_entry = current_dir in datasets_df.index

        if has_dataset_entry:
            experiment_df = q.client.app_db.get_experiments_df()
            source_id = int(datasets_df.loc[current_dir, "id"])
            has_experiment = any(experiment_df["dataset"].astype(int) == source_id)
        else:
            source_id = None
            has_experiment = False

        current_files = os.listdir(current_dir)
        current_files = [x for x in current_files if not x.endswith(".yaml")]
        target_files = os.listdir(target_dir)
        overlapping_files = list(set(current_files).intersection(set(target_files)))
        rename_map = {}

        # Rename overlapping files by appending a counter so nothing is overwritten.
        for file in overlapping_files:
            tmp_str = file.split(".")
            if len(tmp_str) == 1:
                file_name, extension = file, ""
            else:
                file_name, extension = ".".join(tmp_str[:-1]), f".{tmp_str[-1]}"

            cnt = 1
            while f"{file_name}_{cnt}{extension}" in target_files:
                cnt += 1

            rename_map[file] = f"{file_name}_{cnt}{extension}"
            target_files.append(rename_map[file])

        if len(overlapping_files):
            warning = (
                f"Renamed {', '.join(rename_map.keys())} to "
                f"{', '.join(rename_map.values())} due to duplicated entries."
            )
        else:
            warning = ""

        for file in current_files:
            new_file = rename_map.get(file, file)
            src = os.path.join(current_dir, file)
            dst = os.path.join(target_dir, new_file)

            # Copy if the source dataset is still linked to experiments, else move.
            if has_experiment:
                if os.path.isdir(src):
                    shutil.copytree(src, dst)
                else:
                    shutil.copy(src, dst)
            else:
                shutil.move(src, dst)

        if not has_experiment:
            shutil.rmtree(current_dir)
            if has_dataset_entry:
                q.client.app_db.delete_dataset(source_id)

        dataset_id = int(datasets_df.loc[target_dir, "id"])
        await dataset_edit(q, dataset_id, warning=warning, allow_merge=False)


async def dataset_list_table(
    q: Q,
    show_experiment_datasets: bool = True,
) -> None:
    """Prepare dataset list form card.

    Args:
        q: Q
        show_experiment_datasets: whether to also show datasets linked to experiments
    """

    q.client["dataset/list/df_datasets"] = get_datasets(
        q=q,
        show_experiment_datasets=show_experiment_datasets,
    )

    df_viz = q.client["dataset/list/df_datasets"].copy()

    columns_to_drop = [
        "id",
        "path",
        "config_file",
        "validation dataframe",
    ]

    df_viz = df_viz.drop(columns=columns_to_drop, errors="ignore")
    if "problem type" in df_viz.columns:
        df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "")

    widths = {
        "name": "200",
        "problem type": "210",
        "train dataframe": "190",
        "train rows": "120",
        "validation rows": "130",
        "labels": "120",
        "actions": "5",
    }

    actions_dict = {
        "dataset/newexperiment": "New experiment",
        "dataset/edit": "Edit dataset",
        "dataset/delete/dialog/single": "Delete dataset",
    }

    q.page["dataset/list"] = ui.form_card(
        box="content",
        items=[
            ui_table_from_df(
                q=q,
                df=df_viz,
                name="dataset/list/table",
                sortables=["train rows", "validation rows"],
                filterables=["name", "problem type"],
                searchables=[],
                min_widths=widths,
                link_col="name",
                height="calc(100vh - 245px)",
                actions=actions_dict,
            ),
            ui.message_bar(type="info", text=""),
        ],
    )
    q.client.delete_cards.add("dataset/list")


async def dataset_list(q: Q, reset: bool = True) -> None:
    """Display all datasets."""
    q.client["nav/active"] = "dataset/list"

    if reset:
        await clean_dashboard(q, mode="full")
        await dataset_list_table(q)

        q.page["dataset/display/footer"] = ui.form_card(
            box="footer",
            items=[
                ui.inline(
                    items=[
                        ui.button(
                            name="dataset/import", label="Import dataset", primary=True
                        ),
                        ui.button(
                            name="dataset/list/delete",
                            label="Delete datasets",
                            primary=False,
                        ),
                    ],
                    justify="start",
                )
            ],
        )
        q.client.delete_cards.add("dataset/display/footer")
        remove_temp_files(q)

    await q.page.save()


async def dataset_newexperiment(q: Q, dataset_id: int):
    """Start a new experiment from given dataset."""

    dataset = q.client.app_db.get_dataset(dataset_id)

    q.client["experiment/start/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["experiment/start/cfg_category"] = q.client[
        "experiment/start/cfg_file"
    ].split("_")[0]
    q.client["experiment/start/dataset"] = str(dataset_id)

    await experiment_start(q)


async def dataset_edit(
    q: Q, dataset_id: int, error: str = "", warning: str = "", allow_merge: bool = True
):
    """Edit selected dataset.

    Args:
        q: Q
        dataset_id: dataset id to edit
        error: optional error message
        warning: optional warning message
        allow_merge: whether to allow merging dataset when editing
    """

    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    experiments_df = experiments_df[experiments_df["dataset"] == str(dataset_id)]
    statuses, _ = get_experiments_status(experiments_df)
    num_invalid = len([stat for stat in statuses if stat in ["running", "queued"]])

    if num_invalid:
        info = "s" if num_invalid > 1 else ""
        info_str = (
            f"Dataset <strong>{dataset.name}</strong> is linked to {num_invalid} "
            f"running or queued experiment{info}. Wait for them to finish or stop them "
            "first before editing the dataset."
        )
        q.page["dataset/list"].items[1].message_bar.text = info_str
        return

    q.client["dataset/import/id"] = dataset_id

    q.client["dataset/import/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["dataset/import/cfg_category"] = q.client["dataset/import/cfg_file"].split(
        "_"
    )[0]
    q.client["dataset/import/path"] = dataset.path
    q.client["dataset/import/name"] = dataset.name
    q.client["dataset/import/original_name"] = dataset.name
    q.client["dataset/import/cfg"] = load_config_yaml(dataset.config_file)

    if allow_merge and experiments_df.shape[0]:
        allow_merge = False

    await dataset_import(
        q=q, step=2, edit=True, error=error, warning=warning, allow_merge=allow_merge
    )


async def dataset_list_delete(q: Q):
    """Allow selecting multiple datasets for deletion."""

    await dataset_list_table(q, show_experiment_datasets=False)

    q.page["dataset/list"].items[0].table.multiple = True

    info_str = "Only datasets not linked to experiments can be deleted."

    q.page["dataset/list"].items[1].message_bar.text = info_str

    q.page["dataset/display/footer"].items = [
        ui.inline(
            items=[
                ui.button(
                    name="dataset/delete/dialog", label="Delete datasets", primary=True
                ),
                ui.button(name="dataset/list/delete/abort", label="Abort"),
            ]
        )
    ]


async def dataset_delete(q: Q, dataset_ids: List[int]):
    """Delete selected datasets.

    Args:
        q: Q
        dataset_ids: list of dataset ids to delete
    """

    for dataset_id in dataset_ids:
        dataset = q.client.app_db.get_dataset(dataset_id)
        q.client.app_db.delete_dataset(dataset.id)

        try:
            shutil.rmtree(dataset.path)
        except OSError:
            pass


async def dataset_delete_single(q: Q, dataset_id: int):
    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    num_experiments = sum(experiments_df["dataset"] == str(dataset_id))
    if num_experiments:
        info = "s" if num_experiments > 1 else ""
        info_str = (
            f"Dataset <strong>{dataset.name}</strong> is linked to {num_experiments} "
            f"experiment{info}. Only datasets not linked to experiments can be deleted."
        )
        await dataset_list(q)
        q.page["dataset/list"].items[1].message_bar.text = info_str
    else:
        await dataset_delete(q, [dataset_id])
        await dataset_list(q)


async def dataset_display(q: Q) -> None:
    """Display a selected dataset."""

    dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[
        q.client["dataset/display/id"]
    ]
    dataset: Dataset = q.client.app_db.get_dataset(dataset_id)
    config_filename = dataset.config_file
    cfg = load_config_yaml(config_filename)
    dataset_filename = cfg.dataset.train_dataframe

    if (
        q.client["dataset/display/tab"] is None
        or q.args["dataset/display/data"] is not None
    ):
        q.client["dataset/display/tab"] = "dataset/display/data"

    if q.args["dataset/display/visualization"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/visualization"

    if q.args["dataset/display/statistics"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/statistics"

    if q.args["dataset/display/summary"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/summary"

    await clean_dashboard(q, mode=q.client["dataset/display/tab"])

    items: List[Tab] = [
        ui.tab(name="dataset/display/data", label="Sample Train Data"),
        ui.tab(
            name="dataset/display/visualization", label="Sample Train Visualization"
        ),
        ui.tab(name="dataset/display/statistics", label="Train Data Statistics"),
        ui.tab(name="dataset/display/summary", label="Summary"),
    ]

    q.page["dataset/display/tab"] = ui.tab_card(
        box="nav2",
        link=True,
        items=items,
        value=q.client["dataset/display/tab"],
    )
    q.client.delete_cards.add("dataset/display/tab")

    if q.client["dataset/display/tab"] == "dataset/display/data":
        await show_data_tab(q=q, cfg=cfg, filename=dataset_filename)

    elif q.client["dataset/display/tab"] == "dataset/display/visualization":
        await show_visualization_tab(q, cfg)

    elif q.client["dataset/display/tab"] == "dataset/display/statistics":
        await show_statistics_tab(
            q, dataset_filename=dataset_filename, config_filename=config_filename
        )

    elif q.client["dataset/display/tab"] == "dataset/display/summary":
        await show_summary_tab(q, dataset_id)

q.page["dataset/display/footer"] = ui.form_card( |
|
box="footer", |
|
items=[ |
|
ui.inline( |
|
items=[ |
|
ui.button( |
|
name="dataset/newexperiment/from_current", |
|
label="Create experiment", |
|
primary=False, |
|
disabled=False, |
|
tooltip=None, |
|
), |
|
ui.button(name="dataset/list", label="Back", primary=False), |
|
], |
|
justify="start", |
|
) |
|
], |
|
) |
|
q.client.delete_cards.add("dataset/display/footer") |
|
|
|
|
|
async def show_data_tab(q, cfg, filename: str):
    fill_columns = get_fill_columns(cfg)
    df = read_dataframe(filename, n_rows=200, fill_columns=fill_columns)
    q.page["dataset/display/data"] = ui.form_card(
        box="first",
        items=[
            ui_table_from_df(
                q=q,
                df=df,
                name="dataset/display/data/table",
                sortables=list(df.columns),
                height="calc(100vh - 265px)",
                cell_overflow="wrap",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/data")


async def show_visualization_tab(q, cfg):
    try:
        plot = cfg.logging.plots_class.plot_data(cfg)
    except Exception as error:
        logger.error(f"Error while plotting data preview: {error}", exc_info=True)
        plot = PlotData("<h2>Error while plotting data.</h2>", encoding="html")
    card: ImageCard | MarkupCard | FormCard
    if plot.encoding == "image":
        card = ui.image_card(box="first", title="", type="png", image=plot.data)
    elif plot.encoding == "html":
        card = ui.markup_card(box="first", title="", content=plot.data)
    elif plot.encoding == "df":
        df = pd.read_parquet(plot.data)
        df = df.iloc[:2000]
        min_widths = {"Content": "800"}
        card = ui.form_card(
            box="first",
            items=[
                ui_table_from_df(
                    q=q,
                    df=df,
                    name="dataset/display/visualization/table",
                    markdown_cells=list(df.columns),
                    searchables=list(df.columns),
                    downloadable=True,
                    resettable=True,
                    min_widths=min_widths,
                    height="calc(100vh - 245px)",
                    max_char_length=50_000,
                    cell_overflow="tooltip",
                )
            ],
        )
    else:
        raise ValueError(f"Unknown plot encoding `{plot.encoding}`")
    q.page["dataset/display/visualization"] = card
    q.client.delete_cards.add("dataset/display/visualization")


async def show_summary_tab(q, dataset_id):
    dataset_df = get_datasets(q)
    dataset_df = dataset_df[dataset_df.id == dataset_id]
    stat_list_items: List[StatListItem] = []
    for col in dataset_df.columns:
        if col in ["id", "config_file", "path", "process_id", "status"]:
            continue
        v = dataset_df[col].values[0]
        t: StatListItem = ui.stat_list_item(label=make_label(col), value=str(v))
        stat_list_items.append(t)
    q.page["dataset/display/summary"] = ui.stat_list_card(
        box="first", items=stat_list_items, title=""
    )
    q.client.delete_cards.add("dataset/display/summary")


async def show_statistics_tab(q, dataset_filename, config_filename):
    # Hash the config so the LRU cache below is invalidated when the config changes.
    with open(config_filename, "rb") as config_file:
        cfg_hash = hashlib.md5(config_file.read()).hexdigest()
    stats_dict = compute_dataset_statistics(dataset_filename, config_filename, cfg_hash)

    for chat_type in ["prompts", "answers"]:
        q.page[f"dataset/display/statistics/{chat_type}_histogram"] = histogram_card(
            x=stats_dict[chat_type],
            x_axis_description=f"text_length_{chat_type.capitalize()}",
            title=f"Text Length Distribution for {chat_type.capitalize()}"
            f" (split by whitespace)",
            histogram_box="first",
        )
        q.client.delete_cards.add(f"dataset/display/statistics/{chat_type}_histogram")

    q.page["dataset/display/statistics/full_conversation_histogram"] = histogram_card(
        x=stats_dict["complete_conversations"],
        x_axis_description="text_length_complete_conversations",
        title="Text Length Distribution for complete "
        "conversations (split by whitespace)",
        histogram_box="second",
    )
    q.client.delete_cards.add("dataset/display/statistics/full_conversation_histogram")

    if len(set(stats_dict["number_of_prompts"])) > 1:
        q.page["dataset/display/statistics/parent_id_length_histogram"] = (
            histogram_card(
                x=stats_dict["number_of_prompts"],
                x_axis_description="number_of_prompts",
                title="Distribution of number of prompt-answer turns per conversation.",
                histogram_box="second",
            )
        )
        q.client.delete_cards.add(
            "dataset/display/statistics/parent_id_length_histogram"
        )

    df_stats = stats_dict["df_stats"]
    if df_stats is None:
        component_items = [
            ui.text(
                "Dataset does not contain numerical or text features. "
                "No statistics available."
            )
        ]
    else:
        if df_stats.shape[1] > 5:
            widths = {col: "77" for col in df_stats}
        else:
            widths = None
        component_items = [
            ui_table_from_df(
                q=q,
                df=df_stats,
                name="dataset/display/statistics/table",
                sortables=list(df_stats.columns),
                min_widths=widths,
                height="265px",
            )
        ]
    q.page["dataset/display/statistics"] = ui.form_card(
        box="third",
        items=component_items,
    )
    q.client.delete_cards.add("dataset/display/statistics")


@functools.lru_cache()
def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
    """
    Compute various statistics for a dataset.
    - text length distribution for prompts and answers
    - text length distribution for complete conversations
    - distribution of number of prompt-answer turns per conversation
    - statistics for non text features

    We use LRU caching to avoid recomputing the statistics for the same dataset.
    Thus, cfg_hash is used as a function argument to identify the dataset.
    """
    df_train = read_dataframe(dataset_path)
    cfg = load_config_yaml(cfg_path)
    conversations = get_conversation_chains(
        df=df_train, cfg=cfg, limit_chained_samples=True
    )
    stats_dict = {}
    for chat_type in ["prompts", "answers"]:
        text_lengths = [
            [len(text.split(" ")) for text in conversation[chat_type]]
            for conversation in conversations
        ]
        text_lengths = [item for sublist in text_lengths for item in sublist]
        stats_dict[chat_type] = text_lengths
    input_texts = []
    for conversation in conversations:
        input_text = conversation["systems"][0]
        prompts = conversation["prompts"]
        answers = conversation["answers"]
        for prompt, answer in zip(prompts, answers):
            input_text += prompt + answer
        input_texts += [input_text]
    stats_dict["complete_conversations"] = [
        len(text.split(" ")) for text in input_texts
    ]
    stats_dict["number_of_prompts"] = [
        len(conversation["prompts"]) for conversation in conversations
    ]
    stats_dict["df_stats"] = get_frame_stats(df_train)
    return stats_dict


async def dataset_import_uploaded_file(q: Q):
    local_path = await q.site.download(
        q.args["dataset/import/local_upload"][0],
        f"{get_data_dir(q)}/"
        f'{q.args["dataset/import/local_upload"][0].split("/")[-1]}',
    )
    await q.site.unload(q.args["dataset/import/local_upload"][0])
    valid, error = check_valid_upload_content(local_path)
    if valid:
        q.args["dataset/import/local_path"] = local_path
        q.client["dataset/import/local_path"] = q.args["dataset/import/local_path"]
        await dataset_import(q, step=2)
    else:
        await dataset_import(q, step=1, error=error)


async def dataset_delete_current_datasets(q: Q):
    dataset_ids = list(
        q.client["dataset/list/df_datasets"]["id"].iloc[
            list(map(int, q.client["dataset/list/table"]))
        ]
    )
    await dataset_delete(q, dataset_ids)
    await dataset_list(q)