File size: 2,627 Bytes
f40466f
 
d9d7db9
171cc73
f40466f
6b10bcf
 
d9d7db9
ae2331d
057b8f0
b150b57
 
 
d9d7db9
 
b150b57
 
 
 
 
 
 
 
d9d7db9
b150b57
 
 
 
d9da728
 
 
d9d7db9
 
d9da728
 
ed47265
 
 
b150b57
f40466f
ed47265
b150b57
f43b9bc
ed47265
 
 
 
 
b150b57
 
d9d7db9
b150b57
33ce564
 
 
d9da728
b150b57
 
d9da728
057b8f0
d9da728
33ce564
 
d9da728
33ce564
6b10bcf
33ce564
171cc73
33ce564
d9da728
 
33ce564
b150b57
 
b45af94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, AutoImageProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler:
    """Inference handler serving a Qwen2-VL model for chat-style image/video QA.

    A request is either a JSON document (``str``/``bytes``) or an already
    decoded ``dict``. It must contain a non-empty ``messages`` list in the chat
    format accepted by ``processor.apply_chat_template`` and
    ``qwen_vl_utils.process_vision_info``. An optional ``max_new_tokens``
    integer caps the generated length.

    The response is always a JSON string: a list with one decoded reply per
    prompt on success, or ``{"error": "Error: <message>"}`` on any failure.
    """

    # Generation length used when the request does not specify one.
    DEFAULT_MAX_NEW_TOKENS = 128

    def __init__(self, model_dir):
        """Load the Qwen2-VL model and its processor from ``model_dir``.

        Args:
            model_dir: local path or hub id passed to ``from_pretrained``.

        Raises:
            Exception: any loading error is printed and re-raised unchanged.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        try:
            # device_map="auto" lets accelerate place the weights itself. The
            # model must not be moved afterwards with .to(): dispatched models
            # reject that. Half precision is only used on GPU because many
            # CPU kernels do not implement float16.
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto",
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        try:
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise

        # Kept for callers that read this attribute. The combined processor
        # already carries the image processor, so it is not loaded twice.
        self.image_processor = getattr(self.processor, "image_processor", None)

    def preprocess(self, request_data):
        """Convert a decoded request into model-ready tensors.

        Args:
            request_data: dict holding a non-empty ``messages`` list.

        Returns:
            The processor's batch (``input_ids``, ``attention_mask`` and any
            vision tensors), moved to the model's device.

        Raises:
            ValueError: if ``request_data`` is not a dict or lacks ``messages``.
        """
        if not isinstance(request_data, dict):
            raise ValueError("Request data must be a JSON object.")
        messages = request_data.get('messages')
        if not messages:
            raise ValueError("Missing 'messages' in request data.")

        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # With device_map="auto" the embedding layer lives on model.device,
        # which is not necessarily self.device; inputs must follow the weights.
        return inputs.to(self.model.device)

    def inference(self, inputs, max_new_tokens=DEFAULT_MAX_NEW_TOKENS):
        """Generate replies for a batch produced by :meth:`preprocess`.

        Args:
            inputs: processor batch containing at least ``input_ids``.
            max_new_tokens: upper bound on generated tokens per prompt.

        Returns:
            list[str]: one decoded completion per prompt, prompt text excluded.
        """
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the new tokens.
        trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        return self.processor.batch_decode(
            trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def postprocess(self, inference_output):
        """Serialize the inference result to a JSON string."""
        return json.dumps(inference_output)

    def __call__(self, request):
        """Handle one request end to end.

        Args:
            request: JSON ``str``/``bytes``, or an already-decoded ``dict``.

        Returns:
            JSON string with the replies, or ``{"error": ...}`` on failure.
        """
        # Top-level service boundary: every failure is reported to the client
        # as a JSON error payload instead of propagating.
        try:
            if isinstance(request, (str, bytes, bytearray)):
                request_data = json.loads(request)
            else:
                request_data = request
            inputs = self.preprocess(request_data)
            max_new_tokens = int(request_data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS))
            result = self.inference(inputs, max_new_tokens=max_new_tokens)
            return self.postprocess(result)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})