import functools
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Amino acids scored at every position; their order fixes the heatmap rows.
_AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")


@functools.lru_cache(maxsize=4)
def _load_esm(model_name):
    """Load and cache the tokenizer and masked-LM model for ``model_name``.

    Loading is expensive, so repeated requests reuse the same objects
    instead of calling ``from_pretrained`` on every heatmap.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    model.eval()  # inference only
    return tokenizer, model


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None,
                     model_name="facebook/esm2_t6_8M_UR50D"):
    """Render a masked-marginal log-likelihood-ratio heatmap to a PNG file.

    For each position in ``start_pos..end_pos`` (1-based, inclusive) the
    residue is masked, the model predicts a distribution for that slot, and
    every amino acid's score is ``log p(mutant) - log p(wild type)``.

    Args:
        protein_sequence: Protein sequence to score.
        start_pos: First 1-based position to score.
        end_pos: Last 1-based position to score. ``None`` or a value past the
            end of the tokenized sequence means "to the end".
        model_name: Hugging Face checkpoint to use for scoring.

    Returns:
        Path to a newly created PNG file containing the heatmap.

    Raises:
        ValueError: If the requested range is empty or starts before 1.
    """
    tokenizer, model = _load_esm(model_name)

    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    # Two special tokens surround the residues (one at index 0), so token
    # index i holds residue i in 1-based numbering.
    sequence_length = input_ids.shape[1] - 2

    if end_pos is None or end_pos > sequence_length:
        end_pos = sequence_length
    if not 1 <= start_pos <= end_pos:
        raise ValueError(
            f"Invalid range start_pos={start_pos}, end_pos={end_pos} "
            f"for a sequence of {sequence_length} residues."
        )

    positions = range(start_pos, end_pos + 1)
    aa_token_ids = tokenizer.convert_tokens_to_ids(_AMINO_ACIDS)
    heatmap = np.zeros((len(_AMINO_ACIDS), len(positions)))

    with torch.no_grad():
        for column, position in enumerate(positions):
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, position] = tokenizer.mask_token_id
            logits = model(masked_input_ids).logits
            # log_softmax avoids the underflow to -inf that log(softmax(x))
            # can hit for very unlikely tokens.
            log_probs = torch.nn.functional.log_softmax(logits[0, position], dim=-1)
            wt_log_prob = log_probs[input_ids[0, position]]
            heatmap[:, column] = (log_probs[aa_token_ids] - wt_log_prob).numpy()

    # Label columns with the tokens actually scored, so labels always match
    # the heatmap width even if the raw string contained non-residue chars.
    residue_labels = tokenizer.convert_ids_to_tokens(
        input_ids[0, start_pos:end_pos + 1].tolist()
    )

    fig, ax = plt.subplots(figsize=(15, 5))
    try:
        image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
        ax.set_xticks(range(len(positions)))
        ax.set_xticklabels(residue_labels)
        ax.set_yticks(range(len(_AMINO_ACIDS)))
        ax.set_yticklabels(_AMINO_ACIDS)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")

        # A unique file per call, so concurrent requests cannot overwrite
        # each other's image.
        fd, temp_file = tempfile.mkstemp(prefix="heatmap_", suffix=".png")
        os.close(fd)
        fig.savefig(temp_file, bbox_inches="tight")
    finally:
        # Always release the figure; pyplot otherwise keeps it alive forever.
        plt.close(fig)

    return temp_file

def heatmap_interface(sequence, start, end=None):
    """Validate Gradio inputs and render the mutation heatmap.

    Args:
        sequence: Protein sequence as entered by the user. All whitespace
            (e.g. line breaks from a pasted multi-line sequence) is removed
            and residues are upper-cased, so the validated length matches
            the residues actually scored.
        start: 1-based first position. It may arrive as a float from
            ``gr.Number``; ``None`` (an empty field) is treated as 1.
        end: 1-based last position. ``None``, values <= 0 and values past
            the end of the sequence all mean "to the end of the sequence".

    Returns:
        Whatever ``generate_heatmap`` returns (the heatmap file path), or an
        error message string when the requested range is invalid.
    """
    sequence = "".join((sequence or "").split()).upper()
    start = 1 if start is None else int(start)
    end = None if end is None else int(end)

    # A missing, non-positive or too-large end means "to the end".
    if end is None or end > len(sequence) or end <= 0:
        end = len(sequence)

    if start < 1 or start > len(sequence):
        return "Start position is out of bounds."
    # Without this check an inverted range fails deep inside the plotting code.
    if end < start:
        return "End position must not be before the start position."

    return generate_heatmap(sequence, start, end)

# Define the Gradio interface
iface = gr.Interface(
        gr.Textbox(lines=2, placeholder="Enter Protein Sequence Here..."),
        gr.Number(label="Start Position", value=1),
        gr.Number(label="End Position")  # No default value needed

# Run the Gradio app