|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Normalization modules.""" |
|
|
|
import typing as tp |
|
|
|
import einops |
|
import torch |
|
from torch import nn |
|
|
|
|
|
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to the last dimensions
    before running the normalization and moves them back to their original
    position right after.

    The input is treated as ``(batch, *channel_dims, time)``. The last axis is
    moved to position 1 so that ``channel_dims`` become the trailing axes that
    :class:`torch.nn.LayerNorm` normalizes over. ``normalized_shape`` must
    therefore match ``channel_dims`` (e.g. ``C`` for a ``(B, C, T)`` input).
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its channel axes, keeping the input layout.

        Args:
            x: tensor laid out as ``(batch, *channel_dims, time)``; must be at
                least 2-D.

        Returns:
            The normalized tensor, with the same shape and axis order as ``x``.
        """
        # (b, ..., t) -> (b, t, ...): put the channel axes last for LayerNorm.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # (b, t, ...) -> (b, ..., t): restore the caller's layout and hand the
        # result back (the tensor must be returned, not dropped).
        return x.movedim(1, -1)
|
|