|
import torch as t |
|
import torch.nn as nn |
|
import torch.functional as F |
|
import torch.optim as optim |
|
import argparse |
|
from utils import OsSoluConfig |
|
from model import OsSoluModel |
|
from typing import Tuple |
|
|
|
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse command-line hyperparameters for the model.

    Args:
        argv: Optional list of argument strings to parse. Defaults to None,
            in which case argparse reads sys.argv (identical to the previous
            behaviour). Passing an explicit list makes this function usable
            from tests and notebooks without touching sys.argv.

    Returns:
        argparse.Namespace holding the parsed hyperparameters.
    """
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional. ")
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
    args = parser.parse_args(argv)
    return args
|
|
|
def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
    """Train the model.

    NOTE(review): training is not yet implemented — this is a stub that
    returns the model unchanged. TODO: add data loading, an optimiser
    (config carries a learning_rate argument), and the forward/backward loop.

    Args:
        config: Hyperparameter configuration built from the CLI arguments.
        model: The model instance to train.

    Returns:
        The (currently untrained) model, unchanged.
    """

    return model
|
|
|
def eval():
    """Evaluate the model.

    NOTE(review): unimplemented stub. The name also shadows the builtin
    `eval` at module scope — consider renaming to `evaluate` (the caller in
    the `__main__` block below would need updating too). It takes no model
    argument yet; presumably it will once implemented.
    """
    pass
|
|
|
def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
    """Build the run configuration and a fresh model from the CLI arguments.

    Returns:
        A (config, model) pair: the OsSoluConfig parsed from the command
        line and the OsSoluModel constructed from it.
    """
    cli_args = parse_arguments()
    run_config = OsSoluConfig(cli_args)
    return run_config, OsSoluModel(run_config)
|
|
|
if __name__=="__main__":
    # Entry point: build config + model from the CLI, run the (stub) training
    # loop, then the (stub) evaluation.
    config, model = setup()
    trained_model = train(config, model)
    # NOTE(review): eval() takes no arguments, so trained_model is currently
    # unused here — presumably evaluation will consume it once implemented.
    eval()