hperkins committed
Commit
d9d7db9
1 Parent(s): d4f4061

Update handler.py

Files changed (1)
  1. handler.py +8 -23
handler.py CHANGED
@@ -1,19 +1,18 @@
 import json
 import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, PreTrainedImageProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, AutoImageProcessor
 from qwen_vl_utils import process_vision_info
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        # Configure device settings
+        # Setup device configuration
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         try:
-            # Load the model with automatic device mapping and memory-efficient precision
             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                 model_dir,
-                torch_dtype=torch.float16,  # Use half-precision for better GPU use
-                device_map="auto"  # Automatically map model to GPU(s)
+                torch_dtype=torch.float16,
+                device_map="auto"
             )
             self.model.to(self.device)
         except Exception as e:
@@ -21,31 +20,27 @@ class EndpointHandler:
             raise
 
         try:
-            # Initialize processor
             self.processor = AutoProcessor.from_pretrained(model_dir)
+            self.image_processor = AutoImageProcessor.from_pretrained(model_dir)  # Ensure you have the correct processor
         except Exception as e:
             print(f"Error loading processor: {e}")
             raise
 
-        # Define a VQA pipeline with explicitly provided processor
         self.vqa_pipeline = pipeline(
             task="visual-question-answering",
             model=self.model,
-            image_processor=self.processor,  # Explicitly pass the image processor
-            device=0 if torch.cuda.is_available() else -1  # Use first GPU or CPU
+            image_processor=self.image_processor,  # Explicit image processor if needed
+            device=0 if torch.cuda.is_available() else -1
         )
 
     def preprocess(self, request_data):
-        # Extract messages
         messages = request_data.get('messages')
         if not messages:
             raise ValueError("Missing 'messages' in request data.")
 
-        # Process visual and text inputs
         image_inputs, video_inputs = process_vision_info(messages)
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-        # Prepare inputs for the model
         inputs = self.processor(
             text=[text],
             images=image_inputs,
@@ -53,36 +48,26 @@ class EndpointHandler:
             padding=True,
             return_tensors="pt"
         ).to(self.device)
-
+
         return inputs
 
     def inference(self, inputs):
-        # Execute model inference without gradient computation
         with torch.no_grad():
             result = self.vqa_pipeline(
                 images=inputs.get("images", None),
                 videos=inputs.get("videos", None),
                 question=inputs["text"]
             )
-
         return result
 
     def postprocess(self, inference_output):
-        # Serialize inference result to JSON
         return json.dumps(inference_output)
 
     def __call__(self, request):
         try:
-            # Parse the incoming request
             request_data = json.loads(request)
-
-            # Preprocess input data
             inputs = self.preprocess(request_data)
-
-            # Perform inference
             result = self.inference(inputs)
-
-            # Return postprocessed result
             return self.postprocess(result)
         except Exception as e:
             error_message = f"Error: {str(e)}"
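
The swap of PreTrainedImageProcessor for AutoImageProcessor appears to be the point of this commit: transformers does not export a PreTrainedImageProcessor class, so the old import fails at module load, whereas AutoImageProcessor.from_pretrained resolves the checkpoint's image processor. For reference, a minimal sketch of how the updated handler might be invoked locally; the checkpoint name "Qwen/Qwen2-VL-7B-Instruct" and the image URL are placeholder assumptions, not part of this commit, and the messages payload follows the chat format that qwen_vl_utils.process_vision_info expects.

import json
from handler import EndpointHandler

# Placeholder checkpoint; any local Qwen2-VL model directory would do.
handler = EndpointHandler("Qwen/Qwen2-VL-7B-Instruct")

# Hypothetical request: one user turn carrying an image plus a question.
request = json.dumps({
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://example.com/demo.jpg"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
})

# __call__ parses the JSON, then runs preprocess -> inference -> postprocess
# and returns the result serialized with json.dumps.
print(handler(request))

One caveat, unrelated to this commit: calling self.model.to(self.device) after from_pretrained(..., device_map="auto") is typically redundant, since accelerate has already placed the weights on the available device(s).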