# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import torch


class LARS(torch.optim.Optimizer):
    """
    Layer-wise Adaptive Rate Scaling (LARS) optimizer.

    Follows the MoCo v3 implementation: the layer-wise trust-ratio scaling
    and weight decay are applied only to parameters with more than one
    dimension; 1-D parameters (biases, normalization gamma/beta) get a plain
    momentum-SGD update.

    Args:
        params: iterable of parameters (or param-group dicts), standard
            ``torch.optim`` interface.
        lr: learning rate (default 0 — callers are expected to set/schedule it).
        weight_decay: L2 penalty added to the gradient of >1-D parameters.
        momentum: momentum factor for the per-parameter velocity buffer.
        trust_coefficient: LARS coefficient scaling the ratio ``||w|| / ||g||``.
    """

    def __init__(
        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            trust_coefficient=trust_coefficient,
        )
        super().__init__(params, defaults)

    # no_grad is required: the update mutates leaf tensors in place, which
    # autograd forbids while recording (RuntimeError without the decorator).
    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (standard ``Optimizer.step`` contract).

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad
                if dp is None:
                    continue

                if p.ndim > 1:  # i.e. not a normalization gamma/beta or a bias
                    dp = dp.add(p, alpha=g["weight_decay"])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Layer-wise trust ratio; fall back to 1 when either norm
                    # is zero to avoid dividing by zero.
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0,
                            (g["trust_coefficient"] * param_norm / update_norm),
                            one,
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                # Momentum buffer, lazily created per parameter.
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])

        return loss