# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import fnmatch
import inspect
import itertools
import logging
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import hydra

import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor


class Optimizer:
    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        if self.schedulers is None:
            return
        for set_of_schedulers in self.schedulers:
            for option, _ in set_of_schedulers.items():
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    def step_schedulers(self, where: float, step: int) -> None:
        if self.schedulers is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[i].items():
                if "step" in inspect.signature(scheduler.__call__).parameters:
                    new_value = scheduler(step=step, where=where)
                elif (
                    hasattr(scheduler, "scheduler")
                    and "step"
                    in inspect.signature(scheduler.scheduler.__call__).parameters
                ):
                    # To handle ValueScaler wrappers
                    new_value = scheduler(step=step, where=where)
                else:
                    new_value = scheduler(where)
                param_group[option] = new_value

    def step(self, where, step, closure=None):
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        return self.optimizer.zero_grad(*args, **kwargs)
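
# Example (illustrative sketch, not part of the library API): manually wrapping a
# torch.optim.SGD instance with a single "lr" scheduler. `model` and `loss` are
# assumed to exist; `LinearDecay` is a stand-in for an fvcore-style ParamScheduler
# whose __call__ maps training progress `where` in [0, 1] to a value.
#
#   class LinearDecay:
#       def __call__(self, where):
#           return 0.1 * (1.0 - where)
#
#   sgd = torch.optim.SGD(model.parameters(), lr=0.1)
#   optim = Optimizer(sgd, schedulers=[{"lr": LinearDecay()}])  # one dict per param group
#   optim.zero_grad()
#   loss.backward()
#   optim.step(where=0.5, step=100)  # sets lr to 0.05, then calls sgd.step()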


def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Set up the "default" scheduler with the right parameters.

    Args:
        scheduler_cfgs: A list of scheduler configs, where each scheduler also
            specifies which parameters it applies to, based on the names of parameters
            or the class of the modules. At most one scheduler is allowed to skip this
            specification, which is used as a "default" specification for any remaining
            parameters.
        all_parameter_names: Names of all the parameters to consider.
    """
    constraints = [
        scheduler_cfg.parameter_names
        for scheduler_cfg in scheduler_cfgs
        if scheduler_cfg.parameter_names is not None
    ]
    if len(constraints) == 0:
        default_params = set(all_parameter_names)
    else:
        default_params = all_parameter_names - set.union(*constraints)
    default_count = 0
    for scheduler_cfg in scheduler_cfgs:
        if scheduler_cfg.parameter_names is None:
            scheduler_cfg.parameter_names = default_params
            default_count += 1
    assert default_count <= 1, "Only one scheduler per option can be default"
    if default_count == 0:
        # No default scheduler specified, add a default, but without any scheduler
        # for that option
        scheduler_cfgs.append({"parameter_names": default_params})


def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return parameters which match the intersection of parameter constraints.

    Note that this returns the parameters themselves, not their names.

    Args:
        param_constraints: A list, with each element being a set of allowed parameters.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        A list containing the parameters which overlap with _each_ constraint set from
        param_constraints.
    """
    matching_names = set.intersection(*param_constraints)
    return [value for name, value in named_parameters.items() if name in matching_names]
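
# Example (illustrative sketch): only parameters present in *every* constraint set
# are returned. The names and parameter objects below are hypothetical.
#
#   named = {"w1": p1, "w2": p2, "b": p3}
#   name_constraints_to_parameters([{"w1", "w2"}, {"w2", "b"}], named)
#   # -> [p2]  (only "w2" lies in the intersection of the constraint sets)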


def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Produce parameter groups corresponding to all the scheduler configs.

    Takes all the scheduler configs, each of which applies to a specific optimizer
    option (like "lr" or "weight_decay") and has a set of parameter names which it
    applies to, and produces a final set of param groups where each param group
    covers all the options which apply to a particular set of parameters.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.
    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """

    scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
    schedulers = []
    param_groups = []
    for scheduler_cfgs in scheduler_cfgs_per_param_group:
        param_constraints = [
            scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
        ]
        matching_parameters = name_constraints_to_parameters(
            param_constraints, named_parameters
        )
        if len(matching_parameters) == 0:  # If no overlap of parameters, skip
            continue
        schedulers_for_group = {
            scheduler_cfg["option"]: scheduler_cfg["scheduler"]
            for scheduler_cfg in scheduler_cfgs
            if "option" in scheduler_cfg
        }
        schedulers.append(schedulers_for_group)
        param_groups.append({"params": matching_parameters})
    return schedulers, param_groups
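
# Example (illustrative sketch): two options, "lr" and "weight_decay", where weight
# decay only has a scheduler for a subset of the parameters. `lr_sched`, `wd_sched`,
# `w` and `b` are hypothetical scheduler callables / parameters.
#
#   lr_cfgs = [{"option": "lr", "scheduler": lr_sched, "parameter_names": {"w", "b"}}]
#   wd_cfgs = [
#       {"option": "weight_decay", "scheduler": wd_sched, "parameter_names": {"w"}},
#       {"parameter_names": {"b"}},  # default config: no scheduler for this option
#   ]
#   schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
#       [lr_cfgs, wd_cfgs], {"w": w, "b": b}
#   )
#   # schedulers   == [{"lr": lr_sched, "weight_decay": wd_sched}, {"lr": lr_sched}]
#   # param_groups == [{"params": [w]}, {"params": [b]}]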


def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups
    """
    for pg in param_groups:
        # no param should be repeated within a group
        assert len(pg["params"]) == len(set(pg["params"]))
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    for p1, p2 in itertools.permutations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
    assert set.union(*parameters) == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(set.union(*parameters))} params whereas model has"
        f" {len(model_parameters)} params"
    )


def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.
    """
    if filter_module_cls_names is None:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(
            f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
        )
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)


def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]
        parameter_names: The names of all parameters to match the patterns against
            (e.g. the keys of `model.named_parameters()`).
    """

    if filter_param_names is None:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert (
            len(matching_parameters) >= 1
        ), f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
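
# Example (illustrative sketch): unix-style wildcards are expanded with fnmatch
# against the full parameter names. The names below are hypothetical.
#
#   names = {"blocks.0.attn.qkv.weight", "blocks.1.attn.qkv.weight", "head.bias"}
#   unix_param_pattern_to_parameter_names(["blocks.*.attn.*"], names)
#   # -> {"blocks.0.attn.qkv.weight", "blocks.1.attn.qkv.weight"}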


def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in scheduler_cfg.

    Args:
        scheduler_cfg: The config for the scheduler
        parameter_names: The set of all parameter names which will be filtered
    """
    if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
        return None
    return unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    ).union(
        unix_module_cls_pattern_to_parameter_names(
            scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
        )
    )


def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Optional[Set[str]] = None
) -> Dict[Type, Set[str]]:
    """Produce a mapping from all the modules classes to the names of parames they own.

    Only counts a parameter as part of the immediate parent module, i.e. recursive
    parents do not count.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed
    """

    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        module_cls = type(module)
        module_cls_to_params.setdefault(module_cls, set())
        for param_name, _ in module.named_parameters(recurse=False):
            full_param_name = get_full_parameter_name(module_name, param_name)
            if param_allowlist is None or full_param_name in param_allowlist:
                module_cls_to_params[module_cls].add(full_param_name)
    return module_cls_to_params
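
# Example (illustrative sketch): mapping module classes to the parameters they
# directly own (recurse=False), keyed by fully qualified parameter names.
#
#   model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
#   mapping = get_module_cls_to_param_names(model)
#   # mapping[nn.Linear]     == {"0.weight", "0.bias"}
#   # mapping[nn.LayerNorm]  == {"1.weight", "1.bias"}
#   # mapping[nn.Sequential] == set()  (the container owns no parameters directly)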


def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Optional[Mapping[str, List]] = None,
    param_group_modifiers_conf: Optional[List[Callable]] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Constructs a stochastic gradient descent, Adam, or AdamW optimizer (with
    momentum where applicable), i.e. builds a torch.optim.Optimizer with support
    for zero weight decay on BatchNorm and/or 1-D parameters, based on the config.

    Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
    (LARS): https://arxiv.org/abs/1708.03888

    Args:
        model: model to perform stochastic gradient descent
            optimization or ADAM optimization.
        optimizer_conf: Hydra config consisting of a partial torch optimizer, such
            as SGD or Adam, still missing the params argument, which this function
            provides to produce the final optimizer.
        options_conf: Hydra config mapping each optimizer option (e.g. "lr",
            "weight_decay") to a list of scheduler configs, each describing which
            parameters it applies to.
        param_group_modifiers_conf: Optional user-specified functions which can
            modify the final scheduler configs before the optimizer's param groups
            are built.
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, validates that the produced param_groups
            don't overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }

    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    all_parameter_names = {
        name for name, _ in model.named_parameters() if name in param_allowlist
    }
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    if param_group_modifiers_conf:
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
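
# Example (illustrative sketch): the simplest path, with no per-option scheduler
# configs. `model` is any nn.Module; the optimizer class and its hyperparameters
# are supplied through a Hydra-style target config (the values here are illustrative).
#
#   optimizer_conf = {"_target_": "torch.optim.AdamW", "lr": 1e-3, "weight_decay": 0.01}
#   optim = construct_optimizer(model, optimizer_conf)
#   # `optim` is an `Optimizer` wrapper; with no schedulers, optim.step(where, step)
#   # simply forwards to AdamW.step().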


def get_full_parameter_name(module_name, param_name):
    if module_name == "":
        return param_name
    return f"{module_name}.{param_name}"


class GradientClipper:
    """
    Gradient clipping utility that works with DDP.
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert isinstance(max_norm, (int, float)) or max_norm is None
        self.max_norm = max_norm if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        if self.max_norm is None:
            return  # no-op

        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
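
# Example (illustrative sketch): clipping gradients between the backward pass and
# the optimizer step. `model`, `loss`, `optim`, `where` and `step` are assumed to exist.
#
#   clipper = GradientClipper(max_norm=1.0, norm_type=2)
#   loss.backward()
#   clipper(model)  # in-place clipping; a no-op when max_norm is None
#   optim.step(where, step)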


class ValueScaler:
    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        val = self.scheduler(*args, **kwargs)
        return val * self.mult_val


def rgetattr(obj, rattrs: Optional[str] = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
    """
    if rattrs is None:
        return obj
    attrs = rattrs.split(".")
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
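
# Example (illustrative sketch): the attribute path is hypothetical.
#
#   rgetattr(model, "encoder.layers")  # equivalent to model.encoder.layers
#   rgetattr(model)                    # returns model unchanged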


def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Args
    - scheduler_cfgs: a list of lists (or omegaconf.ListConfigs) of scheduler configs.
        Each inner element is an omegaconf.DictConfig with the following structure
        {
            "scheduler": <some fvcore scheduler>
            "option": <value> possible options are "lr", "weight_decay" etc.
            "parameter_names": Set of str indicating param names that this scheduler applies to
        }
    - model: a model that implements a method `get_layer_id`, which maps a parameter
            name to an integer layer id, and a method `get_num_layers`.
            Alternatively, use the apply_to argument to select a specific component of the model.
    - layer_decay_value: base multiplicative decay factor applied per layer (float)
    - layer_decay_min: minimum value for the per-layer decay scale
    - apply_to: optional arg to select which component of the model to apply the layer decay modifier to
    - overrides: manual lr overrides for specific patterns; a list of dicts, each with keys "pattern" and "value"
    Returns
    - scheduler_cfgs: same structure as the input; elements may be modified
    """
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            if scheduler_cfg["option"] != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart (D38591759)
            parameter_names = sorted(scheduler_cfg["parameter_names"])

            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                if param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                    # Overrides
                    for override in overrides:
                        if fnmatch.fnmatchcase(param_name, override["pattern"]):
                            this_scale = float(override["value"])
                            layer_id = override["pattern"]
                            break

                if layer_id not in layer_cfg_groups:
                    curr_param = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    curr_param = layer_cfg_groups[layer_id]
                    curr_param["parameter_names"].add(param_name)
                layer_cfg_groups[layer_id] = curr_param

            for layer_cfg in layer_cfg_groups.values():
                curr_cfg_group.append(layer_cfg)

        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
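
# Example (illustrative sketch): enabling layer-wise lr decay via the
# `param_group_modifiers_conf` argument of `construct_optimizer`. The target path
# and the `trunk` attribute below are assumptions for illustration; the component
# selected by `apply_to` must implement `get_num_layers()` and
# `get_layer_id(param_name)`. `optimizer_conf` and `options_conf` are assumed to be
# defined as in the `construct_optimizer` docstring.
#
#   param_group_modifiers_conf = [
#       {
#           "_target_": "<path.to.this_module>.layer_decay_param_modifier",
#           "_partial_": True,  # bind these kwargs; scheduler_cfgs/model are passed later
#           "layer_decay_value": 0.9,
#           "apply_to": "trunk",
#       }
#   ]
#   optim = construct_optimizer(
#       model, optimizer_conf, options_conf, param_group_modifiers_conf
#   )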