# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
try:
    import apex
except ImportError:
    print("apex is not installed")

from mmcv.runner import OptimizerHook, HOOKS

# Register the hook so it can be referenced by name from mmcv-style configs.
@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook for distributed training.

    Accumulates gradients over ``update_interval`` iterations before each
    optimizer step, optionally clips gradients, and can scale the loss
    through apex AMP when ``use_fp16`` is enabled.
    """

    def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        self.use_fp16 = use_fp16

    def before_run(self, runner):
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        # Average the loss over the accumulation window so the effective
        # step size is unchanged by gradient accumulation.
        runner.outputs["loss"] /= self.update_interval
        if self.use_fp16:
            # With apex AMP, the loss is scaled before backward to avoid
            # underflow in fp16 gradients.
            with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            runner.outputs["loss"].backward()
        # Step the optimizer only once every `update_interval` iterations.
        if self.every_n_iters(runner, self.update_interval):
            if self.grad_clip is not None:
                self.clip_grads(runner.model.parameters())
            runner.optimizer.step()
            runner.optimizer.zero_grad()
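
# Example usage (a minimal sketch, not part of the upstream file): once the hook
# is registered with HOOKS, an mmcv-style config can select it by name through
# `optimizer_config`. The parameter values below are illustrative assumptions.
#
#     optimizer_config = dict(
#         type="DistOptimizerHook",
#         update_interval=2,             # accumulate gradients over 2 iterations
#         grad_clip=dict(max_norm=5.0),  # kwargs forwarded to clip_grad_norm_
#         use_fp16=True,                 # requires apex to be installed
#     )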