Update handler.py
Browse files - handler.py +6 -12
handler.py
CHANGED
@@ -3,15 +3,15 @@ from qwen_vl_utils import process_vision_info
|
|
3 |
import torch
|
4 |
import json
|
5 |
|
6 |
-
class
|
7 |
-
def __init__(self):
|
8 |
# Load the model and processor for Qwen2-VL-7B without FlashAttention2
|
9 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
10 |
-
|
11 |
torch_dtype=torch.float16, # Use FP16 for reduced memory usage
|
12 |
device_map="auto" # Automatically assigns the model to the available GPU(s)
|
13 |
)
|
14 |
-
self.processor = AutoProcessor.from_pretrained(
|
15 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
self.model.to(self.device)
|
17 |
self.model.eval()
|
@@ -72,7 +72,7 @@ class Qwen2VL7bHandler:
|
|
72 |
)
|
73 |
return output_text
|
74 |
|
75 |
-
def
|
76 |
try:
|
77 |
# Parse the JSON request data
|
78 |
request_data = json.loads(request)
|
@@ -84,10 +84,4 @@ class Qwen2VL7bHandler:
|
|
84 |
result = self.postprocess(outputs)
|
85 |
return json.dumps({"result": result})
|
86 |
except Exception as e:
|
87 |
-
return json.dumps({"error": str(e)})
|
88 |
-
|
89 |
-
# Instantiate the handler for deployment
|
90 |
-
_service = Qwen2VL7bHandler()
|
91 |
-
|
92 |
-
def handle(request):
|
93 |
-
return _service.handle(request)
|
|
|
3 |
import torch
|
4 |
import json
|
5 |
|
6 |
+
class EndpointHandler:
|
7 |
+
def __init__(self, model_dir):
|
8 |
# Load the model and processor for Qwen2-VL-7B without FlashAttention2
|
9 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
10 |
+
model_dir,
|
11 |
torch_dtype=torch.float16, # Use FP16 for reduced memory usage
|
12 |
device_map="auto" # Automatically assigns the model to the available GPU(s)
|
13 |
)
|
14 |
+
self.processor = AutoProcessor.from_pretrained(model_dir)
|
15 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
self.model.to(self.device)
|
17 |
self.model.eval()
|
|
|
72 |
)
|
73 |
return output_text
|
74 |
|
75 |
+
def __call__(self, request):
|
76 |
try:
|
77 |
# Parse the JSON request data
|
78 |
request_data = json.loads(request)
|
|
|
84 |
result = self.postprocess(outputs)
|
85 |
return json.dumps({"result": result})
|
86 |
except Exception as e:
|
87 |
+
return json.dumps({"error": str(e)})
|
|
|
|
|
|
|
|
|
|
|
|