"""Custom inference-endpoint handler that turns an input image into generated text."""
import io
import logging
from typing import Dict, Any, List

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Load a causal vision-language model and generate text for one image per request.

    The request payload is a dict whose ``"inputs"`` entry is a PIL image, a path to an
    image file, or raw encoded image bytes. The response is a one-element list holding
    either ``{"label": <text>, "score": 1.0}`` or ``{"error": <message>}``.
    """

    # Placeholder tokens used to splice the image into the text prompt. These are the
    # Idefics-style tokens this model family expects (NOTE(review): confirm they match
    # the deployed checkpoint's tokenizer if the model is swapped).
    IMAGE_TOKEN = "<image>"
    FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
    # Square side length, in pixels, that every input image is resized to.
    IMAGE_SIZE = 960

    def __init__(self, model_path: str):
        """Load the processor and model from ``model_path`` onto GPU when available.

        Args:
            model_path: Local directory or hub id passed to ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
        ).to(self.device)
        # Number of latent vectors the perceiver resampler produces per image; the prompt
        # must contain exactly this many image placeholder tokens.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # Forbid the model from emitting the image placeholder tokens in its output.
        self.bad_words_ids = self.processor.tokenizer(
            [self.IMAGE_TOKEN, self.FAKE_IMAGE_TOKEN], add_special_tokens=False
        ).input_ids

    def convert_to_rgb(self, image: Image.Image) -> Image.Image:
        """Return ``image`` as RGB, flattening any transparency onto a white background.

        A plain ``convert("RGB")`` would discard the alpha channel and expose whatever
        colour values sit under transparent pixels; compositing over opaque white first
        makes transparent regions render as white instead.
        """
        if image.mode == "RGB":
            return image

        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, image: Image.Image) -> torch.Tensor:
        """Preprocess one image into a normalized channels-first tensor.

        Steps: RGB conversion, bilinear resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``,
        rescale to [0, 1], normalize with the image processor's mean/std, and move the
        channel axis first.
        """
        image = self.convert_to_rgb(image)
        array = to_numpy_array(image)
        array = resize(
            array, (self.IMAGE_SIZE, self.IMAGE_SIZE), resample=PILImageResampling.BILINEAR
        )
        image_processor = self.processor.image_processor
        array = image_processor.rescale(array, scale=1 / 255)
        array = image_processor.normalize(
            array, mean=image_processor.image_mean, std=image_processor.image_std
        )
        array = to_channel_dimension_format(array, ChannelDimension.FIRST)
        return torch.tensor(array)

    @staticmethod
    def _load_image(source: Any) -> Image.Image:
        """Open a file path or raw encoded bytes fully into memory and release the file."""
        if isinstance(source, (bytes, bytearray)):
            source = io.BytesIO(source)
        with Image.open(source) as opened:
            # copy() forces a full decode, so the image stays valid after the file closes.
            return opened.copy()

    def _build_prompt(self) -> str:
        """Return the text prompt: BOS followed by the wrapped image placeholder tokens."""
        bos = self.bos_token or ""
        return (
            f"{bos}{self.FAKE_IMAGE_TOKEN}"
            f"{self.IMAGE_TOKEN * self.image_seq_len}{self.FAKE_IMAGE_TOKEN}"
        )

    def generate_responses(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the image in ``data["inputs"]``.

        Args:
            data: Request payload; ``"inputs"`` must be a PIL image, a file path, or
                encoded image bytes.

        Returns:
            A one-element list with either ``{"label", "score"}`` or ``{"error"}``.
            Errors are reported in the payload (and logged), never raised.
        """
        results: List[Dict[str, Any]] = []
        image = data.get("inputs")

        if isinstance(image, (str, bytes, bytearray)):
            try:
                image = self._load_image(image)
            except Exception as e:
                logger.exception("Failed to open input image")
                results.append({"error": f"Failed to open image: {e}"})
                return results

        if not isinstance(image, Image.Image):
            results.append(
                {"error": f"Expected an image in 'inputs', got {type(image).__name__}"}
            )
            return results

        try:
            inputs = self.processor.tokenizer(
                self._build_prompt(),
                return_tensors="pt",
                add_special_tokens=False,
            )
            # The image processor is given our transform in place of its default
            # preprocessing (presumably so resizing uses BILINEAR; confirm against the
            # model's processor implementation).
            inputs["pixel_values"] = self.processor.image_processor(
                [image], transform=self.custom_transform
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            generated_ids = self.model.generate(
                **inputs,
                bad_words_ids=self.bad_words_ids,
                max_length=2048,
                early_stopping=True,
            )
            generated_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            results.append({"label": generated_text, "score": 1.0})
        except torch.cuda.CudaError as e:
            logger.exception("CUDA error during generation")
            results.append({"error": f"CUDA error: {e}"})
        except Exception as e:
            logger.exception("Unexpected error during generation")
            results.append({"error": f"Unexpected error: {e}"})
        return results

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Entry point invoked by the serving runtime; delegates to ``generate_responses``."""
        return self.generate_responses(data)