import torch
from timm.loss import SoftTargetCrossEntropy

from timm.models.layers import DropPath

from .infinity import Infinity, sample_with_top_k_top_p_also_inplace_modifying_logits_

def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )
# Make these loss classes print their public, plain-valued attributes (as
# rendered by `_ex_repr`). DropPath is intentionally left out of this tuple
# so its drop_prob is no longer shown in printed models.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy):
    if not hasattr(clz, 'extra_repr'):
        # Fallback for classes without an `extra_repr` hook: override
        # `__repr__` directly. `type(self)` is resolved at call time, so the
        # lambda does not depend on the loop variable.
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
        continue
    # torch.nn.Module.__repr__ embeds whatever extra_repr() returns.
    clz.extra_repr = _ex_repr

# Print DropPath as just its class name followed by "(...)", hiding its
# drop probability from printed model summaries.
DropPath.__repr__ = lambda self: '{}(...)'.format(type(self).__name__)

# Short model-name aliases: 'd6' -> 'infinity_d6', 'd8' -> 'infinity_d8', ...,
# 'd40' -> 'infinity_d40' (even depths from 6 to 40 inclusive).
# Built with a comprehension so no loop variable leaks into the module namespace.
alias_dict = {f'd{depth}': f'infinity_d{depth}' for depth in range(6, 40 + 2, 2)}
# Reverse lookup: full model name -> short alias.
alias_dict_inv = {v: k for k, v in alias_dict.items()}