import argparse
import os
import shutil
from functools import partial
from typing import Dict

import lightning
import optuna
import torch
from optuna.integration import PyTorchLightningPruningCallback
from pydantic import BaseModel
from torch.utils.data import Dataset

from get_dataset_dictionaries import get_dict_pair


class FFNModule(torch.nn.Module):
    """
    A PyTorch module that regresses from a hidden state representation of a word
    to its continuous linguistic feature norm vector.

    It is a FFN with the general structure of:
    input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
    ):
        super(FFNModule, self).__init__()

        layers = []
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(input_size, hidden_size))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(dropout))
            # changes input size to hidden size after first layer
            input_size = hidden_size
        # final projection; input_size has already been set to hidden_size here
        # unless num_layers == 1, in which case it is still the true input size
        layers.append(torch.nn.Linear(input_size, output_size))

        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class FFNParams(BaseModel):
    input_size: int
    output_size: int
    hidden_size: int
    num_layers: int
    dropout: float


class TrainingParams(BaseModel):
    num_epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float


class FeatureNormPredictor(lightning.LightningModule):
    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
        super().__init__()
        self.save_hyperparameters()
        self.ffn_params = ffn_params
        self.training_params = training_params
        self.model = FFNModule(**ffn_params.model_dump())
        self.loss_function = torch.nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        return self.model(batch)

    def predict(self, batch):
        return self.model(batch)

    def __call__(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.training_params.learning_rate,
            weight_decay=self.training_params.weight_decay,
        )
        return optimizer

    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str):
        self.model.load_state_dict(torch.load(path))


class HiddenStateFeatureNormDataset(Dataset):
    def __init__(
        self,
        input_embeddings: Dict[str, torch.Tensor],
        feature_norms: Dict[str, torch.Tensor],
    ):
        # Invariant: input_embeddings and feature_norms have exactly the same keys;
        # this should be guaranteed by the train/test split and upstream data processing.
        assert input_embeddings.keys() == feature_norms.keys()

        self.words = list(input_embeddings.keys())
        self.input_embeddings = torch.stack([input_embeddings[word] for word in self.words])
        self.feature_norms = torch.stack([feature_norms[word] for word in self.words])

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        return self.input_embeddings[idx], self.feature_norms[idx]
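
# A minimal inference sketch (kept as comments so nothing runs on import). The
# dimensions, path, and hyperparameter values below are placeholders, not values
# used elsewhere in this project:
#
#     predictor = FeatureNormPredictor(
#         FFNParams(input_size=768, output_size=300, hidden_size=300, num_layers=2, dropout=0.1),
#         TrainingParams(num_epochs=10, batch_size=32, learning_rate=1e-3, weight_decay=0.0),
#     )
#     predictor.load_model("path/to/saved_state_dict.pt")
#     feature_norm_vector = predictor.predict(torch.rand(768))
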
# this is used when not optimizing
def train(args: argparse.Namespace):
    # input_embeddings = torch.load(args.input_embeddings)
    # feature_norms = torch.load(args.feature_norms)
    # words = list(input_embeddings.keys())
    input_embeddings, feature_norms, norm_list = get_dict_pair(
        args.norm,
        args.embedding_dir,
        args.lm_layer,
        translated=not args.raw_buchanan,
        normalized=args.normal_buchanan,
    )

    with open(os.path.join(args.save_dir, args.save_model_name + ".txt"), "w") as norms_file:
        norms_file.write("\n".join(norm_list))

    words = list(input_embeddings.keys())

    model = FeatureNormPredictor(
        FFNParams(
            input_size=input_embeddings[words[0]].shape[0],
            output_size=feature_norms[words[0]].shape[0],
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            dropout=args.dropout,
        ),
        TrainingParams(
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
        ),
    )

    # train/val split
    train_size = int(len(words) * 0.8)
    valid_size = len(words) - train_size
    train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])

    # TODO: Methodology Decision: should we be normalizing the hidden states/feature norms?
    train_embeddings = {word: input_embeddings[word] for word in train_words}
    train_feature_norms = {word: feature_norms[word] for word in train_words}
    validation_embeddings = {word: input_embeddings[word] for word in validation_words}
    validation_feature_norms = {word: feature_norms[word] for word in validation_words}

    train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )
    validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    callbacks = [
        lightning.pytorch.callbacks.ModelCheckpoint(
            save_last=True,
            dirpath=args.save_dir,
            filename=args.save_model_name,
        ),
    ]
    if args.early_stopping is not None:
        callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=args.early_stopping,
            mode="min",
            min_delta=0.0,
        ))

    # TODO Design Decision - other trainer args? Is device necessary?
    # cpu is fine for the scale of this model - only a few layers and a few hundred words
    trainer = lightning.Trainer(
        max_epochs=args.num_epochs,
        callbacks=callbacks,
        accelerator="cpu",
        log_every_n_steps=7,
    )
    trainer.fit(model, train_dataloader, validation_dataloader)
    trainer.validate(model, validation_dataloader)

    return model
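
# A sketch of calling train() directly rather than through the CLI in __main__.
# The argparse.Namespace mirrors the command-line flags defined below; every value
# here is a placeholder:
#
#     example_args = argparse.Namespace(
#         norm="buchanan", embedding_dir="embeddings/", lm_layer=8,
#         raw_buchanan=False, normal_buchanan=False,
#         num_layers=2, hidden_size=100, dropout=0.1,
#         num_epochs=10, batch_size=32, learning_rate=1e-3, weight_decay=0.0,
#         early_stopping=None, save_dir="models/", save_model_name="demo",
#     )
#     trained_model = train(example_args)
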
# this is used when optimizing
def objective(trial: optuna.trial.Trial, args: argparse.Namespace) -> float:
    # optimizing hidden size, batch size, and learning rate
    input_embeddings, feature_norms, norm_list = get_dict_pair(
        args.norm,
        args.embedding_dir,
        args.lm_layer,
        translated=not args.raw_buchanan,
        normalized=args.normal_buchanan,
    )

    with open(os.path.join(args.save_dir, args.save_model_name + ".txt"), "w") as norms_file:
        norms_file.write("\n".join(norm_list))

    words = list(input_embeddings.keys())
    input_size = input_embeddings[words[0]].shape[0]
    output_size = feature_norms[words[0]].shape[0]

    # search hidden sizes from the smaller of the two dimensions up to twice that,
    # capped at the larger dimension
    min_size = min(output_size, input_size)
    max_size = min(output_size, 2 * input_size) if min_size == input_size else min(2 * output_size, input_size)
    hidden_size = trial.suggest_int("hidden_size", min_size, max_size, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1, log=True)

    model = FeatureNormPredictor(
        FFNParams(
            input_size=input_size,
            output_size=output_size,
            hidden_size=hidden_size,
            num_layers=args.num_layers,
            dropout=args.dropout,
        ),
        TrainingParams(
            num_epochs=args.num_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=args.weight_decay,
        ),
    )

    # train/val split
    train_size = int(len(words) * 0.8)
    valid_size = len(words) - train_size
    train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])

    train_embeddings = {word: input_embeddings[word] for word in train_words}
    train_feature_norms = {word: feature_norms[word] for word in train_words}
    validation_embeddings = {word: input_embeddings[word] for word in validation_words}
    validation_feature_norms = {word: feature_norms[word] for word in validation_words}

    train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,  # use the batch size suggested by the trial
        shuffle=True,
    )
    validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    callbacks = [
        # all trial models will be saved in a temporary directory
        lightning.pytorch.callbacks.ModelCheckpoint(
            save_last=True,
            dirpath=os.path.join(args.save_dir, "optuna_trials"),
            filename="{}".format(trial.number),
        ),
    ]
    if args.prune:
        callbacks.append(PyTorchLightningPruningCallback(
            trial,
            monitor="val_loss",
        ))
    if args.early_stopping is not None:
        callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=args.early_stopping,
            mode="min",
            min_delta=0.0,
        ))
    # note: during optimization, vanilla early stopping is only applied if --early_stopping
    # is set; otherwise unpromising trials are handled by pruning (if enabled)

    # TODO Design Decision - other trainer args? Is device necessary?
    # cpu is fine for the scale of this model - only a few layers and a few hundred words
    trainer = lightning.Trainer(
        max_epochs=args.num_epochs,
        callbacks=callbacks,
        accelerator="cpu",
        log_every_n_steps=7,
        # enable_checkpointing=False
    )
    trainer.fit(model, train_dataloader, validation_dataloader)
    trainer.validate(model, validation_dataloader)

    return trainer.callback_metrics["val_loss"].item()
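
# Worked example of the hidden_size search bounds in objective(), with purely
# hypothetical dimensions: if input_size=768 and output_size=300, then
# min_size=300 and max_size=min(2*300, 768)=600, so hidden_size is sampled
# (log-uniformly) from [300, 600].
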
if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser()
    # TODO: Design Decision: Should we input paths to the pre-extracted layers, or the model/layer we want to generate them from?

    # required inputs
    parser.add_argument("--norm", type=str, required=True, help="feature norm set to use")
    parser.add_argument("--embedding_dir", type=str, required=True, help="directory containing embeddings")
    parser.add_argument("--lm_layer", type=int, required=True, help="layer of embeddings to use")

    # if the user selects optimize, hidden_size, batch_size, and learning_rate will be optimized
    parser.add_argument("--optimize", action="store_true", help="optimize hyperparameters for training")
    parser.add_argument("--prune", action="store_true", help="prune unpromising trials when optimizing")

    # optional hyperparameter specs
    parser.add_argument("--num_layers", type=int, default=2, help="number of layers in FFN")
    parser.add_argument("--hidden_size", type=int, default=100, help="hidden size of FFN")
    parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate of FFN")
    # set this to at least 100 if doing early stopping
    parser.add_argument("--num_epochs", type=int, default=10, help="number of epochs to train for")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size for training")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate for training")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for training")
    parser.add_argument("--early_stopping", type=int, default=None, help="number of epochs to wait for early stopping")

    # optional dataset specs, for buchanan really
    parser.add_argument("--raw_buchanan", action="store_true", help="do not use translated values for buchanan")
    parser.add_argument("--normal_buchanan", action="store_true", help="use normalized features for buchanan")

    # required for output
    parser.add_argument("--save_dir", type=str, required=True, help="directory to save model to")
    parser.add_argument("--save_model_name", type=str, required=True, help="name of model to save")

    args = parser.parse_args()

    if args.early_stopping is not None:
        args.num_epochs = max(50, args.num_epochs)

    torch.manual_seed(10)

    if args.optimize:
        # call optimizer code here
        print("optimizing for learning rate, batch size, and hidden size")
        pruner = optuna.pruners.MedianPruner() if args.prune else optuna.pruners.NopPruner()
        sampler = optuna.samplers.TPESampler(seed=10)
        study = optuna.create_study(direction="minimize", pruner=pruner, sampler=sampler)
        study.optimize(partial(objective, args=args), n_trials=100, timeout=600)

        other_params = {
            "num_layers": args.num_layers,
            "num_epochs": args.num_epochs,
            "dropout": args.dropout,
            "weight_decay": args.weight_decay,
        }

        print("Number of finished trials: {}".format(len(study.trials)))
        trial = study.best_trial
        print("Best trial: " + str(trial.number))
        print(" Validation Loss: {}".format(trial.value))
        print(" Optimized Params: ")
        for key, value in trial.params.items():
            print(" {}: {}".format(key, value))
        print(" User Defined Params: ")
        for key, value in other_params.items():
            print(" {}: {}".format(key, value))
        print("saving best trial")
        for filename in os.listdir(os.path.join(args.save_dir, "optuna_trials")):
            if filename == "{}.ckpt".format(trial.number):
                shutil.move(
                    os.path.join(args.save_dir, "optuna_trials", filename),
                    os.path.join(args.save_dir, "{}.ckpt".format(args.save_model_name)),
                )
        shutil.rmtree(os.path.join(args.save_dir, "optuna_trials"))

    else:
        model = train(args)
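
# Example invocation (the script filename and all values are placeholders; only
# the flags themselves are defined above):
#
#     python train_feature_norm_predictor.py \
#         --norm buchanan --embedding_dir embeddings/ --lm_layer 8 \
#         --save_dir models/ --save_model_name demo \
#         --optimize --prune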