"""
This script provides an example of wrapping TencentPretrain for ChID (a multiple-choice dataset).
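
Example invocation (illustrative only; the remaining options come from finetune_opts
and adv_opts in tencentpretrain/opts.py, and the model/data paths below are placeholders):

python3 finetune/run_chid.py --pretrained_model_path models/pretrained_model.bin \
                             --vocab_path models/google_zh_vocab.txt \
                             --config_path models/bert/base_config.json \
                             --train_path datasets/chid/train.json \
                             --train_answer_path datasets/chid/train_answer.json \
                             --dev_path datasets/chid/dev.json \
                             --dev_answer_path datasets/chid/dev_answer.json \
                             --seq_length 64 --max_choices_num 10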
"""
import sys
import os
import argparse
import json
import random
import torch

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.constants import *
from tencentpretrain.utils.tokenizers import *
from tencentpretrain.utils.optimizers import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.seed import set_seed
from tencentpretrain.utils.logging import init_logger
from tencentpretrain.model_saver import save_model
from tencentpretrain.opts import finetune_opts, adv_opts
from finetune.run_c3 import MultipleChoice
from finetune.run_classifier import build_optimizer, load_or_initialize_parameters, train_model, batch_loader, evaluate


def tokenize_chid(text):
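    """
    Split a ChID context into character-level tokens while keeping each
    13-character idiom placeholder (e.g. "#idiom000001#") as a single token.
    """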
    output = []
    first_idiom = True
    while True:
        if first_idiom:
            idiom_index = text.find("#idiom")
            output.extend(text[:idiom_index])
            output.append(text[idiom_index : idiom_index + 13])
            pre_idiom_index = idiom_index
            first_idiom = False
        else:
            if text[idiom_index + 1 :].find("#idiom") == -1:
                output.extend(text[pre_idiom_index + 13 :])
                break
            else:
                idiom_index = idiom_index + 1 + text[idiom_index + 1 :].find("#idiom")
                output.extend(text[pre_idiom_index + 13 : idiom_index])
                output.append(text[idiom_index : idiom_index + 13])
                pre_idiom_index = idiom_index

    return output


def add_tokens_around(tokens, idiom_index, tokens_num):
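    """
    Collect up to tokens_num context tokens around the idiom placeholder at
    idiom_index, taking roughly half from each side and borrowing from the
    other side when the placeholder sits close to either end of the sequence.
    """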
    left_tokens_num = tokens_num // 2
    right_tokens_num = tokens_num - left_tokens_num

    if idiom_index >= left_tokens_num and (len(tokens) - 1 - idiom_index) >= right_tokens_num:
        left_tokens = tokens[idiom_index - left_tokens_num : idiom_index]
        right_tokens = tokens[idiom_index + 1 : idiom_index + 1 + right_tokens_num]
    elif idiom_index < left_tokens_num:
        left_tokens = tokens[:idiom_index]
        right_tokens = tokens[idiom_index + 1 : idiom_index + 1 + tokens_num - len(left_tokens)]
    elif (len(tokens) - 1 - idiom_index) < right_tokens_num:
        right_tokens = tokens[idiom_index + 1 :]
        left_tokens = tokens[idiom_index - (tokens_num - len(right_tokens)) : idiom_index]

    return left_tokens, right_tokens


def read_dataset(args, data_path, answer_path):
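    """
    Read ChID data in JSON-lines format. For every idiom blank, build one
    instance containing one encoded sequence per candidate idiom:
    [CLS] option [SEP] left context [SEP] right context [SEP].
    Other blanks in the same context are replaced by the mask token, the
    option lists are padded to args.max_choices_num, and the target label is
    read from answer_path (or set to -1 when answer_path is None).
    """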
    if answer_path is not None:
        with open(answer_path, mode="r", encoding="utf-8") as f:
            answers = json.load(f)
    dataset = []
    max_tokens_for_doc = args.seq_length - 3
    group_index = 0

    for line in open(data_path, mode="r", encoding="utf-8"):
        example = json.loads(line)
        options = example["candidates"]
        for context in example["content"]:
            chid_tokens = tokenize_chid(context)
            tags = [token for token in chid_tokens if "#idiom" in token]
            for tag in tags:
                if answer_path is not None:
                    tgt = answers[tag]
                else:
                    tgt = -1
                tokens = []
                for i, token in enumerate(chid_tokens):
                    if "#idiom" in token:
                        sub_tokens = [str(token)]
                    else:
                        sub_tokens = args.tokenizer.tokenize(token)
                    for sub_token in sub_tokens:
                        tokens.append(sub_token)
                idiom_index = tokens.index(tag)
                left_tokens, right_tokens = add_tokens_around(tokens, idiom_index, max_tokens_for_doc - 1)

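                # Mask every other idiom placeholder in the context so that only the current blank remains visible.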
                for i in range(len(left_tokens)):
                    if "#idiom" in left_tokens[i] and left_tokens[i] != tag:
                        left_tokens[i] = MASK_TOKEN
                for i in range(len(right_tokens)):
                    if "#idiom" in right_tokens[i] and right_tokens[i] != tag:
                        right_tokens[i] = MASK_TOKEN

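                # One instance per blank: (src list per option, target, seg list per option, blank tag, passage group index).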
                dataset.append(([], tgt, [], tag, group_index))

                for option in options:
                    option_tokens = args.tokenizer.tokenize(option)
                    tokens = [CLS_TOKEN] + option_tokens + [SEP_TOKEN] + left_tokens + [SEP_TOKEN] + right_tokens + [SEP_TOKEN]

                    src = args.tokenizer.convert_tokens_to_ids(tokens)[: args.seq_length]
                    seg = [0] * len(src)

                    while len(src) < args.seq_length:
                        src.append(0)
                        seg.append(0)

                    dataset[-1][0].append(src)
                    dataset[-1][2].append(seg)

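                # Pad the option lists with all-zero sequences up to max_choices_num.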
                while len(dataset[-1][0]) < args.max_choices_num:
                    dataset[-1][0].append([0] * args.seq_length)
                    dataset[-1][2].append([0] * args.seq_length)
        group_index += 1

    return dataset


def main():
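    """
    Fine-tune a multiple-choice model on ChID: parse options, build the
    tokenizer and model, train for args.epochs_num epochs, evaluate on the
    dev set after each epoch, and save the best checkpoint.
    """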
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    finetune_opts(parser)

    parser.add_argument("--vocab_path", default=None, type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--spm_model_path", default=None, type=str,
                        help="Path of the sentence piece model.")
    parser.add_argument("--train_answer_path", type=str, required=True,
                        help="Path of the answers for trainset.")
    parser.add_argument("--dev_answer_path", type=str, required=True,
                        help="Path of the answers for devset.")

    parser.add_argument("--max_choices_num", default=10, type=int,
                        help="The maximum number of cadicate answer, shorter than this will be padded.")

    adv_opts(parser)

    args = parser.parse_args()

    args.labels_num = args.max_choices_num

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    set_seed(args.seed)

    # Build tokenizer.
    args.tokenizer = CharTokenizer(args)

    # Build multiple choice model.
    model = MultipleChoice(args)

    # Load or initialize parameters.
    load_or_initialize_parameters(args, model)

    # Get logger.
    args.logger = init_logger(args)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)

    # Training phase.
    trainset = read_dataset(args, args.train_path, args.train_answer_path)
    instances_num = len(trainset)
    batch_size = args.batch_size

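    # Total number of training steps; build_optimizer uses it to configure the learning-rate scheduler.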
    args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    args.logger.info("Batch size: {}".format(batch_size))
    args.logger.info("The number of training instances: {}".format(instances_num))

    optimizer, scheduler = build_optimizer(args, model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        args.amp = amp

    if torch.cuda.device_count() > 1:
        args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    args.model = model

    if args.use_adv:
        args.adv_method = str2adv[args.adv_type](model)

    total_loss, result, best_result = 0.0, 0.0, 0.0

    args.logger.info("Start training.")

    for epoch in range(1, args.epochs_num + 1):
        random.shuffle(trainset)
        src = torch.LongTensor([example[0] for example in trainset])
        tgt = torch.LongTensor([example[1] for example in trainset])
        seg = torch.LongTensor([example[2] for example in trainset])

        model.train()
        for i, (src_batch, tgt_batch, seg_batch, _) in enumerate(batch_loader(batch_size, src, tgt, seg)):

            loss = train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_batch)
            total_loss += loss.item()

            if (i + 1) % args.report_steps == 0:
                args.logger.info("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1, total_loss / args.report_steps))
                total_loss = 0.0

        result = evaluate(args, read_dataset(args, args.dev_path, args.dev_answer_path))
        if result[0] > best_result:
            best_result = result[0]
            save_model(model, args.output_model_path)


if __name__ == "__main__":
    main()