# APISR / train_code / train.py -- HikariDawn, "feat: initial push" (commit 561c629)
# -*- coding: utf-8 -*-
# Training entry script: parses the command line, prepares the working
# folders, then runs the trainer selected by opt['architecture'].
import argparse
import os, shutil, sys
import time
import warnings
# Suppress every warning for the rest of the process.
warnings.filterwarnings("ignore")
# import from local folder
# The absolute path of the current working directory is appended to sys.path;
# presumably this lets `opt` and the train_* modules resolve when the script
# is launched from the folder that contains them.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
def storage_manage():
    """Archive the current ``runs/`` folder under ``runs_last/``.

    The whole ``runs/`` tree is copied to ``runs_last/<unix-timestamp>/``.
    ``runs/`` itself is left in place; deleting it is up to the caller.

    If an archive for the current second already exists, a numeric suffix
    (``<unix-timestamp>_1``, ``_2``, ...) is used instead, because
    ``shutil.copytree`` refuses to copy into an existing directory.

    Raises:
        FileNotFoundError: if ``runs/`` does not exist.
    """
    archive_root = "runs_last/"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(archive_root, exist_ok=True)

    # copy to the new address
    stamp = str(int(time.time()))
    new_address = archive_root + stamp + "/"
    attempt = 0
    # Two archives within the same second would collide; pick a free name.
    while os.path.exists(new_address):
        attempt += 1
        new_address = "{}{}_{}/".format(archive_root, stamp, attempt)
    shutil.copytree("runs/", new_address)
def parse_args(argv=None):
    """Parse the command line and reset ``./runs`` for a fresh run.

    The parsed namespace is stored in the module-level ``args`` and is also
    returned.

    Recognised options:
        --auto_resume_closest  flag; cannot be combined with the next one
        --auto_resume_best     flag
        --pretrained_path      string, defaults to ""

    When neither resume flag is given and ``./runs`` exists, that folder is
    archived with ``storage_manage()`` and then deleted.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Raises:
        SystemExit: with status 2 when both resume flags are given (or on
            any other argparse usage error).
    """
    global args

    parser = argparse.ArgumentParser()
    parser.add_argument('--auto_resume_closest', action='store_true')
    parser.add_argument('--auto_resume_best', action='store_true')
    parser.add_argument('--pretrained_path', type=str, default="")
    args = parser.parse_args(argv)

    if args.auto_resume_closest and args.auto_resume_best:
        # parser.error() prints usage plus the message to stderr and exits
        # with status 2, instead of a hard os._exit(0) that reports success
        # and skips flushing of buffered output.
        parser.error("you could only resume either nearest or best, not both")

    if not (args.auto_resume_closest or args.auto_resume_best):
        # Restart tensorboard (delete all things under ./runs)
        if os.path.exists("./runs"):
            storage_manage()
            shutil.rmtree("./runs")

    return args
def folder_prepare():
    """Create the working folders the training run expects.

    ``saved_models/``, ``saved_models/checkpoints/`` and ``datasets/`` are
    created when nothing exists at those paths; existing content is kept.
    The list of folders to wipe and recreate is currently empty.
    """
    def _ensure(path):
        # Leave anything that is already there untouched.
        if os.path.exists(path):
            return
        os.makedirs(path)

    def _recreate(path):
        # Start from an empty folder: drop the old tree first, if any.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # The lists we care about
    keep_if_present = ("saved_models/", "saved_models/checkpoints/", "datasets/")
    always_reset = ()

    for path in keep_if_present:
        _ensure(path)
    for path in always_reset:
        _recreate(path)
def process(options, cli_args=None):
    """Run the trainer selected by ``options['architecture']`` and time it.

    The matching ``train_*`` module is imported, its trainer is built with
    ``(options, cli_args)``, and ``run()`` is called on it. Afterwards the
    elapsed wall-clock time is printed as whole hours, minutes and seconds.

    Args:
        options: mapping read for the key ``'architecture'`` and handed to
            the trainer. Supported values: "ESRNET", "ESRGAN", "GRL",
            "GRLGAN", "CUNET", "CUGAN".
        cli_args: parsed command-line namespace handed to the trainer. When
            ``None`` (the default) the module-level ``args`` set by
            ``parse_args()`` is used.

    Raises:
        NotImplementedError: if the architecture is not one of the
            supported values.
    """
    if cli_args is None:
        cli_args = args  # module-level namespace populated by parse_args()
    print(cli_args)

    start = time.time()

    # Switch based on the model architecture.
    # Each import sits inside its branch, so only the selected trainer
    # module is loaded.
    architecture = options['architecture']
    if architecture == "ESRNET":
        from train_esrnet import train_esrnet
        obj = train_esrnet(options, cli_args)
    elif architecture == "ESRGAN":
        from train_esrgan import train_esrgan
        obj = train_esrgan(options, cli_args)
    elif architecture == "GRL":
        from train_grl import train_grl
        obj = train_grl(options, cli_args)
    elif architecture == "GRLGAN":
        from train_grlgan import train_grlgan
        obj = train_grlgan(options, cli_args)
    elif architecture == "CUNET":
        from train_cunet import train_cunet
        obj = train_cunet(options, cli_args)
    elif architecture == "CUGAN":
        from train_cugan import train_cugan
        obj = train_cugan(options, cli_args)
    else:
        raise NotImplementedError("This is not a supported model architecture")

    obj.run()

    total_time = time.time() - start
    # The seconds field is the remainder below one minute (the old code used
    # the remainder below one hour, so 3725 s was reported as "125 s").
    hours, remainder = divmod(int(total_time), 3600)
    minutes, seconds = divmod(remainder, 60)
    print("All programs spent {} hour {} min {} s".format(hours, minutes, seconds))
def main():
    """Script entry point: parse the CLI, prepare folders, then train.

    The call order matters: ``parse_args()`` fills the module-level ``args``
    that ``process()`` reads.
    """
    # 1) Command line (also archives and clears ./runs on a fresh start).
    parse_args()

    # 2) Output / dataset folders.
    folder_prepare()

    # 3) Training, configured by the imported ``opt`` object.
    process(opt)
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()