File size: 4,085 Bytes
8e2b754
 
 
 
 
98c2b8e
8e2b754
 
 
 
 
 
 
98c2b8e
 
 
 
 
8e2b754
 
98c2b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e2b754
98c2b8e
8e2b754
 
 
 
98c2b8e
8e2b754
98c2b8e
8e2b754
 
 
 
 
 
 
 
 
 
 
 
 
98c2b8e
8e2b754
 
98c2b8e
 
 
 
 
8e2b754
 
 
 
 
98c2b8e
8e2b754
 
 
 
 
 
 
 
 
98c2b8e
8e2b754
 
 
98c2b8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import json
import logging
import os
import time
from typing import List
import urllib.request
import urllib.error

import pandas as pd
from tqdm import tqdm


# Root logging for this script: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
# Module-level logger used by the download/preparation functions below.
logger = logging.getLogger(name=__name__)

def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float):
    """Split ``lines`` in order into train/valid/test parts and write each one to ``output_dir``.

    The first ``train_proportion`` of the lines go to ``train_dataset.json``, the
    next ``valid_proportion`` to ``valid_dataset.json`` and the remainder to
    ``test_dataset.json``. Each file holds one entry of ``lines`` per line, joined
    with ``"\\n"`` (no trailing newline). Existing files are overwritten.

    Args:
        lines: Serialized records, one per output line.
        output_dir: Directory to write the three files into (must already exist).
        train_proportion: Fraction of lines assigned to the training split.
        valid_proportion: Fraction of lines assigned to the validation split.
    """
    total_lines = len(lines)
    # Cut points are truncated towards zero, so rounding leftovers land in the test split.
    train_end = int(total_lines * train_proportion)
    valid_end = int(total_lines * (train_proportion + valid_proportion))
    splits = {
        "train": lines[:train_end],
        "valid": lines[train_end:valid_end],
        "test": lines[valid_end:],
    }
    for split_name, split_lines in splits.items():
        path = os.path.join(output_dir, f"{split_name}_dataset.json")
        # Records may contain non-ASCII text (e.g. json.dumps(..., ensure_ascii=False)),
        # so write UTF-8 explicitly instead of relying on the locale's default encoding.
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))

def _download_with_retries(url: str, path: str, retries: int, pause: float) -> bool:
    """Download ``url`` to ``path``, retrying transient failures.

    Makes up to ``retries`` attempts (at least one), sleeping ``pause`` seconds
    between failed attempts. Permanent client errors (HTTP 4xx other than 408
    and 429, e.g. 404) are not retried.

    Returns:
        True if the file was downloaded; False if the URL could not be fetched
        (the caller should skip the row).

    Raises:
        ValueError: if the final failure was HTTP 429 (Too Many Requests), i.e.
            the server is still rate limiting after all retries.
    """
    attempts = max(1, retries)
    last_error = None
    for attempt in range(attempts):
        try:
            urllib.request.urlretrieve(url, path)
            return True
        except urllib.error.HTTPError as err:
            last_error = err
            if 400 <= err.code < 500 and err.code not in (408, 429):
                break  # permanent client error such as 404: retrying will not help
        except (urllib.error.URLError, ConnectionError, TimeoutError) as err:
            last_error = err
        # Back off before the next attempt instead of hammering the server.
        if attempt < attempts - 1:
            time.sleep(pause)
    if isinstance(last_error, urllib.error.HTTPError) and last_error.code == 429:
        raise ValueError(f"Rate limit achieved: {last_error}") from last_error
    logger.warning("Skipping %s after failed download: %s", url, last_error)
    return False


def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause: float=0.1, retries: int=5):
    """Download WIT images for one language and write train/valid/test JSON-lines splits.

    Rows of the tab-separated ``tsv`` whose ``language_col`` equals ``language``
    and whose ``caption_col`` is not null are shuffled with ``seed``. Each image
    is downloaded into ``output_dir``, and every successful download becomes one
    JSON line ``{"image_path": ..., "captions": [caption]}``. The lines are
    written with ``split_and_save_datasets`` every ``backup_period`` successful
    downloads (disabled when ``backup_period <= 0``), and always once more at the
    end, even if an error interrupts the download.

    Args:
        tsv: Path to the WIT TSV file.
        language: Language code to keep (compared against ``language_col``).
        output_dir: Directory for images and split files; created if missing.
        seed: Random state for the shuffle, making splits reproducible.
        train_proportion: Fraction of lines for the training split.
        valid_proportion: Fraction of lines for the validation split.
        backup_period: Number of successful downloads between backups.
        language_col: Column holding the language code.
        caption_col: Column holding the caption text.
        url_col: Column holding the image URL.
        pause: Seconds to sleep between retries of a failed download.
        retries: Maximum download attempts per image.

    Raises:
        ValueError: if the server keeps answering HTTP 429 after all retries.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Use the configurable column names rather than hard-coded ones.
    df = df[(df[language_col] == language) & df[caption_col].notna()]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)
    logger.info("Download started")
    lines = []
    try:
        # A single progress bar over all candidate rows (skipped rows included).
        for _, row in tqdm(df.iterrows(), total=len(df)):
            url = row[url_col]
            caption = row[caption_col]
            if not isinstance(url, str):
                logger.warning("Skipping row without a usable URL: %r", url)
                continue
            # Trim image file names so that they are no longer than 100 characters
            image_filename = url.split("/")[-1][-100:]
            image_path = f"{output_dir}/{image_filename}"
            if not _download_with_retries(url, image_path, retries, pause):
                continue
            lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
            # Back up only right after a success, so a run of failed rows does not
            # rewrite the same snapshot over and over.
            if backup_period > 0 and len(lines) % backup_period == 0:
                logger.info("Saving dataset backup: Number of lines %d", len(lines))
                split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
    finally:
        # Save existing dataset, even upon failure or interruption
        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)

if __name__ == "__main__":
    # Default data root under the user's home directory; expanduser avoids the
    # KeyError that os.environ['USER'] raises when USER is not set.
    default_root = os.path.expanduser("~/data/wit")
    parser = argparse.ArgumentParser(description = "Download and prepare the WIT dataset")
    parser.add_argument("--tsv", type=str, default=os.path.join(default_root, "wit_v1.train.all-1percent_sample.tsv"))
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument("--output_dir", type=str, default=os.path.join(default_root, "prepared_dataset"))
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--backup_period", type=int, default=1000)
    parser.add_argument("--pause", type=float, default=0.1, help="Seconds to wait between download retries")
    parser.add_argument("--retries", type=int, default=5, help="Maximum download attempts per image")

    args = parser.parse_args()
    # Validate with parser.error (exits with a usage message) rather than assert,
    # which is stripped when Python runs with -O.
    if args.train_proportion < 0 or args.valid_proportion < 0:
        parser.error("train_proportion and valid_proportion must be non-negative")
    if args.train_proportion + args.valid_proportion >= 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")
    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        args.backup_period,
        pause=args.pause,
        retries=args.retries,
    )