Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Helpers to train with 16-bit precision. | |
""" | |
import numpy as np | |
import torch as th | |
import torch.nn as nn | |
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |
from . import logger | |
# Default log2 of the fp16 loss scale: the loss is multiplied by 2**20.0 before
# backward when MixedPrecisionTrainer runs with use_fp16.
INITIAL_LOG_LOSS_SCALE = 20.0
def convert_module_to_f16(l):
    """
    Cast a convolution module's parameters to float16 in place.

    Only nn.Conv1d / nn.Conv2d / nn.Conv3d are touched: their weight, and
    their bias when one exists, are replaced by half-precision copies. Any
    other module type is returned from immediately without changes.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(l, conv_types):
        return
    l.weight.data = l.weight.data.half()
    if l.bias is not None:
        l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
    """
    Cast a convolution module's parameters back to float32 in place,
    undoing convert_module_to_f16().

    Non-convolution modules are left unchanged.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    weight, bias = l.weight, l.bias
    weight.data = weight.data.float()
    if bias is not None:
        bias.data = bias.data.float()
def make_master_params(param_groups_and_shapes):
    """
    Build one full-precision master parameter per parameter group.

    For every ``(param_group, shape)`` pair, the group's tensors are detached,
    cast to float32, flattened into a single tensor with
    ``_flatten_dense_tensors`` and viewed as ``shape``.

    :param param_groups_and_shapes: iterable of ``(param_group, shape)`` where
        ``param_group`` is a sequence of ``(name, tensor)`` pairs.
    :return: list of ``nn.Parameter`` (requires_grad=True), one per group.
    """
    def _flat_fp32(param_group, shape):
        tensors = [tensor.detach().float() for _name, tensor in param_group]
        return _flatten_dense_tensors(tensors).view(shape)

    return [
        nn.Parameter(_flat_fp32(group, shape), requires_grad=True)
        for group, shape in param_groups_and_shapes
    ]
def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().

    For each group, the parameters' gradients (zeros for a parameter without a
    gradient, via param_grad_or_zeros) are flattened into one tensor, viewed
    as the group's ``shape`` and assigned as the matching master parameter's
    ``.grad``.

    :param param_groups_and_shapes: ``(param_group, shape)`` pairs, aligned
        one-to-one with ``master_params``.
    :param master_params: flattened master parameters from make_master_params().
    """
    for master_param, (param_group, shape) in zip(master_params,
                                                  param_groups_and_shapes):
        flat_grad = _flatten_dense_tensors([
            param_grad_or_zeros(param) for (_, param) in param_group
        ]).view(shape)
        # Master params are float32 (make_master_params calls .float()), but a
        # group made only of fp16-converted params yields an fp16 flat grad.
        # Assigning a .grad whose dtype differs from the param's is rejected
        # by PyTorch, so match the master dtype (returns flat_grad itself when
        # the dtypes already agree).
        master_param.grad = flat_grad.to(dtype=master_param.dtype)
def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.

    Each flattened master parameter is split back into per-parameter tensors
    (unflatten_master_params) and written into the corresponding model
    parameter in place with ``copy_``.

    :param param_groups_and_shapes: ``(param_group, shape)`` pairs, aligned
        one-to-one with ``master_params``.
    :param master_params: flattened master parameters from make_master_params().
    """
    for master_param, (param_group, _) in zip(master_params,
                                              param_groups_and_shapes):
        # param_group is iterated twice: once inside unflatten_master_params
        # and once by the zip below. Materialize it so a one-shot generator is
        # not exhausted by the first pass, which would silently copy nothing.
        param_group = list(param_group)
        unflat_master_params = unflatten_master_params(param_group,
                                                       master_param.view(-1))
        for (_, param), unflat_master_param in zip(param_group,
                                                   unflat_master_params):
            param.detach().copy_(unflat_master_param)
def unflatten_master_params(param_group, master_param):
    """
    Split a flat master tensor into pieces shaped like the group's parameters.

    :param param_group: sequence of ``(name, tensor)`` pairs; the tensors
        serve as shape templates.
    :param master_param: flat tensor holding the concatenated values.
    :return: the per-parameter tensors from ``_unflatten_dense_tensors``, in
        ``param_group`` order.
    """
    template_tensors = [tensor for _name, tensor in param_group]
    return _unflatten_dense_tensors(master_param, template_tensors)
def get_param_groups_and_shapes(named_model_params):
    """
    Partition named parameters into groups that are flattened together.

    Parameters with ``ndim <= 1`` form one group whose flattened shape is
    ``-1``. Parameters with ``ndim > 1`` form a second group whose flattened
    shape is ``(1, -1)``.

    :param named_model_params: iterable of ``(name, parameter)`` pairs, e.g.
        ``model.named_parameters()``. It is materialized once, so a generator
        is fine.
    :return: list of ``(param_group, shape)`` pairs. The order is the
        scalar/vector group first, then the matrix group. A group with no
        parameters is omitted: flattening an empty tensor list fails in
        ``_flatten_dense_tensors``, which would crash make_master_params()
        and model_grads_to_master_grads().
    """
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        -1,  # a plain int: the flattened group is 1-D
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [
        group_and_shape
        for group_and_shape in (scalar_vector_named_params, matrix_named_params)
        if group_and_shape[0]
    ]
def master_params_to_state_dict(model, param_groups_and_shapes, master_params,
                                use_fp16):
    """
    Return ``model.state_dict()`` with its parameter entries taken from the
    master parameters.

    :param use_fp16: if true, ``master_params`` are flattened groups from
        make_master_params(). Each one is split back per parameter name using
        ``param_groups_and_shapes``. Otherwise ``master_params[i]`` is stored
        under the name of the i-th entry of ``model.named_parameters()``.
    :return: the updated state dict. Entries that are not overwritten (e.g.
        buffers) keep the values from ``model.state_dict()``.
    """
    state_dict = model.state_dict()
    if not use_fp16:
        for idx, (param_name, _unused) in enumerate(model.named_parameters()):
            assert param_name in state_dict
            state_dict[param_name] = master_params[idx]
        return state_dict

    for flat_master, (param_group, _shape) in zip(master_params,
                                                  param_groups_and_shapes):
        pieces = unflatten_master_params(param_group, flat_master.view(-1))
        for (param_name, _unused), piece in zip(param_group, pieces):
            assert param_name in state_dict
            state_dict[param_name] = piece
    return state_dict
def state_dict_to_master_params(model, state_dict, use_fp16):
    """
    Counterpart of master_params_to_state_dict().

    Picks the model's parameters out of ``state_dict``, in
    ``model.named_parameters()`` order. If ``use_fp16`` is true, they are
    packed into flattened float32 master parameters via
    get_param_groups_and_shapes() + make_master_params(). Otherwise the
    state-dict tensors are returned as a list.
    """
    named_values = [(param_name, state_dict[param_name])
                    for param_name, _unused in model.named_parameters()]
    if not use_fp16:
        return [value for _name, value in named_values]
    return make_master_params(get_param_groups_and_shapes(named_values))
def zero_master_grads(master_params):
    """Drop every master parameter's gradient by resetting ``.grad`` to None."""
    for master_param in master_params:
        master_param.grad = None
def zero_grad(model_params):
    """
    Zero, in place, every existing gradient in ``model_params``.

    Parameters whose ``.grad`` is None are skipped. Each gradient is detached
    in place before being filled with zeros.
    """
    # Pattern taken from
    # https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
def param_grad_or_zeros(param):
    """
    Return ``param``'s gradient detached from autograd. If it has no gradient
    yet, return a zero tensor created with ``th.zeros_like(param)``.
    """
    if param.grad is None:
        return th.zeros_like(param)
    return param.grad.data.detach()
class MixedPrecisionTrainer:
    """
    Backward/optimizer-step driver for one of three precision modes.

    * ``use_fp16``: after ``model.convert_to_fp16()``, the optimizer steps on
      flattened float32 master copies of the parameters. The loss is scaled by
      ``2**lg_loss_scale``. That exponent drops by 1 when the gradient norm is
      inf/NaN and grows by ``fp16_scale_growth`` after every successful step.
    * ``use_amp``: loss scaling is delegated to ``th.cuda.amp.GradScaler``.
    * otherwise: plain backward and optimizer step.

    In the AMP and plain modes, gradients are clipped to ``max_grad_norm``
    (L2) when ``clip_grad`` is true. The fp16 mode never clips.
    """

    # L2 threshold passed to clip_grad_norm_ in the AMP / plain paths.
    max_grad_norm = 5.0

    def __init__(self,
                 *,
                 model,
                 use_fp16=False,
                 use_amp=False,
                 fp16_scale_growth=1e-3,
                 initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
                 model_name='ddpm',
                 submodule_name='',
                 model_params=None):
        """
        :param model: module being trained. With ``use_fp16`` it must provide
            ``convert_to_fp16()``.
        :param use_fp16: train with float32 master params + manual loss scaling.
        :param use_amp: train with ``GradScaler`` loss scaling.
        :param fp16_scale_growth: added to ``lg_loss_scale`` after every
            successful fp16 step.
        :param initial_lg_loss_scale: starting log2 loss scale for fp16.
        :param model_name: label used in log messages.
        :param submodule_name: accepted for interface compatibility; unused.
        :param model_params: optional iterable of parameters to manage instead
            of ``model.parameters()``. A list is stored as-is (not copied).
        """
        self.model_name = model_name
        self.model = model
        self.use_fp16 = use_fp16
        self.use_amp = use_amp
        if self.use_amp:
            # https://github.com/pytorch/pytorch/issues/40497#issuecomment-1262373602
            # https://github.com/pytorch/pytorch/issues/111739
            self.scaler = th.cuda.amp.GradScaler(enabled=use_amp, init_scale=2**15, growth_interval=100)
            logger.log(model_name, 'enables AMP to accelerate training')
        else:
            logger.log(model_name, 'not enables AMP to accelerate training')
        self.fp16_scale_growth = fp16_scale_growth

        if model_params is None:
            self.model_params = list(self.model.parameters())
        elif isinstance(model_params, list):
            self.model_params = model_params
        else:
            self.model_params = list(model_params)

        # Without fp16 the optimizer steps directly on the model parameters.
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale
        if self.use_fp16:
            # NOTE(review): the fp16 groups are built from all of
            # model.named_parameters(), not from model_params -- confirm this
            # is intended when a parameter subset is passed.
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters())
            self.master_params = make_master_params(
                self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        """Zero the gradients of the managed model parameters."""
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor, disable_amp=False, **kwargs):
        """
        Backpropagate ``loss`` with the scaling the active mode requires.

        :param disable_amp: in AMP mode, skip GradScaler scaling for this call.
        :param kwargs: forwarded to ``Tensor.backward`` (e.g. retain_graph=True).
        """
        if self.use_fp16:
            loss_scale = 2**self.lg_loss_scale
            (loss * loss_scale).backward(**kwargs)
        elif self.use_amp and not disable_amp:
            self.scaler.scale(loss).backward(**kwargs)
        else:
            # NOTE(review): with use_amp, optimize() still calls
            # scaler.unscale_(), so gradients from this unscaled backward are
            # divided by the current scale -- confirm callers expect this.
            loss.backward(**kwargs)

    def optimize(self, opt: th.optim.Optimizer, clip_grad=True):
        """
        Take one optimizer step on the master parameters.

        :param clip_grad: clip gradients to ``max_grad_norm`` before stepping
            (AMP / plain modes only).
        :return: False if the fp16 mode found an inf/NaN gradient norm and
            skipped the step; True otherwise.
        """
        if self.use_fp16:
            return self._optimize_fp16(opt)
        if self.use_amp:
            return self._optimize_amp(opt, clip_grad)
        return self._optimize_normal(opt, clip_grad)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes,
                                    self.master_params)
        grad_norm, param_norm = self._compute_norms(
            grad_scale=2**self.lg_loss_scale)
        if check_overflow(grad_norm):
            # Scaled gradients overflowed: shrink the scale and skip the step.
            self.lg_loss_scale -= 1
            logger.log(
                f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        # Undo the loss scaling before the optimizer sees the gradients.
        for p in self.master_params:
            p.grad.mul_(1.0 / (2**self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes,
                                      self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_amp(self, opt: th.optim.Optimizer, clip_grad=False):
        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
        # Unscale first so clipping and the logged norms see true gradients.
        self.scaler.unscale_(opt)
        self._clip_and_log_norms(clip_grad)
        self.scaler.step(opt)
        self.scaler.update()
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer, clip_grad: bool = False):
        self._clip_and_log_norms(clip_grad)
        opt.step()
        return True

    def _clip_and_log_norms(self, clip_grad):
        """Optionally clip master grads to ``max_grad_norm``, then log norms."""
        if clip_grad:
            # Clip before computing norms so the logged grad_norm is post-clip.
            th.nn.utils.clip_grad_norm_(  # type: ignore
                self.master_params,
                self.max_grad_norm,
                norm_type=2,
                error_if_nonfinite=False,
                foreach=True,
            )
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

    def _compute_norms(self, grad_scale=1.0):
        """
        Return ``(global grad L2 norm / grad_scale, global param L2 norm)``
        over the master params. Each is a numpy float from ``np.sqrt``.
        Params without a gradient contribute only to the param norm.
        """
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item()**2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2,
                                         dtype=th.float32).item()**2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params, model=None):
        """State dict of ``model`` (default: self.model) from master params."""
        if model is None:
            model = self.model
        return master_params_to_state_dict(model, self.param_groups_and_shapes,
                                           master_params, self.use_fp16)

    def state_dict_to_master_params(self, state_dict, model=None):
        """Master params for ``model`` (default: self.model) from a state dict."""
        if model is None:
            model = self.model
        return state_dict_to_master_params(model, state_dict, self.use_fp16)

    def state_dict_to_master_params_given_submodule_name(
            self, state_dict, submodule_name):
        """Like state_dict_to_master_params, for ``self.model.<submodule_name>``."""
        return state_dict_to_master_params(getattr(self.model, submodule_name),
                                           state_dict, self.use_fp16)
def check_overflow(value):
    """
    Return a truthy result when ``value`` is +inf, -inf or NaN.

    NaN is detected by ``value != value``, since NaN compares unequal to
    itself.
    """
    inf = float("inf")
    hit_inf = (value == inf) or (value == -inf)
    return hit_inf or (value != value)