import json

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler:
    def __init__(self, model_dir):
        # Device is used for moving input tensors in preprocess()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            # device_map="auto" already places the model weights, so a
            # separate .to(self.device) call is unnecessary (and can fail
            # on accelerate-dispatched models)
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        try:
            # AutoProcessor bundles the tokenizer and image processor for
            # Qwen2-VL, so no separate AutoImageProcessor is needed
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise

    def preprocess(self, request_data):
        messages = request_data.get('messages')
        if not messages:
            raise ValueError("Missing 'messages' in request data.")
        # Pull the image/video inputs referenced by the chat messages
        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        return inputs

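    # Illustrative shape of the 'messages' payload this handler expects
    # (the standard Qwen2-VL chat format; the image URL below is a
    # placeholder, not part of the original code):
    #
    #   {"messages": [{"role": "user", "content": [
    #       {"type": "image", "image": "https://example.com/photo.jpg"},
    #       {"type": "text", "text": "Describe this image."}]}]}
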
    def inference(self, inputs):
        # Generate with the model directly: the generic
        # "visual-question-answering" pipeline does not support
        # Qwen2VLForConditionalGeneration, and the processor output holds
        # input_ids/pixel_values rather than "images"/"text" keys
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Strip the prompt tokens so only the generated answer is decoded
        trimmed_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(trimmed_ids, skip_special_tokens=True)

    def postprocess(self, inference_output):
        return json.dumps(inference_output)

    def __call__(self, request):
        try:
            # Accept either a raw JSON string or an already-parsed dict
            request_data = json.loads(request) if isinstance(request, str) else request
            inputs = self.preprocess(request_data)
            result = self.inference(inputs)
            return self.postprocess(result)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})
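
# Minimal local smoke test: a sketch, assuming the model weights live in
# "./model"; the image URL is a placeholder.
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    request = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "https://example.com/photo.jpg"},
                    {"type": "text", "text": "What is shown in this image?"}
                ]
            }
        ]
    })
    print(handler(request))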