# interface.py

# Importar 'spaces' y decoradores antes que cualquier biblioteca que pueda inicializar CUDA
from decorators import gpu_decorator

# Luego importar cualquier cosa relacionada con PyTorch o el modelo que va a usar la GPU
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
from sympy import symbols, lambdify, sympify

# Importar otras partes necesarias del código (config, etc.)
from config import DEVICE, MODEL_PATH, MAX_LENGTH, TEMPERATURE

# Load the tokenizer and model once, at import time and outside any function, so they
# are not re-initialized on every call. No device is selected here; generate_analysis
# moves `model` to the requested device when it is called.
model_path = MODEL_PATH
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

###############################

# bioprocess_model.py

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error
import seaborn as sns
from sympy import symbols, lambdify, sympify

class BioprocessModel:
    """Container for bioprocess kinetic models, their fitted parameters and fit metrics.

    A model is registered per variable type ('biomass', 'substrate' or 'product') with
    :meth:`set_model`, fitted to data with :meth:`fit_model`, and can be drawn with
    :meth:`plot_combined_results`. Closed-form logistic-growth expressions and the
    corresponding rate equations are provided as helper methods.
    """

    def __init__(self):
        # Fitted parameters per model type: {model_type: {param_name: value}}.
        self.params = {}
        # Goodness-of-fit metrics per model type, written by fit_model.
        self.r2 = {}
        self.rmse = {}
        # Raw replicate arrays appended by process_data (one row per matching column).
        self.datax = []
        self.datas = []
        self.datap = []
        # Mean across replicate columns, appended by process_data.
        self.dataxp = []
        self.datasp = []
        self.datapp = []
        # Sample standard deviation (ddof=1) across replicate columns.
        self.datax_std = []
        self.datas_std = []
        self.datap_std = []
        # Registered models: {model_type: {'function': callable, 'params': [names]}}.
        self.models = {}

    @staticmethod
    def logistic(time, xo, xm, um):
        """Return xo*exp(um*t) / (1 - (xo/xm)*(1 - exp(um*t))) evaluated at ``time``."""
        return (xo * np.exp(um * time)) / (1 - (xo / xm) * (1 - np.exp(um * time)))

    @staticmethod
    def substrate(time, so, p, q, xo, xm, um):
        """Return the closed-form substrate expression built on the logistic term.

        Computes ``so - p*xo*(G - 1) - q*(xm/um)*log(D)`` with
        ``D = 1 - (xo/xm)*(1 - exp(um*t))`` and ``G = exp(um*t) / D``.
        """
        return so - (p * xo * ((np.exp(um * time)) / (1 - (xo / xm) * (1 - np.exp(um * time))) - 1)) - \
               (q * (xm / um) * np.log(1 - (xo / xm) * (1 - np.exp(um * time))))

    @staticmethod
    def product(time, po, alpha, beta, xo, xm, um):
        """Return the closed-form product expression built on the logistic term.

        Computes ``po + alpha*xo*(G - 1) + beta*(xm/um)*log(D)`` with
        ``D = 1 - (xo/xm)*(1 - exp(um*t))`` and ``G = exp(um*t) / D``.
        """
        return po + (alpha * xo * ((np.exp(um * time) / (1 - (xo / xm) * (1 - np.exp(um * time)))) - 1)) + \
               (beta * (xm / um) * np.log(1 - (xo / xm) * (1 - np.exp(um * time))))

    @staticmethod
    def logistic_diff(X, t, params):
        """Return dX/dt = um * X * (1 - X/xm); ``params`` unpacks as (xo, xm, um).

        ``t`` and ``xo`` are accepted but not used in the expression.
        """
        xo, xm, um = params
        dXdt = um * X * (1 - X / xm)
        return dXdt

    def substrate_diff(self, S, t, params, biomass_params, X_func):
        """Return dS/dt = -p * (um*X*(1 - X/xm)) - q*X, with X = ``X_func(t)``.

        ``params`` unpacks as (so, p, q) and ``biomass_params`` as (xo, xm, um).
        """
        so, p, q = params
        xo, xm, um = biomass_params
        X_t = X_func(t)
        dSdt = -p * (um * X_t * (1 - X_t / xm)) - q * X_t
        return dSdt

    def product_diff(self, P, t, params, biomass_params, X_func):
        """Return dP/dt = alpha * (um*X*(1 - X/xm)) + beta*X, with X = ``X_func(t)``.

        ``params`` unpacks as (po, alpha, beta) and ``biomass_params`` as (xo, xm, um).
        """
        po, alpha, beta = params
        xo, xm, um = biomass_params
        X_t = X_func(t)
        dPdt = alpha * (um * X_t * (1 - X_t / xm)) + beta * X_t
        return dPdt

    def process_data(self, df):
        """Extract replicate columns from ``df`` and append their statistics.

        Columns are selected by substring match on their names: 'Biomasa',
        'Sustrato', 'Producto', and the first column containing 'Tiempo' as the
        time axis. For each variable the stacked replicate array, its mean across
        replicates and its sample standard deviation (ddof=1) are appended to the
        corresponding lists; the time values are stored in ``self.time``.

        :raises IndexError: if no column name contains 'Tiempo'.
        """
        biomass_cols = [col for col in df.columns if 'Biomasa' in col]
        substrate_cols = [col for col in df.columns if 'Sustrato' in col]
        product_cols = [col for col in df.columns if 'Producto' in col]

        time_col = [col for col in df.columns if 'Tiempo' in col][0]
        time = df[time_col].values

        data_biomass = np.array([df[col].values for col in biomass_cols])
        self.datax.append(data_biomass)
        self.dataxp.append(np.mean(data_biomass, axis=0))
        self.datax_std.append(np.std(data_biomass, axis=0, ddof=1))

        data_substrate = np.array([df[col].values for col in substrate_cols])
        self.datas.append(data_substrate)
        self.datasp.append(np.mean(data_substrate, axis=0))
        self.datas_std.append(np.std(data_substrate, axis=0, ddof=1))

        data_product = np.array([df[col].values for col in product_cols])
        self.datap.append(data_product)
        self.datapp.append(np.mean(data_product, axis=0))
        self.datap_std.append(np.std(data_product, axis=0, ddof=1))

        self.time = time

    def _fitted_biomass(self, t):
        """Evaluate the currently registered biomass model at ``t`` with its fitted parameters.

        Used as the implementation of ``X(t)`` inside substrate/product equations.
        The lookup happens at call time, so the biomass model only needs to be
        fitted before the dependent model is evaluated, not before it is set.
        """
        biomass = self.models.get('biomass')
        fitted = self.params.get('biomass')
        if biomass is None or fitted is None:
            raise ValueError("Biomasa debe estar ajustada antes de evaluar X(t).")
        try:
            values = [fitted[name] for name in biomass['params']]
        except KeyError as exc:
            raise ValueError(
                "Los parámetros ajustados de biomasa no corresponden al modelo de biomasa actual."
            ) from exc
        return biomass['function'](t, *values)

    def set_model(self, model_type, equation, params_str):
        """
        Sets up the model based on the type, equation, and parameters.

        The equation is parsed with sympy and compiled to a NumPy function whose
        arguments are ``t`` followed by the listed parameters. In substrate and
        product equations the function call ``X(t)`` is bound to the fitted
        biomass model.

        :param model_type: Type of the model ('biomass', 'substrate', 'product')
        :param equation: The equation as a string
        :param params_str: Comma-separated string of parameter names
        :raises ValueError: if a listed parameter does not appear in the equation,
            if a substrate/product model is set before the biomass model, or if
            ``model_type`` is not supported.
        """
        t_symbol = symbols('t')
        expr = sympify(equation)
        params = [param.strip() for param in params_str.split(',')]
        params_symbols = symbols(params)

        # Names of the symbols the expression actually uses, excluding time.
        used_params = {str(s) for s in expr.free_symbols if s != t_symbol}

        # Every declared parameter must appear in the equation.
        for param in params:
            if param not in used_params:
                raise ValueError(f"El parámetro '{param}' no se usa en la ecuación '{equation}'.")

        if model_type == 'biomass':
            modules = 'numpy'
        elif model_type in ('substrate', 'product'):
            # These models depend on biomass, which must already be registered.
            if 'biomass' not in self.models:
                raise ValueError("Biomasa debe estar configurada antes de Sustrato o Producto.")
            # The previous expr.subs('X(t)', <python function>) could not substitute a
            # Python callable into a sympy expression. Instead, expose the fitted
            # biomass model to the generated code under the name X.
            modules = [{'X': self._fitted_biomass}, 'numpy']
        else:
            raise ValueError(f"Tipo de modelo no soportado: {model_type}")

        self.models[model_type] = {
            'function': lambdify((t_symbol, *params_symbols), expr, modules),
            'params': params
        }

    @staticmethod
    def _broadcast_bounds(bounds, n_params):
        """Return (lower, upper) as float arrays of length ``n_params``.

        Scalars and length-1 sequences are broadcast; any other length mismatch
        raises ValueError.
        """
        lower, upper = bounds
        lower = np.broadcast_to(np.asarray(lower, dtype=float), (n_params,)).copy()
        upper = np.broadcast_to(np.asarray(upper, dtype=float), (n_params,)).copy()
        return lower, upper

    @staticmethod
    def _initial_guess(lower, upper):
        """Return a starting point inside the bounds.

        Uses 1 for unbounded parameters, the midpoint for two-sided bounds, and a
        unit step inside the finite side for one-sided bounds.
        """
        guess = np.ones_like(lower)
        lower_finite = np.isfinite(lower)
        upper_finite = np.isfinite(upper)

        both = lower_finite & upper_finite
        guess[both] = 0.5 * (lower[both] + upper[both])

        only_lower = lower_finite & ~upper_finite
        guess[only_lower] = lower[only_lower] + 1

        only_upper = ~lower_finite & upper_finite
        guess[only_upper] = upper[only_upper] - 1
        return guess

    def fit_model(self, model_type, time, data, bounds=([-np.inf], [np.inf])):
        """
        Fits the model to the data.

        On success the fitted parameters are stored in ``self.params[model_type]``
        and the R² and RMSE of the fit in ``self.r2`` / ``self.rmse``.

        :param model_type: Type of the model ('biomass', 'substrate', 'product')
        :param time: Time data
        :param data: Observed data to fit
        :param bounds: Bounds for the parameters as (lower, upper); each side may
            be a scalar, a length-1 sequence, or one value per parameter
        :return: Predicted data from the model
        :raises ValueError: if ``model_type`` has not been registered with set_model.
            Any error raised during fitting is printed and re-raised.
        """
        if model_type not in self.models:
            raise ValueError(f"Model type '{model_type}' is not set. Please use set_model first.")

        func = self.models[model_type]['function']
        params = self.models[model_type]['params']

        # Debug output: show which function and parameters are being fitted.
        print(f"Fitting {model_type} model with function: {func} and parameters: {params}")

        # Wrapper that reports evaluation errors before propagating them.
        def fit_func(t, *args):
            try:
                return func(t, *args)
            except Exception as e:
                print(f"Error in fit_func: {e}")
                raise

        print(f"Number of parameters to fit: {len(params)}")

        try:
            print(f"Calling curve_fit with time: {time}, data: {data}, bounds: {bounds}")

            time = np.asarray(time, dtype=float)
            data = np.asarray(data, dtype=float)

            # fit_func takes *args, so curve_fit cannot infer the number of
            # parameters from its signature; an explicit p0 is required.
            lower, upper = self._broadcast_bounds(bounds, len(params))
            p0 = self._initial_guess(lower, upper)

            popt, _ = curve_fit(fit_func, time, data, p0=p0, bounds=(lower, upper), maxfev=10000)
            print(f"Optimal parameters found: {popt}")

            # Store the fitted parameters and the fit metrics.
            self.params[model_type] = {param: val for param, val in zip(params, popt)}
            y_pred = fit_func(time, *popt)
            residuals = data - y_pred
            self.r2[model_type] = 1 - (np.sum(residuals ** 2) / np.sum((data - np.mean(data)) ** 2))
            self.rmse[model_type] = np.sqrt(np.mean(residuals ** 2))
            return y_pred
        except Exception as e:
            print(f"Error while fitting {model_type} model: {str(e)}")
            raise

    def plot_combined_results(self, time, biomass, substrate, product,
                              y_pred_biomass, y_pred_substrate, y_pred_product,
                              biomass_std=None, substrate_std=None, product_std=None,
                              experiment_name='', legend_position='best', params_position='upper right',
                              show_legend=True, show_params=True,
                              style='whitegrid', line_color='#0000FF', point_color='#000000',
                              line_style='-', marker_style='o'):
        """Plot data and model curves for biomass, substrate and product on one figure.

        Biomass uses the left axis; substrate and product each get their own right
        axis, the product axis being offset outward. Data are drawn as markers and
        model predictions as lines.

        NOTE: ``biomass_std``, ``substrate_std``, ``product_std``, ``experiment_name``,
        ``legend_position``, ``params_position``, ``show_legend`` and ``show_params``
        are accepted but not used by the current implementation.

        :return: the matplotlib Figure; closing it is left to the caller.
        """
        # Applies the seaborn style globally, not only to this figure.
        sns.set_style(style)

        fig, ax1 = plt.subplots(figsize=(10, 7))
        ax1.set_xlabel('Tiempo')
        ax1.set_ylabel('Biomasa', color=line_color)

        ax1.plot(time, biomass, marker=marker_style, linestyle='', color=point_color, label='Biomasa (Datos)')
        ax1.plot(time, y_pred_biomass, linestyle=line_style, color=line_color, label='Biomasa (Modelo)')
        ax1.tick_params(axis='y', labelcolor=line_color)

        ax2 = ax1.twinx()
        ax2.set_ylabel('Sustrato', color='green')
        ax2.plot(time, substrate, marker=marker_style, linestyle='', color='green', label='Sustrato (Datos)')
        ax2.plot(time, y_pred_substrate, linestyle=line_style, color='green', label='Sustrato (Modelo)')
        ax2.tick_params(axis='y', labelcolor='green')

        ax3 = ax1.twinx()
        # Shift the third axis outward so it does not overlap the substrate axis.
        ax3.spines["right"].set_position(("axes", 1.1))
        ax3.set_ylabel('Producto', color='red')
        ax3.plot(time, product, marker=marker_style, linestyle='', color='red', label='Producto (Datos)')
        ax3.plot(time, y_pred_product, linestyle=line_style, color='red', label='Producto (Modelo)')
        ax3.tick_params(axis='y', labelcolor='red')

        fig.tight_layout()
        return fig

###############################

# GPU decorator applied so the call can run on a GPU when one is available
@gpu_decorator(duration=300)
def generate_analysis(prompt, max_length=1024, device=None):
    """Generate a text continuation of ``prompt`` with the module-level causal LM.

    :param prompt: Text the model should continue.
    :param max_length: Number of tokens allowed on top of the prompt length; the
        total is capped at ``model.config.max_position_embeddings`` when the
        config defines it.
    :param device: Target device (``torch.device`` or anything ``torch.device``
        accepts). Defaults to CPU when omitted.
    :return: The generated text without the prompt, stripped of surrounding
        whitespace, or an error message string if generation fails (errors are
        reported through the return value, not raised).
    """
    try:
        # Default to CPU when no device is given; accept strings as well as torch.device.
        device = torch.device('cpu') if device is None else torch.device(device)

        # Move the model only when it is not already on the target device.
        if next(model.parameters()).device != device:
            model.to(device)

        # Tokenize on the target device, keeping the attention mask for generate().
        encoded = tokenizer(prompt, return_tensors='pt').to(device)
        input_ids = encoded['input_ids']
        prompt_tokens = input_ids.size(1)

        max_gen_length = max_length + prompt_tokens
        context_limit = getattr(model.config, 'max_position_embeddings', None)
        if context_limit is not None:
            max_gen_length = min(max_gen_length, context_limit)

        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=encoded.get('attention_mask'),
            max_length=max_gen_length,
            # temperature only takes effect in sampling mode.
            do_sample=True,
            temperature=0.7,
            num_return_sequences=1,
            no_repeat_ngram_size=2
        )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is unreliable because decoding does not always reproduce
        # the prompt character for character.
        new_tokens = generated_ids[0][prompt_tokens:]
        analysis = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return analysis
    except RuntimeError as e:
        return f"Error durante la ejecución: {str(e)}"
    except Exception as e:
        return f"Ocurrió un error durante el análisis: {e}"

def parse_bounds(bounds_str, num_params):
    """Parse user-entered parameter bounds such as ``"(0, 10), (-inf, np.inf)"``.

    The text is read as a comma-separated sequence of ``(lower, upper)`` pairs,
    one per parameter. Infinity may be written as ``inf``, ``np.inf``,
    ``numpy.inf`` or ``math.inf``, optionally with a sign.

    SECURITY: this text comes from the user, so it is parsed with a restricted
    AST walk instead of ``eval``. Only numeric literals, unary +/-, tuples/lists
    and the infinity spellings above are accepted; anything else (function
    calls, arithmetic, other names) is treated as a parse error.

    :param bounds_str: Bounds text entered by the user.
    :param num_params: Number of parameters the bounds must cover.
    :return: ``(lower_bounds, upper_bounds)`` as two lists of length
        ``num_params``. On any parse error, or when the number of pairs does not
        match, a message is printed and unbounded defaults
        (``-inf`` / ``+inf``) are returned.
    """
    import ast  # local import: only this function needs it

    inf_modules = ('np', 'numpy', 'math')

    def _evaluate(node):
        """Convert an allowed AST node to a number or a list of numbers."""
        if isinstance(node, (ast.Tuple, ast.List)):
            return [_evaluate(item) for item in node.elts]
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; it is not a meaningful bound.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            operand = _evaluate(node.operand)
            if not isinstance(operand, list):
                return -operand if isinstance(node.op, ast.USub) else operand
        elif isinstance(node, ast.Name):
            if node.id.lower() in ('inf', 'infinity'):
                return np.inf
        elif isinstance(node, ast.Attribute):
            if (isinstance(node.value, ast.Name)
                    and node.value.id in inf_modules
                    and node.attr == 'inf'):
                return np.inf
        raise ValueError("Expresión no permitida en los límites.")

    try:
        tree = ast.parse(f"[{bounds_str}]", mode='eval')
        bounds = _evaluate(tree.body)
        if len(bounds) != num_params:
            raise ValueError("Número de límites no coincide con el número de parámetros.")
        lower_bounds = [b[0] for b in bounds]
        upper_bounds = [b[1] for b in bounds]
        return lower_bounds, upper_bounds
    except Exception as e:
        # Deliberate best-effort: malformed bounds fall back to an unbounded fit.
        print(f"Error al parsear los límites: {e}. Usando límites por defecto.")
        lower_bounds = [-np.inf] * num_params
        upper_bounds = [np.inf] * num_params
        return lower_bounds, upper_bounds

def _fit_equation_group(model, model_type, label, time, data, equations, params_list, bounds_list):
    """Configure and fit every equation of one variable; return one result dict per equation."""
    results = []
    for index, (equation, params_str, bounds_str) in enumerate(
            zip(equations, params_list, bounds_list), start=1):
        try:
            model.set_model(model_type, equation, params_str)
        except ValueError as ve:
            raise ValueError(
                f"Error en la configuración del modelo de {label} {index}: {ve}"
            ) from ve

        param_names = [param.strip() for param in params_str.split(',')]
        lower_bounds, upper_bounds = parse_bounds(bounds_str, len(param_names))

        try:
            y_pred = model.fit_model(
                model_type, time, data,
                bounds=(lower_bounds, upper_bounds)
            )
            results.append({
                'model': model,
                'y_pred': y_pred,
                'equation': equation,
                'params': model.params[model_type],
                # Captured now because the next fit of this type overwrites them.
                'r2': model.r2.get(model_type),
                'rmse': model.rmse.get(model_type),
            })
        except Exception as e:
            raise RuntimeError(f"Error al ajustar el modelo de {label} {index}: {e}") from e
    return results


def _plot_fit_panel(ax, time, data, results, label, show_legend, legend_position):
    """Draw the observed data and every fitted curve of one variable on ``ax``."""
    ax.plot(time, data, 'o', label=f'Datos de {label}')
    for index, result in enumerate(results, start=1):
        ax.plot(time, result['y_pred'], '-', label=f'Modelo de {label} {index}')
    ax.set_xlabel('Tiempo')
    ax.set_ylabel(label)
    if show_legend:
        ax.legend(loc=legend_position)


def _summarize_results(results):
    """Render fit results as compact text (equation, parameters, R2, RMSE) for the prompt."""
    if not results:
        return "[]"
    lines = []
    for index, result in enumerate(results, start=1):
        fitted = ", ".join(
            f"{name} = {float(value):.6g}" for name, value in result['params'].items()
        )
        line = f"Modelo {index}: ecuación = {result['equation']}; parámetros: {fitted}"
        if result.get('r2') is not None:
            line += f"; R2 = {float(result['r2']):.4f}"
        if result.get('rmse') is not None:
            line += f"; RMSE = {float(result['rmse']):.4g}"
        lines.append(line)
    return "\n".join(lines)


def process_and_plot(
    file,
    biomass_eq1, biomass_eq2, biomass_eq3,
    biomass_param1, biomass_param2, biomass_param3,
    biomass_bound1, biomass_bound2, biomass_bound3,
    substrate_eq1, substrate_eq2, substrate_eq3,
    substrate_param1, substrate_param2, substrate_param3,
    substrate_bound1, substrate_bound2, substrate_bound3,
    product_eq1, product_eq2, product_eq3,
    product_param1, product_param2, product_param3,
    product_bound1, product_bound2, product_bound3,
    legend_position,
    show_legend,
    show_params,
    biomass_eq_count,
    substrate_eq_count,
    product_eq_count,
    device=None
):
    """Fit the user-supplied equations to an Excel data set, plot the fits and request an analysis.

    The Excel file must contain the columns 'Tiempo', 'Biomasa', 'Sustrato' and
    'Producto'. For each variable, up to three equations are given together with
    their comma-separated parameter names and bounds text; the ``*_eq_count``
    arguments say how many of the three slots are in use. Biomass equations are
    fitted first, then substrate, then product, all on one shared
    ``BioprocessModel`` instance (so each result's ``'model'`` entry refers to
    that same instance).

    :param file: Uploaded file object exposing ``.name``, or a plain path.
    :param legend_position: Matplotlib legend location used when ``show_legend`` is true.
    :param show_legend: Whether to draw a legend on each panel.
    :param show_params: Accepted for interface compatibility; currently unused.
    :param device: Passed through to ``generate_analysis``.
    :return: ``(image, analysis)``: a PIL image of the three-panel figure and the
        text returned by ``generate_analysis``.
    :raises KeyError: if a required column is missing.
    :raises ValueError: if no biomass equation is provided or a model cannot be configured.
    :raises RuntimeError: if fitting a model fails.
    """
    # Read the Excel file (accept both an upload object and a plain path).
    df = pd.read_excel(getattr(file, 'name', file))

    # Make sure every required column is present.
    expected_columns = ['Tiempo', 'Biomasa', 'Sustrato', 'Producto']
    for col in expected_columns:
        if col not in df.columns:
            raise KeyError(f"La columna esperada '{col}' no se encuentra en el archivo Excel.")

    time = df['Tiempo'].values
    biomass_data = df['Biomasa'].values
    substrate_data = df['Sustrato'].values
    product_data = df['Producto'].values

    # The counters may arrive as floats or strings from the UI.
    biomass_eq_count = int(biomass_eq_count)
    substrate_eq_count = int(substrate_eq_count)
    product_eq_count = int(product_eq_count)

    # Keep only the equation slots that are in use.
    biomass_eqs = [biomass_eq1, biomass_eq2, biomass_eq3][:biomass_eq_count]
    biomass_params = [biomass_param1, biomass_param2, biomass_param3][:biomass_eq_count]
    biomass_bounds = [biomass_bound1, biomass_bound2, biomass_bound3][:biomass_eq_count]

    substrate_eqs = [substrate_eq1, substrate_eq2, substrate_eq3][:substrate_eq_count]
    substrate_params = [substrate_param1, substrate_param2, substrate_param3][:substrate_eq_count]
    substrate_bounds = [substrate_bound1, substrate_bound2, substrate_bound3][:substrate_eq_count]

    product_eqs = [product_eq1, product_eq2, product_eq3][:product_eq_count]
    product_params = [product_param1, product_param2, product_param3][:product_eq_count]
    product_bounds = [product_bound1, product_bound2, product_bound3][:product_eq_count]

    # Previously an empty biomass list surfaced later as a bare IndexError.
    if not biomass_eqs:
        raise ValueError("Se requiere al menos una ecuación de biomasa.")

    main_model = BioprocessModel()

    biomass_results = _fit_equation_group(
        main_model, 'biomass', 'biomasa', time, biomass_data,
        biomass_eqs, biomass_params, biomass_bounds
    )
    substrate_results = _fit_equation_group(
        main_model, 'substrate', 'sustrato', time, substrate_data,
        substrate_eqs, substrate_params, substrate_bounds
    )
    product_results = _fit_equation_group(
        main_model, 'product', 'producto', time, product_data,
        product_eqs, product_params, product_bounds
    )

    # Render the three panels to an in-memory PNG. The figure is always closed,
    # otherwise each call would leave an open figure behind.
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    try:
        panels = (
            ('Biomasa', biomass_data, biomass_results),
            ('Sustrato', substrate_data, substrate_results),
            ('Producto', product_data, product_results),
        )
        for ax, (label, data, results) in zip(axs, panels):
            _plot_fit_panel(ax, time, data, results, label, show_legend, legend_position)

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)

    # Summaries replace the raw dict reprs (object addresses, full arrays) and
    # include the fit metrics the model is asked to judge.
    biomass_summary = _summarize_results(biomass_results)
    substrate_summary = _summarize_results(substrate_results)
    product_summary = _summarize_results(product_results)

    prompt = f"""
Eres un experto en modelado de bioprocesos.

Analiza los siguientes resultados experimentales y proporciona un veredicto sobre la calidad de los modelos, sugiriendo mejoras si es necesario.

Biomasa:
{biomass_summary}

Sustrato:
{substrate_summary}

Producto:
{product_summary}
"""
    analysis = generate_analysis(prompt, device=device)

    return image, analysis