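"""Fine-tune a pre-trained BERT model for binary hint classification.

Reads action sequences (one per line) from a text file and their
'last_hint_class' labels from a CSV, splits them 80/20 into train and
validation sets, and fine-tunes a pre-trained BERT encoder with a
classification head via BERTFineTuneTrainer1.
"""
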
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
import pandas as pd

def preprocess_labels(label_csv_path):
    """Read the 'last_hint_class' labels from a CSV file as a long tensor."""
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading label file: {e}")
        return None

def preprocess_data(data_path, vocab, max_length=128):
    """Tokenize one sequence per line, truncate/pad to max_length, and return
    (input_ids, segment_labels) as tensors of shape (num_sequences, max_length)."""
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    pad_id = vocab.vocab.get('[PAD]', 0)
    input_id_rows = []
    segment_rows = []
    for sequence in sequences:
        sequence = sequence.strip()
        if not sequence:
            continue
        encoded = vocab.to_seq(sequence, seq_len=max_length)
        # Truncate first, then pad, so the pad count is never negative.
        encoded = encoded[:max_length]
        encoded = encoded + [pad_id] * (max_length - len(encoded))
        input_id_rows.append(torch.tensor(encoded, dtype=torch.long))
        # Single-sequence inputs use segment id 0 throughout.
        segment_rows.append(torch.zeros(max_length, dtype=torch.long))

    input_ids = torch.stack(input_id_rows, dim=0)
    segment_labels = torch.stack(segment_rows, dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")

    return input_ids, segment_labels

def custom_collate_fn(batch):
    """Collate (input_ids, segment_label, label) tuples from the TensorDataset
    into the batch dict format expected by the trainer."""
    input_ids, segment_labels, labels = zip(*batch)
    return {
        'input': torch.stack(input_ids, dim=0),
        'label': torch.stack(labels, dim=0),
        'segment_label': torch.stack(segment_labels, dim=0),
    }
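
# ---------------------------------------------------------------------------
# NOTE: CustomBERTModel is used in main() but is neither defined nor imported
# in this file. The class below is a minimal, assumed sketch: it loads the
# pre-trained encoder with torch.load and adds a linear classification head
# over the first position's hidden state. The `hidden` attribute on the
# loaded encoder, its (batch, seq, hidden) output shape, and the torch.load
# checkpoint format are assumptions; replace this with the real class if the
# repository provides one.
# ---------------------------------------------------------------------------
class CustomBERTModel(nn.Module):
    """Pre-trained BERT encoder with a linear classification head."""

    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super().__init__()
        # vocab_size is kept for interface compatibility; the loaded encoder
        # already embeds its own vocabulary.
        self.vocab_size = vocab_size
        # Load the pre-trained encoder; map to CPU so loading works without a GPU.
        self.bert = torch.load(pre_trained_model_path, map_location='cpu')
        self.classifier = nn.Linear(self.bert.hidden, output_dim)

    def forward(self, input_ids, segment_label):
        # Encode the batch and classify from the first position's representation.
        hidden = self.bert(input_ids, segment_label)
        return self.classifier(hidden[:, 0])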

def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=50)  # Using sequence length 50
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # The label CSV and the sequence file are assumed to correspond row-for-row;
    # bail out on a count mismatch rather than training on misaligned pairs.
    if len(labels) != input_ids.size(0):
        print(f"Mismatch: {len(labels)} labels vs {input_ids.size(0)} sequences. Exiting.")
        return

    # Create TensorDataset and split it 80/20 into train and validation sets
    dataset = TensorDataset(input_ids, segment_labels, labels)
    train_size = int(0.8 * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

    # Initialize custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=1e-5,  # Using learning rate 10^-5 as specified
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train for 20 epochs. This assumes trainer.train(epoch) runs a single
    # epoch per call (BERT-pytorch-style); if BERTFineTuneTrainer1 loops
    # internally, revert to a single trainer.train(epoch=20) call.
    for epoch in range(20):
        trainer.train(epoch)

    # Save the full fine-tuned model object (loading it later requires the
    # CustomBERTModel class definition to be importable)
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_3.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct', help='Path to the folder for saving logs.')


    opt = parser.parse_args()
    main(opt)