import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os
import gradio as gr

# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(token=None) falls back to an interactive prompt, which
# would block (or crash on EOF) when this script runs on a headless server.
access_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
if access_token:
    login(token=access_token)

# LoRA adapter repository; its PeftConfig records which base model it was
# trained on (config.base_model_name_or_path).
peft_model_id = "kuyesu22/sunbird-ug-lang-v1.0-bloom-7b1-lora"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.float16,  # Load weights in half precision to save memory
    device_map="auto",          # Automatically allocate to available devices
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Wrap the base model with the LoRA fine-tuned adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)

# Inference only: put the model in evaluation mode (disables dropout).
model.eval()

# Define inference function
def make_inference(english_text):
    """Translate English text to Runyankole with the LoRA-adapted model.

    Args:
        english_text: English sentence(s) to translate (as delivered by the
            Gradio textbox).

    Returns:
        The generated Runyankole continuation only, with the prompt removed and
        surrounding whitespace stripped. Empty, whitespace-only, or None input
        returns "" without invoking the model.
    """
    # Skip the (expensive) model call when there is nothing to translate.
    if not english_text or not english_text.strip():
        return ""

    # Tokenize the prompt and move it to the same device as the model.
    prompt = f"### English:\n{english_text}\n\n### Runyankole:"
    batch = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    # torch.cuda.amp.autocast() is deprecated in favor of torch.autocast with an
    # explicit device type. The old call silently disabled itself (with a
    # warning) when CUDA was unavailable; `enabled=` reproduces that quietly.
    autocast = torch.autocast(device_type="cuda", enabled=torch.cuda.is_available())

    with torch.no_grad(), autocast:
        output_tokens = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=100,
            do_sample=True,  # Enables sampling for more creative responses
            temperature=0.7,  # Control randomness in predictions
            num_return_sequences=1,  # Return only one translation
            pad_token_id=tokenizer.eos_token_id,  # Handle padding tokens
        )

    # For a decoder-only model, generate() returns prompt + continuation.
    # Drop the prompt tokens so the UI shows only the translation.
    prompt_length = batch["input_ids"].shape[-1]
    new_tokens = output_tokens[0][prompt_length:]
    translated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return translated_text.strip()

# Web UI
def launch_gradio_interface():
    """Build the English -> Runyankole translator UI and start the Gradio server.

    A single multi-line English textbox feeds make_inference; its return value
    is shown in the output textbox. This call blocks while the server runs.
    """
    demo = gr.Interface(
        fn=make_inference,
        inputs=gr.components.Textbox(lines=2, label="English Text"),
        outputs=gr.components.Textbox(label="Translated Runyankole Text"),
        title="Sunbird UG Lang Translator",
        description="Translate English to Runyankole using BLOOM model fine-tuned with LoRA.",
    )
    demo.launch()


# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    launch_gradio_interface()