import os
import threading
import time

import evaluate
import numpy as np
import streamlit as st
import torch
# from gliner import GLiNER
from datasets import load_dataset
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model, TaskType
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments


# seqeval metric; compute_metrics below reads its overall precision/recall/f1/accuracy.
seqeval = evaluate.load("seqeval")

# Legacy helper that built label maps from the dataset's own "I-" labels. Superseded:
# the label maps are now taken from model.config further down. Kept for reference.
# id2label = {0: "O"}
# label2id = {"O": 0}
# def build_id2label(examples):
#     for i, label in enumerate(examples["mbert_token_classes"]):
#         if label.startswith("I-") and label not in label2id:
#             current_len = len(id2label)
#             id2label[current_len] = label
#             label2id[label] = current_len

# Report whether CUDA is usable and, if so, the name of the active device.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    active_device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(active_device_index)}")

# Load the pretrained piiranha token-classification model and its tokenizer from the
# Hugging Face Hub. (This is not a GLiNER model; the GLiNER import is commented out.)
st.write('Loading the pretrained model ...')
model_name = "iiiorg/piiranha-v1-detect-personal-information"
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Print the module tree, e.g. to check the layer names used as LoRA targets below.
print(model)

# Prepare model for LoRA training.
model.train()  # training mode: dropout modules are active
# Gradient checkpointing: trade extra compute for lower activation memory.
model.gradient_checkpointing_enable()

# NOTE(review): no quantization is involved here. The model is loaded above without any
# 8-bit/4-bit arguments. The call is presumably kept for its preparation side effects
# (e.g. freezing base weights, upcasting half-precision params); verify against the
# installed peft version.
model = prepare_model_for_kbit_training(model)

# LoRA config: rank-8 adapters with 5% dropout, inserted into the modules named
# "query_proj" (check the print(model) output above for the exact layer names).
config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.TOKEN_CLS,
)

# Wrap the base model so that only the LoRA-related parameters are trainable.
model = get_peft_model(model, config)

print(model)
# Reports total vs. trainable parameter counts.
model.print_trainable_parameters()

if torch.cuda.is_available():
    model = model.to("cuda")

# Load the first 1000 rows of the training split (rows 0..999).
raw_dataset = load_dataset("ai4privacy/pii-masking-400k", split='train[:1000]')
# raw_dataset = raw_dataset.filter(lambda example: example["language"].startswith("en"))
# Hold out 20% for evaluation; the fixed seed keeps the split identical across runs,
# so metrics from different runs are comparable.
raw_dataset = raw_dataset.train_test_split(test_size=0.2, seed=42)
print(raw_dataset)
print(raw_dataset.column_names)

# Build label2id and id2label from the pretrained model's own config, so label ids
# match the classifier head being fine-tuned.
st.write("Building label mappings")
label2id = model.config.label2id
id2label = model.config.id2label

st.write("id2label: ", id2label)
st.write("label2id: ", label2id)

# Label alignment and tokenization.
# The dataset's "mbert_tokens" are re-tokenized by this model's tokenizer, so one input
# word may become several sub-tokens, and words past max_length are truncated away.
# Labels therefore have to be mapped through word_ids() to line up with input_ids.


def align_labels_with_tokens(labels, word_ids=None, label_map=None):
    """Convert per-word string labels into per-token label ids.

    Every ``B-`` tag is rewritten to the matching ``I-`` tag before lookup (as the
    original code did; presumably the model's label set only has ``I-`` tags -
    confirm against ``model.config.label2id``).

    Args:
        labels: string labels, one per input word (e.g. ``"O"``, ``"B-EMAIL"``).
        word_ids: optional output of ``BatchEncoding.word_ids(i)`` for the same
            example. For each produced token it holds the index of its source word,
            or ``None`` for special/padding tokens. Every sub-token of a word gets
            that word's label. When omitted, the legacy behaviour is kept: one id
            per word, framed by a leading and a trailing ``-100``. That is only
            correct if each word became exactly one token and nothing was truncated.
        label_map: mapping from label string to id; defaults to the module-level
            ``label2id``.

    Returns:
        A list of label ids; ``-100`` marks positions the loss ignores.

    Raises:
        KeyError: if a label (after ``B-`` -> ``I-``) is missing from ``label_map``.
    """
    if label_map is None:
        label_map = label2id

    def to_id(label):
        if label.startswith("B-"):
            label = "I-" + label[2:]
        return label_map[label]

    word_label_ids = [to_id(label) for label in labels]

    if word_ids is None:
        return [-100] + word_label_ids + [-100]

    return [
        -100 if word_id is None else word_label_ids[word_id]
        for word_id in word_ids
    ]


def tokenize_function(examples):
    """Tokenize a batch of pre-split examples and attach token-aligned label ids.

    Meant for ``Dataset.map(..., batched=True)``. Mapping a DatasetDict applies it
    per split, so ``examples`` holds columns of a single split.

    Args:
        examples: batch with ``"mbert_tokens"`` (one word list per example) and
            ``"mbert_token_classes"`` (the matching label lists).

    Returns:
        The tokenizer output with an added ``"labels"`` column. Each label row has
        exactly the same length as that example's ``input_ids``.
    """
    inputs = tokenizer(
        examples['mbert_tokens'],
        is_split_into_words=True,
        # No padding here: DataCollatorForTokenClassification pads input_ids and
        # labels (with -100) per training batch.
        truncation=True,
        max_length=512,
    )
    # word_ids() requires a "fast" tokenizer.
    inputs["labels"] = [
        align_labels_with_tokens(labels, word_ids=inputs.word_ids(batch_index=i))
        for i, labels in enumerate(examples['mbert_token_classes'])
    ]
    return inputs

# Tokenize the train and test splits. The raw text/metadata columns are dropped so
# only tokenizer outputs and "labels" reach the Trainer and the data collator.
tokenized_data = raw_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_dataset["train"].column_names,
)

# Pads each training batch dynamically (labels are padded with -100).
data_collator = DataCollatorForTokenClassification(tokenizer)

st.write(tokenized_data["train"][:2]["labels"])

# Print all CUDA environment variables
for key, value in os.environ.items():
    if "CUDA" in key.upper():
        print(f"{key}={value}")

def compute_metrics(eval_preds):
    """Score an evaluation pass with seqeval.

    Args:
        eval_preds: ``(logits, labels)`` pair from the Trainer. Predictions are the
            argmax over the last axis of ``logits``. Positions whose label id is
            ``-100`` (special/padding tokens) are dropped before scoring.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` and ``accuracy``, taken from
        seqeval's ``overall_*`` scores.
    """
    logits, labels = eval_preds
    predicted_ids = np.argmax(logits, axis=-1)

    references = []
    for label_row in labels:
        references.append([id2label[tag] for tag in label_row if tag != -100])

    hypotheses = []
    for pred_row, label_row in zip(predicted_ids, labels):
        hypotheses.append(
            [id2label[pred] for pred, tag in zip(pred_row, label_row) if tag != -100]
        )

    scores = seqeval.compute(predictions=hypotheses, references=references)
    return {
        name: scores["overall_" + name]
        for name in ("precision", "recall", "f1", "accuracy")
    }

# hyperparameters
lr = 2e-4
batch_size = 4
num_epochs = 4
output_dir = "xia-lora-deberta-v2"

# Hub destination. Defined before TrainingArguments so the Trainer pushes to the same
# repo as model.push_to_hub below. The token is read up front so a missing secret
# fails immediately rather than after training has finished.
hf_name = 'CarolXia'  # your hf username or org name
model_id = hf_name + "/" + output_dir
hf_token = st.secrets["HUGGINGFACE_TOKEN"]

# fp16 mixed precision and the paged 8-bit (bitsandbytes) optimizer both require a GPU;
# on CPU-only hosts fall back to fp32 and plain AdamW instead of failing at startup.
use_cuda = torch.cuda.is_available()

# define training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    logging_strategy="epoch",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # switch if the installed version rejects `evaluation_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    fp16=use_cuda,
    optim="paged_adamw_8bit" if use_cuda else "adamw_torch",
    hub_model_id=model_id,
)

# configure trainer
trainer = Trainer(
    model=model,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# train model
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

# re-enable the cache for inference
model.config.use_cache = True

st.write('Pushing model to huggingface')

# Push model to huggingface
model.push_to_hub(model_id, token=hf_token)
# Trainer.push_to_hub's first positional parameter is commit_message, not a repo id;
# the target repo comes from hub_model_id in the TrainingArguments above.
trainer.push_to_hub(token=hf_token)