# semantic-demo / model.py
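"""Train a feed-forward regressor from language-model hidden states to
continuous feature norm vectors, with an optional optuna search over
hidden size, batch size, and learning rate."""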
import torch
import lightning
from torch.utils.data import Dataset
from typing import Dict
import argparse
from pydantic import BaseModel
from get_dataset_dictionaries import get_dict_pair
import os
import shutil
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from functools import partial
class FFNModule(torch.nn.Module):
"""
A pytorch module that regresses from a hidden state representation of a word
to its continuous linguistic feature norm vector.
It is a FFN with the general structure of:
input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output
"""
def __init__(
self,
input_size: int,
output_size: int,
hidden_size: int,
num_layers: int,
dropout: float,
):
super(FFNModule, self).__init__()
layers = []
for _ in range(num_layers - 1):
layers.append(torch.nn.Linear(input_size, hidden_size))
layers.append(torch.nn.ReLU())
layers.append(torch.nn.Dropout(dropout))
# changes input size to hidden size after first layer
input_size = hidden_size
        # final layer maps from the last hidden size (or the original input size if num_layers == 1)
        layers.append(torch.nn.Linear(input_size, output_size))
self.network = torch.nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
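
# A minimal usage sketch of FFNModule with hypothetical sizes (768-dim hidden
# states, 2000 feature norms, the default two-layer setup), giving the stack
# Linear(768, 100) -> ReLU -> Dropout(0.1) -> Linear(100, 2000):
# ffn = FFNModule(input_size=768, output_size=2000, hidden_size=100, num_layers=2, dropout=0.1)
# preds = ffn(torch.rand(32, 768))  # (batch, input_size) -> (batch, output_size)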
class FFNParams(BaseModel):
input_size: int
output_size: int
hidden_size: int
num_layers: int
dropout: float
class TrainingParams(BaseModel):
num_epochs: int
batch_size: int
learning_rate: float
weight_decay: float
class FeatureNormPredictor(lightning.LightningModule):
    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
        super().__init__()
        self.save_hyperparameters()
        self.ffn_params = ffn_params
        self.training_params = training_params
        self.model = FFNModule(**ffn_params.model_dump())
        self.loss_function = torch.nn.MSELoss()
def training_step(self, batch, batch_idx):
x,y = batch
outputs = self.model(x)
loss = self.loss_function(outputs, y)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
x,y = batch
outputs = self.model(x)
loss = self.loss_function(outputs, y)
self.log("val_loss", loss, on_epoch=True, prog_bar=True)
return loss
    def test_step(self, batch, batch_idx):
        x, _ = batch  # dataset batches are (embedding, feature norm) pairs
        return self.model(x)
def predict(self, batch):
return self.model(batch)
def __call__(self, input):
return self.model(input)
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.training_params.learning_rate,
weight_decay=self.training_params.weight_decay,
)
return optimizer
def save_model(self, path: str):
torch.save(self.model.state_dict(), path)
def load_model(self, path: str):
self.model.load_state_dict(torch.load(path))
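
# A minimal construction sketch with hypothetical sizes; because __init__ calls
# save_hyperparameters(), Lightning's standard load_from_checkpoint should also
# be able to restore a trained predictor from a saved .ckpt file:
# predictor = FeatureNormPredictor(
#     FFNParams(input_size=768, output_size=2000, hidden_size=100, num_layers=2, dropout=0.1),
#     TrainingParams(num_epochs=10, batch_size=32, learning_rate=1e-3, weight_decay=0.0),
# )
# predictor = FeatureNormPredictor.load_from_checkpoint("path/to/model.ckpt")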
class HiddenStateFeatureNormDataset(Dataset):
def __init__(
self,
input_embeddings: Dict[str, torch.Tensor],
feature_norms: Dict[str, torch.Tensor],
):
        # Invariant: input_embeddings and feature_norms have exactly the same keys;
        # this is guaranteed by upstream data processing and the train/val split
        assert input_embeddings.keys() == feature_norms.keys()
self.words = list(input_embeddings.keys())
self.input_embeddings = torch.stack([
input_embeddings[word] for word in self.words
])
self.feature_norms = torch.stack([
feature_norms[word] for word in self.words
])
def __len__(self):
return len(self.words)
def __getitem__(self, idx):
return self.input_embeddings[idx], self.feature_norms[idx]
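
# A minimal usage sketch with toy tensors (words and sizes are illustrative only):
# embeddings = {"cat": torch.rand(768), "dog": torch.rand(768)}
# norms = {"cat": torch.rand(2000), "dog": torch.rand(2000)}
# dataset = HiddenStateFeatureNormDataset(embeddings, norms)
# embedding, norm = dataset[0]  # one (hidden state, feature norm) pair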
# this is used when hyperparameters are fixed (no optimization)
def train(args: argparse.Namespace) -> FeatureNormPredictor:
# input_embeddings = torch.load(args.input_embeddings)
# feature_norms = torch.load(args.feature_norms)
# words = list(input_embeddings.keys())
    input_embeddings, feature_norms, norm_list = get_dict_pair(
        args.norm,
        args.embedding_dir,
        args.lm_layer,
        translated=not args.raw_buchanan,
        normalized=args.normal_buchanan,
    )
    # save the list of norms predicted by this model alongside it
    with open(os.path.join(args.save_dir, args.save_model_name + '.txt'), 'w') as norms_file:
        norms_file.write("\n".join(norm_list))
words = list(input_embeddings.keys())
model = FeatureNormPredictor(
FFNParams(
input_size=input_embeddings[words[0]].shape[0],
output_size=feature_norms[words[0]].shape[0],
hidden_size=args.hidden_size,
num_layers=args.num_layers,
dropout=args.dropout,
),
TrainingParams(
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
),
)
# train/val split
train_size = int(len(words) * 0.8)
valid_size = len(words) - train_size
train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])
# TODO: Methodology Decision: should we be normalizing the hidden states/feature norms?
train_embeddings = {word: input_embeddings[word] for word in train_words}
train_feature_norms = {word: feature_norms[word] for word in train_words}
validation_embeddings = {word: input_embeddings[word] for word in validation_words}
validation_feature_norms = {word: feature_norms[word] for word in validation_words}
train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
)
validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
validation_dataloader = torch.utils.data.DataLoader(
validation_dataset,
batch_size=args.batch_size,
shuffle=True,
)
callbacks = [
lightning.pytorch.callbacks.ModelCheckpoint(
save_last=True,
dirpath=args.save_dir,
filename=args.save_model_name,
),
]
if args.early_stopping is not None:
callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss",
patience=args.early_stopping,
mode='min',
min_delta=0.0
))
#TODO Design Decision - other trainer args? Is device necessary?
# cpu is fine for the scale of this model - only a few layers and a few hundred words
trainer = lightning.Trainer(
max_epochs=args.num_epochs,
callbacks=callbacks,
accelerator="cpu",
log_every_n_steps=7
)
trainer.fit(model, train_dataloader, validation_dataloader)
trainer.validate(model, validation_dataloader)
return model
# this is used when optimizing hyperparameters with optuna
def objective(trial: optuna.trial.Trial, args: argparse.Namespace) -> float:
# optimizing hidden size, batch size, and learning rate
    input_embeddings, feature_norms, norm_list = get_dict_pair(
        args.norm,
        args.embedding_dir,
        args.lm_layer,
        translated=not args.raw_buchanan,
        normalized=args.normal_buchanan,
    )
    with open(os.path.join(args.save_dir, args.save_model_name + '.txt'), 'w') as norms_file:
        norms_file.write("\n".join(norm_list))
words = list(input_embeddings.keys())
    input_size = input_embeddings[words[0]].shape[0]
    output_size = feature_norms[words[0]].shape[0]
    # search hidden sizes between the smaller of input/output size and at most
    # twice that, capped at the larger of the two
    min_size = min(output_size, input_size)
    max_size = min(output_size, 2 * input_size) if min_size == input_size else min(2 * output_size, input_size)
    hidden_size = trial.suggest_int("hidden_size", min_size, max_size, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1, log=True)
model = FeatureNormPredictor(
FFNParams(
input_size=input_size,
output_size=output_size,
hidden_size=hidden_size,
num_layers=args.num_layers,
dropout=args.dropout,
),
TrainingParams(
num_epochs=args.num_epochs,
batch_size=batch_size,
learning_rate=learning_rate,
weight_decay=args.weight_decay,
),
)
# train/val split
train_size = int(len(words) * 0.8)
valid_size = len(words) - train_size
train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])
train_embeddings = {word: input_embeddings[word] for word in train_words}
train_feature_norms = {word: feature_norms[word] for word in train_words}
validation_embeddings = {word: input_embeddings[word] for word in validation_words}
validation_feature_norms = {word: feature_norms[word] for word in validation_words}
train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
        batch_size=batch_size,  # use the trial-suggested batch size
shuffle=True,
)
validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
validation_dataloader = torch.utils.data.DataLoader(
validation_dataset,
        batch_size=batch_size,
shuffle=True,
)
callbacks = [
        # all trial checkpoints are saved in a temporary directory (removed after the study)
lightning.pytorch.callbacks.ModelCheckpoint(
save_last=True,
dirpath=os.path.join(args.save_dir,'optuna_trials'),
filename="{}".format(trial.number)
),
]
    if args.prune:
callbacks.append(PyTorchLightningPruningCallback(
trial,
monitor='val_loss'
))
if args.early_stopping is not None:
callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
monitor="val_loss",
patience=args.early_stopping,
mode='min',
min_delta=0.0
))
    # note: if --early_stopping is set, vanilla early stopping is also applied within each optimization trial
#TODO Design Decision - other trainer args? Is device necessary?
# cpu is fine for the scale of this model - only a few layers and a few hundred words
trainer = lightning.Trainer(
max_epochs=args.num_epochs,
callbacks=callbacks,
accelerator="cpu",
log_every_n_steps=7,
# enable_checkpointing=False
)
trainer.fit(model, train_dataloader, validation_dataloader)
trainer.validate(model, validation_dataloader)
return trainer.callback_metrics['val_loss'].item()
if __name__ == "__main__":
# parse args
parser = argparse.ArgumentParser()
    # TODO: Design decision: should we take paths to pre-extracted layers, or the model/layer to generate them from?
# required inputs
parser.add_argument("--norm", type=str, required=True, help="feature norm set to use")
parser.add_argument("--embedding_dir", type=str, required=True, help=" directory containing embeddings")
parser.add_argument("--lm_layer", type=int, required=True, help="layer of embeddings to use")
# if user selects optimize, hidden_size, batch_size and learning_rate will be optimized.
parser.add_argument("--optimize", action="store_true", help="optimize hyperparameters for training")
parser.add_argument("--prune", action="store_true", help="prune unpromising trials when optimizing")
# optional hyperparameter specs
parser.add_argument("--num_layers", type=int, default=2, help="number of layers in FFN")
parser.add_argument("--hidden_size", type=int, default=100, help="hidden size of FFN")
parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate of FFN")
    # automatically raised to at least 50 when early stopping is enabled (see below)
parser.add_argument("--num_epochs", type=int, default=10, help="number of epochs to train for")
parser.add_argument("--batch_size", type=int, default=32, help="batch size for training")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate for training")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for training")
parser.add_argument("--early_stopping", type=int, default=None, help="number of epochs to wait for early stopping")
    # optional dataset specs (mainly relevant for the Buchanan norms)
parser.add_argument('--raw_buchanan', action="store_true", help="do not use translated values for buchanan")
parser.add_argument('--normal_buchanan', action="store_true", help="use normalized features for buchanan")
# required for output
parser.add_argument("--save_dir", type=str, required=True, help="directory to save model to")
parser.add_argument("--save_model_name", type=str, required=True, help="name of model to save")
args = parser.parse_args()
if args.early_stopping is not None:
args.num_epochs = max(50, args.num_epochs)
torch.manual_seed(10)
if args.optimize:
# call optimizer code here
print("optimizing for learning rate, batch size, and hidden size")
pruner = optuna.pruners.MedianPruner() if args.prune else optuna.pruners.NopPruner()
sampler = optuna.samplers.TPESampler(seed=10)
study = optuna.create_study(direction='minimize', pruner=pruner, sampler=sampler)
        study.optimize(partial(objective, args=args), n_trials=100, timeout=600)
other_params = {
"num_layers": args.num_layers,
"num_epochs": args.num_epochs,
"dropout": args.dropout,
"weight_decay": args.weight_decay,
}
print("Number of finished trials: {}".format(len(study.trials)))
trial = study.best_trial
print("Best trial: "+str(trial.number))
print(" Validation Loss: {}".format(trial.value))
print(" Optimized Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
print(" User Defined Params: ")
for key, value in other_params.items():
print(" {}: {}".format(key, value))
print('saving best trial')
        best_ckpt = os.path.join(args.save_dir, 'optuna_trials', "{}.ckpt".format(trial.number))
        if os.path.exists(best_ckpt):
            shutil.move(best_ckpt, os.path.join(args.save_dir, "{}.ckpt".format(args.save_model_name)))
        shutil.rmtree(os.path.join(args.save_dir, 'optuna_trials'))
else:
model = train(args)
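
# Example invocation (paths, layer, and names are illustrative; --norm must be a
# norm set understood by get_dict_pair):
#   python model.py --norm buchanan --embedding_dir embeddings/ --lm_layer 8 \
#       --save_dir models/ --save_model_name buchanan_layer8 --optimize --prune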