from PIL import Image
from io import BytesIO
import base64
import math
import ast
import torch
from transformers import StoppingCriteria
from oryx.constants import IMAGE_TOKEN_INDEX
import os

video_base = 0       # base_size passed to resize_video (0 = keep native area within the caps below)
video_ps = 64        # patch stride for video frames
highres_base = 0     # base_size for the high-resolution image branch
highres_ps = 32      # patch stride for high-resolution images
MAXRES = 1536        # image area cap (MAXRES * MAXRES pixels)
MINRES = 0           # image area floor
VIDEO_MAXRES = 480   # video-frame area cap
VIDEO_MINRES = 288   # video-frame area floor
LOWRES_RESIZE = (384, 32)  # (base_size, patch stride) for the low-resolution global view
PAD2STRIDE = False   # pad to the patch stride instead of resizing when no rescale is needed

def pad_image(image, target_resolution, value=0):
    """
    Pad an image to a target resolution by centering it on a new canvas.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the canvas.

    Returns:
        PIL.Image.Image: The padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution
    # Create a new image with the target size and paste the original image onto its center
    new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
    paste_x = (target_width - original_width) // 2
    paste_y = (target_height - original_height) // 2
    new_image.paste(image, (paste_x, paste_y))
    return new_image
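
# Usage sketch (hypothetical values): center a 300x200 image on a grey 384x384 canvas.
# padded = pad_image(Image.new('RGB', (300, 200)), (384, 384), value=127)
# padded.size  # -> (384, 384)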

def resize_images(image, patch_size=14, base_size=896):
    # Note: PIL's Image.size is (width, height), so h/w are swapped here; the names
    # follow the original convention and are used consistently, so behavior is correct.
    h, w = image.size
    if base_size == 0:
        if h * w > MAXRES * MAXRES:
            # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
            scale = MAXRES * MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < MINRES * MINRES:
            # print(f'{h}x{w} smaller than min size {MINRES}, resize to {MINRES}')
            scale = MINRES * MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        # Rescale so the total area matches the target, then snap both sides to the patch stride
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        # Pad (rather than resize) up to the next multiple of patch_size
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size
        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    return image
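
# Usage sketch (hypothetical values): with base_size=0 the image roughly keeps its area
# unless it exceeds MAXRES**2 pixels; with a non-zero base_size it is rescaled to about
# base_size**2 pixels. Either way both output dimensions are multiples of patch_size.
# resized = resize_images(Image.new('RGB', (1000, 700)), patch_size=32, base_size=896)
# assert resized.size[0] % 32 == 0 and resized.size[1] % 32 == 0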

def resize_video(image, patch_size=14, base_size=896):
    # Same logic as resize_images, but with the video-specific area caps
    h, w = image.size
    if base_size == 0:
        if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
            # print(f'{h}x{w} larger than max size {VIDEO_MAXRES}, resize to {VIDEO_MAXRES}')
            scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < VIDEO_MINRES * VIDEO_MINRES:
            # print(f'{h}x{w} smaller than min size {VIDEO_MINRES}, resize to {VIDEO_MINRES}')
            scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size
        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    return image

def process_anyres_video_genli(image, processor):
    image = resize_video(image, patch_size=video_ps, base_size=video_base)
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return image.unsqueeze(0)


def process_anyres_video_genli_long(image, processor):
    # Variant with twice the patch stride (the '_long' path)
    image = resize_video(image, patch_size=video_ps * 2, base_size=video_base)
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return image.unsqueeze(0)
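
# Usage sketch (assumes a Hugging Face image processor, e.g. SiglipImageProcessor, and a
# PIL frame extracted elsewhere): each call returns a [1, C, H, W] tensor for one frame.
# frame_tensor = process_anyres_video_genli(frame, image_processor)
# frame_tensor_long = process_anyres_video_genli_long(frame, image_processor)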

def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))

def process_anyres_highres_image_genli(image, processor):
    h, w = image.size
    # Upsample very small images so the shorter side is scaled up to roughly 64 pixels
    if h < 32 and w < 32:
        min_size = min(h, w)
        ratio = 64 / min_size
        image = image.resize((int(h * ratio), int(w * ratio)))
    elif h < 32:
        ratio = 64 / h
        image = image.resize((int(h * ratio), int(w * ratio)))
    elif w < 32:
        ratio = 64 / w
        image = image.resize((int(h * ratio), int(w * ratio)))

    image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
    image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])

    image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0]
    image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return image_patches.unsqueeze(0), image_padded.unsqueeze(0)
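
# Usage sketch: the first tensor is the low-resolution global view, the second the
# high-resolution (patch-stride-aligned) view; both carry a leading batch dimension.
# lowres, highres = process_anyres_highres_image_genli(pil_image, image_processor)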

def read_image_patch(patch_info):
    if 'img_path' in patch_info.keys():
        image = Image.open(patch_info['img_path']).convert('RGB')
    else:
        # Normalize the misspelled 'image_encoing' key if present
        if 'image_encoing' in patch_info.keys():
            patch_info['image_encoding'] = patch_info['image_encoing']
        image_file_name = patch_info['patch']
        start_bytes = int(patch_info['start_num'])
        file_size = int(patch_info['size'])
        with open(image_file_name, 'rb') as f:
            f.seek(start_bytes)
            if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
                image = Image.open(BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
            else:
                image = Image.open(BytesIO(f.read(file_size))).convert("RGB")
    return image
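
# Usage sketch (hypothetical annotation entry): read a raw image byte range from a packed
# patch file; an 'image_encoding': 'base64' entry switches to base64-decoding those bytes.
# info = {'patch': 'patches/shard_000.bin', 'start_num': 0, 'size': 12345}
# image = read_image_patch(info)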

def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    # Tokenize the text around each '<image>' placeholder, then splice the image token index in between
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
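
# Usage sketch: each '<image>' placeholder is replaced by IMAGE_TOKEN_INDEX so the model
# can splice visual features in at those positions during the forward pass.
# ids = tokenizer_image_token('<image>\nDescribe this video.', tokenizer, return_tensors='pt')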

def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so only the keyword's own tokens are matched
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal avoids the ambiguous truth value of an element-wise tensor comparison
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
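
# Usage sketch (assumes an already-tokenized prompt and a loaded model): stop generation
# once a keyword such as '</s>' or a conversation separator appears in the new tokens.
# stopping_criteria = KeywordsStoppingCriteria(['</s>'], tokenizer, input_ids)
# output_ids = model.generate(input_ids, stopping_criteria=[stopping_criteria], max_new_tokens=256)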