Update handler.py

handler.py (CHANGED, +8 -23)
@@ -1,19 +1,18 @@
 import json
 import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline,
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, AutoImageProcessor
 from qwen_vl_utils import process_vision_info
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        #
+        # Setup device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         try:
-            # Load the model with automatic device mapping and memory-efficient precision
             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                 model_dir,
-                torch_dtype=torch.float16,
-                device_map="auto"
+                torch_dtype=torch.float16,
+                device_map="auto"
             )
             self.model.to(self.device)
         except Exception as e:
@@ -21,31 +20,27 @@ class EndpointHandler:
             raise
 
         try:
-            # Initialize processor
             self.processor = AutoProcessor.from_pretrained(model_dir)
+            self.image_processor = AutoImageProcessor.from_pretrained(model_dir)  # Ensure you have the correct processor
         except Exception as e:
             print(f"Error loading processor: {e}")
             raise
 
-        # Define a VQA pipeline with explicitly provided processor
         self.vqa_pipeline = pipeline(
             task="visual-question-answering",
             model=self.model,
-            image_processor=self.
-            device=0 if torch.cuda.is_available() else -1
+            image_processor=self.image_processor,  # Explicit image processor if needed
+            device=0 if torch.cuda.is_available() else -1
         )
 
     def preprocess(self, request_data):
-        # Extract messages
         messages = request_data.get('messages')
         if not messages:
             raise ValueError("Missing 'messages' in request data.")
 
-        # Process visual and text inputs
         image_inputs, video_inputs = process_vision_info(messages)
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-        # Prepare inputs for the model
         inputs = self.processor(
             text=[text],
             images=image_inputs,
@@ -53,36 +48,26 @@
             padding=True,
             return_tensors="pt"
         ).to(self.device)
-
+
         return inputs
 
     def inference(self, inputs):
-        # Execute model inference without gradient computation
         with torch.no_grad():
             result = self.vqa_pipeline(
                 images=inputs.get("images", None),
                 videos=inputs.get("videos", None),
                 question=inputs["text"]
             )
-
         return result
 
     def postprocess(self, inference_output):
-        # Serialize inference result to JSON
         return json.dumps(inference_output)
 
     def __call__(self, request):
         try:
-            # Parse the incoming request
             request_data = json.loads(request)
-
-            # Preprocess input data
             inputs = self.preprocess(request_data)
-
-            # Perform inference
             result = self.inference(inputs)
-
-            # Return postprocessed result
             return self.postprocess(result)
         except Exception as e:
             error_message = f"Error: {str(e)}"
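
For reference, here is a minimal way to exercise the handler locally. This is a hypothetical smoke test, not part of the commit: the module name, model path, and image URL are placeholders, and the request layout simply mirrors what `preprocess` and `qwen_vl_utils.process_vision_info` expect.

import json
from handler import EndpointHandler  # placeholder module name

handler = EndpointHandler("Qwen/Qwen2-VL-2B-Instruct")  # placeholder model_dir

# __call__ takes a JSON string with a "messages" key in the Qwen2-VL
# chat schema understood by qwen_vl_utils.process_vision_info.
request = json.dumps({
    "messages": [{
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/demo.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }]
})

print(handler(request))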
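
One caveat worth flagging: `preprocess` returns the processor's tensor batch (`input_ids`, `pixel_values`, and friends), while the `visual-question-answering` pipeline expects raw images and a question string, so the `inputs.get("images")` / `inputs["text"]` lookups in `inference` will not find those keys. Below is a sketch of an `inference` variant that consumes the tensors directly via `generate`, following the usage pattern from the Qwen2-VL model card; treat it as an untested alternative, not part of this commit.

    def inference(self, inputs):
        # Generate from the processor's tensor batch instead of re-entering
        # the pipeline with raw images.
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens from each sequence before decoding.
        trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )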