File size: 598 Bytes
404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
from train import Trainer
from config import get_config
from utils import prepare_dirs
from data_loader import get_data_loader
def main(config):
# ensure directories are setup
prepare_dirs(config)
# ensure reproducibility
torch.manual_seed(config.seed)
if config.use_gpu:
torch.cuda.manual_seed(config.seed)
# instantiate train data loaders
train_loader = get_data_loader(config=config)
trainer = Trainer(config, train_loader=train_loader)
trainer.train()
if __name__ == '__main__':
config, unparsed = get_config()
main(config) |