import torch
import torch.nn as nn


class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that performs the forward pass in float32
    before casting back to the input dtype.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then restore the input dtype.
        return super().forward(x.float()).type(x.dtype)


class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


class ChannelLayerNorm32(LayerNorm32):
    """
    A LayerNorm32 that normalizes over the channel dimension of a
    channels-first tensor of shape (N, C, *spatial).
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        DIM = x.dim()
        # Move channels last: (N, C, *spatial) -> (N, *spatial, C).
        x = x.permute(0, *range(2, DIM), 1).contiguous()
        x = super().forward(x)
        # Restore the channels-first layout: (N, *spatial, C) -> (N, C, *spatial).
        x = x.permute(0, DIM - 1, *range(1, DIM - 1)).contiguous()
        return x
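

# Minimal usage sketch (an illustrative assumption, not part of the original
# module): ChannelLayerNorm32 expects a channels-first tensor, so its
# normalized_shape must equal the channel count C. The output keeps the
# input's shape and dtype even though normalization runs in float32.
if __name__ == "__main__":
    norm = ChannelLayerNorm32(16)  # hypothetical example with C = 16 channels
    x = torch.randn(2, 16, 8, 8, dtype=torch.float16)
    y = norm(x)
    assert y.shape == x.shape and y.dtype == x.dtype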