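# Evaluation entry point: loads a clean/noisy test split, restores trained model
# weights, and runs the spelling corrector's evaluation on it.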
import os
from params import *
from dataset.vocab import Vocab
from dataset.util import load_dataset, load_vsec_dataset

if __name__ == "__main__":
    import argparse

    description = '''
        Corrector: evaluate a trained spelling correction model on a test set.

        Usage: python corrector.py --model tfmwtr --data_path ./data --dataset binhvq

        Params:
            --model:        tfmwtr - Transformer with Tokenization Repair (default: 'tfmwtr')
            --data_path:    root data directory (default: './data')
            --dataset:      dataset the checkpoint was trained on (default: 'binhvq')
            --test_dataset: dataset to evaluate on, e.g. 'binhvq' or 'vsec' (default: 'binhvq')
            --beams:        number of beams used during decoding (default: 2)
            --fraction:     fraction of the test set to evaluate (default: 1.0)
            --text:         sample noisy sentence (parsed but not used by this script)
    '''
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--model', type=str, default='tfmwtr')
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='binhvq')
    parser.add_argument('--test_dataset', type=str, default='binhvq')
    parser.add_argument('--beams', type=int, default=2)
    parser.add_argument('--fraction', type=float, default=1.0)
    # Example noisy Vietnamese sentence; parsed but not used in the evaluation flow below.
    parser.add_argument('--text', type=str, default='Bình mnh ơi day ch ưa, café xáng vớitôi dược không?')
    args = parser.parse_args()

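    # Resolve dataset, checkpoint, and vocabulary paths. The vocabulary is always
    # the one built from the binhvq corpus, independent of --test_dataset.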
    dataset_path = os.path.join(args.data_path, args.test_dataset)

    weight_ext = 'pth'

    checkpoint_dir = os.path.join(args.data_path, f'checkpoints/{args.model}')

    weight_path = os.path.join(checkpoint_dir, f'{args.dataset}.weights.{weight_ext}')
    vocab_path = os.path.join(args.data_path, 'binhvq/binhvq.vocab.pkl')

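    # Test split file names: '<test_dataset>.test' holds the correct sentences and
    # '<test_dataset>.test.noise' their noisy counterparts.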
    correct_file = f'{args.test_dataset}.test'
    incorrect_file = f'{args.test_dataset}.test.noise'
    length_file = f'{args.dataset}.length.test'

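    # VSEC has its own loader, which does not take a length file.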
    if args.test_dataset != "vsec":
        test_data = load_dataset(base_path=dataset_path, corr_file=correct_file, incorr_file=incorrect_file,
                                 length_file=length_file)
    else:
        test_data = load_vsec_dataset(base_path=dataset_path, corr_file=correct_file, incorr_file=incorrect_file)

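    # Optionally evaluate on only the first --fraction of the test set.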
    length_of_data = len(test_data)
    test_data = test_data[:int(args.fraction * length_of_data)]

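    # Load the vocabulary from its pickled dictionary.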
    vocab = Vocab()
    vocab.load_vocab_dict(vocab_path)

    from dataset.autocorrect_dataset import SpellCorrectDataset
    from models.corrector import Corrector
    from models.model import ModelWrapper
    from models.util import load_weights

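    # Wrap the loaded test pairs in a dataset object the corrector can evaluate.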
    test_dataset = SpellCorrectDataset(dataset=test_data)

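    # Build the model selected by --model, bind it to the vocabulary, and wrap it
    # in a Corrector that drives decoding and evaluation.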
    model_wrapper = ModelWrapper(args.model, vocab)

    corrector = Corrector(model_wrapper)

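    # Restore the checkpoint trained on --dataset.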
    load_weights(corrector.model, weight_path)

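    # Run evaluation with the requested number of beams.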
    corrector.evaluate(test_dataset, beams=args.beams)