lorocksUMD committed on
Commit fbfc736
1 Parent(s): afcb7d3

Update llava/mm_utils.py

Files changed (1)
  1. llava/mm_utils.py +248 -247
llava/mm_utils.py CHANGED
@@ -1,247 +1,248 @@
-from PIL import Image
-from io import BytesIO
-import base64
-import torch
-import math
-import ast
-
-from transformers import StoppingCriteria
-from llava.constants import IMAGE_TOKEN_INDEX
-
-
-def select_best_resolution(original_size, possible_resolutions):
-    """
-    Selects the best resolution from a list of possible resolutions based on the original size.
-
-    Args:
-        original_size (tuple): The original size of the image in the format (width, height).
-        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
-
-    Returns:
-        tuple: The best fit resolution in the format (width, height).
-    """
-    original_width, original_height = original_size
-    best_fit = None
-    max_effective_resolution = 0
-    min_wasted_resolution = float('inf')
-
-    for width, height in possible_resolutions:
-        scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
-        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
-        wasted_resolution = (width * height) - effective_resolution
-
-        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
-            max_effective_resolution = effective_resolution
-            min_wasted_resolution = wasted_resolution
-            best_fit = (width, height)
-
-    return best_fit
-
-
-def resize_and_pad_image(image, target_resolution):
-    """
-    Resize and pad an image to a target resolution while maintaining aspect ratio.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        target_resolution (tuple): The target resolution (width, height) of the image.
-
-    Returns:
-        PIL.Image.Image: The resized and padded image.
-    """
-    original_width, original_height = image.size
-    target_width, target_height = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    # Resize the image
-    resized_image = image.resize((new_width, new_height))
-
-    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
-    paste_x = (target_width - new_width) // 2
-    paste_y = (target_height - new_height) // 2
-    new_image.paste(resized_image, (paste_x, paste_y))
-
-    return new_image
-
-
-def divide_to_patches(image, patch_size):
-    """
-    Divides an image into patches of a specified size.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        patch_size (int): The size of each patch.
-
-    Returns:
-        list: A list of PIL.Image.Image objects representing the patches.
-    """
-    patches = []
-    width, height = image.size
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            box = (j, i, j + patch_size, i + patch_size)
-            patch = image.crop(box)
-            patches.append(patch)
-
-    return patches
-
-
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    """
-    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (tuple): The size of the input image in the format (width, height).
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-        patch_size (int): The size of each image patch.
-
-    Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    width, height = select_best_resolution(image_size, possible_resolutions)
-    return width // patch_size, height // patch_size
-
-
-def process_anyres_image(image, processor, grid_pinpoints):
-    """
-    Process an image with variable resolutions.
-
-    Args:
-        image (PIL.Image.Image): The input image to be processed.
-        processor: The image processor object.
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-
-    Returns:
-        torch.Tensor: A tensor containing the processed image patches.
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    best_resolution = select_best_resolution(image.size, possible_resolutions)
-    image_padded = resize_and_pad_image(image, best_resolution)
-
-    patches = divide_to_patches(image_padded, processor.crop_size['height'])
-
-    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
-
-    image_patches = [image_original_resize] + patches
-    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
-                     for image_patch in image_patches]
-    return torch.stack(image_patches, dim=0)
-
-
-def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    elif width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def process_images(images, image_processor, model_cfg):
-    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
-    new_images = []
-    if image_aspect_ratio == 'pad':
-        for image in images:
-            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
-            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-            new_images.append(image)
-    elif image_aspect_ratio == "anyres":
-        for image in images:
-            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
-            new_images.append(image)
-    else:
-        return image_processor(images, return_tensors='pt')['pixel_values']
-    if all(x.shape == new_images[0].shape for x in new_images):
-        new_images = torch.stack(new_images, dim=0)
-    return new_images
-
-
-def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
-    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
-
-    def insert_separator(X, sep):
-        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
-
-    input_ids = []
-    offset = 0
-    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        input_ids.append(prompt_chunks[0][0])
-
-    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-        input_ids.extend(x[offset:])
-
-    if return_tensors is not None:
-        if return_tensors == 'pt':
-            return torch.tensor(input_ids, dtype=torch.long)
-        raise ValueError(f'Unsupported tensor type: {return_tensors}')
-    return input_ids
-
-
-def get_model_name_from_path(model_path):
-    model_path = model_path.strip("/")
-    model_paths = model_path.split("/")
-    if model_paths[-1].startswith('checkpoint-'):
-        return model_paths[-2] + "_" + model_paths[-1]
-    else:
-        return model_paths[-1]
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords, tokenizer, input_ids):
-        self.keywords = keywords
-        self.keyword_ids = []
-        self.max_keyword_len = 0
-        for keyword in keywords:
-            cur_keyword_ids = tokenizer(keyword).input_ids
-            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                cur_keyword_ids = cur_keyword_ids[1:]
-            if len(cur_keyword_ids) > self.max_keyword_len:
-                self.max_keyword_len = len(cur_keyword_ids)
-            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-        self.tokenizer = tokenizer
-        self.start_len = input_ids.shape[1]
-
-    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
-        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
-        for keyword_id in self.keyword_ids:
-            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
-            if torch.equal(truncated_output_ids, keyword_id):
-                return True
-        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-        for keyword in self.keywords:
-            if keyword in outputs:
-                return True
-        return False
-
-    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        outputs = []
-        for i in range(output_ids.shape[0]):
-            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
-        return all(outputs)
 
 
+from PIL import Image
+from io import BytesIO
+import base64
+import torch
+import math
+import ast
+
+from transformers import StoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX
+
+
+def select_best_resolution(original_size, possible_resolutions):
+    """
+    Selects the best resolution from a list of possible resolutions based on the original size.
+
+    Args:
+        original_size (tuple): The original size of the image in the format (width, height).
+        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+    Returns:
+        tuple: The best fit resolution in the format (width, height).
+    """
+    original_width, original_height = original_size
+    best_fit = None
+    max_effective_resolution = 0
+    min_wasted_resolution = float('inf')
+
+    for width, height in possible_resolutions:
+        scale = min(width / original_width, height / original_height)
+        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        wasted_resolution = (width * height) - effective_resolution
+
+        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+            max_effective_resolution = effective_resolution
+            min_wasted_resolution = wasted_resolution
+            best_fit = (width, height)
+
+    return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+    """
+    Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        target_resolution (tuple): The target resolution (width, height) of the image.
+
+    Returns:
+        PIL.Image.Image: The resized and padded image.
+    """
+    original_width, original_height = image.size
+    target_width, target_height = target_resolution
+
+    scale_w = target_width / original_width
+    scale_h = target_height / original_height
+
+    if scale_w < scale_h:
+        new_width = target_width
+        new_height = min(math.ceil(original_height * scale_w), target_height)
+    else:
+        new_height = target_height
+        new_width = min(math.ceil(original_width * scale_h), target_width)
+
+    # Resize the image
+    resized_image = image.resize((new_width, new_height))
+
+    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+    paste_x = (target_width - new_width) // 2
+    paste_y = (target_height - new_height) // 2
+    new_image.paste(resized_image, (paste_x, paste_y))
+
+    return new_image
+
+
+def divide_to_patches(image, patch_size):
+    """
+    Divides an image into patches of a specified size.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        patch_size (int): The size of each patch.
+
+    Returns:
+        list: A list of PIL.Image.Image objects representing the patches.
+    """
+    patches = []
+    width, height = image.size
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            box = (j, i, j + patch_size, i + patch_size)
+            patch = image.crop(box)
+            patches.append(patch)
+
+    return patches
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+    """
+    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (tuple): The size of the input image in the format (width, height).
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+        patch_size (int): The size of each image patch.
+
+    Returns:
+        tuple: The shape of the image patch grid in the format (width, height).
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    width, height = select_best_resolution(image_size, possible_resolutions)
+    return width // patch_size, height // patch_size
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+    """
+    Process an image with variable resolutions.
+
+    Args:
+        image (PIL.Image.Image): The input image to be processed.
+        processor: The image processor object.
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+    Returns:
+        torch.Tensor: A tensor containing the processed image patches.
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+    image_padded = resize_and_pad_image(image, best_resolution)
+
+    patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+    image_patches = [image_original_resize] + patches
+    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                     for image_patch in image_patches]
+    return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    image_aspect_ratio = 'pad'
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    elif image_aspect_ratio == "anyres":
+        for image in images:
+            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+
+def get_model_name_from_path(model_path):
+    model_path = model_path.strip("/")
+    model_paths = model_path.split("/")
+    if model_paths[-1].startswith('checkpoint-'):
+        return model_paths[-2] + "_" + model_paths[-1]
+    else:
+        return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.keyword_ids = []
+        self.max_keyword_len = 0
+        for keyword in keywords:
+            cur_keyword_ids = tokenizer(keyword).input_ids
+            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
+            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+        self.tokenizer = tokenizer
+        self.start_len = input_ids.shape[1]
+
+    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+        for keyword_id in self.keyword_ids:
+            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
+            if torch.equal(truncated_output_ids, keyword_id):
+                return True
+        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+        for keyword in self.keywords:
+            if keyword in outputs:
+                return True
+        return False
+
+    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        outputs = []
+        for i in range(output_ids.shape[0]):
+            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+        return all(outputs)
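The only functional difference between the two revisions is the hard-coded `image_aspect_ratio = 'pad'` at the top of `process_images`, which routes every image through the `expand2square` padding branch regardless of what the model config specifies. Below is a minimal sketch of driving the patched function; the `openai/clip-vit-large-patch14-336` checkpoint and the `DummyConfig` stand-in are illustrative assumptions, not part of this commit.

# Illustrative sketch only: exercises the patched process_images().
# DummyConfig is a hypothetical stand-in for the model config object.
from PIL import Image
from transformers import CLIPImageProcessor

from llava.mm_utils import process_images


class DummyConfig:
    # After this commit the value below is ignored:
    # process_images() overrides it with 'pad'.
    image_aspect_ratio = "anyres"
    image_grid_pinpoints = "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]"


processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
images = [Image.new("RGB", (640, 480), (127, 127, 127))]

# Each image is padded to a square with the processor's mean color,
# then preprocessed; the results are stacked into a single tensor.
pixel_values = process_images(images, processor, DummyConfig())
print(pixel_values.shape)  # e.g. torch.Size([1, 3, 336, 336])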