import logging
import random

import webdataset as wds
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener

from m4.training.types import DatasetTypes


meta_prefix = "__"
meta_suffix = "__"

logger = logging.getLogger(__name__)
trace = False
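
# Webdocument shards store each document as a group of tar members sharing a key:
#   "{key}.{idx}.text.{ext}" / "{key}.{idx}.image.{ext}"  -> the idx-th element of the document
#   "{key}.metadata.txt"                                  -> newline-separated list of the member names expected in the sample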


def webdoc_valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
        and sample_has_all_files(sample)
    )


def sample_has_all_files(current_sample):
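    """Check that every file listed in the sample's metadata entry was actually extracted.

    The "metadata.value" entry is expected to hold a newline-separated list of the file
    names that make up the sample; the sample is valid only if each listed name appears
    among the collected ".fname" entries.
    """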
    meta = current_sample.get("metadata.value", None)
    if meta is None:
        return False
    meta = meta.decode("utf-8")
    if len(meta) == 0:
        return False
    target_file_list = meta.split("\n")
    fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
    fnames = [current_sample[key] for key in fname_keys]
    return all(fname in fnames for fname in target_file_list)


class ImageDecoder:
    def __call__(self, bytes_):
        import io

        import PIL.Image

        img = PIL.Image.open(io.BytesIO(bytes_))
        img.load()
        return img


# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


# Adapt group_by_keys to our webdocument format, in which each sample contains several text and image files
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
def group_by_keys_interleaved(data, handler=log_and_continue):
    """Return function over iterator that groups key, value pairs into samples."""
    current_sample = None
    for filesample in data:
        try:
            assert isinstance(filesample, dict)
            fname, value = filesample["fname"], filesample["data"]
            fname = fname.strip("./")
            if fname.endswith(".metadata.txt"):
                prefix, data_type, extension = fname.split(".")
                idx = None  # metadata files carry no per-element index
                suffix = data_type
            else:
                prefix, idx, data_type, extension = fname.split(".")
                if data_type not in ["text", "image"]:
                    raise ValueError(f"{fname}: unknown data type {data_type}")
                suffix = idx
            if trace:
                print(
                    f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
                    f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
                )
            if prefix is None:
                continue
            if current_sample is None or prefix != current_sample["__key__"]:
                valid = webdoc_valid_sample(current_sample)
                if valid:
                    yield current_sample
                elif current_sample is not None:
                    logging.warning(f"{fname}: invalid sample {current_sample} ignored")
                current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
            if f"{suffix}.value" in current_sample:
                raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
            current_sample[f"{suffix}.value"] = value
            current_sample[f"{suffix}.type"] = data_type
            current_sample[f"{suffix}.fname"] = fname
        except Exception as exn:
            exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
            if handler(exn):
                continue
            else:
                break

    if webdoc_valid_sample(current_sample):
        yield current_sample


def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_interleaved(files, handler=handler)
    return samples


tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)


def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
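    """Collate a grouped webdocument sample into parallel "texts" and "images" lists.

    Element i of the document goes into texts[i] or images[i] depending on its type;
    the other list keeps None at that position, so the interleaving order is preserved.
    """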
    for sample in data:
        try:
            max_example_indices = max(
                [int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
            )
            texts = [None for _ in range(max_example_indices + 1)]
            images = [None for _ in range(max_example_indices + 1)]
            for idx in range(max_example_indices + 1):
                if f"{idx}.value" not in sample:
                    continue
                if "text" in sample[f"{idx}.type"]:
                    texts[idx] = sample[f"{idx}.value"]
                elif "image" in sample[f"{idx}.type"]:
                    images[idx] = sample[f"{idx}.value"]
                else:
                    raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
            example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
            yield example
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)


def _decode_image_and_text_webdocument(data, handler=log_and_continue):
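    """Decode raw bytes into PIL images and utf-8 strings, keeping None placeholders untouched."""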
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
            sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)


def collate_dicts(samples):
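    """Collate a list of sample dicts into a single dict of lists (columnar batch).

    Assumes every sample shares the keys of the first one.
    """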
    keys = samples[0].keys()
    batched_samples = {key: [sample[key] for sample in samples] for key in keys}
    return batched_samples


def get_webdocuments_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
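    """Build a webdataset pipeline over webdocument tar shards.

    The pipeline shards the URL list across nodes and workers, expands each tar into
    grouped samples, collates and decodes them, and batches with `collate_dicts`.
    Each `shuffle_*_buffer_size` argument sets the buffer of the shuffle inserted at
    that stage; passing None disables that shuffle.

    Example (hypothetical shard path):
        dataset = get_webdocuments_webdataset(["shards/webdoc-00000.tar"], batch_size=4)
        batch = next(iter(dataset))
        # batch is a dict whose "texts" and "images" entries are lists of per-document lists
    """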
    if shuffle_initial_urls_list:
        random.shuffle(urls)

    pipeline_list = [wds.SimpleShardList(urls)]

    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))

    pipeline_list.append(wds.split_by_node)

    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))

    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_webdocument_samples(),
        ]
    )

    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))

    pipeline_list.extend(
        [
            collate_texts_and_images_webdocument(),
            decode_image_and_text_webdocument(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
        ]
    )

    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))

    dataset = wds.DataPipeline(pipeline_list)
    return dataset


def split_keep_2(x):
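    """Map a tar member name to a (key, suffix) pair, keeping only the first two dot-separated parts.

    For example, a member named "000123.image.jpg" maps to ("000123", "image"); any further
    components (such as the extension) are ignored, so group_by_keys collects the member
    under the "image" suffix of sample "000123".
    """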
    x = x.strip("./")
    x_splitter = x.split(".")
    return x_splitter[0], x_splitter[1]


def _tarfile_to_pair_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys(files, keys=split_keep_2, handler=handler)
    return samples


tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)


def _decode_image_and_text_pairs(data, handler=log_and_continue):
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["image"] = image_decoder(sample["image"])
            sample["text"] = sample["text"].decode("utf-8")
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)


def get_image_caption_pairs_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
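    """Build a webdataset pipeline over image/caption pair tar shards.

    Mirrors `get_webdocuments_webdataset`, but each sample holds a single decoded
    "image" (PIL image) and "text" (str). Passing None for any shuffle buffer size
    disables that shuffle stage.

    Example (hypothetical shard path):
        dataset = get_image_caption_pairs_webdataset(["shards/pairs-00000.tar"], batch_size=8)
        batch = next(iter(dataset))  # dict with "image" and "text" lists of length <= 8
    """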
    if shuffle_initial_urls_list:
        random.shuffle(urls)

    pipeline_list = [wds.SimpleShardList(urls)]

    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))

    pipeline_list.append(wds.split_by_node)

    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))

    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_pair_samples(handler=log_and_continue),
        ]
    )

    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))

    pipeline_list.extend(
        [
            decode_image_and_text_pairs(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),  # todo: check if partial is needed
        ]
    )

    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))

    dataset = wds.DataPipeline(pipeline_list)
    return dataset


def get_webdataset(
    urls,
    ds_type: DatasetTypes,
    batch_size: int,
    shuffle_initial_urls_list,
    shuffle_before_split_by_node_buffer_size,
    shuffle_before_split_by_worker_buffer_size,
    shuffle_after_tarfile_to_samples_buffer_size,
    shuffle_after_batching_buffer_size,
):
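    """Dispatch to the webdocument or image/caption pair pipeline according to `ds_type`."""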
    if ds_type == DatasetTypes.WEB_DOCUMENTS:
        return get_webdocuments_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
        return get_image_caption_pairs_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    else:
        raise ValueError(f"Unknown dataset type: {ds_type}")


def check_webdataset_command(command):
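    """Sanity-check a webdataset shard command.

    Plain (non-S3) URLs are accepted as-is. Commands touching S3 must be a bash pipe that
    fetches the shard through get_file.sh and ends in ".tar", e.g. (hypothetical argument
    order): "pipe:bash get_file.sh s3://bucket/shard-00000.tar".
    """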
    if "s3:/" not in command:
        return True

    command = command.strip()
    if not command.startswith("pipe:bash"):
        return False

    if not command.endswith(".tar"):
        return False

    if "get_file.sh" not in command:
        return False

    return True