# interface.py

from UI import create_interface
from models import BioprocessModel
import io
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import copy
from config import DEVICE, MODEL_PATH, MAX_LENGTH, TEMPERATURE

# Runtime settings taken from the config module.
model_path = MODEL_PATH
device = DEVICE

# Language model used by generate_analysis(). Both objects are created once,
# at import time, so importing this module runs from_pretrained (which may
# download weights if MODEL_PATH is a hub id rather than a local directory).
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device)  # Module.to moves parameters and returns the same module
model.eval()              # inference mode (e.g. disables dropout)

def generate_analysis(prompt, max_length=MAX_LENGTH):
    """Generate a text continuation of *prompt* with the module-level model.

    Args:
        prompt: Text used as the generation prompt.
        max_length: Maximum number of *new* tokens to generate; the prompt's
            own tokens do not count against this budget.

    Returns:
        Only the generated continuation (the prompt is excluded), stripped of
        surrounding whitespace. On any failure a string of the form
        ``"An error occurred during analysis: <error>"`` is returned instead
        of raising an exception.
    """
    try:
        # Calling the tokenizer (rather than .encode) also produces an
        # attention_mask, so generate() does not have to guess one.
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        prompt_len = inputs['input_ids'].shape[-1]

        gen_kwargs = {
            # Same budget as the old `max_length + prompt length`.
            'max_new_tokens': max_length,
            'num_return_sequences': 1,
            'no_repeat_ngram_size': 2,
            # Many causal LMs define no pad token; use EOS explicitly (what
            # generate() would otherwise do while emitting a warning).
            'pad_token_id': (
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        }
        # `temperature` only has an effect when sampling; with the default
        # greedy decoding it was silently ignored. Sample for a positive
        # temperature, otherwise keep deterministic greedy decoding.
        # (`early_stopping` is omitted: it only affects beam search.)
        if TEMPERATURE is not None and TEMPERATURE > 0:
            gen_kwargs['do_sample'] = True
            gen_kwargs['temperature'] = TEMPERATURE

        generated_ids = model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated token ids. Slicing the decoded text
        # by len(prompt) is unreliable: decode() need not reproduce the prompt
        # text character-for-character.
        new_token_ids = generated_ids[0, prompt_len:]
        return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    except Exception as e:
        # Boundary handler: report the failure as text instead of raising.
        return f"An error occurred during analysis: {e}"

def parse_bounds(bounds_str, num_params):
    """Parse per-parameter ``(lower, upper)`` bounds from user-entered text.

    *bounds_str* is a comma-separated sequence of pairs, for example
    ``"(0, 1), (0, np.inf)"``. Infinite limits may be written as ``inf``,
    ``np.inf`` or ``numpy.inf``, optionally negated.

    Args:
        bounds_str: The bounds text.
        num_params: Number of parameters the bounds must cover.

    Returns:
        ``(lower_bounds, upper_bounds)``, two lists of length *num_params*.
        If the text cannot be parsed, or is not exactly *num_params* numeric
        pairs, both lists fall back to fully unbounded values
        (``-np.inf`` / ``np.inf``).
    """
    import ast
    import re

    unbounded = ([-np.inf] * num_params, [np.inf] * num_params)

    # 1e999 is a float literal that overflows to inf, which lets
    # ast.literal_eval (literals only) accept the infinity spellings.
    text = re.sub(r'\b(?:(?:np|numpy)\.)?inf\b', '1e999', str(bounds_str))
    try:
        # literal_eval instead of eval: this text comes from the user and
        # must never be executed as code.
        bounds = ast.literal_eval(f"[{text}]")
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return unbounded

    well_formed = len(bounds) == num_params and all(
        isinstance(pair, (tuple, list))
        and len(pair) == 2
        and all(isinstance(value, (int, float)) for value in pair)
        for pair in bounds
    )
    if not well_formed:
        return unbounded

    lower_bounds = [pair[0] for pair in bounds]
    upper_bounds = [pair[1] for pair in bounds]
    return lower_bounds, upper_bounds

def process_and_plot(
    file,
    biomass_equations_list,
    biomass_params_list,
    biomass_bounds_list,
    substrate_equations_list,
    substrate_params_list,
    substrate_bounds_list,
    product_equations_list,
    product_params_list,
    product_bounds_list,
    legend_position,
    show_legend,
    show_params,
    biomass_eq_count,
    substrate_eq_count,
    product_eq_count
):
    """Process uploaded data, fit the models, plot them and produce an analysis.

    NOTE(review): this is still a stub. It is meant to become the callback
    passed to ``create_interface``, and all of the work below remains to be
    written.

    Args (meanings inferred from names only; confirm against UI.py):
        file: Uploaded data file to fit against.
        biomass_equations_list, biomass_params_list, biomass_bounds_list:
            Model equations, parameter names and bounds text for biomass
            (bounds presumably parsed with ``parse_bounds``).
        substrate_equations_list, substrate_params_list, substrate_bounds_list:
            The same, for substrate.
        product_equations_list, product_params_list, product_bounds_list:
            The same, for product.
        legend_position, show_legend, show_params: Plot display options.
        biomass_eq_count, substrate_eq_count, product_eq_count: How many of
            the equation entries are in use for each variable.

    Returns:
        Intended: ``([image], analysis)``, a list of plot images plus the
        analysis text.

    Raises:
        NotImplementedError: Always, until the body is implemented.
    """
    # The placeholder used to return undefined names (`image`, `analysis`),
    # which surfaced as a confusing NameError; fail explicitly instead.
    raise NotImplementedError(
        "process_and_plot is a stub: data processing, model fitting, "
        "plotting and analysis have not been implemented yet."
    )

if __name__ == "__main__":
    # Script entry point: build the UI around the processing callback and
    # start it (create_interface returns an object with .launch(); presumably
    # a Gradio app — confirm in UI.py).
    app = create_interface(process_and_plot)
    app.launch()