# Author: VictorSanh
# Commit ec50e73: "working but checkpoints are weird"
import os
import torch
from transformers import AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from PIL import Image
# Shared processor (tokenizer + image processor) used by `custom_transform` and
# `create_model_inputs`. The access token is read from HF_AUTH_TOKEN when it is
# set; when it is not, `token=None` lets the Hugging Face libraries fall back to
# their stored credentials instead of failing with a KeyError at import time.
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2",
    token=os.environ.get("HF_AUTH_TOKEN"),
)
def convert_to_rgb(image):
    """Return an RGB version of *image*, placing transparent areas on white.

    A bare ``image.convert("RGB")`` is only safe for images without
    transparency (e.g. .jpg); for transparent images it yields a wrong
    background. Non-RGB images are therefore converted to RGBA, composited over
    a white canvas with ``Image.alpha_composite``, then converted to RGB.

    Images already in ``"RGB"`` mode are returned unchanged (same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        image = Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
def custom_transform(x):
x = convert_to_rgb(x)
x = to_numpy_array(x)
height, width = x.shape[:2]
aspect_ratio = width / height
if width >= height and width > 980:
width = 980
height = int(width / aspect_ratio)
elif height > width and height > 980:
height = 980
width = int(height * aspect_ratio)
width = max(width, 378)
height = max(height, 378)
x = resize(x, (height, width), resample=PILImageResampling.BILINEAR)
x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
x = PROCESSOR.image_processor.normalize(
x,
mean=PROCESSOR.image_processor.image_mean,
std=PROCESSOR.image_processor.image_std
)
x = to_channel_dimension_format(x, ChannelDimension.FIRST)
x = torch.tensor(x)
return x
def _pad_image_batch(output_images):
    """Pack per-sample image tensors into zero-padded batch tensors plus a pixel mask.

    Args:
        output_images: One list per sample. Each entry is a tensor whose dims 2
            and 3 are read as (height, width); presumably the processor's stacked
            ``(1, 3, H, W)`` output for a single image (TODO confirm).

    Returns:
        ``None`` when no sample contains an image (including an empty batch).
        Otherwise a tuple ``(pixel_values, pixel_attention_mask)``:
        ``pixel_values`` has shape ``(batch, max_num_images, 3, max_height, max_width)``,
        filled with zeros outside each image. ``pixel_attention_mask`` has shape
        ``(batch, max_num_images, max_height, max_width)`` and is ``True`` exactly
        on real pixels.
    """
    total_batch_size = len(output_images)
    # `default=0` keeps an empty batch from raising ValueError in max().
    max_num_images = max((len(img_l) for img_l in output_images), default=0)
    if max_num_images == 0:
        return None

    all_images = [img for img_l in output_images for img in img_l]
    max_height = max(img.size(2) for img in all_images)
    max_width = max(img.size(3) for img in all_images)

    pixel_values = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
    pixel_attention_mask = torch.zeros(
        total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
    )
    for batch_idx, img_l in enumerate(output_images):
        for img_idx, img in enumerate(img_l):
            im_height, im_width = img.size()[2:]
            # Images are anchored at the top-left corner; the rest stays zero / False.
            pixel_values[batch_idx, img_idx, :, :im_height, :im_width] = img
            pixel_attention_mask[batch_idx, img_idx, :im_height, :im_width] = True
    return pixel_values, pixel_attention_mask


def create_model_inputs(
    input_texts: List[str],
    image_lists: List[List[Image.Image]],
):
    """Tokenize a batch of prompts and pack their images into padded tensors.

    All this logic will eventually be handled inside the model processor.

    Args:
        input_texts: One prompt per sample. Tokenized with ``padding=True``, no
            special tokens added, returned as PyTorch tensors.
        image_lists: One list of PIL images per sample (presumably aligned with
            ``input_texts``). Samples may hold different numbers of images,
            including none.

    Returns:
        The tokenizer output. When at least one image is present, it also holds
        ``"pixel_values"`` (zero-padded image batch) and
        ``"pixel_attention_mask"`` (``True`` on real pixels). When there are no
        images at all, including when ``image_lists`` is empty, those two keys
        are omitted.
    """
    inputs = PROCESSOR.tokenizer(
        input_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    )
    output_images = [
        [PROCESSOR.image_processor(img, transform=custom_transform) for img in im_list]
        for im_list in image_lists
    ]
    padded = _pad_image_batch(output_images)
    if padded is not None:
        pixel_values, pixel_attention_mask = padded
        inputs["pixel_values"] = pixel_values
        inputs["pixel_attention_mask"] = pixel_attention_mask
    return inputs