import os
import subprocess
import sys
import time
from typing import List, Optional

from distilabel.steps.generators.data import LoadDataFromDicts
from distilabel.steps.expand import ExpandColumns
from distilabel.steps.keep import KeepColumns
from distilabel.steps.tasks.self_instruct import SelfInstruct
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import TextGenerationToArgilla
from dotenv import load_dotenv

from domain import (
    DomainExpert,
    CleanNumberedList,
    create_topics,
    create_examples_template,
    APPLICATION_DESCRIPTION,
)

# Load environment variables from a local .env file (python-dotenv) as a
# module-import side effect, before any of the functions below are called.
load_dotenv()


def define_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    examples: List[dict],
    hub_token: str,
    endpoint_base_url: str,
    pipeline_name: str = "farming",
    argilla_workspace: str = "admin",
):
    """Define the synthetic-data pipeline for a specific domain.

    Steps, in order: seed one row per term from ``create_topics`` ->
    SelfInstruct (5 instructions per input) -> expand the instruction list
    into one ``question`` per row -> CleanNumberedList -> EvolInstruct
    (2 evolutions, stored) -> expand into one ``evolved_questions`` per row ->
    DomainExpert answer -> keep the model name, question and answer ->
    push to Argilla.

    Args:
        argilla_api_key: API key for the Argilla instance.
        argilla_api_url: Base URL of the Argilla instance.
        argilla_dataset_name: Argilla dataset the results are written to.
        topics: Topics combined with ``perspectives`` by ``create_topics``.
        perspectives: Perspectives combined with ``topics``.
        domain_expert_prompt: System prompt assigned to the domain expert step.
        examples: Example records rendered by ``create_examples_template``
            into the domain expert's template.
        hub_token: Hugging Face token, used as the inference endpoint API key.
        endpoint_base_url: Base URL of the inference endpoint shared by all
            LLM-backed steps.
        pipeline_name: Name of the distilabel ``Pipeline``. Defaults to the
            previously hard-coded ``"farming"``.
        argilla_workspace: Argilla workspace for the dataset. Defaults to the
            previously hard-coded ``"admin"``.

    Returns:
        The configured (not yet executed) ``Pipeline``.
    """
    terms = create_topics(topics, perspectives)
    template = create_examples_template(examples)
    with Pipeline(pipeline_name) as pipeline:
        load_data = LoadDataFromDicts(
            name="load_data",
            data=[{"input": term} for term in terms],
            batch_size=64,
        )
        # A single LLM client shared by every generation step.
        llm = InferenceEndpointsLLM(
            base_url=endpoint_base_url,
            api_key=hub_token,
        )
        self_instruct = SelfInstruct(
            name="self-instruct",
            application_description=APPLICATION_DESCRIPTION,
            num_instructions=5,
            input_batch_size=8,
            llm=llm,
        )

        evol_instruction_complexity = EvolInstruct(
            name="evol_instruction_complexity",
            llm=llm,
            num_evolutions=2,
            store_evolutions=True,
            input_batch_size=8,
            include_original_instruction=True,
            # The expanded/cleaned column is named "question", not "instruction".
            input_mappings={"instruction": "question"},
        )

        expand_instructions = ExpandColumns(
            name="expand_columns", columns={"instructions": "question"}
        )
        cleaner = CleanNumberedList(name="clean_numbered_list")
        expand_evolutions = ExpandColumns(
            name="expand_columns_evolved",
            columns={"evolved_instructions": "evolved_questions"},
        )

        domain_expert = DomainExpert(
            name="domain_expert",
            llm=llm,
            input_batch_size=8,
            input_mappings={"instruction": "evolved_questions"},
            output_mappings={"generation": "domain_expert_answer"},
        )

        # NOTE(review): sets private attributes of DomainExpert (defined in
        # `domain`); presumably the task reads them when building prompts —
        # confirm against that class before refactoring.
        domain_expert._system_prompt = domain_expert_prompt
        domain_expert._template = template

        keep_columns = KeepColumns(
            name="keep_columns",
            columns=["model_name", "evolved_questions", "domain_expert_answer"],
        )

        to_argilla = TextGenerationToArgilla(
            name="text_generation_to_argilla",
            dataset_name=argilla_dataset_name,
            dataset_workspace=argilla_workspace,
            api_url=argilla_api_url,
            api_key=argilla_api_key,
            input_mappings={
                "instruction": "evolved_questions",
                "generation": "domain_expert_answer",
            },
        )

        # Linear DAG: connect each step to the next, in this exact order.
        ordered_steps = [
            load_data,
            self_instruct,
            expand_instructions,
            cleaner,
            evol_instruction_complexity,
            expand_evolutions,
            domain_expert,
            keep_columns,
            to_argilla,
        ]
        for upstream, downstream in zip(ordered_steps, ordered_steps[1:]):
            upstream.connect(downstream)
    return pipeline


def serialize_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    hub_token: str,
    endpoint_base_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    examples: Optional[List[dict]] = None,
):
    """Build the domain pipeline with ``define_pipeline`` and save it as YAML.

    Args:
        argilla_api_key: API key for the Argilla instance.
        argilla_api_url: Base URL of the Argilla instance.
        argilla_dataset_name: Argilla dataset the pipeline writes to.
        topics: Topics forwarded to ``define_pipeline``.
        perspectives: Perspectives forwarded to ``define_pipeline``.
        domain_expert_prompt: System prompt for the domain expert step.
        hub_token: Hugging Face token used as the LLM API key.
        endpoint_base_url: Base URL of the inference endpoint.
        pipeline_config_path: Destination file; an existing file is overwritten.
        examples: Example records for the domain expert template. ``None``
            (the default) means no examples.
    """
    # Use a fresh list per call: a shared ``[]`` default would carry any
    # downstream mutation over into later calls.
    if examples is None:
        examples = []
    pipeline = define_pipeline(
        argilla_api_key=argilla_api_key,
        argilla_api_url=argilla_api_url,
        argilla_dataset_name=argilla_dataset_name,
        topics=topics,
        perspectives=perspectives,
        domain_expert_prompt=domain_expert_prompt,
        hub_token=hub_token,
        endpoint_base_url=endpoint_base_url,
        examples=examples,
    )
    pipeline.save(path=pipeline_config_path, overwrite=True, format="yaml")


def create_pipelines_run_command(
    hub_token: str,
    argilla_api_key: str,
    argilla_api_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Build the argv list for ``python -m distilabel pipeline run``.

    Runtime parameters are emitted as ``--param <step>.<field>=<value>`` pairs.
    They cover the Argilla destination (dataset name, API key, URL) for the
    ``text_generation_to_argilla`` step, and the Hugging Face token as the LLM
    API key for each LLM-backed step. The command ends with ``--ignore-cache``.

    Returns:
        The command as a list, to be executed without a shell.
    """
    runtime_params = {
        "text_generation_to_argilla.dataset_name": argilla_dataset_name,
        "text_generation_to_argilla.api_key": argilla_api_key,
        "text_generation_to_argilla.api_url": argilla_api_url,
        "self-instruct.llm.api_key": hub_token,
        "evol_instruction_complexity.llm.api_key": hub_token,
        "domain_expert.llm.api_key": hub_token,
    }

    argv = [sys.executable, "-m", "distilabel", "pipeline", "run"]
    argv += ["--config", pipeline_config_path]
    for dotted_name, value in runtime_params.items():
        argv += ["--param", f"{dotted_name}={value}"]
    argv.append("--ignore-cache")
    return argv


def run_pipeline(
    hub_token: str,
    argilla_api_key: str,
    argilla_api_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Run the pipeline in a subprocess and yield its output as log lines.

    The command comes from ``create_pipelines_run_command``. Because this is
    a generator, the subprocess is only started once iteration begins.

    Args:
        hub_token: Hugging Face token. It is passed to the child as
            ``HF_TOKEN`` and as the LLM API key params.
        argilla_api_key: API key for the Argilla instance.
        argilla_api_url: Base URL of the Argilla instance.
        pipeline_config_path: Path of the serialized pipeline YAML.
        argilla_dataset_name: Argilla dataset the pipeline writes to.

    Yields:
        Each line of the child's combined stdout/stderr, decoded as UTF-8.
        Undecodable bytes are replaced rather than raising.
    """
    command_to_run = create_pipelines_run_command(
        hub_token=hub_token,
        pipeline_config_path=pipeline_config_path,
        argilla_dataset_name=argilla_dataset_name,
        argilla_api_key=argilla_api_key,
        argilla_api_url=argilla_api_url,
    )

    # Inherit the parent environment and only override HF_TOKEN. Passing a
    # bare dict as ``env`` would strip PATH, HOME, proxy/SSL settings, etc.
    # from the child process.
    child_env = {**os.environ, "HF_TOKEN": hub_token}

    process = subprocess.Popen(
        args=command_to_run,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout. An unread stderr pipe can fill up and
        # block the child, and it carries the messages users need on failure.
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    # readline() returns b"" only at EOF, i.e. once the child closes stdout.
    for raw_line in iter(process.stdout.readline, b""):
        # Keep the original pacing between emitted lines (presumably to
        # throttle UI refreshes — confirm before removing).
        time.sleep(0.2)
        yield raw_line.decode("utf-8", errors="replace")

    # EOF reached: release our pipe end and reap the child so it does not
    # linger as a zombie process.
    process.stdout.close()
    process.wait()