import json
import numbers
import os
from collections import OrderedDict
from typing import Union

import cv2
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from transformers import AutoConfig, AutoModel, SiglipImageProcessor, SiglipVisionConfig, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipVisionModel


def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
        ]

    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return cropped


class Normalize(object):
    """Normalize a clip with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, clip):
        """
        Args:
            clip (Tensor): Tensor clip of size (C, T, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor clip.
        """
        return normalize(clip, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class CenterCrop(object):
    """Extract center crop at the same location for a list of images
    Args:
        size (sequence or int): Desired output size for the
        crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (list of PIL.Image or numpy.ndarray): List of images to be cropped,
            in format (h, w, c) for numpy.ndarray
        Returns:
            list of PIL.Image or numpy.ndarray: Cropped list of images
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image '
                            'but got list of {0}'.format(type(clip[0])))
        if w > im_w or h > im_h:
            error_msg = (
                'Initial image size should be larger than '
                'cropped size but got cropped sizes : ({w}, {h}) while '
                'initial image is ({im_w}, {im_h})'.format(
                    im_w=im_w, im_h=im_h, w=w, h=h))
            raise ValueError(error_msg)

        x1 = int(round((im_w - w) / 2.))
        y1 = int(round((im_h - h) / 2.))
        cropped = crop_clip(clip, y1, x1, h, w)

        return cropped


def resize_clip(clip, size, interpolation='bilinear'):
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape

            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]
        if interpolation == 'bilinear':
            np_inter = cv2.INTER_LINEAR
        else:
            np_inter = cv2.INTER_NEAREST
        scaled = [
            cv2.resize(img, size, interpolation=np_inter) for img in clip
        ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size

            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return scaled


def _is_tensor_clip(clip):
    return torch.is_tensor(clip) and clip.ndimension() == 4


def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


def normalize(clip, mean, std, inplace=False):
    if not _is_tensor_clip(clip):
        raise TypeError('clip should be a 4D torch tensor.')

    if not inplace:
        clip = clip.clone()

    dtype = clip.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])

    return clip


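# A minimal sketch of the shorter-side resize rule implemented by
# get_resize_sizes above. `_example_get_resize_sizes` is a hypothetical helper
# added for illustration only; it is not used anywhere else in this file.
def _example_get_resize_sizes():
    # A 240x320 frame resized with the shorter side set to 120 keeps its
    # aspect ratio: the height becomes 120 and the width scales to 160.
    oh, ow = get_resize_sizes(im_h=240, im_w=320, size=120)
    assert (oh, ow) == (120, 160)
    return oh, ow

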
class Resize(object):
    """Resizes a list of (H x W x C) numpy.ndarray to the final size
    The larger the original image is, the more time it takes to
    interpolate
    Args:
        interpolation (str): Can be one of 'nearest', 'bilinear'
        defaults to nearest
        size (tuple): (width, height)
    """

    def __init__(self, size, interpolation='nearest'):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, clip):
        resized = resize_clip(
            clip, self.size, interpolation=self.interpolation)
        return resized


class Compose(object):
    """Composes several transforms
    Args:
        transforms (list of ``Transform`` objects): list of transforms
        to compose
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, clip):
        for t in self.transforms:
            clip = t(clip)
        return clip


def convert_img(img):
    """Converts a (H, W, C) numpy.ndarray to (C, H, W) format"""
    if len(img.shape) == 3:
        img = img.transpose(2, 0, 1)
    if len(img.shape) == 2:
        img = np.expand_dims(img, 0)
    return img


class ClipToTensor(object):
    """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
    to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
    """

    def __init__(self, channel_nb=3, div_255=True, numpy=False):
        self.channel_nb = channel_nb
        self.div_255 = div_255
        self.numpy = numpy

    def __call__(self, clip):
        """
        Args: clip (list of numpy.ndarray): clip (list of images)
        to be converted to tensor.
        """

        if isinstance(clip[0], np.ndarray):
            h, w, ch = clip[0].shape
            assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
        elif isinstance(clip[0], Image.Image):
            w, h = clip[0].size
        else:
            raise TypeError(
                "Expected numpy.ndarray or PIL.Image "
                "but got list of {0}".format(type(clip[0]))
            )

        np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])

        for img_idx, img in enumerate(clip):
            if isinstance(img, np.ndarray):
                pass
            elif isinstance(img, Image.Image):
                img = np.array(img, copy=False)
            else:
                raise TypeError(
                    "Expected numpy.ndarray or PIL.Image "
                    "but got list of {0}".format(type(clip[0]))
                )
            img = convert_img(img)
            np_clip[:, img_idx, :, :] = img
        if self.numpy:
            if self.div_255:
                np_clip = np_clip / 255.0
            return np_clip

        else:
            tensor_clip = torch.from_numpy(np_clip)

            if not isinstance(tensor_clip, torch.FloatTensor):
                tensor_clip = tensor_clip.float()
            if self.div_255:
                tensor_clip = torch.div(tensor_clip, 255)
            return tensor_clip


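# A minimal sketch of how the transforms above compose into a clip
# preprocessing pipeline. `_example_clip_pipeline` is a hypothetical helper for
# illustration only; the sizes and normalization statistics are placeholders,
# not the values used by any particular checkpoint.
def _example_clip_pipeline():
    # Eight dummy RGB frames in (H, W, C) uint8 layout, like a decoded video clip.
    clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]
    transform = Compose([
        Resize(224, interpolation='bilinear'),      # shorter side -> 224, keeps aspect ratio
        CenterCrop(size=(224, 224)),
        ClipToTensor(),                             # (C, T, H, W) float tensor in [0, 1]
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    tensor_clip = transform(clip)
    return tensor_clip.shape                        # torch.Size([3, 8, 224, 224])

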
class VisionTowerConfig(PretrainedConfig):
    model_type = "vision_tower"

    def __init__(self, vision_tower_name: str = None, **kwargs):
        super().__init__(**kwargs)
        self.vision_tower_name = vision_tower_name


class ProcessorWrapper:
    def __init__(self, transform=None, processor=None, height=378, width=378, frames_per_clip=1,
                 image_mean=[0.48145466, 0.4578275, 0.40821073]):
        assert transform is not None or processor is not None, "ERROR: neither `transform` nor `processor` was defined! You must define one of them"
        assert transform is None or processor is None, "ERROR: both `transform` and `processor` were defined! You must define only one of them"
        self._size = {
            "height": height,
            "width": width,
            "frames_per_clip": frames_per_clip
        }
        self._transforms = transform
        self._processor = processor
        self.image_mean = image_mean

    @property
    def size(self):
        return self._size

    def preprocess(self, image, return_tensors='pt'):
        output = {}
        if self._transforms is not None:
            output['pixel_values'] = [self._transforms(image)]

        else:
            output = self._processor(image, return_tensors=return_tensors)
        return output

    def save_pretrained(self, save_path):
        if self._transforms is not None:
            transform_dict = transform_to_dict(self._transforms)
            transform_dict["image_processor_type"] = "transforms"
            with open(os.path.join(save_path, 'preprocessor_config.json'), 'w') as f:
                json.dump(transform_dict, f, indent=4)
        else:
            self._processor.save_pretrained(save_path)
        return


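# A minimal sketch of wrapping a transform-based pipeline in ProcessorWrapper
# so it exposes the same `preprocess` interface as a HuggingFace image
# processor. `_example_processor_wrapper` is a hypothetical helper for
# illustration only; sizes and statistics are placeholders.
def _example_processor_wrapper():
    transform = Compose([
        Resize(224, interpolation='bilinear'),
        CenterCrop(size=(224, 224)),
        ClipToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    processor = ProcessorWrapper(transform=transform, height=224, width=224,
                                 frames_per_clip=8, image_mean=[0.485, 0.456, 0.406])
    clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]
    output = processor.preprocess(clip)
    return output['pixel_values'][0].shape          # torch.Size([3, 8, 224, 224])

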
class VisionTower(PreTrainedModel):
    config_class = VisionTowerConfig

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: VisionTowerConfig = None):
        super().__init__(vision_config)
        self.vision_tower_name = model_name_or_path
        self.vision_config = vision_config
        self.select_layer = getattr(config, "mm_vision_select_layer", -2)
        self.select_feature = getattr(config, "mm_vision_select_feature", "patch")
        self.encode_batch_size = getattr(config, "encode_batch_size", 0) // 2
        self.num_encode_batch = getattr(config, "num_encode_batch", 0) // 2
        self.temporal_tubelet_size = getattr(vision_config, "tubelet_size", 1)

    def feature_select(self, image_features):
        if self.select_layer is not None:
            image_features = image_features.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        return image_features

    def vision_tower_forward(self, image):
        image_feature = self.vision_tower(image, output_hidden_states=True)
        return image_feature

    def _forward(self, images, out_T=1):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower_forward(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_feature).to(image.dtype)
                image_feature = image_feature.reshape(image_feature.shape[0], self.W, self.H, self.hidden_size)
                image_features.append(image_feature)
        else:
            original_shape = images.shape
            if len(original_shape) == 5 and self.T == 1:
                # Temporally subsample the video input down to out_T frames for a single-frame encoder.
                images = images[:, ::original_shape[1] // out_T, ...]
                original_shape = images.shape
                images = images.view(-1, *original_shape[2:])

            image_features = self.vision_tower_forward(images.to(device=self.device, dtype=self.dtype))
            image_features = self.feature_select(image_features).to(images.dtype)

            if len(original_shape) == 5 and self.T == 1:
                # Restore the (batch, frames, ...) layout after the frame-wise forward pass.
                new_shape = list(image_features.shape[:-2]) + [self.W, self.H, self.hidden_size]
                image_features = image_features.reshape(new_shape)
                feature_size = image_features.shape[1:]
                image_features = image_features.view(original_shape[0], original_shape[1], *feature_size)

            else:
                image_features = image_features.reshape(image_features.shape[0], self.T, self.W, self.H, self.hidden_size)

        return image_features

    def forward(self, images):
        return self._forward(images)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


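# A minimal sketch of what VisionTower.feature_select does, using dummy
# tensors in place of a real backbone output. `_example_feature_select` is a
# hypothetical helper for illustration only; the SimpleNamespace stands in for
# a `transformers` model output and the shapes are arbitrary.
def _example_feature_select():
    from types import SimpleNamespace
    batch, seq_len, dim = 2, 1 + 16, 8          # one CLS token plus 16 patch tokens
    dummy_output = SimpleNamespace(
        hidden_states=[torch.randn(batch, seq_len, dim) for _ in range(4)]
    )
    hidden = dummy_output.hidden_states[-2]     # select_layer defaults to -2
    patch_tokens = hidden[:, 1:]                # select_feature == "patch": drop the CLS token
    return patch_tokens.shape                   # torch.Size([2, 16, 8])

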
class InternVideoTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        if vision_config is None:
            vision_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        super().__init__(model_name_or_path, config, vision_config)
        self.vision_config = vision_config
        # ImageNet mean/std used to normalize the input clip.
        norm_mean, norm_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        print('loading: ', model_name_or_path)
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.vision_tower = model.to(dtype=eval(config.model_dtype))

        transform = Compose([
            Resize(self.vision_config.img_size, interpolation='bilinear'),
            CenterCrop(size=(self.vision_config.img_size, self.vision_config.img_size)),
            ClipToTensor(),
            Normalize(mean=norm_mean, std=norm_std)
        ])

        self.vision_processor = ProcessorWrapper(transform=transform,
                                                 height=self.vision_config.img_size,
                                                 width=self.vision_config.img_size,
                                                 frames_per_clip=self.vision_config.num_frames,
                                                 image_mean=norm_mean)

        self.W = self.H = vision_config.img_size // vision_config.patch_size
        self.T = self.vision_config.num_frames // self.vision_config.tubelet_size
        self.num_frames = self.vision_config.num_frames
        self.hidden_size = vision_config.d_model
        self.vision_select_layer = self.select_layer
        self.select_layer = None

    def vision_tower_forward(self, video):
        if video.shape[-3] < self.num_frames:
            video = video.repeat_interleave(self.num_frames, dim=-3)
        elif video.shape[-3] > self.num_frames:
            video = video[:, :, ::video.shape[-3] // self.num_frames, ...]

        video_feature = self.vision_tower(video.to(device=self.device, dtype=self.dtype),
                                          x_vis_return_idx=self.vision_select_layer, x_vis_only=True)

        return video_feature

    @property
    def device(self):
        return self.vision_tower.pos_embed.device


class SiglipVisionTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        if vision_config is None:
            vision_config = SiglipVisionConfig.from_pretrained(model_name_or_path)

        super().__init__(model_name_or_path, config, vision_config)
        self.vision_config = vision_config
        self.vision_tower_name = model_name_or_path
        self.vision_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)

        print('loading: ', model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)

        self.hidden_size = self.vision_config.hidden_size
        self.W = self.H = self.vision_config.image_size // self.vision_config.patch_size
        self.T = 1
        # SigLIP has no CLS token, so keep the full token sequence.
        self.select_feature = "cls_patch"


class ApolloVisionTower(PreTrainedModel):
    def __init__(self, config, vision_tower_cfg):
        super(ApolloVisionTower, self).__init__(config, vision_tower_cfg)
        self.model_name_or_path = vision_tower_cfg._name_or_path
        self.vision_towers = vision_tower_cfg.vision_towers
        self._config = vision_tower_cfg

        for vision_tower_name in self.vision_towers:
            if 'internvideo' in vision_tower_name.lower():
                vision_tower = InternVideoTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name), config)
            elif 'siglip' in vision_tower_name.lower():
                vision_tower = SiglipVisionTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name),
                                                 config)
            else:
                raise ValueError(f"Unknown vision tower: {vision_tower_name}")

            setattr(self, vision_tower_name, vision_tower)

        self.vision_processor = [getattr(self, vt).vision_processor for vt in self.vision_towers]
        self.num_vision_encoders = len(self.vision_towers)
        self.W = self.H = max([getattr(self, vt).W for vt in self.vision_towers])
        self.T = max([getattr(self, vt).T for vt in self.vision_towers])
        self.max_tubelet_size = max(
            [getattr(getattr(self, vt).vision_config, 'tubelet_size', 1) for vt in self.vision_towers])

        self._hidden_size = sum([getattr(self, vt).hidden_size for vt in self.vision_towers])
        self.token_output_shape = (self.T, self.W, self.H)
        self.config.num_vision_encoders = self.num_vision_encoders
        self.config.vision_towers = self.vision_towers
        self.config.token_output_shape = self.token_output_shape

    def forward(self, x):
        output_features = []
        for x_s, vision_tower_name in zip(x, self.vision_towers):
            vision_tower = getattr(self, vision_tower_name)
            features = vision_tower._forward(x_s, out_T=self.T)

            if len(features.shape) != len(self.token_output_shape) + 2:
                features = features.unsqueeze(1)

            if features.shape[-len(self.token_output_shape) - 1:-1] != self.token_output_shape:
                # Resample this tower's token grid to the shared (T, W, H) shape before concatenation.
                features = features.permute(0, 4, 1, 2, 3).contiguous()
                features = F.interpolate(features.to(torch.float32), size=self.token_output_shape, mode='trilinear',
                                         align_corners=False).to(features.dtype)
                features = features.permute(0, 2, 3, 4, 1).contiguous()

            output_features.append(features)

        output_features = torch.cat(output_features, dim=-1)
        output_features = torch.flatten(output_features, start_dim=1, end_dim=-2)
        return output_features

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            state_dict=None,
            **kwargs,
    ):
        if state_dict is None:
            state_dict = self.state_dict()

        for vision_tower_name in self.vision_towers:
            vision_tower = getattr(self, vision_tower_name)
            vision_tower_state_dict = OrderedDict(
                {k.split(f"vision_tower.{vision_tower_name}.vision_tower.")[-1]: v for k, v in state_dict.items() if
                 vision_tower_name in k}
            )
            vision_tower.vision_tower.save_pretrained(os.path.join(save_directory, vision_tower_name),
                                                      state_dict=vision_tower_state_dict, **kwargs)
            vision_tower.vision_processor.save_pretrained(os.path.join(save_directory, vision_tower_name))

        config = self.config
        config.configs = {}
        config.save_pretrained(save_directory)

    @property
    def patch_size(self):
        return self._patch_size

    @property
    def image_size(self):
        return self._image_size

    @property
    def hidden_size(self):
        return self._hidden_size


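# A minimal sketch of the token-grid alignment performed in
# ApolloVisionTower.forward: features from two towers with different
# temporal/spatial token grids are resampled to a shared (T, W, H) grid and
# concatenated along the hidden dimension. `_example_token_alignment` is a
# hypothetical helper for illustration only; the grid sizes and hidden
# dimensions below are arbitrary.
def _example_token_alignment():
    target_shape = (4, 27, 27)                   # shared (T, W, H) token grid
    feats_a = torch.randn(1, 4, 27, 27, 1152)    # (B, T, W, H, C), already aligned
    feats_b = torch.randn(1, 1, 16, 16, 1408)    # needs temporal and spatial resampling
    aligned = []
    for features in (feats_a, feats_b):
        if features.shape[1:4] != target_shape:
            features = features.permute(0, 4, 1, 2, 3).contiguous()      # -> (B, C, T, W, H)
            features = F.interpolate(features.float(), size=target_shape,
                                     mode='trilinear', align_corners=False)
            features = features.permute(0, 2, 3, 4, 1).contiguous()      # back to (B, T, W, H, C)
        aligned.append(features)
    merged = torch.cat(aligned, dim=-1)                                  # (1, 4, 27, 27, 2560)
    return torch.flatten(merged, start_dim=1, end_dim=-2).shape          # (1, 2916, 2560)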