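"""Fine-tune a pre-trained BERT model for hint classification.

Loads tokenized interaction sequences and their `last_hint_class` labels,
splits them 80/20 into training and validation sets, and fine-tunes the
model with BERTFineTuneTrainer1. The fine-tuned model is saved to
--output_dir.
"""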
import argparse
import os

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
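
# NOTE: CustomBERTModel is referenced in main() but was neither defined nor
# imported in the original script. The class below is a minimal sketch of a
# plausible implementation, assuming the checkpoint at
# `pre_trained_model_path` is a fully serialized encoder (saved with
# torch.save(model, path)) that exposes a `hidden` size attribute and a
# forward(input_ids, segment_label) signature, as in src.bert.BERT. Adapt it
# to the real encoder API if it differs.
class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super().__init__()
        self.vocab_size = vocab_size
        # Load the pre-trained encoder onto the CPU; .to(device) in main()
        # moves the whole module to the target device afterwards.
        self.bert = torch.load(pre_trained_model_path, map_location='cpu')
        # Linear classification head over the encoder's hidden size
        # (assumed attribute; adjust to the actual encoder).
        self.classifier = nn.Linear(self.bert.hidden, output_dim)

    def forward(self, input_ids, segment_label):
        # Encode the batch and classify from the first token's representation.
        hidden = self.bert(input_ids, segment_label)  # (batch, seq_len, hidden)
        return self.classifier(hidden[:, 0])
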
def preprocess_labels(label_csv_path):
    """Read the label CSV and return the `last_hint_class` column as a LongTensor."""
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading label file: {e}")
        return None
def preprocess_data(data_path, vocab, max_length=128):
    """Tokenize one sequence per line of `data_path`, pad/truncate to
    `max_length`, and return (input_ids, segment_labels) tensors."""
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    tokenized_sequences = []
    for sequence in sequences:
        sequence = sequence.strip()
        if not sequence:
            continue
        encoded = vocab.to_seq(sequence, seq_len=max_length)
        # Truncate, then pad with [PAD] (falling back to id 0 if absent).
        encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
        # Single-sentence inputs: every position belongs to segment 0.
        segment_label = [0] * max_length
        tokenized_sequences.append({
            'input_ids': torch.tensor(encoded),
            'segment_label': torch.tensor(segment_label),
        })

    input_ids = torch.stack([t['input_ids'] for t in tokenized_sequences])
    segment_labels = torch.stack([t['segment_label'] for t in tokenized_sequences])

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")
    return input_ids, segment_labels
def custom_collate_fn(batch):
    """Collate (input_ids, segment_label, label) tuples into the dict format
    expected by BERTFineTuneTrainer1.

    TensorDataset yields plain tuples, not dicts, so the items are unpacked
    by position in the order the tensors were passed to TensorDataset.
    """
    inputs = torch.stack([item[0] for item in batch])
    segment_labels = torch.stack([item[1] for item in batch])
    labels = torch.stack([item[2] for item in batch])
    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels,
    }
def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=50)  # sequence length 50
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Create a TensorDataset and split it 80/20 into train and validation sets
    dataset = TensorDataset(input_ids, segment_labels, labels)
    val_size = len(dataset) - int(0.8 * len(dataset))
    val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

    # Initialize the custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=1e-5,  # learning rate 10^-5 as specified
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model
    trainer.train(epoch=20)

    # Save the fine-tuned model
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_3.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the label CSV file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct', help='Path to the folder for saving logs.')
    opt = parser.parse_args()
    main(opt)
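
# Example invocation (hypothetical script name; the paths shown are the
# argparse defaults and should point at real files):
#   python fine_tune_hint_classifier.py \
#       --data_path /home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt \
#       --dataset /home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv \
#       --output_dir /home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification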