import functools
import hashlib
import logging
import os
import re
import shutil
import time
import traceback
from typing import List, Optional

import pandas as pd
from h2o_wave import Q, ui
from h2o_wave.types import FormCard, ImageCard, MarkupCard, StatListItem, Tab

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Dataset
from llm_studio.app_utils.sections.common import clean_dashboard
from llm_studio.app_utils.sections.experiment import experiment_start
from llm_studio.app_utils.sections.histogram_card import histogram_card
from llm_studio.app_utils.utils import (
    add_model_type,
    azure_download,
    azure_file_options,
    check_valid_upload_content,
    clean_error,
    dir_file_table,
    get_data_dir,
    get_dataset_elements,
    get_datasets,
    get_experiments_status,
    get_frame_stats,
    get_model_types,
    get_problem_types,
    get_unique_dataset_name,
    kaggle_download,
    local_download,
    make_label,
    parse_ui_elements,
    remove_temp_files,
    s3_download,
    s3_file_options,
)
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.utils.config_utils import (
    load_config_py,
    load_config_yaml,
    save_config_yaml,
)
from llm_studio.src.utils.data_utils import (
    get_fill_columns,
    read_dataframe,
    read_dataframe_drop_missing_labels,
    sanity_check,
)
from llm_studio.src.utils.plot_utils import PlotData

logger = logging.getLogger(__name__)


def file_extension_is_compatible(q: Q) -> bool:
    """Check whether the selected train/validation files have an allowed extension."""
    cfg = q.client["dataset/import/cfg"]
    allowed_extensions = cfg.dataset._allowed_file_extensions

    is_correct_extension = []
    for mode in ["train", "validation"]:
        dataset_name = q.client[f"dataset/import/cfg/{mode}_dataframe"]

        if dataset_name is None or dataset_name == "None":
            continue
        is_correct_extension.append(dataset_name.endswith(allowed_extensions))
    return all(is_correct_extension)


async def dataset_import(
    q: Q,
    step: int,
    edit: Optional[bool] = False,
    error: Optional[str] = "",
    warning: Optional[str] = "",
    info: Optional[str] = "",
    allow_merge: bool = True,
) -> None:
    """Display dataset import cards.

    Args:
        q: Q
        step: current step of wizard
        edit: whether in edit mode
        error: optional error message
        warning: optional warning message
        info: optional info message
        allow_merge: whether to allow merging the dataset when importing
    """

    await clean_dashboard(q, mode="full")
    q.client["nav/active"] = "dataset/import"
    if step == 1:  # select import data source
        q.page["dataset/import"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/import")

        if q.client["dataset/import/source"] is None:
            q.client["dataset/import/source"] = "Upload"

        import_choices = [
            ui.choice("Upload", "Upload"),
            ui.choice("Local", "Local"),
            ui.choice("S3", "AWS S3"),
            ui.choice("Azure", "Azure Datalake"),
            ui.choice("Kaggle", "Kaggle"),
        ]

        items = [
            ui.text_l("Import dataset"),
            ui.dropdown(
                name="dataset/import/source",
                label="Source",
                value=(
                    "Upload"
                    if q.client["dataset/import/source"] is None
                    else q.client["dataset/import/source"]
                ),
                choices=import_choices,
                trigger=True,
                tooltip="Source of dataset import",
            ),
        ]

        if (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "S3"
        ):
            if q.client["dataset/import/s3_bucket"] is None:
                q.client["dataset/import/s3_bucket"] = q.client[
                    "default_aws_bucket_name"
                ]
            if q.client["dataset/import/s3_access_key"] is None:
                q.client["dataset/import/s3_access_key"] = q.client[
                    "default_aws_access_key"
                ]
            if q.client["dataset/import/s3_secret_key"] is None:
                q.client["dataset/import/s3_secret_key"] = q.client[
                    "default_aws_secret_key"
                ]

            files = s3_file_options(
                q.client["dataset/import/s3_bucket"],
                q.client["dataset/import/s3_access_key"],
                q.client["dataset/import/s3_secret_key"],
            )

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                if default_cfg.s3_filename in files:
                    default_file = default_cfg.s3_filename
                else:
                    default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/s3_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/s3_bucket",
                    label="S3 bucket name",
                    value=q.client["dataset/import/s3_bucket"],
                    trigger=True,
                    required=True,
                    tooltip="S3 bucket name including relative paths",
                ),
                ui.textbox(
                    name="dataset/import/s3_access_key",
                    label="AWS access key",
                    value=q.client["dataset/import/s3_access_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS access key; empty for anonymous access.",
                ),
                ui.textbox(
                    name="dataset/import/s3_secret_key",
                    label="AWS secret key",
                    value=q.client["dataset/import/s3_secret_key"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Optional AWS secret key; empty for anonymous access.",
                ),
                ui_filename,
            ]

        elif (
            q.client["dataset/import/source"] is None
            or q.client["dataset/import/source"] == "Azure"
        ):
            if q.client["dataset/import/azure_conn_string"] is None:
                q.client["dataset/import/azure_conn_string"] = q.client[
                    "default_azure_conn_string"
                ]
            if q.client["dataset/import/azure_container"] is None:
                q.client["dataset/import/azure_container"] = q.client[
                    "default_azure_container"
                ]

            files = azure_file_options(
                q.client["dataset/import/azure_conn_string"],
                q.client["dataset/import/azure_container"],
            )

            if not files:
                ui_filename = ui.textbox(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value="",
                    required=True,
                    tooltip="File name to be imported",
                )
            else:
                default_file = files[0]
                ui_filename = ui.dropdown(
                    name="dataset/import/azure_filename",
                    label="File name",
                    value=default_file,
                    choices=[ui.choice(x, x.split("/")[-1]) for x in files],
                    required=True,
                    tooltip="File name to be imported",
                )

            items += [
                ui.textbox(
                    name="dataset/import/azure_conn_string",
                    label="Datalake connection string",
                    value=q.client["dataset/import/azure_conn_string"],
                    trigger=True,
                    required=True,
                    password=True,
                    tooltip="Azure connection string to connect to Datalake storage",
                ),
                ui.textbox(
                    name="dataset/import/azure_container",
                    label="Datalake container name",
                    value=q.client["dataset/import/azure_container"],
                    trigger=True,
                    required=True,
                    tooltip="Azure Datalake container name including relative paths",
                ),
                ui_filename,
            ]

        elif q.client["dataset/import/source"] == "Upload":
            items += [
                ui.file_upload(
                    name="dataset/import/local_upload",
                    label="Upload!",
                    multiple=False,
                    file_extensions=default_cfg.allowed_file_extensions,
                )
            ]

        elif q.client["dataset/import/source"] == "Local":
            current_path = (
                q.client["dataset/import/local_path_current"]
                if q.client["dataset/import/local_path_current"] is not None
                else os.path.expanduser("~")
            )

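            # clicking a row in the file browser table below navigates into the
            # selected directory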
            if q.args.__wave_submission_name__ == "dataset/import/local_path_list":
                idx = int(q.args["dataset/import/local_path_list"][0])
                options = q.client["dataset/import/local_path_list_last"]
                new_path = os.path.abspath(os.path.join(current_path, options[idx]))
                if os.path.exists(new_path):
                    current_path = new_path

            results_df = dir_file_table(current_path)
            files_list = results_df[current_path].tolist()
            q.client["dataset/import/local_path_list_last"] = files_list
            q.client["dataset/import/local_path_current"] = current_path

            items += [
                ui.textbox(
                    name="dataset/import/local_path",
                    label="File location",
                    value=current_path,
                    required=True,
                    tooltip="Location of file to be imported",
                ),
                ui_table_from_df(
                    q=q,
                    df=results_df,
                    name="dataset/import/local_path_list",
                    sortables=[],
                    searchables=[],
                    min_widths={current_path: "400"},
                    link_col=current_path,
                    height="calc(65vh)",
                ),
            ]

        elif q.client["dataset/import/source"] == "Kaggle":
            if q.client["dataset/import/kaggle_access_key"] is None:
                q.client["dataset/import/kaggle_access_key"] = q.client[
                    "default_kaggle_username"
                ]
            if q.client["dataset/import/kaggle_secret_key"] is None:
                q.client["dataset/import/kaggle_secret_key"] = q.client[
                    "default_kaggle_secret_key"
                ]

            items += [
                ui.textbox(
                    name="dataset/import/kaggle_command",
                    label="Kaggle API command",
                    value=default_cfg.kaggle_command,
                    required=True,
                    tooltip="Kaggle API command to be executed",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_access_key",
                    label="Kaggle username",
                    value=q.client["dataset/import/kaggle_access_key"],
                    required=True,
                    password=False,
                    tooltip="Kaggle username for API authentication",
                ),
                ui.textbox(
                    name="dataset/import/kaggle_secret_key",
                    label="Kaggle secret key",
                    value=q.client["dataset/import/kaggle_secret_key"],
                    required=True,
                    password=True,
                    tooltip="Kaggle secret key for API authentication",
                ),
            ]

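        # format the allowed extensions as a readable list,
        # replacing the final comma with "or"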
        allowed_types = ", ".join(default_cfg.allowed_file_extensions)
        allowed_types = " or".join(allowed_types.rsplit(",", 1))
        items += [
            ui.message_bar(type="info", text=info + f"Must be a {allowed_types} file."),
            ui.message_bar(type="error", text=error),
            ui.message_bar(type="warning", text=warning),
        ]

        q.page["dataset/import"].items = items

        buttons = [ui.button(name="dataset/list", label="Abort")]
        if q.client["dataset/import/source"] != "Upload":
            buttons.insert(
                0, ui.button(name="dataset/import/2", label="Continue", primary=True)
            )

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

        q.client["dataset/import/id"] = None
        q.client["dataset/import/cfg_file"] = None

    elif step == 2:  # download / import data from source
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])
        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                if q.client["dataset/import/source"] == "S3":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await s3_download(
                        q,
                        q.client["dataset/import/s3_bucket"],
                        q.client["dataset/import/s3_filename"],
                        q.client["dataset/import/s3_access_key"],
                        q.client["dataset/import/s3_secret_key"],
                    )
                elif q.client["dataset/import/source"] == "Azure":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await azure_download(
                        q,
                        q.client["dataset/import/azure_conn_string"],
                        q.client["dataset/import/azure_container"],
                        q.client["dataset/import/azure_filename"],
                    )
                elif q.client["dataset/import/source"] in ("Upload", "Local"):
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await local_download(q, q.client["dataset/import/local_path"])
                elif q.client["dataset/import/source"] == "Kaggle":
                    (
                        q.client["dataset/import/path"],
                        q.client["dataset/import/name"],
                    ) = await kaggle_download(
                        q,
                        q.client["dataset/import/kaggle_command"],
                        q.client["dataset/import/kaggle_access_key"],
                        q.client["dataset/import/kaggle_secret_key"],
                    )

            # store if in edit mode
            q.client["dataset/import/edit"] = edit

            # clear dataset triggers from client
            for trigger_key in default_cfg.dataset_trigger_keys:
                if q.client[f"dataset/import/cfg/{trigger_key}"]:
                    del q.client[f"dataset/import/cfg/{trigger_key}"]

            await dataset_import(
                q,
                step=3,
                edit=edit,
                error=error,
                warning=warning,
                allow_merge=allow_merge,
            )
        except Exception:
            logger.error("Dataset error:", exc_info=True)
            error = (
                "Dataset import failed. Please make sure all required "
                "fields are filled correctly."
            )
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 3:  # set dataset configuration
        q.page["dataset/import/footer"] = ui.form_card(box="footer", items=[])
        try:
            if not q.args["dataset/import/cfg_file"] and not edit:
                q.client["dataset/import/name"] = get_unique_dataset_name(
                    q, q.client["dataset/import/name"]
                )
            q.page["dataset/import"] = ui.form_card(box="content", items=[])
            q.client.delete_cards.add("dataset/import")

            wizard = q.page["dataset/import"]

            title = "Configure dataset"

            items = [
                ui.text_l(title),
                ui.textbox(
                    name="dataset/import/name",
                    label="Dataset name",
                    value=q.client["dataset/import/name"],
                    required=True,
                ),
            ]

            choices_problem_types = [
                ui.choice(name, label) for name, label in get_problem_types()
            ]

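            # infer the problem type from the dataset name: pick the config whose
            # name (without the "_config" suffix) is the longest substring match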
            if q.client["dataset/import/cfg_file"] is None:
                max_substring_len = 0
                for c in choices_problem_types:
                    problem_type_name = c.name.replace("_config", "")
                    if problem_type_name in q.client["dataset/import/name"]:
                        if len(problem_type_name) > max_substring_len:
                            q.client["dataset/import/cfg_file"] = c.name
                            q.client["dataset/import/cfg_category"] = c.name.split("_")[
                                0
                            ]
                            max_substring_len = len(problem_type_name)
            if q.client["dataset/import/cfg_file"] is None:
                q.client["dataset/import/cfg_file"] = default_cfg.cfg_file
                q.client["dataset/import/cfg_category"] = q.client[  # type: ignore
                    "dataset/import/cfg_file"
                ].split("_")[0]

            # fall back to the category's first problem type if the current
            # cfg file does not match the selected category
            if (
                q.client["dataset/import/cfg_category"]
                not in q.client["dataset/import/cfg_file"]
            ):
                q.client["dataset/import/cfg_file"] = get_problem_types(
                    category=q.client["dataset/import/cfg_category"]
                )[0][0]

            model_types = get_model_types(q.client["dataset/import/cfg_file"])
            if len(model_types) > 0:
                # append the default model type to the cfg file name
                q.client["dataset/import/cfg_file"] = add_model_type(
                    q.client["dataset/import/cfg_file"], model_types[0][0]
                )
            if not edit:
                q.client["dataset/import/cfg"] = load_config_py(
                    config_path=(
                        f"llm_studio/python_configs/"
                        f"{q.client['dataset/import/cfg_file']}"
                    ),
                    config_name="ConfigProblemBase",
                )

            option_items = get_dataset_elements(cfg=q.client["dataset/import/cfg"], q=q)
            items.extend(option_items)
            items.append(ui.message_bar(type="error", text=error))
            items.append(ui.message_bar(type="warning", text=warning))
            if file_extension_is_compatible(q):
                ui_nav_name = "dataset/import/4/edit" if edit else "dataset/import/4"
                buttons = [
                    ui.button(name=ui_nav_name, label="Continue", primary=True),
                    ui.button(name="dataset/list", label="Abort"),
                ]
                if allow_merge:
                    datasets_df = q.client.app_db.get_datasets_df()
                    if datasets_df.shape[0]:
                        label = "Merge With Existing Dataset"
                        buttons.insert(1, ui.button(name="dataset/merge", label=label))
            else:
                problem_type = make_label(
                    re.sub("_config.*", "", q.client["dataset/import/cfg_file"])
                )
                items += [
                    ui.text(
                        "<b> The chosen file extensions is not "
                        f"compatible with {problem_type}.</b> "
                    )
                ]
                buttons = [
                    ui.button(name="dataset/list", label="Abort"),
                ]
            q.page["dataset/import/footer"] = ui.form_card(
                box="footer", items=[ui.inline(items=buttons, justify="start")]
            )

            wizard.items = items

            q.client.delete_cards.add("dataset/import/footer")

        except Exception as exception:
            logger.error("Dataset error:", exc_info=True)
            error = clean_error(str(exception))
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=1, error=str(error))

    elif step == 4:  # check that the dataset name does not already exist
        dataset_name = q.client["dataset/import/name"]
        original_name = q.client["dataset/import/original_name"]  # used in edit mode
        valid_dataset_name = get_unique_dataset_name(q, dataset_name)
        if valid_dataset_name != dataset_name and not (
            q.client["dataset/import/edit"] and dataset_name == original_name
        ):
            err = f"Dataset <strong>{dataset_name}</strong> already exists"
            q.client["dataset/import/name"] = valid_dataset_name
            await dataset_import(q, 3, edit=edit, error=err)
        else:
            await dataset_import(q, 5, edit=edit)

    elif step == 5:  # visualize dataset
        header = "<h2>Sample Data Visualization</h2>"
        valid_visualization = False
        try:
            cfg = q.client["dataset/import/cfg"]
            cfg = parse_ui_elements(
                cfg=cfg, q=q, limit=default_cfg.dataset_keys, pre="dataset/import/cfg/"
            )

            q.client["dataset/import/cfg"] = cfg
            plot = cfg.logging.plots_class.plot_data(cfg)
            text = (
                "Data Validity Check. Click <strong>Continue</strong> if the input "
                "data and labels appear correct."
            )
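            # render the preview based on the encoding returned by the
            # problem type's plotting class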
            if plot.encoding == "image":
                plot_item = ui.image(title="", type="png", image=plot.data)
            elif plot.encoding == "html":
                plot_item = ui.markup(content=plot.data)
            elif plot.encoding == "df":
                df = pd.read_parquet(plot.data)
                df = df.iloc[:2000]
                min_widths = {"Content": "800"}
                plot_item = ui_table_from_df(
                    q=q,
                    df=df,
                    name="experiment/display/table",
                    markdown_cells=list(df.columns),
                    searchables=list(df.columns),
                    downloadable=False,
                    resettable=False,
                    min_widths=min_widths,
                    height="calc(100vh - 245px)",
                    max_char_length=5_000,
                    cell_overflow="tooltip",
                )
            else:
                raise ValueError(f"Unknown plot encoding `{plot.encoding}`")

            items = [ui.markup(content=header), ui.message_bar(text=text), plot_item]
            valid_visualization = True

            await busy_dialog(
                q=q,
                title="Performing sanity checks on the data",
                text="Please be patient...",
            )
            # add one-second delay for datasets where sanity check is instant
            # to avoid flickering dialog
            time.sleep(1)
            sanity_check(cfg)

        except Exception as exception:
            logger.error(
                f"Error while plotting data preview: {exception}", exc_info=True
            )
            text = (
                "An error occurred while visualizing the data. Please go back and "
                "verify that the problem type and other settings are set properly."
            )
            items = [
                ui.markup(content=header),
                ui.message_bar(text=text, type="error"),
                ui.expander(
                    name="expander",
                    label="Expand Error Traceback",
                    items=[ui.markup(f"<pre>{traceback.format_exc()}</pre>")],
                ),
            ]

        buttons = [
            ui.button(
                name="dataset/import/6", label="Continue", primary=valid_visualization
            ),
            ui.button(
                name="dataset/import/3/edit",
                label="Back",
                primary=not valid_visualization,
            ),
            ui.button(name="dataset/list", label="Abort"),
        ]

        q.page["dataset/import"] = ui.form_card(box="content", items=items)
        q.client.delete_cards.add("dataset/import")

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

    elif step == 6:  # create dataset
        if q.client["dataset/import/name"] == "":
            await clean_dashboard(q, mode="full")
            await dataset_import(q, step=2, error="Please enter all required fields!")

        else:
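            # rename the imported folder to the dataset name by replacing the
            # last occurrence of the original folder name in its path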
            folder_name = q.client["dataset/import/path"].split("/")[-1]
            new_folder = q.client["dataset/import/name"]
            act_path = q.client["dataset/import/path"]
            new_path = new_folder.join(act_path.rsplit(folder_name, 1))

            try:
                shutil.move(q.client["dataset/import/path"], new_path)

                cfg = q.client["dataset/import/cfg"]

                # remap old path to new path
                for k in default_cfg.dataset_folder_keys:
                    old_path = getattr(cfg.dataset, k, None)
                    if old_path is not None:
                        setattr(
                            cfg.dataset,
                            k,
                            old_path.replace(q.client["dataset/import/path"], new_path),
                        )

                # change the default validation strategy if validation df set
                if cfg.dataset.validation_dataframe != "None":
                    cfg.dataset.validation_strategy = "custom"
                cfg_path = f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml"
                save_config_yaml(cfg_path, cfg)

                train_rows = None
                if os.path.exists(cfg.dataset.train_dataframe):
                    train_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.train_dataframe, cfg
                    ).shape[0]
                validation_rows = None
                if os.path.exists(cfg.dataset.validation_dataframe):
                    validation_rows = read_dataframe_drop_missing_labels(
                        cfg.dataset.validation_dataframe, cfg
                    ).shape[0]

                dataset = Dataset(
                    id=q.client["dataset/import/id"],
                    name=q.client["dataset/import/name"],
                    path=new_path,
                    config_file=cfg_path,
                    train_rows=train_rows,
                    validation_rows=validation_rows,
                )
                if q.client["dataset/import/id"] is not None:
                    q.client.app_db.delete_dataset(dataset.id)
                q.client.app_db.add_dataset(dataset)
                await dataset_list(q)

            except Exception as exception:
                logger.error("Dataset error:", exc_info=True)
                q.client.app_db._session.rollback()
                error = clean_error(str(exception))
                await clean_dashboard(q, mode="full")
                await dataset_import(q, step=2, error=str(error))


async def dataset_merge(q: Q, step: int, error: str = "") -> None:
    """Merge the dataset currently being imported into an existing dataset."""
    if step == 1:  # Select which dataset to merge
        await clean_dashboard(q, mode="full")
        q.client["nav/active"] = "dataset/merge"

        q.page["dataset/merge"] = ui.form_card(box="content", items=[])
        q.client.delete_cards.add("dataset/merge")

        datasets_df = q.client.app_db.get_datasets_df()
        import_choices = [
            ui.choice(x["path"], x["name"]) for idx, x in datasets_df.iterrows()
        ]

        items = [
            ui.text_l("Merge current dataset with an existing dataset"),
            ui.dropdown(
                name="dataset/merge/target",
                label="Dataset",
                value=datasets_df.iloc[0]["path"],
                choices=import_choices,
                trigger=False,
                tooltip="Source of dataset import",
            ),
        ]

        if error:
            items.append(ui.message_bar(type="error", text=error))

        q.page["dataset/merge"].items = items

        buttons = [
            ui.button(name="dataset/merge/action", label="Merge", primary=True),
            ui.button(name="dataset/import/3", label="Back", primary=False),
            ui.button(name="dataset/list", label="Abort"),
        ]

        q.page["dataset/import/footer"] = ui.form_card(
            box="footer", items=[ui.inline(items=buttons, justify="start")]
        )
        q.client.delete_cards.add("dataset/import/footer")

    elif step == 2:  # copy/move files into the target dataset and edit it
        current_dir = q.client["dataset/import/path"]
        target_dir = q.args["dataset/merge/target"]

        if current_dir == target_dir:
            await dataset_merge(q, step=1, error="Cannot merge dataset with itself")
            return

        datasets_df = q.client.app_db.get_datasets_df().set_index("path")
        has_dataset_entry = current_dir in datasets_df.index

        if has_dataset_entry:
            experiment_df = q.client.app_db.get_experiments_df()
            source_id = int(datasets_df.loc[current_dir, "id"])
            has_experiment = any(experiment_df["dataset"].astype(int) == source_id)
        else:
            source_id = None
            has_experiment = False

        current_files = os.listdir(current_dir)
        current_files = [x for x in current_files if not x.endswith(".yaml")]
        target_files = os.listdir(target_dir)
        overlapping_files = list(set(current_files).intersection(set(target_files)))
        rename_map = {}

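        # give overlapping files a unique name in the target dataset by
        # appending an incrementing numeric suffix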
        for file in overlapping_files:
            tmp_str = file.split(".")
            if len(tmp_str) == 1:
                file_name, extension = file, ""
            else:
                file_name, extension = ".".join(tmp_str[:-1]), f".{tmp_str[-1]}"

            cnt = 1
            while f"{file_name}_{cnt}{extension}" in target_files:
                cnt += 1

            rename_map[file] = f"{file_name}_{cnt}{extension}"
            target_files.append(rename_map[file])

        if len(overlapping_files):
            warning = (
                f"Renamed {', '.join(rename_map.keys())} to "
                f"{', '.join(rename_map.values())} as these file names already "
                "exist in the target dataset."
            )
        else:
            warning = ""

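        # keep the source dataset intact (copy) if experiments still reference it,
        # otherwise move its files and remove the original dataset afterwards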
        for file in current_files:
            new_file = rename_map.get(file, file)
            src = os.path.join(current_dir, file)
            dst = os.path.join(target_dir, new_file)

            if has_experiment:
                if os.path.isdir(src):
                    shutil.copytree(src, dst)
                else:
                    shutil.copy(src, dst)
            else:
                shutil.move(src, dst)

        if not has_experiment:
            shutil.rmtree(current_dir)
            if has_dataset_entry:
                q.client.app_db.delete_dataset(source_id)

        dataset_id = int(datasets_df.loc[target_dir, "id"])
        await dataset_edit(q, dataset_id, warning=warning, allow_merge=False)


async def dataset_list_table(
    q: Q,
    show_experiment_datasets: bool = True,
) -> None:
    """Pepare dataset list form card

    Args:
        q: Q
        show_experiment_datasets: whether to also show datasets linked to experiments
    """

    q.client["dataset/list/df_datasets"] = get_datasets(
        q=q,
        show_experiment_datasets=show_experiment_datasets,
    )

    df_viz = q.client["dataset/list/df_datasets"].copy()

    columns_to_drop = [
        "id",
        "path",
        "config_file",
        "validation dataframe",
    ]

    df_viz = df_viz.drop(columns=columns_to_drop, errors="ignore")
    if "problem type" in df_viz.columns:
        df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "")

    widths = {
        "name": "200",
        "problem type": "210",
        "train dataframe": "190",
        "train rows": "120",
        "validation rows": "130",
        "labels": "120",
        "actions": "5",
    }

    actions_dict = {
        "dataset/newexperiment": "New experiment",
        "dataset/edit": "Edit dataset",
        "dataset/delete/dialog/single": "Delete dataset",
    }

    q.page["dataset/list"] = ui.form_card(
        box="content",
        items=[
            ui_table_from_df(
                q=q,
                df=df_viz,
                name="dataset/list/table",
                sortables=["train rows", "validation rows"],
                filterables=["name", "problem type"],
                searchables=[],
                min_widths=widths,
                link_col="name",
                height="calc(100vh - 245px)",
                actions=actions_dict,
            ),
            ui.message_bar(type="info", text=""),
        ],
    )
    q.client.delete_cards.add("dataset/list")


async def dataset_list(q: Q, reset: bool = True) -> None:
    """Display all datasets."""
    q.client["nav/active"] = "dataset/list"

    if reset:
        await clean_dashboard(q, mode="full")
        await dataset_list_table(q)

    q.page["dataset/display/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(
                items=[
                    ui.button(
                        name="dataset/import", label="Import dataset", primary=True
                    ),
                    ui.button(
                        name="dataset/list/delete",
                        label="Delete datasets",
                        primary=False,
                    ),
                ],
                justify="start",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/footer")
    remove_temp_files(q)

    await q.page.save()


async def dataset_newexperiment(q: Q, dataset_id: int):
    """Start a new experiment from given dataset."""

    dataset = q.client.app_db.get_dataset(dataset_id)

    q.client["experiment/start/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["experiment/start/cfg_category"] = q.client[
        "experiment/start/cfg_file"
    ].split("_")[0]
    q.client["experiment/start/dataset"] = str(dataset_id)

    await experiment_start(q)


async def dataset_edit(
    q: Q, dataset_id: int, error: str = "", warning: str = "", allow_merge: bool = True
):
    """Edit selected dataset.

    Args:
        q: Q
        dataset_id: dataset id to edit
        error: optional error message
        warning: optional warning message
        allow_merge: whether to allow merging the dataset when editing
    """

    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    experiments_df = experiments_df[experiments_df["dataset"] == str(dataset_id)]
    statuses, _ = get_experiments_status(experiments_df)
    num_invalid = len([stat for stat in statuses if stat in ["running", "queued"]])

    if num_invalid:
        info = "s" if num_invalid > 1 else ""
        info_str = (
            f"Dataset <strong>{dataset.name}</strong> is linked to {num_invalid} "
            f"running or queued experiment{info}. Wait for them to finish or stop them "
            "first before editing the dataset."
        )
        q.page["dataset/list"].items[1].message_bar.text = info_str
        return

    q.client["dataset/import/id"] = dataset_id

    q.client["dataset/import/cfg_file"] = dataset.config_file.split("/")[-1].replace(
        ".yaml", ""
    )
    q.client["dataset/import/cfg_category"] = q.client["dataset/import/cfg_file"].split(
        "_"
    )[0]
    q.client["dataset/import/path"] = dataset.path
    q.client["dataset/import/name"] = dataset.name
    q.client["dataset/import/original_name"] = dataset.name
    q.client["dataset/import/cfg"] = load_config_yaml(dataset.config_file)

    if allow_merge and experiments_df.shape[0]:
        allow_merge = False

    await dataset_import(
        q=q, step=2, edit=True, error=error, warning=warning, allow_merge=allow_merge
    )


async def dataset_list_delete(q: Q):
    """Allow to select multiple datasets for deletion."""

    await dataset_list_table(q, show_experiment_datasets=False)

    q.page["dataset/list"].items[0].table.multiple = True

    info_str = "Only datasets not linked to experiments can be deleted."

    q.page["dataset/list"].items[1].message_bar.text = info_str

    q.page["dataset/display/footer"].items = [
        ui.inline(
            items=[
                ui.button(
                    name="dataset/delete/dialog", label="Delete datasets", primary=True
                ),
                ui.button(name="dataset/list/delete/abort", label="Abort"),
            ]
        )
    ]


async def dataset_delete(q: Q, dataset_ids: List[int]):
    """Delete selected datasets.

    Args:
        q: Q
        dataset_ids: list of dataset ids to delete
    """

    for dataset_id in dataset_ids:
        dataset = q.client.app_db.get_dataset(dataset_id)
        q.client.app_db.delete_dataset(dataset.id)

        try:
            shutil.rmtree(dataset.path)
        except OSError:
            pass


async def dataset_delete_single(q: Q, dataset_id: int):
    dataset = q.client.app_db.get_dataset(dataset_id)

    experiments_df = q.client.app_db.get_experiments_df()
    num_experiments = sum(experiments_df["dataset"] == str(dataset_id))
    if num_experiments:
        info = "s" if num_experiments > 1 else ""
        info_str = (
            f"Dataset <strong>{dataset.name}</strong> is linked to {num_experiments} "
            f"experiment{info}. Only datasets not linked to experiments can be deleted."
        )
        await dataset_list(q)
        q.page["dataset/list"].items[1].message_bar.text = info_str
    else:
        await dataset_delete(q, [dataset_id])
        await dataset_list(q)


async def dataset_display(q: Q) -> None:
    """Display a selected dataset."""

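    # "dataset/display/id" holds the row index of the selected dataset in the
    # cached datasets dataframe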
    dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[
        q.client["dataset/display/id"]
    ]
    dataset: Dataset = q.client.app_db.get_dataset(dataset_id)
    config_filename = dataset.config_file
    cfg = load_config_yaml(config_filename)
    dataset_filename = cfg.dataset.train_dataframe

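    # default to the sample data tab and switch tabs based on which tab button
    # was clicked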
    if (
        q.client["dataset/display/tab"] is None
        or q.args["dataset/display/data"] is not None
    ):
        q.client["dataset/display/tab"] = "dataset/display/data"

    if q.args["dataset/display/visualization"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/visualization"

    if q.args["dataset/display/statistics"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/statistics"

    if q.args["dataset/display/summary"] is not None:
        q.client["dataset/display/tab"] = "dataset/display/summary"

    await clean_dashboard(q, mode=q.client["dataset/display/tab"])

    items: List[Tab] = [
        ui.tab(name="dataset/display/data", label="Sample Train Data"),
        ui.tab(
            name="dataset/display/visualization", label="Sample Train Visualization"
        ),
        ui.tab(name="dataset/display/statistics", label="Train Data Statistics"),
        ui.tab(name="dataset/display/summary", label="Summary"),
    ]

    q.page["dataset/display/tab"] = ui.tab_card(
        box="nav2",
        link=True,
        items=items,
        value=q.client["dataset/display/tab"],
    )
    q.client.delete_cards.add("dataset/display/tab")

    if q.client["dataset/display/tab"] == "dataset/display/data":
        await show_data_tab(q=q, cfg=cfg, filename=dataset_filename)

    elif q.client["dataset/display/tab"] == "dataset/display/visualization":
        await show_visualization_tab(q, cfg)

    elif q.client["dataset/display/tab"] == "dataset/display/statistics":
        await show_statistics_tab(
            q, dataset_filename=dataset_filename, config_filename=config_filename
        )

    elif q.client["dataset/display/tab"] == "dataset/display/summary":
        await show_summary_tab(q, dataset_id)

    q.page["dataset/display/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(
                items=[
                    ui.button(
                        name="dataset/newexperiment/from_current",
                        label="Create experiment",
                        primary=False,
                        disabled=False,
                        tooltip=None,
                    ),
                    ui.button(name="dataset/list", label="Back", primary=False),
                ],
                justify="start",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/footer")


async def show_data_tab(q, cfg, filename: str):
    fill_columns = get_fill_columns(cfg)
    df = read_dataframe(filename, n_rows=200, fill_columns=fill_columns)
    q.page["dataset/display/data"] = ui.form_card(
        box="first",
        items=[
            ui_table_from_df(
                q=q,
                df=df,
                name="dataset/display/data/table",
                sortables=list(df.columns),
                height="calc(100vh - 265px)",
                cell_overflow="wrap",
            )
        ],
    )
    q.client.delete_cards.add("dataset/display/data")


async def show_visualization_tab(q, cfg):
    try:
        plot = cfg.logging.plots_class.plot_data(cfg)
    except Exception as error:
        logger.error(f"Error while plotting data preview: {error}", exc_info=True)
        plot = PlotData("<h2>Error while plotting data.</h2>", encoding="html")
    card: ImageCard | MarkupCard | FormCard
    if plot.encoding == "image":
        card = ui.image_card(box="first", title="", type="png", image=plot.data)
    elif plot.encoding == "html":
        card = ui.markup_card(box="first", title="", content=plot.data)
    elif plot.encoding == "df":
        df = pd.read_parquet(plot.data)
        df = df.iloc[:2000]
        min_widths = {"Content": "800"}
        card = ui.form_card(
            box="first",
            items=[
                ui_table_from_df(
                    q=q,
                    df=df,
                    name="dataset/display/visualization/table",
                    markdown_cells=list(df.columns),
                    searchables=list(df.columns),
                    downloadable=True,
                    resettable=True,
                    min_widths=min_widths,
                    height="calc(100vh - 245px)",
                    max_char_length=50_000,
                    cell_overflow="tooltip",
                )
            ],
        )

    else:
        raise ValueError(f"Unknown plot encoding `{plot.encoding}`")
    q.page["dataset/display/visualization"] = card
    q.client.delete_cards.add("dataset/display/visualization")


async def show_summary_tab(q, dataset_id):
    dataset_df = get_datasets(q)
    dataset_df = dataset_df[dataset_df.id == dataset_id]
    stat_list_items: List[StatListItem] = []
    for col in dataset_df.columns:
        if col in ["id", "config_file", "path", "process_id", "status"]:
            continue
        v = dataset_df[col].values[0]
        t: StatListItem = ui.stat_list_item(label=make_label(col), value=str(v))

        stat_list_items.append(t)
    q.page["dataset/display/summary"] = ui.stat_list_card(
        box="first", items=stat_list_items, title=""
    )
    q.client.delete_cards.add("dataset/display/summary")


async def show_statistics_tab(q, dataset_filename, config_filename):
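    # hash the config file so that cached statistics are recomputed whenever
    # the dataset configuration changes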
    with open(config_filename, "rb") as config_file:
        cfg_hash = hashlib.md5(config_file.read()).hexdigest()
    stats_dict = compute_dataset_statistics(dataset_filename, config_filename, cfg_hash)

    for chat_type in ["prompts", "answers"]:
        q.page[f"dataset/display/statistics/{chat_type}_histogram"] = histogram_card(
            x=stats_dict[chat_type],
            x_axis_description=f"text_length_{chat_type.capitalize()}",
            title=f"Text Length Distribution for {chat_type.capitalize()}"
            f" (split by whitespace)",
            histogram_box="first",
        )
        q.client.delete_cards.add(f"dataset/display/statistics/{chat_type}_histogram")

    q.page["dataset/display/statistics/full_conversation_histogram"] = histogram_card(
        x=stats_dict["complete_conversations"],
        x_axis_description="text_length_complete_conversations",
        title="Text Length Distribution for complete "
        "conversations (split by whitespace)",
        histogram_box="second",
    )
    q.client.delete_cards.add("dataset/display/statistics/full_conversation_histogram")

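    # only show the turn-count histogram when conversations differ in their
    # number of prompt-answer turns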
    if len(set(stats_dict["number_of_prompts"])) > 1:
        q.page["dataset/display/statistics/parent_id_length_histogram"] = (
            histogram_card(
                x=stats_dict["number_of_prompts"],
                x_axis_description="number_of_prompts",
                title="Distribution of number of prompt-answer turns per conversation.",
                histogram_box="second",
            )
        )
        q.client.delete_cards.add(
            "dataset/display/statistics/parent_id_length_histogram"
        )

    df_stats = stats_dict["df_stats"]
    if df_stats is None:
        component_items = [
            ui.text(
                "Dataset does not contain numerical or text features. "
                "No statistics available."
            )
        ]
    else:
        if df_stats.shape[1] > 5:  # mixed text and numeric
            widths = {col: "77" for col in df_stats}
        else:  # only text features
            widths = None
        component_items = [
            ui_table_from_df(
                q=q,
                df=df_stats,
                name="dataset/display/statistics/table",
                sortables=list(df_stats.columns),
                min_widths=widths,
                height="265px",
            )
        ]
    q.page["dataset/display/statistics"] = ui.form_card(
        box="third",
        items=component_items,
    )
    q.client.delete_cards.add("dataset/display/statistics")


@functools.lru_cache()
def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
    """
    Compute various statistics for a dataset.
    - text length distribution for prompts and answers
    - text length distribution for complete conversations
    - distribution of number of prompt-answer turns per conversation
    - statistics for non-text features

    We use LRU caching to avoid recomputing the statistics for the same dataset.
    cfg_hash is passed as an argument so that the cache is invalidated whenever
    the dataset configuration changes.
    """
    df_train = read_dataframe(dataset_path)
    cfg = load_config_yaml(cfg_path)
    conversations = get_conversation_chains(
        df=df_train, cfg=cfg, limit_chained_samples=True
    )
    stats_dict = {}
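    # word counts (split by whitespace) for all prompts and all answers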
    for chat_type in ["prompts", "answers"]:
        text_lengths = [
            [len(text.split(" ")) for text in conversation[chat_type]]
            for conversation in conversations
        ]
        text_lengths = [item for sublist in text_lengths for item in sublist]
        stats_dict[chat_type] = text_lengths
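    # rebuild each full conversation (system prompt plus all prompt-answer turns)
    # to measure its total word count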
    input_texts = []
    for conversation in conversations:
        input_text = conversation["systems"][0]
        prompts = conversation["prompts"]
        answers = conversation["answers"]
        for prompt, answer in zip(prompts, answers):
            input_text += prompt + answer
        input_texts += [input_text]
    stats_dict["complete_conversations"] = [
        len(text.split(" ")) for text in input_texts
    ]
    stats_dict["number_of_prompts"] = [
        len(conversation["prompts"]) for conversation in conversations
    ]
    stats_dict["df_stats"] = get_frame_stats(df_train)
    return stats_dict


async def dataset_import_uploaded_file(q: Q):
    local_path = await q.site.download(
        q.args["dataset/import/local_upload"][0],
        f"{get_data_dir(q)}/"
        f'{q.args["dataset/import/local_upload"][0].split("/")[-1]}',
    )
    await q.site.unload(q.args["dataset/import/local_upload"][0])
    valid, error = check_valid_upload_content(local_path)
    if valid:
        q.args["dataset/import/local_path"] = local_path
        q.client["dataset/import/local_path"] = q.args["dataset/import/local_path"]
        await dataset_import(q, step=2)
    else:
        await dataset_import(q, step=1, error=error)


async def dataset_delete_current_datasets(q: Q):
    dataset_ids = list(
        q.client["dataset/list/df_datasets"]["id"].iloc[
            list(map(int, q.client["dataset/list/table"]))
        ]
    )
    await dataset_delete(q, dataset_ids)
    await dataset_list(q)