File size: 3,356 Bytes
ed47265
 
33ce564
 
 
6b10bcf
 
d67e0d7
ed47265
6b10bcf
d67e0d7
 
ed47265
6b10bcf
33ce564
 
 
 
d67e0d7
057b8f0
 
ed47265
 
 
 
 
33ce564
ed47265
 
33ce564
ed47265
 
 
 
33ce564
ed47265
 
 
 
 
 
 
 
 
 
33ce564
 
ed47265
33ce564
057b8f0
 
 
 
 
 
 
ed47265
 
 
057b8f0
ed47265
057b8f0
 
 
 
ed47265
33ce564
 
ed47265
 
 
 
 
33ce564
6b10bcf
33ce564
ed47265
 
 
33ce564
 
 
 
 
ed47265
33ce564
6b10bcf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import logging

import torch
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

class EndpointHandler:
    """Serve Qwen2-VL chat requests (text plus optional images/videos).

    A request is a JSON object (or its serialized string/bytes) whose
    ``messages`` entry is a chat-format list consumed by
    ``processor.apply_chat_template`` and ``process_vision_info``.
    The response is always a JSON string: ``{"result": [...]}`` on success
    or ``{"error": "..."}`` on failure.
    """

    # Default upper bound on generated tokens per request.
    DEFAULT_MAX_NEW_TOKENS = 128

    _logger = logging.getLogger(__name__)

    def __init__(self, model_dir):
        """Load the Qwen2-VL model and its processor from ``model_dir``.

        Args:
            model_dir: Path (or hub id) passed to ``from_pretrained``.
        """
        use_cuda = torch.cuda.is_available()
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_dir,
            # FP16 halves GPU memory; many CPU kernels have no half-precision
            # implementation, so use FP32 when no GPU is available.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            # accelerate places the weights. Do not call .to() on the model
            # afterwards: it is redundant on a single device and is warned
            # about / rejected when the model is split or offloaded.
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_dir)
        # The device the model reports (where its first parameters live);
        # input tensors are moved here before generation.
        self.device = self.model.device
        # Inference only: disable dropout etc. No gradient checkpointing,
        # since that only saves memory during backward passes.
        self.model.eval()

    def preprocess(self, request_data):
        """Turn a parsed request into model-ready tensors.

        Args:
            request_data: Mapping containing a non-empty ``messages`` list.

        Returns:
            The processor's batch output (it contains ``input_ids`` plus any
            vision tensors), moved to the model's device.

        Raises:
            ValueError: if ``request_data`` is not a mapping or has no
                ``messages``.
        """
        if not isinstance(request_data, dict):
            raise ValueError("Request body must be a JSON object")
        messages = request_data.get('messages')
        if not messages:
            raise ValueError("Messages are required")

        # Pull the image/video payloads referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        # Render the chat as text, ending with the generation prompt so the
        # model continues as the assistant.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        return inputs.to(self.model.device)

    def inference(self, inputs, max_new_tokens=DEFAULT_MAX_NEW_TOKENS):
        """Generate a continuation and strip the prompt tokens from it.

        Args:
            inputs: Batch produced by ``preprocess``.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            A list with one tensor of newly generated token ids per row.
        """
        with torch.no_grad():
            # Only real generate() arguments may be passed here: unknown
            # keywords are treated as model kwargs and rejected by generate.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=1,  # single beam keeps memory use minimal
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_ids = inputs["input_ids"]
        trimmed = [out[len(inp):] for inp, out in zip(prompt_ids, generated_ids)]

        if torch.cuda.is_available():
            # Hand cached-but-unused GPU blocks back between requests.
            torch.cuda.empty_cache()

        return trimmed

    def postprocess(self, inference_output):
        """Decode generated token ids into strings, one per batch row."""
        return self.processor.batch_decode(
            inference_output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def __call__(self, request):
        """Handle one request end to end, reporting failures as JSON.

        Args:
            request: A JSON ``str``/``bytes`` payload, or an already-decoded
                dict (Hugging Face Inference Endpoints custom handlers
                receive the decoded payload).

        Returns:
            ``json.dumps({"result": [...]})`` on success, otherwise
            ``json.dumps({"error": message})``.
        """
        try:
            if isinstance(request, (str, bytes, bytearray)):
                request_data = json.loads(request)
            else:
                # Already decoded by the serving framework.
                request_data = request
            inputs = self.preprocess(request_data)
            outputs = self.inference(inputs)
            result = self.postprocess(outputs)
            return json.dumps({"result": result})
        except Exception as e:
            # Top-level boundary: report the failure to the caller, but keep
            # the traceback in the server logs instead of discarding it.
            self._logger.exception("Request handling failed")
            return json.dumps({"error": str(e)})