hperkins committed
Commit f43b9bc
Parent: 20ed6e0

Update handler.py

Files changed (1)
  1. handler.py +21 -27
handler.py CHANGED
@@ -2,37 +2,30 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 import json
-import os
-
-# Set the environment variable to handle memory fragmentation
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        # Load the model with automatic device dispatching
+        # Load the model and processor for Qwen2-VL-7B
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             model_dir,
-            torch_dtype=torch.float16,  # Use FP16 for memory efficiency
-            device_map="auto",  # Auto device dispatch across available GPUs
-            low_cpu_mem_usage=True  # Minimize CPU memory usage
+            torch_dtype=torch.float32,  # Load weights in full (float32) precision
+            device_map="auto"  # Automatically assign to available GPU(s)
         )
         self.processor = AutoProcessor.from_pretrained(model_dir)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # No need to move model to device manually; device_map handles it
         self.model.eval()
 
-        # Enable gradient checkpointing for further memory optimization
+        # Enable gradient checkpointing for memory savings
         self.model.gradient_checkpointing_enable()
 
     def preprocess(self, request_data):
-        # Handle the request and extract vision data (images, videos)
+        # Handle image and video input from the request
         messages = request_data.get('messages')
         if not messages:
             raise ValueError("Messages are required")
-
-        # Process vision input from the messages
+
+        # Process vision info (image or video) from the messages
         image_inputs, video_inputs = process_vision_info(messages)
-
+
         # Prepare text input for the chat model
         text = self.processor.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
@@ -47,30 +40,31 @@ class EndpointHandler:
             return_tensors="pt",
         )
 
-        return inputs.to(self.device)
+        return inputs.to(self.model.device)
 
     def inference(self, inputs):
-        # Perform inference using memory-efficient settings
+        # Perform inference with the model
         with torch.no_grad():
             generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=64,  # Reduce max tokens for memory optimization
-                num_beams=1,  # Reduce beam size to save memory
-                max_batch_size=1  # Keep batch size small to minimize memory usage
+                **inputs,
+                max_new_tokens=256,  # Increased token length for richer output
+                num_beams=5,  # Increase beam size for better quality
+                early_stopping=True,  # Stop when all beams have finished
+                max_batch_size=1  # Keep batch size small to manage memory usage
             )
 
-        # Trim the output by removing input tokens from the generated output
+        # Trim the output (remove input tokens from generated output)
         generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
 
-        # Clear CUDA memory cache after inference to free up memory
+        # Clear the CUDA cache after inference to release unused memory
         torch.cuda.empty_cache()
 
         return generated_ids_trimmed
 
     def postprocess(self, inference_output):
-        # Decode the model's output into human-readable text
+        # Decode the generated output from the model
         output_text = self.processor.batch_decode(
             inference_output, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
@@ -78,13 +72,13 @@ class EndpointHandler:
 
     def __call__(self, request):
         try:
-            # Parse the JSON request
+            # Parse the JSON request data
             request_data = json.loads(request)
-            # Preprocess the input data
+            # Preprocess the input data (text, images, videos)
             inputs = self.preprocess(request_data)
             # Perform inference
             outputs = self.inference(inputs)
-            # Postprocess the output and return the result
+            # Postprocess the output
            result = self.postprocess(outputs)
             return json.dumps({"result": result})
         except Exception as e:
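
For reference, below is a minimal sketch of how the updated handler could be exercised locally. The model directory, image URL, and prompt are illustrative placeholders, not part of this commit; the payload simply follows the messages structure that preprocess and process_vision_info expect.

# Illustrative local smoke test (not part of the commit). Assumes the model
# weights live in "./Qwen2-VL-7B-Instruct" and the image URL is reachable;
# both values are placeholders.
import json
from handler import EndpointHandler

handler = EndpointHandler("./Qwen2-VL-7B-Instruct")

request = json.dumps({
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://example.com/demo.jpg"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
})

response = handler(request)
print(response)  # JSON string of the form {"result": ["..."]}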