hperkins committed
Commit d9da728
1 Parent(s): f40466f

Update handler.py

Files changed (1)
  1. handler.py +23 -27
handler.py CHANGED
@@ -1,12 +1,12 @@
 import json
 import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline
 from qwen_vl_utils import process_vision_info
 
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        # Load the model and processor for Qwen2-VL
+        # Initialize the model and processor for Visual Question Answering (VQA)
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             model_dir,
             torch_dtype=torch.float16,  # FP16 for memory efficiency
@@ -16,6 +16,13 @@ class EndpointHandler:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.eval()
 
+        # Initialize the VQA pipeline
+        self.vqa_pipeline = pipeline(
+            task="visual-question-answering",
+            model=self.model,
+            device=0 if torch.cuda.is_available() else -1
+        )
+
     def preprocess(self, request_data):
         # Parse messages, extract video and text inputs
         messages = request_data.get('messages')
@@ -41,42 +48,31 @@ class EndpointHandler:
         return inputs.to(self.device)
 
     def inference(self, inputs):
-        # Run inference on the model
+        # Use the VQA pipeline for inference
        with torch.no_grad():
-            generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=128,  # Limit the output length
-                num_beams=1,  # Reduce memory usage
-                max_batch_size=1  # Process one batch at a time
+            result = self.vqa_pipeline(
+                images=inputs["images"] if "images" in inputs else inputs["videos"],
+                question=inputs["text"]
             )
 
-        # Trim generated outputs to remove input tokens
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-
-        return generated_ids_trimmed
+        return result
 
     def postprocess(self, inference_output):
-        # Decode generated output into human-readable text
-        output_text = self.processor.batch_decode(
-            inference_output, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        return output_text
+        # Convert inference output to JSON
+        return json.dumps(inference_output)
 
     def __call__(self, request):
         try:
             # Parse the incoming request data
             request_data = json.loads(request)
-
+
             # Preprocess the input data
             inputs = self.preprocess(request_data)
-
-            # Perform inference
-            outputs = self.inference(inputs)
-
-            # Postprocess the outputs and return results
-            result = self.postprocess(outputs)
-            return json.dumps({"result": result})
+
+            # Perform inference using the VQA pipeline
+            result = self.inference(inputs)
+
+            # Postprocess the result and return JSON output
+            return self.postprocess(result)
         except Exception as e:
             return json.dumps({"error": str(e)})
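
For reference, a minimal sketch of how the updated handler might be invoked once deployed. The request schema is inferred from the preprocess step above (a JSON body carrying a "messages" list in the chat format that qwen_vl_utils.process_vision_info consumes); the model directory, video path, and question text below are placeholders, not values taken from this commit.

import json

# Hypothetical request payload; the "messages" structure follows the usual
# Qwen2-VL chat format read by preprocess() via process_vision_info.
request = json.dumps({
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": "file:///path/to/clip.mp4"},  # placeholder media
                {"type": "text", "text": "What is happening in this video?"}
            ]
        }
    ]
})

handler = EndpointHandler(model_dir="Qwen/Qwen2-VL-7B-Instruct")  # assumed model directory
response = handler(request)  # JSON string from postprocess(), or {"error": ...} on failure
print(json.loads(response))

Note that in the new flow postprocess already serializes the pipeline output to JSON, so __call__ returns that string directly rather than wrapping it in a {"result": ...} object as the previous version did.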