import io
import logging
from typing import Optional

import cv2
import numpy as np
from PIL import Image

import pytesseract

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware

from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder
from mltu.transformers import ImageResizer
from mltu.configs import BaseModelConfigs

from textblob import TextBlob
from happytransformer import HappyTextToText, TTSettings

from transformers import AutoTokenizer, T5ForConditionalGeneration
from pydantic import BaseModel

# Tokenizer and seq2seq model for the "grammarly/coedit-large" checkpoint,
# used by the /chat_prompt/ endpoint. Both are loaded once, at import time,
# so importing this module is slow and requires the checkpoint to be available.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Settings for the handwriting model; this file reads configs.model_path and
# configs.vocab from it. The path is relative, so it presumably resolves
# against the process's working directory, not this file's directory.
configs = BaseModelConfigs.load("./configs.yaml")

#happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")

# Beam-search settings intended for the happytransformer correction step.
# In the code shown here they are only referenced by the commented-out
# happy_tt calls, so they currently have no effect on any response.
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

# FastAPI application; the endpoints below register on it via @app.post.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True means
# any website may issue credentialed cross-origin requests to this API
# (Starlette reflects the caller's Origin in that case). Confirm this is
# intended before deploying anywhere non-local.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageToWordModel(OnnxInferenceModel):
    """Handwritten-word recognizer built on mltu's ``OnnxInferenceModel``.

    Args:
        char_list: Character vocabulary handed to ``ctc_decoder`` to turn the
            network output into a string.
        *args, **kwargs: Forwarded unchanged to ``OnnxInferenceModel``
            (this file passes ``model_path``).
    """

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Resize ``image``, run the ONNX session on it and decode the result.

        Args:
            image: Image array (this file passes the output of ``cv2.imdecode``).

        Returns:
            The first entry produced by ``ctc_decoder`` for the one-image batch.
        """
        # The first two entries of input_shape go to the resizer in reverse
        # order; presumably input_shape is (height, width, ...) and the resizer
        # takes (width, height) -- TODO confirm against mltu.
        target_dims = self.input_shape[:2][::-1]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, *target_dims)

        # Prepend a batch axis of size 1 and cast to float32 for the session.
        batch = np.expand_dims(resized, 0).astype(np.float32)

        feed = {self.input_name: batch}
        network_output = self.model.run(None, feed)[0]

        decoded = ctc_decoder(network_output, self.char_list)
        return decoded[0]


# Handwriting recognizer configured from configs.yaml; one instance serves
# every request to /extract_handwritten_text/.
model = ImageToWordModel(char_list=configs.vocab, model_path=configs.model_path)

# Text from the most recent OCR request (handwritten or printed). This is
# module-level state shared by all clients; /chat_prompt/ reads it.
extracted_text: str = ""

@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognize handwritten text in an uploaded image with the ONNX model.

    On success the recognized text is also stored in the module-level
    ``extracted_text`` so that /chat_prompt/ can use it afterwards.

    Args:
        image: Uploaded file whose bytes are decoded with OpenCV.

    Returns:
        ``{"text": <recognized text>}`` on success, or ``{"error": ...}``
        (the same message /extract_text/ uses) when the upload is empty or
        cannot be decoded as an image. ``extracted_text`` is left unchanged
        on error.
    """
    global extracted_text

    raw_bytes = await image.read()
    # cv2.imdecode raises an assertion error on an empty buffer and returns
    # None for bytes it cannot decode; either would otherwise surface as a 500.
    if not raw_bytes:
        return {"error": "Invalid file format. Please upload an image."}
    decoded = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        return {"error": "Invalid file format. Please upload an image."}

    # NOTE(review): model.predict is synchronous inference inside an async
    # endpoint, so it blocks the event loop while it runs.
    extracted_text = model.predict(decoded)
    return {"text": extracted_text}


@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Extract printed text from an uploaded image with Tesseract OCR.

    On success the text is also stored in the module-level ``extracted_text``
    so that /chat_prompt/ can use it afterwards.

    Args:
        image: Uploaded file; only parts whose Content-Type starts with
            ``image/`` are accepted.

    Returns:
        ``{"text": <OCR output>}`` on success, or ``{"error": ...}`` when the
        upload is not declared as an image or Pillow cannot decode it.
        ``extracted_text`` is left unchanged on error.
    """
    global extracted_text

    # UploadFile.content_type is None when the part carries no Content-Type;
    # treat that as "not an image" instead of raising AttributeError.
    content_type = image.content_type or ""
    if not content_type.startswith("image/"):
        return {"error": "Invalid file format. Please upload an image."}

    image_bytes = await image.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; load() forces decoding so corrupt or truncated
        # data is rejected here rather than failing later inside pytesseract.
        img.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError, as are the
        # truncated-file errors raised by load().
        return {"error": "Invalid file format. Please upload an image."}

    # NOTE(review): Tesseract runs synchronously here and blocks the event loop.
    extracted_text = pytesseract.image_to_string(img)
    return {"text": extracted_text}

class ChatPrompt(BaseModel):
    """Request body for /chat_prompt/.

    Attributes:
        prompt: Instruction prepended to the text as ``"<prompt>: <text>"``.
        text: Optional text to edit. When omitted, the module-level
            ``extracted_text`` from the most recent OCR request is used. That
            value is shared by all clients, so concurrent users should send
            ``text`` explicitly.
    """

    prompt: str
    text: Optional[str] = None


@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    """Apply the prompt's instruction to a text with the CoEdIT model.

    The model input is ``"<prompt>: <text>"``, where ``<text>`` is
    ``request.text`` if supplied. Otherwise it is the module-level
    ``extracted_text`` (the empty string until an OCR endpoint has succeeded).

    Returns:
        ``{"edited_text": <decoded model output>}``.
    """
    # Reading the module-level value needs no `global` statement.
    source_text = extracted_text if request.text is None else request.text
    input_text = f"{request.prompt}: {source_text}"
    # Debug-level so user text is not echoed to stdout on every request.
    logging.getLogger(__name__).debug("chat_prompt input: %s", input_text)

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # NOTE(review): generate() is synchronous and blocks the event loop for the
    # duration of decoding.
    outputs = chatModel.generate(input_ids, max_length=256)
    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return {"edited_text": edited_text}