Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*- | |
import argparse | |
import os, shutil, sys | |
import time | |
import warnings | |
warnings.filterwarnings("ignore") | |
# import from local folder | |
root_path = os.path.abspath('.') | |
sys.path.append(root_path) | |
from opt import opt | |
def storage_manage():
    """Archive the current ``runs/`` directory under ``runs_last/``.

    The whole ``runs/`` tree is copied to ``runs_last/<timestamp>/``, where
    ``<timestamp>`` is the current Unix time truncated to whole seconds.
    Both paths are relative to the current working directory. The source
    ``runs/`` directory itself is left in place.

    Raises:
        OSError: propagated from ``shutil.copytree`` -- e.g. when ``runs/``
            is missing, or when an archive for the same second already
            exists.
    """
    # exist_ok=True avoids the race between a separate exists() check and
    # the directory creation.
    os.makedirs("runs_last/", exist_ok=True)

    # copy to the new address
    new_address = "runs_last/" + str(int(time.time())) + "/"
    shutil.copytree("runs/", new_address)
def parse_args():
    """Parse the command-line flags into the module-level global ``args``.

    Recognised flags:
        --auto_resume_closest: boolean switch (default False).
        --auto_resume_best: boolean switch (default False).
        --pretrained_path: string (default "").

    The two resume switches are mutually exclusive. When neither is given
    and ``./runs`` exists, it is archived via ``storage_manage()`` and then
    deleted.

    Raises:
        SystemExit: with status 2 when both resume switches are passed.
    """
    global args

    parser = argparse.ArgumentParser()
    parser.add_argument('--auto_resume_closest', action='store_true')
    parser.add_argument('--auto_resume_best', action='store_true')
    parser.add_argument('--pretrained_path', type=str, default="")
    args = parser.parse_args()

    if args.auto_resume_closest and args.auto_resume_best:
        # parser.error() prints to stderr and exits with status 2. The
        # previous os._exit(0) signalled success for an invalid invocation
        # and bypassed stdio flushing, so the message could be lost.
        parser.error("you could only resume either nearest or best, not both")

    if not (args.auto_resume_closest or args.auto_resume_best):
        # Restart tensorboard (delete all things under ./runs)
        if os.path.exists("./runs"):
            storage_manage()
            shutil.rmtree("./runs")
def folder_prepare():
    """Make sure the working folders used by the run are present.

    ``saved_models/``, ``saved_models/checkpoints/`` and ``datasets/`` are
    created (relative to the current working directory) when they do not
    exist yet; anything already there is left untouched. A second group of
    folders would be wiped and recreated on every call, but that group is
    currently empty.
    """
    # Folders that only have to exist.
    create_if_missing = ("saved_models/", "saved_models/checkpoints/", "datasets/")
    # Folders that must start out empty (none at the moment).
    always_recreate = ()

    for path in create_if_missing:
        if os.path.exists(path):
            continue
        os.makedirs(path)

    for path in always_recreate:
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
# Supported architectures. For each one, the trainer lives in a module whose
# name is also the name of the trainer callable inside it.
_TRAINER_NAMES = {
    "ESRNET": "train_esrnet",
    "ESRGAN": "train_esrgan",
    "GRL": "train_grl",
    "GRLGAN": "train_grlgan",
    "CUNET": "train_cunet",
    "CUGAN": "train_cugan",
}


def _format_elapsed(total_time):
    """Render a duration given in seconds as the final timing message."""
    hours, remainder = divmod(total_time, 3600)
    # Seconds are what is left after removing whole minutes (mod 60); the
    # earlier code reported the remainder of whole hours instead.
    minutes, seconds = divmod(remainder, 60)
    return "All programs spent {} hour {} min {} s".format(
        int(hours), int(minutes), round(seconds, 2)
    )


def process(options):
    """Build the trainer selected by ``options['architecture']`` and run it.

    The parsed command-line namespace (module global ``args``) is printed,
    the matching trainer module is imported lazily, the trainer is
    constructed as ``trainer(options, args)`` and its ``run()`` method is
    called. The total wall-clock time is printed afterwards.

    Args:
        options: mapping that provides at least the key ``'architecture'``.

    Raises:
        NotImplementedError: if the architecture is not one of the
            supported names.
    """
    # Local import keeps this block self-contained.
    import importlib

    print(args)
    start = time.time()

    # Switch based on the model architecture
    trainer_name = _TRAINER_NAMES.get(options['architecture'])
    if trainer_name is None:
        raise NotImplementedError("This is not a supported model architecture")

    # Import only the module that is actually needed for this run.
    trainer_factory = getattr(importlib.import_module(trainer_name), trainer_name)
    obj = trainer_factory(options, args)
    obj.run()

    print(_format_elapsed(time.time() - start))
def main():
    """Script entry point: parse the CLI, prepare folders, then train.

    The steps run strictly in this order because ``process`` reads the
    global ``args`` that ``parse_args`` sets.
    """
    # 1. Read the command-line flags into the global ``args``.
    parse_args()

    # 2. Create the output folders if they are missing.
    folder_prepare()

    # 3. Launch training with the imported ``opt`` settings.
    process(opt)
if __name__ == "__main__": | |
main() |