File size: 1,492 Bytes
d3b8c8f 31f7840 db40549 d3b8c8f db40549 d3b8c8f db40549 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Optional, Tuple, Union

import torch
from torch import nn
# Accepted forms for per-channel normalization statistics (mean / std):
# a 3-tuple of floats, one value per channel, or a tensor.
norm_t = Union[Tuple[float, float, float], torch.Tensor]
class InputConditioner(nn.Module):
    """Apply fixed per-channel mean/std normalization to an input tensor.

    ``forward`` computes ``(x - norm_mean / input_scale) / (norm_std / input_scale)``,
    which is algebraically equal to ``(x * input_scale - norm_mean) / norm_std``.
    In other words, ``input_scale`` converts the raw input into the units the
    statistics are expressed in (e.g. ``1 / 255`` for 0-255 pixel values paired
    with statistics given on a 0-1 scale); ``1.0`` means no rescaling.

    The statistics are converted by ``_to_tensor`` to float32 tensors of shape
    ``(C, 1, 1)`` and registered as buffers, so they follow the module across
    ``.to()`` / ``.cuda()`` calls and appear in its ``state_dict``. Their shape
    makes them broadcast against dim -3 of ``x`` (presumably the channel axis
    of a channels-first ``(..., C, H, W)`` tensor).

    Args:
        input_scale: Scale applied to the raw input; the statistics are divided
            by it, so it should be non-zero.
        norm_mean: Per-channel mean, as a tuple of floats or a tensor.
        norm_std: Per-channel standard deviation, as a tuple of floats or a tensor.
        dtype: Optional dtype the normalized output is cast to. ``None`` keeps
            the dtype produced by the arithmetic.
    """

    def __init__(
        self,
        input_scale: float,
        norm_mean: norm_t,
        norm_std: norm_t,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.dtype = dtype
        # Fold the input scale into the statistics once, at construction,
        # so forward() never has to rescale x itself.
        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and optionally cast the result.

        Args:
            x: Input tensor whose dim -3 matches (or broadcasts with) the
                number of channels in the statistics.

        Returns:
            ``(x - norm_mean) / norm_std`` computed with the pre-scaled buffers,
            cast to ``self.dtype`` if one was configured.
        """
        # Arithmetic follows PyTorch type promotion (e.g. an integer or
        # half-precision x against float32 buffers yields float32); the
        # optional cast is applied only after normalization.
        y = (x - self.norm_mean) / self.norm_std
        if self.dtype is not None:
            y = y.to(self.dtype)
        return y
def get_default_conditioner():
    """Return an ``InputConditioner`` configured with timm's OpenAI CLIP statistics.

    Uses ``OPENAI_CLIP_MEAN`` / ``OPENAI_CLIP_STD`` from ``timm.data.constants``
    with ``input_scale=1.0``, i.e. the input is taken to already be in the same
    units as those statistics. No output dtype cast is configured.
    """
    # timm is imported here rather than at module level, so it is only loaded
    # when this function is called (presumably to keep it an optional dependency).
    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

    clip_stats = dict(norm_mean=OPENAI_CLIP_MEAN, norm_std=OPENAI_CLIP_STD)
    return InputConditioner(input_scale=1.0, **clip_stats)
def _to_tensor(v: norm_t):
    """Convert per-channel statistics to a float32 tensor of shape ``(C, 1, 1)``.

    ``v`` may be anything ``torch.as_tensor`` accepts (e.g. a tuple of floats or
    a tensor); its ``numel()`` values become the ``C`` entries. ``torch.as_tensor``
    does not copy a tensor that is already float32, so the result may share
    storage with ``v``. The two trailing singleton dimensions let the result
    broadcast over the last three dimensions of an input, one value per index
    of dim -3 (presumably the channel axis).
    """
    stats = torch.as_tensor(v, dtype=torch.float32)
    return stats.view(-1, 1, 1)
|