update
- infer.py +6 -37
- modules/__init__.py +1 -1
- modules/data/__init__.py +0 -1
- modules/data/dataset/__init__.py +0 -1
- modules/data/dataset/dataset.py +0 -252
- modules/solver/__init__.py +0 -1
- modules/solver/ace_solver.py +0 -146
infer.py
CHANGED
@@ -16,43 +16,12 @@ from scepter.modules.utils.distribute import we
 from scepter.modules.utils.logger import get_logger
 from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
 
-
-
-
-
-
-
-    image_tensor, shapes = [], []
-    for img in image_list:
-        if img is None:
-            example = find_example(image_tensor,
-                                   image_list) if example is None else example
-            image_tensor.append(example)
-            shapes.append(None)
-            continue
-        _, c, h, w = img.size()
-        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # h*w, c
-        shapes.append((h, w))
-
-    image_tensor = pad_sequence(image_tensor,
-                                batch_first=True).permute(0, 2, 1)  # b, c, l
-    return image_tensor, shapes
-
-def to_device(inputs, strict=True):
-    if inputs is None:
-        return None
-    if strict:
-        assert all(isinstance(i, torch.Tensor) for i in inputs)
-    return [i.to(we.device_id) if i is not None else None for i in inputs]
-
-
-def unpack_tensor_into_imagelist(image_tensor, shapes):
-    image_list = []
-    for img, shape in zip(image_tensor, shapes):
-        h, w = shape[0], shape[1]
-        image_list.append(img[:, :h * w].view(1, -1, h, w))
-
-    return image_list
+from modules.model.utils.basic_utils import (
+    check_list_of_list,
+    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
+    to_device,
+    unpack_tensor_into_imagelist
+)
 
 
 def process_edit_image(images,
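For context, the removed helpers pack variably-sized images into one padded sequence tensor and back. Below is a minimal, self-contained sketch of that round trip, reconstructed from the removed code (the None-image / find_example branch is omitted); treating pack_imagelist_into_tensor_v2 in modules.model.utils.basic_utils as keeping the same interface is an assumption.

import torch
from torch.nn.utils.rnn import pad_sequence

def pack_imagelist_into_tensor(image_list):
    # Flatten each (1, c, h, w) image to an (h*w, c) token sequence, then pad
    # all sequences to a common length l and stack into a (b, c, l) tensor.
    image_tensor, shapes = [], []
    for img in image_list:
        _, c, h, w = img.size()
        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # (h*w, c)
        shapes.append((h, w))
    image_tensor = pad_sequence(image_tensor,
                                batch_first=True).permute(0, 2, 1)  # (b, c, l)
    return image_tensor, shapes

def unpack_tensor_into_imagelist(image_tensor, shapes):
    # Trim each sequence back to its h*w tokens and restore the spatial layout;
    # reshape (not view) because the packed tensor is non-contiguous after permute.
    return [img[:, :h * w].reshape(1, -1, h, w)
            for img, (h, w) in zip(image_tensor, shapes)]

imgs = [torch.randn(1, 3, 4, 6), torch.randn(1, 3, 8, 2)]
packed, shapes = pack_imagelist_into_tensor(imgs)      # packed: (2, 3, 24)
restored = unpack_tensor_into_imagelist(packed, shapes)
assert all(torch.equal(a, b) for a, b in zip(imgs, restored))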
modules/__init__.py
CHANGED
@@ -1 +1 @@
-from . import
+from . import model
modules/data/__init__.py
DELETED
@@ -1 +0,0 @@
-from . import dataset
modules/data/dataset/__init__.py
DELETED
@@ -1 +0,0 @@
-from .dataset import ACEDemoDataset
modules/data/dataset/dataset.py
DELETED
@@ -1,252 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import io
-import math
-import os
-import sys
-from collections import defaultdict
-
-import numpy as np
-import torch
-import torchvision.transforms as T
-from PIL import Image
-from torchvision.transforms.functional import InterpolationMode
-
-from scepter.modules.data.dataset.base_dataset import BaseDataset
-from scepter.modules.data.dataset.registry import DATASETS
-from scepter.modules.transform.io import pillow_convert
-from scepter.modules.utils.config import dict_to_yaml
-from scepter.modules.utils.file_system import FS
-
-Image.MAX_IMAGE_PIXELS = None
-
-@DATASETS.register_class()
-class ACEDemoDataset(BaseDataset):
-    para_dict = {
-        'MS_DATASET_NAME': {
-            'value': '',
-            'description': 'Modelscope dataset name.'
-        },
-        'MS_DATASET_NAMESPACE': {
-            'value': '',
-            'description': 'Modelscope dataset namespace.'
-        },
-        'MS_DATASET_SUBNAME': {
-            'value': '',
-            'description': 'Modelscope dataset subname.'
-        },
-        'MS_DATASET_SPLIT': {
-            'value': '',
-            'description':
-            'Modelscope dataset split set name, default is train.'
-        },
-        'MS_REMAP_KEYS': {
-            'value':
-            None,
-            'description':
-            'Modelscope dataset header of list file, the default is Target:FILE; '
-            'If your file is not this header, please set this field, which is a map dict.'
-            "For example, { 'Image:FILE': 'Target:FILE' } will replace the filed Image:FILE to Target:FILE"
-        },
-        'MS_REMAP_PATH': {
-            'value':
-            None,
-            'description':
-            'When modelscope dataset name is not None, that means you use the dataset from modelscope,'
-            ' default is None. But if you want to use the datalist from modelscope and the file from '
-            'local device, you can use this field to set the root path of your images. '
-        },
-        'TRIGGER_WORDS': {
-            'value':
-            '',
-            'description':
-            'The words used to describe the common features of your data, especially when you customize a '
-            'tuner. Use these words you can get what you want.'
-        },
-        'HIGHLIGHT_KEYWORDS': {
-            'value':
-            '',
-            'description':
-            'The keywords you want to highlight in prompt, which will be replace by <HIGHLIGHT_KEYWORDS>.'
-        },
-        'KEYWORDS_SIGN': {
-            'value':
-            '',
-            'description':
-            'The keywords sign you want to add, which is like <{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>'
-        },
-    }
-
-    def __init__(self, cfg, logger=None):
-        super().__init__(cfg=cfg, logger=logger)
-        from modelscope import MsDataset
-        from modelscope.utils.constant import DownloadMode
-        ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
-        ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
-        ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
-        ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
-        ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
-        ms_remap_path = cfg.get('MS_REMAP_PATH', None)
-
-        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
-        self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
-        self.d = cfg.get('DOWNSAMPLE_RATIO', 16)
-        self.replace_style = cfg.get('REPLACE_STYLE', False)
-        self.trigger_words = cfg.get('TRIGGER_WORDS', '')
-        self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
-        self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
-        self.add_indicator = cfg.get('ADD_INDICATOR', False)
-        # Use modelscope dataset
-        if not ms_dataset_name:
-            raise ValueError(
-                'Your must set MS_DATASET_NAME as modelscope dataset or your local dataset orignized '
-                'as modelscope dataset.')
-        if FS.exists(ms_dataset_name):
-            ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
-        self.ms_dataset_name = ms_dataset_name
-        # ms_remap_path = ms_dataset_name
-        try:
-            self.data = MsDataset.load(str(ms_dataset_name),
-                                       namespace=ms_dataset_namespace,
-                                       subset_name=ms_dataset_subname,
-                                       split=ms_dataset_split)
-        except Exception:
-            self.logger.info(
-                "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
-            )
-            try:
-                self.data = MsDataset.load(
-                    str(ms_dataset_name),
-                    namespace=ms_dataset_namespace,
-                    subset_name=ms_dataset_subname,
-                    split=ms_dataset_split,
-                    download_mode=DownloadMode.FORCE_REDOWNLOAD)
-            except Exception as sec_e:
-                raise ValueError(f'Load Modelscope dataset failed {sec_e}.')
-        if ms_remap_keys:
-            self.data = self.data.remap_columns(ms_remap_keys.get_dict())
-
-        if ms_remap_path:
-
-            def map_func(example):
-                return {
-                    k: os.path.join(ms_remap_path, v)
-                    if k.endswith(':FILE') else v
-                    for k, v in example.items()
-                }
-
-            self.data = self.data.ds_instance.map(map_func)
-
-        self.transforms = T.Compose([
-            T.ToTensor(),
-            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-        ])
-
-    def __len__(self):
-        if self.mode == 'train':
-            return sys.maxsize
-        else:
-            return len(self.data)
-
-    def _get(self, index: int):
-        current_data = self.data[index % len(self.data)]
-
-        tar_image_path = current_data.get('Target:FILE', '')
-        src_image_path = current_data.get('Source:FILE', '')
-
-        style = current_data.get('Style', '')
-        prompt = current_data.get('Prompt', current_data.get('prompt', ''))
-        if self.replace_style and not style == '':
-            prompt = prompt.replace(style, f'<{self.keywords_sign}>')
-
-        elif not self.replace_keywords.strip() == '':
-            prompt = prompt.replace(
-                self.replace_keywords,
-                '<' + self.replace_keywords + f'{self.keywords_sign}>')
-
-        if not self.trigger_words == '':
-            prompt = self.trigger_words.strip() + ' ' + prompt
-
-        src_image = self.load_image(self.ms_dataset_name,
-                                    src_image_path,
-                                    cvt_type='RGB')
-        tar_image = self.load_image(self.ms_dataset_name,
-                                    tar_image_path,
-                                    cvt_type='RGB')
-        src_image = self.image_preprocess(src_image)
-        tar_image = self.image_preprocess(tar_image)
-
-        tar_image = self.transforms(tar_image)
-        src_image = self.transforms(src_image)
-        src_mask = torch.ones_like(src_image[[0]])
-        tar_mask = torch.ones_like(tar_image[[0]])
-        if self.add_indicator:
-            if '{image}' not in prompt:
-                prompt = '{image}, ' + prompt
-
-        return {
-            'edit_image': [src_image],
-            'edit_image_mask': [src_mask],
-            'image': tar_image,
-            'image_mask': tar_mask,
-            'prompt': [prompt],
-        }
-
-    def load_image(self, prefix, img_path, cvt_type=None):
-        if img_path is None or img_path == '':
-            return None
-        img_path = os.path.join(prefix, img_path)
-        with FS.get_object(img_path) as image_bytes:
-            image = Image.open(io.BytesIO(image_bytes))
-            if cvt_type is not None:
-                image = pillow_convert(image, cvt_type)
-        return image
-
-    def image_preprocess(self,
-                         img,
-                         size=None,
-                         interpolation=InterpolationMode.BILINEAR):
-        H, W = img.height, img.width
-        if H / W > self.max_aspect_ratio:
-            img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
-        elif W / H > self.max_aspect_ratio:
-            img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)
-
-        if size is None:
-            # resize image for max_seq_len, while keep the aspect ratio
-            H, W = img.height, img.width
-            scale = min(
-                1.0,
-                math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
-            rH = int(
-                H * scale) // self.d * self.d  # ensure divisible by self.d
-            rW = int(W * scale) // self.d * self.d
-        else:
-            rH, rW = size
-        img = T.Resize((rH, rW), interpolation=interpolation,
-                       antialias=True)(img)
-        return np.array(img, dtype=np.uint8)
-
-    @staticmethod
-    def get_config_template():
-        return dict_to_yaml('DATASet',
-                            __class__.__name__,
-                            ACEDemoDataset.para_dict,
-                            set_name=True)
-
-    @staticmethod
-    def collate_fn(batch):
-        collect = defaultdict(list)
-        for sample in batch:
-            for k, v in sample.items():
-                collect[k].append(v)
-
-        new_batch = dict()
-        for k, v in collect.items():
-            if all([i is None for i in v]):
-                new_batch[k] = None
-            else:
-                new_batch[k] = v
-
-        return new_batch
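The sizing rule in the deleted image_preprocess is the non-obvious part: it caps the aspect ratio at MAX_ASPECT_RATIO, scales the image so the downsampled token count (H/d) * (W/d) stays within MAX_SEQ_LEN, and snaps both sides down to multiples of the downsample ratio. A standalone sketch of just that arithmetic, with defaults taken from the deleted cfg keys:

import math

def target_size(H, W, max_seq_len=1024, d=16):
    # Shrink (never enlarge) so that (H/d) * (W/d) <= max_seq_len,
    # then round both sides down to multiples of d.
    scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
    rH = int(H * scale) // d * d
    rW = int(W * scale) // d * d
    return rH, rW

print(target_size(1024, 1024))  # (512, 512): 4096 tokens is 4x the budget, so halve each side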
modules/solver/__init__.py
DELETED
@@ -1 +0,0 @@
-from .ace_solver import ACESolverV1
modules/solver/ace_solver.py
DELETED
@@ -1,146 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
-import torch
-from tqdm import tqdm
-
-from scepter.modules.utils.data import transfer_data_to_cuda
-from scepter.modules.utils.distribute import we
-from scepter.modules.utils.probe import ProbeData
-from scepter.modules.solver.registry import SOLVERS
-from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver
-
-
-
-@SOLVERS.register_class()
-class ACESolverV1(LatentDiffusionSolver):
-    def __init__(self, cfg, logger=None):
-        super().__init__(cfg, logger=logger)
-        self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)
-
-    def save_results(self, results):
-        log_data, log_label = [], []
-        for result in results:
-            ret_images, ret_labels = [], []
-            edit_image = result.get('edit_image', None)
-            edit_mask = result.get('edit_mask', None)
-            if edit_image is not None:
-                for i, edit_img in enumerate(result['edit_image']):
-                    if edit_img is None:
-                        continue
-                    ret_images.append(
-                        (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(
-                            np.uint8))
-                    ret_labels.append(f'edit_image{i}; ')
-                    if edit_mask is not None:
-                        ret_images.append(
-                            (edit_mask[i].permute(1, 2, 0).cpu().numpy() *
-                             255).astype(np.uint8))
-                        ret_labels.append(f'edit_mask{i}; ')
-
-            target_image = result.get('target_image', None)
-            target_mask = result.get('target_mask', None)
-            if target_image is not None:
-                ret_images.append(
-                    (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(
-                        np.uint8))
-                ret_labels.append('target_image; ')
-            if target_mask is not None:
-                ret_images.append(
-                    (target_mask.permute(1, 2, 0).cpu().numpy() *
-                     255).astype(np.uint8))
-                ret_labels.append('target_mask; ')
-
-            reconstruct_image = result.get('reconstruct_image', None)
-            if reconstruct_image is not None:
-                ret_images.append(
-                    (reconstruct_image.permute(1, 2, 0).cpu().numpy() *
-                     255).astype(np.uint8))
-                ret_labels.append(f"{result['instruction']}")
-            log_data.append(ret_images)
-            log_label.append(ret_labels)
-        return log_data, log_label
-
-    @torch.no_grad()
-    def run_eval(self):
-        self.eval_mode()
-        self.before_all_iter(self.hooks_dict[self._mode])
-        all_results = []
-        for batch_idx, batch_data in tqdm(
-                enumerate(self.datas[self._mode].dataloader)):
-            self.before_iter(self.hooks_dict[self._mode])
-            if self.sample_args:
-                batch_data.update(self.sample_args.get_lowercase_dict())
-            with torch.autocast(device_type='cuda',
-                                enabled=self.use_amp,
-                                dtype=self.dtype):
-                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
-                                             batch_idx,
-                                             step=self.total_iter,
-                                             rank=we.rank)
-            all_results.extend(results)
-            self.after_iter(self.hooks_dict[self._mode])
-        log_data, log_label = self.save_results(all_results)
-        self.register_probe({'eval_label': log_label})
-        self.register_probe({
-            'eval_image':
-            ProbeData(log_data,
-                      is_image=True,
-                      build_html=True,
-                      build_label=log_label)
-        })
-        self.after_all_iter(self.hooks_dict[self._mode])
-
-    @torch.no_grad()
-    def run_test(self):
-        self.test_mode()
-        self.before_all_iter(self.hooks_dict[self._mode])
-        all_results = []
-        for batch_idx, batch_data in tqdm(
-                enumerate(self.datas[self._mode].dataloader)):
-            self.before_iter(self.hooks_dict[self._mode])
-            if self.sample_args:
-                batch_data.update(self.sample_args.get_lowercase_dict())
-            with torch.autocast(device_type='cuda',
-                                enabled=self.use_amp,
-                                dtype=self.dtype):
-                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
-                                             batch_idx,
-                                             step=self.total_iter,
-                                             rank=we.rank)
-            all_results.extend(results)
-            self.after_iter(self.hooks_dict[self._mode])
-        log_data, log_label = self.save_results(all_results)
-        self.register_probe({'test_label': log_label})
-        self.register_probe({
-            'test_image':
-            ProbeData(log_data,
-                      is_image=True,
-                      build_html=True,
-                      build_label=log_label)
-        })
-
-        self.after_all_iter(self.hooks_dict[self._mode])
-
-    @property
-    def probe_data(self):
-        if not we.debug and self.mode == 'train':
-            batch_data = transfer_data_to_cuda(
-                self.current_batch_data[self.mode])
-            self.eval_mode()
-            with torch.autocast(device_type='cuda',
-                                enabled=self.use_amp,
-                                dtype=self.dtype):
-                batch_data['log_num'] = self.log_train_num
-                results = self.run_step_eval(batch_data)
-            self.train_mode()
-            log_data, log_label = self.save_results(results)
-            self.register_probe({
-                'train_image':
-                ProbeData(log_data,
-                          is_image=True,
-                          build_html=True,
-                          build_label=log_label)
-            })
-            self.register_probe({'train_label': log_label})
-        return super(LatentDiffusionSolver, self).probe_data
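One recurring pattern in the deleted save_results is the tensor-to-image conversion applied to every logged result. A small sketch of that conversion, factored out for clarity (the helper name is illustrative, not from the repo):

import numpy as np
import torch

def to_uint8_image(t: torch.Tensor) -> np.ndarray:
    # (c, h, w) float tensor in [0, 1] -> (h, w, c) uint8 array for logging.
    return (t.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)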