from abc import abstractmethod import torch import torch.nn as nn import logging as logger class ProcessorWrapper: def __init__(self, transform, height=378, width=378, image_mean = [0.48145466, 0.4578275, 0.40821073]): self._crop_size = { "height": height, "width": width, } self._transforms = transform #print(transform) self.image_mean = image_mean @property def crop_size(self): return self._crop_size def __call__(self, image, return_tensors='pt'): # Ensure image is a PIL Image if isinstance(image, list): image = image[0] output = {} output['pixel_values'] = [self._transforms(image)] return output def preprocess(self, image, return_tensors='pt'): # Ensure image is a PIL Image if isinstance(image, list): image = image[0] output = {} output['pixel_values'] = [self._transforms(image)] return output class BaseVisionTower(nn.Module): def __init__(self, vision_tower_name, args, delay_load=False): super(BaseVisionTower, self).__init__() self.is_loaded = False self.args = args self.vision_tower_name = vision_tower_name self.select_layer = args.mm_vision_select_layer self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) logger.warning(f"Unfreezing MM Vision Tower: {self.unfreeze_mm_vision_tower}") self.delay_load = delay_load @abstractmethod def load_model(self, device_map=None): raise NotImplementedError("Subclasses must implement load_model") @abstractmethod def _forward(self, images): raise NotImplementedError("Subclasses must implement forward") def forward(self, images): if type(images) is list: image_features = [ self._forward(image.unsqueeze(0)) for image in images ] else: image_features = self._forward(images) return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): # Dynamically infer the dtype from the first parameter, if not explicitly specified if hasattr(self.vision_tower, 'dtype'): return self.vision_tower.dtype else: params = list(self.vision_tower.parameters()) return params[0].dtype if len(params) > 0 else torch.float32 # Default to torch.float32 if no parameters @property def device(self): # Dynamically infer the device from the first parameter, if not explicitly specified if hasattr(self.vision_tower, 'device'): return self.vision_tower.device else: params = list(self.vision_tower.parameters()) return params[0].device if len(params) > 0 else torch.device("cpu") # Default to CPU if no parameters @property def config(self): if self.is_loaded: return self.vision_tower.config else: return self.cfg_only @property def hidden_size(self): try: return self.config.hidden_size except: return self._hidden_size @property def image_size(self): # resolution # return self.config.image_size try: return self.config.image_size except: return self._image_size @property def patch_size(self): # return self.config.patch_size try: return self.config.patch_size except: return self._patch_size @property def num_patches_per_side(self): if self._interp_size is not None: return int(self._interp_size**0.5) try: return self.image_size // self.patch_size except: return self._num_patches_per_side @property def num_patches(self): if self._interp_size is not None: return self._interp_size try: return self.num_patches_per_side ** 2 except: return self._num_patches