import argparse
from typing import Tuple

import torch as t
import torch.nn as nn
import torch.nn.functional as F  # torch.functional has no loss functions; torch.nn.functional is the canonical module
import torch.optim as optim

from model import OsSoluModel
from utils import OsSoluConfig


def parse_arguments() -> argparse.Namespace:
    """Parse command-line hyperparameter arguments for the model."""
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="Type of attention to use: rotary or unidirectional.")
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
    return parser.parse_args()


def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
    """Train the model and return it."""
    # TODO: training loop
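    # A minimal next-token language-modelling sketch, not the original
    # training loop. Assumptions not in the original: `train_dataloader` is a
    # placeholder for batches of token ids with shape (batch, seq), the model
    # returns logits of shape (batch, seq, vocab_size), and OsSoluConfig
    # exposes the parsed `learning_rate`.
    train_dataloader = []  # placeholder: supply real token batches here
    optimiser = optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()
    for batch in train_dataloader:
        logits = model(batch)
        # Shift targets by one position so each token predicts the next.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            batch[:, 1:].reshape(-1),
        )
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()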
    return model


def evaluate(model: OsSoluModel) -> None:
    """Evaluate the trained model (renamed from `eval` to avoid shadowing the built-in)."""
    # TODO: evaluation loop
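    # A minimal held-out-loss sketch, not the original evaluation code.
    # Assumptions not in the original: `eval_dataloader` is a placeholder for
    # batches of token ids with shape (batch, seq), and the model returns
    # logits of shape (batch, seq, vocab_size).
    eval_dataloader = []  # placeholder: supply real token batches here
    model.eval()
    total_loss, num_batches = 0.0, 0
    with t.no_grad():
        for batch in eval_dataloader:
            logits = model(batch)
            # Same next-token shift as in train().
            total_loss += F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                batch[:, 1:].reshape(-1),
            ).item()
            num_batches += 1
    if num_batches:
        print(f"Mean eval loss: {total_loss / num_batches:.4f}")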


def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
    """Parse arguments and build the config and model."""
    # TODO: wandb logging
    args = parse_arguments()
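    # A minimal Weights & Biases setup sketch for the TODO above, kept
    # commented out. Assumptions not in the original: the `wandb` package is
    # installed and the user is logged in; the project name is hypothetical.
    # import wandb
    # wandb.init(project="os-solu", config=vars(args))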
    config = OsSoluConfig(args)
    model = OsSoluModel(config)
    return config, model


if __name__ == "__main__":
    config, model = setup()
    trained_model = train(config, model)
    evaluate(trained_model)