import json
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler:
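    """Custom Inference Endpoints handler for Qwen2-VL.

    Accepts a JSON request containing Qwen-style chat ``messages`` (text plus
    optional image/video entries), runs generation, and returns the decoded
    answer as JSON.
    """
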
    def __init__(self, model_dir):
        # Setup device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # device_map="auto" lets accelerate place the weights, so no manual .to(device) call is needed
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        try:
            # AutoProcessor bundles the tokenizer and image processor for Qwen2-VL,
            # so a separate AutoImageProcessor is not needed
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise

    def preprocess(self, request_data):
        messages = request_data.get("messages")
        if not messages:
            raise ValueError("Missing 'messages' in request data.")

        # Collect image/video inputs referenced in the messages and build the chat-formatted prompt
        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        return inputs

    def inference(self, inputs):
        # Generate a response, then strip the prompt tokens from each sequence before decoding
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        generated_ids_trimmed = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return output_text

    def postprocess(self, inference_output):
        return json.dumps(inference_output)

    def __call__(self, request):
        try:
            # Accept either a raw JSON string or an already-parsed dict
            request_data = json.loads(request) if isinstance(request, str) else request
            inputs = self.preprocess(request_data)
            result = self.inference(inputs)
            return self.postprocess(result)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})