# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist

from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering
    """

    def __init__(self, cfg: FairseqBMUFConfig, optimizer):
        super().__init__(cfg)
        self._optimizer = optimizer
        self._num_updates = 0
        self.sync_iter = cfg.global_sync_iter
        self.block_momentum = cfg.block_momentum
        self.block_lr = cfg.block_lr
        self._reset_local_data()
        self.warmup_iteration = cfg.warmup_iterations
        self.use_nbm = cfg.use_nbm
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.cfg.average_sync
        self.world_size = self.cfg.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        return self._optimizer.optimizer_config

    def get_lr(self):
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        self._optimizer.set_lr(lr)

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        self._optimizer.average_params()

    def _block_sync(self):
        if self.world_size <= 1:
            return

        # Update the global model using the local models from all GPUs.

        # (Step-1) Calculate the grad between the previously synced model and
        # the current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average the gradient from all GPUs
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate the global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average the local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        # Check whether the number of updates has reached the warmup iteration
        if self.get_num_updates() == self.warmup_iteration:
            return True
        return False

    def _is_bmuf_iter(self):
        # Check whether the number of updates has reached a BMUF sync iteration
        if (self.get_num_updates() > self.warmup_iteration) and (
            self.get_num_updates() % self.sync_iter == 0
        ):
            return True
        return False

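    # Example schedule (illustrative values, not necessarily the configured
    # defaults): with warmup_iterations=500 and global_sync_iter=50, step()
    # triggers _warmup_sync() once at update 500 and _block_sync() at updates
    # 550, 600, 650, and so on.
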
    def _warmup_sync(self, root_rank=0):
        if self.world_size <= 1:
            return

        # Broadcast the local model to all GPUs
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update the local optimizer state
        if self.average_sync:
            self._optimizer.average_params()
        else:
            self._optimizer.load_state_dict(self.initial_state)

        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameter updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        # (Step-0) Initialize the global momentum parameters and store a global
        # copy of the model on each GPU
        self.global_params = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

        # Save the global model locally for calculating the gradient during BMUF sync
        for param, global_param in zip(self.params, self.global_params):
            global_param.copy_(param.data)

    @torch.no_grad()
    def _calc_grad(self):
        # global_params is the global copy from the previously finished
        # synchronization, while param.data is the local parameter on this GPU
        # after block_sync_freq steps. The "grad" is therefore the difference
        # between the previously synced model and the current local model.
        for index, (param, global_param) in enumerate(
            zip(self.params, self.global_params)
        ):
            self.grads[index] = global_param - param.data

    def _avg_grad_from_all_gpus(self):
        for index, param in enumerate(self.params):
            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
            # Divide locally, then sum across workers: the all-reduce yields
            # the average of the per-GPU values.
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        for index, (param, global_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.global_params,
                # All GPUs share the same value of smoothed_grad, since it is
                # always computed on synchronized gradients.
                self.smoothed_grads,
                self.grads,
            )
        ):
            # global_param is the last synchronized parameter. Although
            # smoothed_grad is stored locally, every process holds the same
            # value, so param ends up as a globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
            param.data.copy_(global_param - smoothed_grad)

            # Nesterov momentum: do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # Back up for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)
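
# Usage sketch (hypothetical wiring shown for illustration; in fairseq the
# trainer constructs this wrapper around whatever inner FairseqOptimizer was
# built for the model parameters when BMUF training is enabled):
#
#     inner_optimizer = build_optimizer(cfg.optimizer, params)
#     optimizer = FairseqBMUF(cfg.bmuf, inner_optimizer)
#     ...
#     loss.backward()
#     optimizer.clip_grad_norm(max_norm)
#     optimizer.step()        # counts updates and triggers warmup/block syncs
#     optimizer.zero_grad()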