SpaGAN / models /spabert /train_mlm.py
JasonTPhillipsJr's picture
Upload 76 files
46e0dd0 verified
raw
history blame
8.55 kB
import os
import sys
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm # for our progress bar
from transformers import AdamW
import torch
from torch.utils.data import DataLoader
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM
from datasets.osm_sample_loader import PbfMapDataset
from transformers.models.bert.modeling_bert import BertForMaskedLM
import numpy as np
import argparse
import pdb
# Debug switch read by training(): when True, the training DataLoader is built
# with shuffle=False and the loss is printed on every iteration.
DEBUG = False
def _default_model_save_dir(lr, batch_size, bert_option, sep_between_neighbors,
                            freeze_backbone, if_no_spatial_distance):
    """Return the checkpoint directory used when no --model_save_dir is given.

    The path encodes the learning rate, the separator / spatial-context /
    freeze switches, the BERT variant and the batch size.
    """
    sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
    freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
    context_pathstr = '_nocontext' if if_no_spatial_distance else '_withcontext'
    # NOTE(review): the root below is a machine-specific absolute path; pass
    # --model_save_dir when running anywhere else.
    return (
        '/data2/zekun/spatial_bert_weights/mlm_mem_lr' + '{:.0e}'.format(lr)
        + sep_pathstr + context_pathstr + '/' + bert_option + freeze_pathstr
        + '_mlm_mem_london_california_bsize' + str(batch_size)
    )


def _freeze_all_but_top_of_bert_large(model):
    """Disable gradients everywhere except `model.cls` and encoder layers 21-23.

    The layer indices are hard-coded, so this is only meaningful for the
    24-layer configuration built for 'bert-large'.
    """
    for param in model.parameters():
        param.requires_grad = False

    # `model.cls` is presumably the masked-LM prediction head - confirm against
    # SpatialBertForMaskedLM.
    trainable_modules = [model.cls]
    trainable_modules.extend(model.bert.encoder.layer[idx] for idx in (21, 22, 23))
    for module in trainable_modules:
        for param in module.parameters():
            param.requires_grad = True


def training(args):
    """Run masked-language-model training of SpatialBertForMaskedLM.

    Reads these attributes from ``args``: num_workers, batch_size, epochs, lr,
    save_interval, max_token_len, distance_norm_factor, spatial_dist_fill,
    with_type, sep_between_neighbors, freeze_backbone, bert_option,
    no_spatial_distance and model_save_dir.

    The model is initialised from the pretrained ``bert-base-uncased`` or
    ``bert-large-uncased`` weights, trained on the concatenation of the London
    and California OSM point files, and its ``state_dict`` is saved every
    ``save_interval`` iterations and at the last iteration of each epoch.

    Raises:
        AssertionError: if ``args.bert_option`` is not 'bert-base' or
            'bert-large' (NotImplementedError instead when asserts are
            stripped with ``python -O``).
    """
    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone
    bert_option = args.bert_option
    if_no_spatial_distance = args.no_spatial_distance

    assert bert_option in ['bert-base', 'bert-large']

    # London first, then California: the order fixes the sample order of the
    # concatenated dataset when shuffling is off.
    data_file_paths = [
        '../data/sql_output/osm-point-london.json',
        '../data/sql_output/osm-point-california.json',
    ]

    if args.model_save_dir is None:
        model_save_dir = _default_model_save_dir(
            lr, batch_size, bert_option, sep_between_neighbors,
            freeze_backbone, if_no_spatial_distance)
    else:
        model_save_dir = args.model_save_dir
    # Ensure the directory exists in both cases, so a caller that bypasses
    # main() cannot lose a run at the first checkpoint write.
    os.makedirs(model_save_dir, exist_ok=True)

    print('model_save_dir', model_save_dir)
    print('\n')

    use_spatial_distance = not if_no_spatial_distance
    if bert_option == 'bert-base':
        pretrained_name = 'bert-base-uncased'
        config = SpatialBertConfig(use_spatial_distance_embedding=use_spatial_distance)
    elif bert_option == 'bert-large':
        pretrained_name = 'bert-large-uncased'
        config = SpatialBertConfig(use_spatial_distance_embedding=use_spatial_distance,
                                   hidden_size=1024, intermediate_size=4096,
                                   num_attention_heads=16, num_hidden_layers=24)
    else:
        # Only reachable when the assert above has been stripped (python -O).
        raise NotImplementedError

    bert_model = BertForMaskedLM.from_pretrained(pretrained_name)
    tokenizer = BertTokenizer.from_pretrained(pretrained_name)

    model = SpatialBertForMaskedLM(config)
    # strict=False tolerates keys that differ between the two state dicts;
    # original note: load sentence position embedding weights as well.
    model.load_state_dict(bert_model.state_dict(), strict=False)

    if freeze_backbone:
        if bert_option == 'bert-large':
            print('freezing backbone weights')
            _freeze_all_but_top_of_bert_large(model)
        else:
            # Previously this case was ignored without any message even though
            # the default save path is tagged '_freeze'.
            print('warning: freeze_backbone only takes effect with bert-large; '
                  'all weights will be trained')

    region_datasets = [
        PbfMapDataset(data_file_path=data_file_path,
                      tokenizer=tokenizer,
                      max_token_len=max_token_len,
                      distance_norm_factor=distance_norm_factor,
                      spatial_dist_fill=spatial_dist_fill,
                      with_type=with_type,
                      sep_between_neighbors=sep_between_neighbors,
                      label_encoder=None,
                      mode=None)
        for data_file_path in data_file_paths
    ]
    train_dataset = torch.utils.data.ConcatDataset(region_datasets)

    # Shuffling is switched off in DEBUG mode only.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=not DEBUG, pin_memory=True, drop_last=True)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.train()

    # initialize optimizer
    optim = AdamW(model.parameters(), lr=lr)

    print('start training...')

    for epoch in range(epochs):
        # setup loop with TQDM and dataloader
        loop = tqdm(train_loader, leave=True)
        # `step` is the 1-based count of finished iterations in this epoch.
        for step, batch in enumerate(loop, start=1):
            # initialize calculated gradients (from prev step)
            optim.zero_grad()

            # pull all tensor batches required for training
            input_ids = batch['masked_input'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_list_x = batch['norm_lng_list'].to(device)
            position_list_y = batch['norm_lat_list'].to(device)
            sent_position_ids = batch['sent_position_ids'].to(device)
            labels = batch['pseudo_sentence'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask,
                            sent_position_ids=sent_position_ids,
                            position_list_x=position_list_x,
                            position_list_y=position_list_y, labels=labels)

            loss = outputs.loss
            loss.backward()
            optim.step()

            # Read the scalar once and reuse it for the progress bar, the
            # debug print and the checkpoint name.
            loss_value = loss.item()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss_value})

            if DEBUG:
                # The debug line keeps the 0-based iteration index.
                print('ep' + str(epoch) + '_' + '_iter' + str(step - 1).zfill(5), loss_value)

            if step % save_interval == 0 or step == loop.total:
                checkpoint_name = ('mlm_mem_keeppos_ep' + str(epoch) + '_iter' + str(step).zfill(5)
                                   + '_' + '{:.4f}'.format(loss_value) + '.pth')
                save_path = os.path.join(model_save_dir, checkpoint_name)
                torch.save(model.state_dict(), save_path)
                print('saving model checkpoint to', save_path)
def main():
    """Parse the command-line options, prepare the output directory and train.

    The parsed namespace is echoed to stdout before training starts. When
    --model_save_dir is supplied and does not exist yet, it is created here.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ('--num_workers', dict(type=int, default=5)),
        ('--batch_size', dict(type=int, default=12)),
        ('--epochs', dict(type=int, default=10)),
        ('--save_interval', dict(type=int, default=2000)),
        ('--max_token_len', dict(type=int, default=300)),
        ('--lr', dict(type=float, default=5e-5)),
        ('--distance_norm_factor', dict(type=float, default=0.0001)),
        ('--spatial_dist_fill', dict(type=float, default=20)),
        ('--with_type', dict(default=False, action='store_true')),
        ('--sep_between_neighbors', dict(default=False, action='store_true')),
        ('--freeze_backbone', dict(default=False, action='store_true')),
        ('--no_spatial_distance', dict(default=False, action='store_true')),
        ('--bert_option', dict(type=str, default='bert-base')),
        ('--model_save_dir', dict(type=str, default=None)),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    for chunk in ('\n', args, '\n'):
        print(chunk)

    # Create the user-supplied output directory if it is missing.
    out_dir = args.model_save_dir
    if out_dir is not None:
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    training(args)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()