Spaces:
Runtime error
Runtime error
# Copyright 2022 Google LLC | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
r"""The training loop for frame interpolation.

--gin_config: the gin configuration file that defines the model, losses and
  datasets.

To run on GPUs:
  python3 -m frame_interpolation.training.train \
    --gin_config <path to network.gin> \
    --base_folder <base folder for all training runs> \
    --label <descriptive label for the run>

To debug the training loop on CPU:
  python3 -m frame_interpolation.training.train \
    --gin_config <path to config.gin> \
    --base_folder /tmp \
    --label test_run \
    --mode cpu

The training output directory will be created at <base_folder>/<label>.
"""
import os | |
from . import augmentation_lib | |
from . import data_lib | |
from . import eval_lib | |
from . import metrics_lib | |
from . import model_lib | |
from . import train_lib | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import gin.tf | |
from ..losses import losses | |
# Only ERROR-level messages from TensorFlow's native runtime should reach the
# console ('2' hides INFO and WARNING). The variable is set before the
# tensorflow import below so it is already in place when TF loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf  # pylint: disable=g-import-not-at-top

# Same policy for TensorFlow's Python-side logger.
tf.get_logger().setLevel('ERROR')

# Command-line flags for the training run.
_GIN_CONFIG = flags.DEFINE_string(
    name='gin_config',
    default=None,
    help='Gin config file.')
_LABEL = flags.DEFINE_string(
    name='label',
    default='run0',
    help='Descriptive label for this run.')
_BASE_FOLDER = flags.DEFINE_string(
    name='base_folder',
    default=None,
    help='Path to checkpoints/summaries.')
_MODE = flags.DEFINE_enum(
    name='mode',
    default='gpu',
    enum_values=['cpu', 'gpu'],
    help='Distributed strategy approach.')
@gin.configurable('training')
class TrainingOptions(object):
  """Training-related options, populated from gin bindings.

  `main` instantiates this class with no arguments, so it must be registered
  with gin for the constructor arguments to be supplied from the config file
  (bindings of the form `training.<argument> = <value>`). Without the
  registration that call fails with a TypeError for the missing arguments.

  Attributes:
    learning_rate: Initial learning rate of the exponential-decay schedule.
    learning_rate_decay_steps: Number of steps per decay period.
    learning_rate_decay_rate: Multiplicative factor applied per decay period.
    learning_rate_staircase: If True, the learning rate decays in discrete
      steps instead of continuously.
    num_steps: Total number of training iterations (forwarded to the training
      loop as `n_iterations`).
  """

  def __init__(self, learning_rate: float, learning_rate_decay_steps: int,
               learning_rate_decay_rate: float, learning_rate_staircase: bool,
               num_steps: int):
    self.learning_rate = learning_rate
    self.learning_rate_decay_steps = learning_rate_decay_steps
    self.learning_rate_decay_rate = learning_rate_decay_rate
    self.learning_rate_staircase = learning_rate_staircase
    self.num_steps = num_steps
def main(argv):
  """Sets up the run directory, parses the gin config and runs training.

  Output layout under <base_folder>/<label>:
    config.gin   - copy of the gin config used for this run.
    train/       - training folder passed to train_lib.train.
    eval/        - evaluation folder passed to train_lib.train.
    saved_model/ - saved-model folder passed to train_lib.train.

  Args:
    argv: Positional command-line arguments remaining after absl flag
      parsing; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given, or if
      --gin_config or --base_folder is not set.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Both flags default to None. Fail with a usage message here rather than a
  # TypeError from os.path.join / tf.io.gfile.copy further down.
  if _GIN_CONFIG.value is None or _BASE_FOLDER.value is None:
    raise app.UsageError('Both --gin_config and --base_folder must be set.')

  output_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
  logging.info('Creating output_dir @ %s ...', output_dir)
  # Copy config file to <base_folder>/<label>/config.gin so the run keeps a
  # record of the configuration it was launched with.
  tf.io.gfile.makedirs(output_dir)
  tf.io.gfile.copy(
      _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)

  # Make PiecewiseConstantDecay referenceable from gin config files.
  gin.external_configurable(
      tf.keras.optimizers.schedules.PiecewiseConstantDecay,
      module='tf.keras.optimizers.schedules')
  # skip_unknown=True: bindings for configurables that are not registered in
  # this binary are skipped instead of raising an error.
  gin.parse_config_files_and_bindings(
      config_files=[_GIN_CONFIG.value], bindings=None, skip_unknown=True)

  # All constructor arguments are expected to come from gin bindings, which is
  # why none are passed here.
  training_options = TrainingOptions()  # pylint: disable=no-value-for-parameter
  learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=training_options.learning_rate,
      decay_steps=training_options.learning_rate_decay_steps,
      decay_rate=training_options.learning_rate_decay_rate,
      staircase=training_options.learning_rate_staircase,
      name='learning_rate')

  # Initialize data augmentation functions.
  augmentation_fns = augmentation_lib.data_augmentations()

  saved_model_folder = os.path.join(output_dir, 'saved_model')
  train_folder = os.path.join(output_dir, 'train')
  eval_folder = os.path.join(output_dir, 'eval')
  train_lib.train(
      strategy=train_lib.get_strategy(_MODE.value),
      train_folder=train_folder,
      saved_model_folder=saved_model_folder,
      n_iterations=training_options.num_steps,
      create_model_fn=model_lib.create_model,
      create_losses_fn=losses.training_losses,
      create_metrics_fn=metrics_lib.create_metrics_fn,
      dataset=data_lib.create_training_dataset(
          augmentation_fns=augmentation_fns),
      learning_rate=learning_rate,
      eval_loop_fn=eval_lib.eval_loop,
      eval_folder=eval_folder,
      # An empty (falsy) result from create_eval_datasets is passed as None.
      eval_datasets=data_lib.create_eval_datasets() or None)
if __name__ == '__main__':
  # --gin_config and --base_folder default to None but the run cannot proceed
  # without them. Have absl reject missing values at flag-parsing time with a
  # usage message, instead of failing later inside main().
  flags.mark_flags_as_required(['gin_config', 'base_folder'])
  app.run(main)