import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gc
import sys
from diffusers import FluxPipeline
import time
from sentence_transformers import SentenceTransformer
import psutil
import json
import spaces
from threading import Thread
#-----------------
from relatively_constant_variables import knowledge_base

# Initialize the zero tensor on CUDA
# This runs at import time; generate_response later uses `zero.device` as the
# target device for the tokenized inputs.
# NOTE(review): `.cuda()` needs a CUDA-capable (or `spaces`-emulated) runtime at import — confirm for local runs.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # NOTE(review): original author states this prints 'cpu' outside the @spaces.GPU decorated function — confirm

# Candidate text-model identifiers; each is passed unchanged to
# AutoModelForCausalLM.from_pretrained / AutoTokenizer.from_pretrained by load_model.
modelnames = ["stvlynn/Gemma-2-2b-Chinese-it", "unsloth/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-3B-Instruct", "nbeerbower/mistral-nemo-wissenschaft-12B", "princeton-nlp/gemma-2-9b-it-SimPO", "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k", "01-ai/Yi-Coder-9B-Chat", "ArliAI/Llama-3.1-8B-ArliAI-RPMax-v1.1", "ArliAI/Phi-3.5-mini-3.8B-ArliAI-RPMax-v1.1", 
              "Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2-0.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-7B-Instruct", "Qwen/Qwen1.5-MoE-A2.7B-Chat", "HuggingFaceTB/SmolLM-135M-Instruct", "microsoft/Phi-3-mini-4k-instruct", "Groq/Llama-3-Groq-8B-Tool-Use", "hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4", 
              "SpectraSuite/TriLM_3.9B_Unpacked", "h2oai/h2o-danube3-500m-chat", "OuteAI/Lite-Mistral-150M-v2-Instruct", "Zyphra/Zamba2-1.2B", "anthracite-org/magnum-v2-4b", ]

# Candidate image-model identifiers; passed to FluxPipeline.from_pretrained by load_image_model.
imagemodelnames = ["black-forest-labs/FLUX.1-schnell"]

# The first entry of each list is the default loaded at import time.
current_model_index = 0
current_image_model_index = 0
modelname = modelnames[current_model_index]
imagemodelname = imagemodelnames[current_image_model_index]
# Set by load_model to (model_name, model_size, tokenizer_size) after a successful load.
lastmodelnameinloadfunction = None
# Set by load_image_model to (imagemodelname, model_size) after a successful load.
lastimagemodelnameinloadfunction = None

# Load the embedding model
# Used for both the knowledge-base embeddings and the per-query embedding in retrieve().
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize model and tokenizer as global variables
# These hold the single currently active text model, its tokenizer, and the image pipeline.
model = None
tokenizer = None
flux_pipe = None

# Load registry keyed by model name. Once a load finishes, the value left in
# place is [timestamp string, GPU memory delta in bytes] — not the model
# objects themselves (those live in `model`, `tokenizer` and `flux_pipe`).
loaded_models = {}

def get_size_str(bytes):
    """Format a byte count as a human-readable string with two decimals.

    The value is scaled by powers of 1024 through B, KB, MB, GB and TB.
    Values of 1024 TB or more stay expressed in TB instead of falling off the
    end of the unit list (which previously returned ``None``). Negative values
    (for example a memory delta after something was freed) are scaled by
    magnitude and keep their sign.

    Args:
        bytes: Size in bytes (int or float). The name shadows the builtin but
            is kept so existing keyword callers continue to work.

    Returns:
        A string such as ``"1.50 KB"`` or ``"-2.00 KB"``.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    size = float(bytes)
    sign = "-" if size < 0 else ""
    size = abs(size)
    # Stop scaling before the last unit so it acts as the catch-all.
    for unit in units[:-1]:
        if size < 1024:
            return f"{sign}{size:.2f} {unit}"
        size /= 1024
    return f"{sign}{size:.2f} {units[-1]}"

def load_model(model_name):
    """Load a causal LM and its tokenizer, replacing the currently active pair.

    The previously active model and tokenizer are released before the new
    ones are loaded, so only one text model is held by this module at a time.
    The GPU-memory baseline is sampled *after* that release; sampling it
    before (as the code used to) reported the net change between the old and
    the new model, which could even be negative.

    Side effects:
        * rebinds the module globals ``model`` and ``tokenizer``;
        * records ``[timestamp string, GPU bytes used]`` under ``model_name``
          in ``loaded_models``;
        * sets ``lastmodelnameinloadfunction`` to
          ``(model_name, model_size, tokenizer_size)``.

    Args:
        model_name: Identifier handed to ``from_pretrained``.

    Returns:
        A human-readable status string with the model size, tokenizer size
        and GPU memory used.
    """
    global model, tokenizer, lastmodelnameinloadfunction, loaded_models

    print(f"Loading model and tokenizer: {model_name}")

    # Release the previous model/tokenizer. Collect garbage first so that the
    # freed tensors are actually gone before the CUDA cache is emptied.
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()

    # Baseline after cleanup, so the delta below reflects only the new model.
    initial_memory = torch.cuda.memory_allocated()

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_size = sum(p.numel() * p.element_size() for p in model.parameters())
    # sys.getsizeof is shallow, so this is only a rough lower bound.
    tokenizer_size = sum(sys.getsizeof(v) for v in tokenizer.__dict__.values())

    # Calculate memory usage
    memory_used = torch.cuda.memory_allocated() - initial_memory

    # The registry stores load metadata only, not the model objects.
    loaded_models[model_name] = [str(time.time()), memory_used]

    lastmodelnameinloadfunction = (model_name, model_size, tokenizer_size)
    print(f"Model and tokenizer {model_name} loaded successfully")
    print(f"Model size: {get_size_str(model_size)}")
    print(f"Tokenizer size: {get_size_str(tokenizer_size)}")
    print(f"GPU memory used: {get_size_str(memory_used)}")

    return (f"Model and tokenizer {model_name} loaded successfully. "
            f"Model size: {get_size_str(model_size)}, "
            f"Tokenizer size: {get_size_str(tokenizer_size)}, "
            f"GPU memory used: {get_size_str(memory_used)}")

def load_image_model(imagemodelname):
    """Load a FLUX pipeline, replacing the currently active image pipeline.

    The previous pipeline is released before the new one is loaded, and the
    GPU-memory baseline is sampled after that release so the reported figure
    reflects only the new pipeline rather than the net change.

    Side effects:
        * rebinds the module global ``flux_pipe``;
        * records ``[timestamp string, GPU bytes used]`` under
          ``imagemodelname`` in ``loaded_models``;
        * sets ``lastimagemodelnameinloadfunction`` to
          ``(imagemodelname, model_size)``.

    Args:
        imagemodelname: Identifier handed to ``FluxPipeline.from_pretrained``.

    Returns:
        A human-readable status string with the transformer size and the GPU
        memory used.
    """
    global flux_pipe, lastimagemodelnameinloadfunction, loaded_models

    print(f"Loading image model: {imagemodelname}")

    # Release the previous pipeline. Collect garbage first so that the freed
    # tensors are actually gone before the CUDA cache is emptied.
    flux_pipe = None
    gc.collect()
    torch.cuda.empty_cache()

    # Baseline after cleanup, so the delta below reflects only the new pipeline.
    initial_memory = torch.cuda.memory_allocated()

    flux_pipe = FluxPipeline.from_pretrained(imagemodelname, torch_dtype=torch.bfloat16)
    # NOTE(review): with CPU offload enabled the weights presumably stay off the
    # GPU until inference, so the memory delta below may be close to zero.
    flux_pipe.enable_model_cpu_offload()
    # Only the transformer component is counted towards the reported size.
    model_size = sum(p.numel() * p.element_size() for p in flux_pipe.transformer.parameters())

    # Calculate memory usage
    memory_used = torch.cuda.memory_allocated() - initial_memory

    # The registry stores load metadata only, not the pipeline object.
    loaded_models[imagemodelname] = [str(time.time()), memory_used]

    lastimagemodelnameinloadfunction = (imagemodelname, model_size)
    print(f"Model and tokenizer {imagemodelname} loaded successfully")
    print(f"Model size: {get_size_str(model_size)}")
    print(f"GPU memory used: {get_size_str(memory_used)}")

    return (f"Model and tokenizer {imagemodelname} loaded successfully. "
            f"Model size: {get_size_str(model_size)}, "
            f"GPU memory used: {get_size_str(memory_used)}")


def clear_all_models():
    """Release every loaded model and reset the load registry.

    The live objects are held in the module globals ``model``, ``tokenizer``
    and ``flux_pipe``; ``loaded_models`` only contains
    ``[timestamp, bytes]`` metadata lists. The previous implementation called
    ``.to('cpu')`` on those registry values, which raised ``AttributeError``
    as soon as anything had been loaded. Dropping the references and then
    collecting garbage is what actually frees the memory.

    Returns:
        The fixed status string ``"All models cleared from memory."``.
    """
    global model, tokenizer, flux_pipe, loaded_models
    model = None
    tokenizer = None
    flux_pipe = None
    loaded_models.clear()
    # Collect first so the CUDA cache can release the now-unreferenced blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return "All models cleared from memory."

def load_model_list(model_list):
    """Load each named model in order and return the status messages, one per line.

    Every name is passed to ``load_model``; the returned strings are joined
    with newlines. An empty ``model_list`` yields an empty string.
    """
    return "\n".join(load_model(name) for name in model_list)

def loaded_model_list():
    """Return the module-level ``loaded_models`` registry (the same object, not a copy)."""
    return loaded_models


# Initial model load
# Both calls run at import time, so importing this module loads the default
# text model and the default image pipeline before anything else happens.
load_model(modelname)
load_image_model(imagemodelname)

# Create embeddings for the knowledge base
# One embedding per entry's "content" field, in knowledge_base order, so that
# position i here corresponds to knowledge_base[i] when retrieve() ranks them.
knowledge_base_embeddings = embedding_model.encode([doc["content"] for doc in knowledge_base])

def retrieve(query, k=2):
    """Return the ``k`` knowledge-base entries most similar to ``query``.

    The query is embedded with ``embedding_model`` and compared by cosine
    similarity against ``knowledge_base_embeddings``; entries are ranked from
    most to least similar.

    Args:
        query: Text to look up.
        k: Maximum number of entries to return.

    Returns:
        A list of ``(content, id)`` tuples, best match first.
    """
    query_vectors = torch.tensor(embedding_model.encode([query]))
    corpus_vectors = torch.tensor(knowledge_base_embeddings)
    scores = torch.nn.functional.cosine_similarity(query_vectors, corpus_vectors)
    ranking = scores.argsort(descending=True)

    hits = []
    for position in ranking[:k]:
        entry = knowledge_base[position]
        hits.append((entry["content"], entry["id"]))
    return hits

def get_ram_usage():
    """Return a one-line summary of system RAM: percent used, available and total (in GB, base 1024)."""
    bytes_per_gb = 1024 ** 3
    ram = psutil.virtual_memory()
    available_gb = ram.available / bytes_per_gb
    total_gb = ram.total / bytes_per_gb
    return f"RAM Usage: {ram.percent:.2f}%, Available: {available_gb:.2f}GB, Total: {total_gb:.2f}GB"

# Global dictionary to store outputs
# Keys are "output_<n>" strings assigned by generate_response; each value is
# the per-request record dict that generate_response builds and fills in.
output_dict = {}

def empty_output_dict():
    """Reset the module-level ``output_dict`` to a fresh empty dict and log that it happened.

    The global name is rebound to a new dict; the previous dict object is not
    mutated.
    """
    global output_dict
    output_dict = dict()
    print("Output dictionary has been emptied.")

def get_model_details(model):
    """Summarise a model as a dict with its name, architecture and parameter count.

    Args:
        model: Object exposing ``config.name_or_path``,
            ``config.architectures`` and ``parameters()``.

    Returns:
        ``{"name", "architecture", "num_parameters"}``. ``architecture`` is
        the first listed architecture, or ``"Unknown"`` when none is listed.
    """
    config = model.config
    name = config.name_or_path

    architectures = config.architectures
    if architectures:
        architecture = architectures[0]
    else:
        architecture = "Unknown"

    parameter_count = 0
    for param in model.parameters():
        parameter_count += param.numel()

    return {
        "name": name,
        "architecture": architecture,
        "num_parameters": parameter_count,
    }

def get_tokenizer_details(tokenizer):
    """Summarise a tokenizer as a dict with its class name, vocab size and max length.

    Args:
        tokenizer: Object exposing ``vocab_size`` and ``model_max_length``.

    Returns:
        ``{"name", "vocab_size", "model_max_length"}`` where ``name`` is the
        tokenizer's class name.
    """
    details = {"name": tokenizer.__class__.__name__}
    for attribute in ("vocab_size", "model_max_length"):
        details[attribute] = getattr(tokenizer, attribute)
    return details

def _response_snapshot(record):
    """Return the (text, tokens/s, RAM usage, doc ids) tuple that generate_response yields."""
    return (record["generated_text"],
            record["tokens_per_second"],
            record["ram_usage"],
            record["doc_ids"])


@spaces.GPU
def generate_response(prompt, use_rag, stream=False):
    """Generate a chat completion for ``prompt`` and record it in ``output_dict``.

    This is always a generator (the body contains ``yield``): with
    ``stream=True`` it yields after every streamed text chunk, otherwise it
    yields exactly once with the finished response.

    Args:
        prompt: The user's question or instruction.
        use_rag: When truthy, the top matches from ``retrieve`` are prepended
            as context and their ids are reported.
        stream: When truthy, generation runs in a background thread and
            partial text is yielded as it arrives.

    Yields:
        ``(generated_text, tokens_per_second, ram_usage, doc_ids)`` where
        ``tokens_per_second`` is a string with two decimals and ``doc_ids`` is
        the list of retrieved ids or ``"N/A"``.

    Side effects:
        Adds a record under a new ``"output_<n>"`` key in ``output_dict`` and
        updates it as text is produced.
    """
    global output_dict, model, tokenizer

    print(zero.device)  # expected to be 'cuda:0' inside the @spaces.GPU decorated function
    torch.cuda.empty_cache()

    if use_rag:
        retrieved_docs = retrieve(prompt)
        context = " ".join(doc for doc, _ in retrieved_docs)
        doc_ids = [doc_id for _, doc_id in retrieved_docs]
        full_prompt = f"Context: {context}\nQuestion: {prompt}\nAnswer:"
    else:
        full_prompt = prompt
        doc_ids = None
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": full_prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(zero.device)
    start_time = time.time()
    total_tokens = 0

    print(output_dict)
    output_key = f"output_{len(output_dict) + 1}"
    print(output_key)
    # Keep a direct reference to the record so updates keep working even if
    # the global dict is rebound (empty_output_dict) while we are generating.
    record = {
        "input_prompt": prompt,
        "full_prompt": full_prompt,
        "use_rag": use_rag,
        "generated_text": "",
        "tokens_per_second": 0,
        "ram_usage": "",
        "doc_ids": doc_ids if doc_ids else "N/A",
        "model_details": get_model_details(model),
        "tokenizer_details": get_tokenizer_details(tokenizer),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time))
    }
    output_dict[output_key] = record
    print(output_dict)

    if stream:
        # skip_prompt=True keeps the echoed prompt out of the streamed text,
        # matching the non-streaming branch, which strips the prompt tokens.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # NOTE(review): temperature presumably only takes effect when sampling
        # is enabled for the model's generation config — confirm intent.
        generation_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.7,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            record["generated_text"] += new_text
            # Counts streamed chunks, which only approximates the token count.
            total_tokens += 1
            tokens_per_second = total_tokens / (time.time() - start_time)
            record["tokens_per_second"] = f"{tokens_per_second:.2f}"
            record["ram_usage"] = get_ram_usage()
            yield _response_snapshot(record)
        # The streamer is exhausted, so generation is done; reap the worker.
        thread.join()
    else:
        # Pass the attention mask along with the ids (the streaming branch
        # already forwards it through model_inputs).
        generated_ids = model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.get("attention_mask"),
            max_new_tokens=512
        )
        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        total_tokens = len(generated_ids[0])
        tokens_per_second = total_tokens / (time.time() - start_time)

        record["generated_text"] = response
        record["tokens_per_second"] = f"{tokens_per_second:.2f}"
        record["ram_usage"] = get_ram_usage()
        print(output_dict)

        yield _response_snapshot(record)
            
@spaces.GPU
def generate_image(prompt):
    """Render ``prompt`` with the loaded FLUX pipeline and save the result as a PNG.

    The pipeline is called with a CPU generator seeded with 0, 4 inference
    steps and a guidance scale of 0.0. The image is written to the current
    working directory under a timestamp-based file name.

    Args:
        prompt: Text description of the image.

    Returns:
        ``(image_path, ram_usage, image_path)`` — the saved file path is
        returned twice, with the RAM usage summary in between.
    """
    print(dir(flux_pipe))

    # Generate image using FLUX
    seeded_generator = torch.Generator("cpu").manual_seed(0)
    pipeline_output = flux_pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=seeded_generator,
    )
    picture = pipeline_output.images[0]

    image_path = f"flux_output_{time.time()}.png"
    print(image_path)
    picture.save(image_path)

    return image_path, get_ram_usage(), image_path

def get_output_details(output_key):
    """Return the record stored under ``output_key`` in ``output_dict``.

    When the key is absent, a ``"No output found for key: ..."`` message
    string is returned instead of raising.
    """
    if output_key not in output_dict:
        return f"No output found for key: {output_key}"
    return output_dict[output_key]

def switch_model(choice):
    """Make ``choice`` the active text model and load it.

    The module-level ``modelname`` is updated before loading starts.

    Returns:
        ``(load_message, current_model_label)`` where ``load_message`` is the
        status string from ``load_model`` and the label reads
        ``"Current model: <name>"``.
    """
    global modelname
    modelname = choice
    return load_model(modelname), f"Current model: {modelname}"

def model_change_handler(choice):
    """Switch to ``choice`` and return three outputs for the caller.

    Returns:
        ``(load_message, current_model_label, load_message)`` — the load
        message from ``switch_model`` is deliberately repeated as the third
        element.
    """
    load_message, current_label = switch_model(choice)
    return (load_message, current_label, load_message)

def format_output_dict():
    """Render every record in ``output_dict`` as text, print it, and return it.

    Each record becomes a ``"Key: <key>"`` line followed by the value as
    indented JSON and a blank line. An empty ``output_dict`` produces an
    empty string.
    """
    sections = [
        f"Key: {key}\n" + json.dumps(value, indent=2) + "\n\n"
        for key, value in output_dict.items()
    ]
    formatted_output = "".join(sections)
    print(formatted_output)
    return formatted_output