pan-yl commited on
Commit
bf225fc
1 Parent(s): 97e6a2f
infer.py CHANGED
@@ -16,43 +16,12 @@ from scepter.modules.utils.distribute import we
16
  from scepter.modules.utils.logger import get_logger
17
  from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
18
 
19
- def check_list_of_list(ll):
20
- return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
21
-
22
- def pack_imagelist_into_tensor(image_list):
23
- # allow None
24
- example = None
25
- image_tensor, shapes = [], []
26
- for img in image_list:
27
- if img is None:
28
- example = find_example(image_tensor,
29
- image_list) if example is None else example
30
- image_tensor.append(example)
31
- shapes.append(None)
32
- continue
33
- _, c, h, w = img.size()
34
- image_tensor.append(img.view(c, h * w).transpose(1, 0)) # h*w, c
35
- shapes.append((h, w))
36
-
37
- image_tensor = pad_sequence(image_tensor,
38
- batch_first=True).permute(0, 2, 1) # b, c, l
39
- return image_tensor, shapes
40
-
41
- def to_device(inputs, strict=True):
42
- if inputs is None:
43
- return None
44
- if strict:
45
- assert all(isinstance(i, torch.Tensor) for i in inputs)
46
- return [i.to(we.device_id) if i is not None else None for i in inputs]
47
-
48
-
49
- def unpack_tensor_into_imagelist(image_tensor, shapes):
50
- image_list = []
51
- for img, shape in zip(image_tensor, shapes):
52
- h, w = shape[0], shape[1]
53
- image_list.append(img[:, :h * w].view(1, -1, h, w))
54
-
55
- return image_list
56
 
57
 
58
  def process_edit_image(images,
 
16
  from scepter.modules.utils.logger import get_logger
17
  from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
18
 
19
+ from modules.model.utils.basic_utils import (
20
+ check_list_of_list,
21
+ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
22
+ to_device,
23
+ unpack_tensor_into_imagelist
24
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def process_edit_image(images,
modules/__init__.py CHANGED
@@ -1 +1 @@
1
- from . import data, model, solver
 
1
+ from . import model
modules/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import dataset
 
 
modules/data/dataset/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .dataset import ACEDemoDataset
 
 
modules/data/dataset/dataset.py DELETED
@@ -1,252 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
-
4
- import io
5
- import math
6
- import os
7
- import sys
8
- from collections import defaultdict
9
-
10
- import numpy as np
11
- import torch
12
- import torchvision.transforms as T
13
- from PIL import Image
14
- from torchvision.transforms.functional import InterpolationMode
15
-
16
- from scepter.modules.data.dataset.base_dataset import BaseDataset
17
- from scepter.modules.data.dataset.registry import DATASETS
18
- from scepter.modules.transform.io import pillow_convert
19
- from scepter.modules.utils.config import dict_to_yaml
20
- from scepter.modules.utils.file_system import FS
21
-
22
- Image.MAX_IMAGE_PIXELS = None
23
-
24
- @DATASETS.register_class()
25
- class ACEDemoDataset(BaseDataset):
26
- para_dict = {
27
- 'MS_DATASET_NAME': {
28
- 'value': '',
29
- 'description': 'Modelscope dataset name.'
30
- },
31
- 'MS_DATASET_NAMESPACE': {
32
- 'value': '',
33
- 'description': 'Modelscope dataset namespace.'
34
- },
35
- 'MS_DATASET_SUBNAME': {
36
- 'value': '',
37
- 'description': 'Modelscope dataset subname.'
38
- },
39
- 'MS_DATASET_SPLIT': {
40
- 'value': '',
41
- 'description':
42
- 'Modelscope dataset split set name, default is train.'
43
- },
44
- 'MS_REMAP_KEYS': {
45
- 'value':
46
- None,
47
- 'description':
48
- 'Modelscope dataset header of list file, the default is Target:FILE; '
49
- 'If your file is not this header, please set this field, which is a map dict.'
50
- "For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE"
51
- },
52
- 'MS_REMAP_PATH': {
53
- 'value':
54
- None,
55
- 'description':
56
- 'When modelscope dataset name is not None, that means you use the dataset from modelscope,'
57
- ' default is None. But if you want to use the datalist from modelscope and the file from '
58
- 'local device, you can use this field to set the root path of your images. '
59
- },
60
- 'TRIGGER_WORDS': {
61
- 'value':
62
- '',
63
- 'description':
64
- 'The words used to describe the common features of your data, especially when you customize a '
65
- 'tuner. Use these words you can get what you want.'
66
- },
67
- 'HIGHLIGHT_KEYWORDS': {
68
- 'value':
69
- '',
70
- 'description':
71
- 'The keywords you want to highlight in prompt, which will be replace by <HIGHLIGHT_KEYWORDS>.'
72
- },
73
- 'KEYWORDS_SIGN': {
74
- 'value':
75
- '',
76
- 'description':
77
- 'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>'
78
- },
79
- }
80
-
81
- def __init__(self, cfg, logger=None):
82
- super().__init__(cfg=cfg, logger=logger)
83
- from modelscope import MsDataset
84
- from modelscope.utils.constant import DownloadMode
85
- ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
86
- ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
87
- ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
88
- ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
89
- ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
90
- ms_remap_path = cfg.get('MS_REMAP_PATH', None)
91
-
92
- self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
93
- self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
94
- self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
95
- self.replace_style = cfg.get('REPLACE_STYLE', False)
96
- self.trigger_words = cfg.get('TRIGGER_WORDS', '')
97
- self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
98
- self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
99
- self.add_indicator = cfg.get('ADD_INDICATOR', False)
100
- # Use modelscope dataset
101
- if not ms_dataset_name:
102
- raise ValueError(
103
- 'Your must set MS_DATASET_NAME as modelscope dataset or your local dataset orignized '
104
- 'as modelscope dataset.')
105
- if FS.exists(ms_dataset_name):
106
- ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
107
- self.ms_dataset_name = ms_dataset_name
108
- # ms_remap_path = ms_dataset_name
109
- try:
110
- self.data = MsDataset.load(str(ms_dataset_name),
111
- namespace=ms_dataset_namespace,
112
- subset_name=ms_dataset_subname,
113
- split=ms_dataset_split)
114
- except Exception:
115
- self.logger.info(
116
- "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
117
- )
118
- try:
119
- self.data = MsDataset.load(
120
- str(ms_dataset_name),
121
- namespace=ms_dataset_namespace,
122
- subset_name=ms_dataset_subname,
123
- split=ms_dataset_split,
124
- download_mode=DownloadMode.FORCE_REDOWNLOAD)
125
- except Exception as sec_e:
126
- raise ValueError(f'Load Modelscope dataset failed {sec_e}.')
127
- if ms_remap_keys:
128
- self.data = self.data.remap_columns(ms_remap_keys.get_dict())
129
-
130
- if ms_remap_path:
131
-
132
- def map_func(example):
133
- return {
134
- k: os.path.join(ms_remap_path, v)
135
- if k.endswith(':FILE') else v
136
- for k, v in example.items()
137
- }
138
-
139
- self.data = self.data.ds_instance.map(map_func)
140
-
141
- self.transforms = T.Compose([
142
- T.ToTensor(),
143
- T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
144
- ])
145
-
146
- def __len__(self):
147
- if self.mode == 'train':
148
- return sys.maxsize
149
- else:
150
- return len(self.data)
151
-
152
- def _get(self, index: int):
153
- current_data = self.data[index % len(self.data)]
154
-
155
- tar_image_path = current_data.get('Target:FILE', '')
156
- src_image_path = current_data.get('Source:FILE', '')
157
-
158
- style = current_data.get('Style', '')
159
- prompt = current_data.get('Prompt', current_data.get('prompt', ''))
160
- if self.replace_style and not style == '':
161
- prompt = prompt.replace(style, f'<{self.keywords_sign}>')
162
-
163
- elif not self.replace_keywords.strip() == '':
164
- prompt = prompt.replace(
165
- self.replace_keywords,
166
- '<' + self.replace_keywords + f'{self.keywords_sign}>')
167
-
168
- if not self.trigger_words == '':
169
- prompt = self.trigger_words.strip() + ' ' + prompt
170
-
171
- src_image = self.load_image(self.ms_dataset_name,
172
- src_image_path,
173
- cvt_type='RGB')
174
- tar_image = self.load_image(self.ms_dataset_name,
175
- tar_image_path,
176
- cvt_type='RGB')
177
- src_image = self.image_preprocess(src_image)
178
- tar_image = self.image_preprocess(tar_image)
179
-
180
- tar_image = self.transforms(tar_image)
181
- src_image = self.transforms(src_image)
182
- src_mask = torch.ones_like(src_image[[0]])
183
- tar_mask = torch.ones_like(tar_image[[0]])
184
- if self.add_indicator:
185
- if '{image}' not in prompt:
186
- prompt = '{image}, ' + prompt
187
-
188
- return {
189
- 'edit_image': [src_image],
190
- 'edit_image_mask': [src_mask],
191
- 'image': tar_image,
192
- 'image_mask': tar_mask,
193
- 'prompt': [prompt],
194
- }
195
-
196
- def load_image(self, prefix, img_path, cvt_type=None):
197
- if img_path is None or img_path == '':
198
- return None
199
- img_path = os.path.join(prefix, img_path)
200
- with FS.get_object(img_path) as image_bytes:
201
- image = Image.open(io.BytesIO(image_bytes))
202
- if cvt_type is not None:
203
- image = pillow_convert(image, cvt_type)
204
- return image
205
-
206
- def image_preprocess(self,
207
- img,
208
- size=None,
209
- interpolation=InterpolationMode.BILINEAR):
210
- H, W = img.height, img.width
211
- if H / W > self.max_aspect_ratio:
212
- img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
213
- elif W / H > self.max_aspect_ratio:
214
- img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)
215
-
216
- if size is None:
217
- # resize image for max_seq_len, while keep the aspect ratio
218
- H, W = img.height, img.width
219
- scale = min(
220
- 1.0,
221
- math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
222
- rH = int(
223
- H * scale) // self.d * self.d # ensure divisible by self.d
224
- rW = int(W * scale) // self.d * self.d
225
- else:
226
- rH, rW = size
227
- img = T.Resize((rH, rW), interpolation=interpolation,
228
- antialias=True)(img)
229
- return np.array(img, dtype=np.uint8)
230
-
231
- @staticmethod
232
- def get_config_template():
233
- return dict_to_yaml('DATASet',
234
- __class__.__name__,
235
- ACEDemoDataset.para_dict,
236
- set_name=True)
237
-
238
- @staticmethod
239
- def collate_fn(batch):
240
- collect = defaultdict(list)
241
- for sample in batch:
242
- for k, v in sample.items():
243
- collect[k].append(v)
244
-
245
- new_batch = dict()
246
- for k, v in collect.items():
247
- if all([i is None for i in v]):
248
- new_batch[k] = None
249
- else:
250
- new_batch[k] = v
251
-
252
- return new_batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/solver/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .ace_solver import ACESolverV1
 
 
modules/solver/ace_solver.py DELETED
@@ -1,146 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import numpy as np
4
- import torch
5
- from tqdm import tqdm
6
-
7
- from scepter.modules.utils.data import transfer_data_to_cuda
8
- from scepter.modules.utils.distribute import we
9
- from scepter.modules.utils.probe import ProbeData
10
- from scepter.modules.solver.registry import SOLVERS
11
- from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver
12
-
13
-
14
-
15
- @SOLVERS.register_class()
16
- class ACESolverV1(LatentDiffusionSolver):
17
- def __init__(self, cfg, logger=None):
18
- super().__init__(cfg, logger=logger)
19
- self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)
20
-
21
- def save_results(self, results):
22
- log_data, log_label = [], []
23
- for result in results:
24
- ret_images, ret_labels = [], []
25
- edit_image = result.get('edit_image', None)
26
- edit_mask = result.get('edit_mask', None)
27
- if edit_image is not None:
28
- for i, edit_img in enumerate(result['edit_image']):
29
- if edit_img is None:
30
- continue
31
- ret_images.append(
32
- (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(
33
- np.uint8))
34
- ret_labels.append(f'edit_image{i}; ')
35
- if edit_mask is not None:
36
- ret_images.append(
37
- (edit_mask[i].permute(1, 2, 0).cpu().numpy() *
38
- 255).astype(np.uint8))
39
- ret_labels.append(f'edit_mask{i}; ')
40
-
41
- target_image = result.get('target_image', None)
42
- target_mask = result.get('target_mask', None)
43
- if target_image is not None:
44
- ret_images.append(
45
- (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(
46
- np.uint8))
47
- ret_labels.append('target_image; ')
48
- if target_mask is not None:
49
- ret_images.append(
50
- (target_mask.permute(1, 2, 0).cpu().numpy() *
51
- 255).astype(np.uint8))
52
- ret_labels.append('target_mask; ')
53
-
54
- reconstruct_image = result.get('reconstruct_image', None)
55
- if reconstruct_image is not None:
56
- ret_images.append(
57
- (reconstruct_image.permute(1, 2, 0).cpu().numpy() *
58
- 255).astype(np.uint8))
59
- ret_labels.append(f"{result['instruction']}")
60
- log_data.append(ret_images)
61
- log_label.append(ret_labels)
62
- return log_data, log_label
63
-
64
- @torch.no_grad()
65
- def run_eval(self):
66
- self.eval_mode()
67
- self.before_all_iter(self.hooks_dict[self._mode])
68
- all_results = []
69
- for batch_idx, batch_data in tqdm(
70
- enumerate(self.datas[self._mode].dataloader)):
71
- self.before_iter(self.hooks_dict[self._mode])
72
- if self.sample_args:
73
- batch_data.update(self.sample_args.get_lowercase_dict())
74
- with torch.autocast(device_type='cuda',
75
- enabled=self.use_amp,
76
- dtype=self.dtype):
77
- results = self.run_step_eval(transfer_data_to_cuda(batch_data),
78
- batch_idx,
79
- step=self.total_iter,
80
- rank=we.rank)
81
- all_results.extend(results)
82
- self.after_iter(self.hooks_dict[self._mode])
83
- log_data, log_label = self.save_results(all_results)
84
- self.register_probe({'eval_label': log_label})
85
- self.register_probe({
86
- 'eval_image':
87
- ProbeData(log_data,
88
- is_image=True,
89
- build_html=True,
90
- build_label=log_label)
91
- })
92
- self.after_all_iter(self.hooks_dict[self._mode])
93
-
94
- @torch.no_grad()
95
- def run_test(self):
96
- self.test_mode()
97
- self.before_all_iter(self.hooks_dict[self._mode])
98
- all_results = []
99
- for batch_idx, batch_data in tqdm(
100
- enumerate(self.datas[self._mode].dataloader)):
101
- self.before_iter(self.hooks_dict[self._mode])
102
- if self.sample_args:
103
- batch_data.update(self.sample_args.get_lowercase_dict())
104
- with torch.autocast(device_type='cuda',
105
- enabled=self.use_amp,
106
- dtype=self.dtype):
107
- results = self.run_step_eval(transfer_data_to_cuda(batch_data),
108
- batch_idx,
109
- step=self.total_iter,
110
- rank=we.rank)
111
- all_results.extend(results)
112
- self.after_iter(self.hooks_dict[self._mode])
113
- log_data, log_label = self.save_results(all_results)
114
- self.register_probe({'test_label': log_label})
115
- self.register_probe({
116
- 'test_image':
117
- ProbeData(log_data,
118
- is_image=True,
119
- build_html=True,
120
- build_label=log_label)
121
- })
122
-
123
- self.after_all_iter(self.hooks_dict[self._mode])
124
-
125
- @property
126
- def probe_data(self):
127
- if not we.debug and self.mode == 'train':
128
- batch_data = transfer_data_to_cuda(
129
- self.current_batch_data[self.mode])
130
- self.eval_mode()
131
- with torch.autocast(device_type='cuda',
132
- enabled=self.use_amp,
133
- dtype=self.dtype):
134
- batch_data['log_num'] = self.log_train_num
135
- results = self.run_step_eval(batch_data)
136
- self.train_mode()
137
- log_data, log_label = self.save_results(results)
138
- self.register_probe({
139
- 'train_image':
140
- ProbeData(log_data,
141
- is_image=True,
142
- build_html=True,
143
- build_label=log_label)
144
- })
145
- self.register_probe({'train_label': log_label})
146
- return super(LatentDiffusionSolver, self).probe_data