Update handler.py

handler.py  CHANGED  (+23 -27)
@@ -1,12 +1,12 @@
 import json
 import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline
 from qwen_vl_utils import process_vision_info
 
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        # Initialize the model and processor
+        # Initialize the model and processor for Visual Question Answering (VQA)
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             model_dir,
             torch_dtype=torch.float16,  # FP16 for memory efficiency
@@ -16,6 +16,13 @@ class EndpointHandler:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.eval()
 
+        # Initialize the VQA pipeline
+        self.vqa_pipeline = pipeline(
+            task="visual-question-answering",
+            model=self.model,
+            device=0 if torch.cuda.is_available() else -1
+        )
+
     def preprocess(self, request_data):
         # Parse messages, extract video and text inputs
         messages = request_data.get('messages')
@@ -41,42 +48,31 @@ class EndpointHandler:
         return inputs.to(self.device)
 
     def inference(self, inputs):
-        # Generate output token IDs
+        # Use the VQA pipeline for inference
         with torch.no_grad():
-            generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=128,
-                num_beams=1,  # Reduce memory usage
-                max_batch_size=1  # Process one batch at a time
+            result = self.vqa_pipeline(
+                images=inputs["images"] if "images" in inputs else inputs["videos"],
+                question=inputs["text"]
             )
 
-
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-
-        return generated_ids_trimmed
+        return result
 
     def postprocess(self, inference_output):
-        # Decode the generated token IDs into text
-        output_text = self.processor.batch_decode(
-            inference_output, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        return output_text
+        # Convert inference output to JSON
+        return json.dumps(inference_output)
 
     def __call__(self, request):
         try:
             # Parse the incoming request data
             request_data = json.loads(request)
-
+
             # Preprocess the input data
             inputs = self.preprocess(request_data)
-
-            # Perform inference
-            outputs = self.inference(inputs)
-
-            # Postprocess the results
-            result = self.postprocess(outputs)
-            return json.dumps({"result": result})
+
+            # Perform inference using the VQA pipeline
+            result = self.inference(inputs)
+
+            # Postprocess the result and return JSON output
+            return self.postprocess(result)
         except Exception as e:
             return json.dumps({"error": str(e)})
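One caveat worth double-checking before deploying: when pipeline() is handed a model object rather than a checkpoint name, transformers cannot look up the matching tokenizer or image processor on its own, and not every release maps Qwen2-VL onto the "visual-question-answering" task. A hedged sketch of a more explicit construction follows; it assumes self.processor is assigned via AutoProcessor.from_pretrained in the part of __init__ elided from this diff (as the old postprocess implied) and that the composite processor exposes tokenizer and image_processor attributes.

# Sketch only: wire the handler's existing processor into the pipeline so the
# question text and visual inputs can be prepared; whether this task accepts
# a Qwen2-VL model at all depends on the installed transformers version.
self.vqa_pipeline = pipeline(
    task="visual-question-answering",
    model=self.model,
    tokenizer=self.processor.tokenizer,
    image_processor=self.processor.image_processor,
    device=0 if torch.cuda.is_available() else -1
)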
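For reviewers who want to exercise the updated handler end to end, a minimal smoke test might look like the sketch below. The checkpoint name and the shape of the messages payload are assumptions: preprocess is elided from this diff, but the qwen_vl_utils.process_vision_info import suggests it expects the Qwen2-VL chat format, where each message carries a list of typed content blocks.

import json

from handler import EndpointHandler  # assumes the file above is saved as handler.py

# Hypothetical checkpoint; any local Qwen2-VL model directory should work the same way.
handler = EndpointHandler("Qwen/Qwen2-VL-2B-Instruct")

# Assumed request shape, following the Qwen2-VL chat format that qwen_vl_utils parses.
request = json.dumps({
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": "file:///tmp/clip.mp4"},
                {"type": "text", "text": "What is happening in this video?"},
            ],
        }
    ]
})

# __call__ returns a JSON string: either the pipeline result or {"error": ...}.
print(handler(request))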