# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from typing import Any, Optional, Union

import numpy as np
import torch
# Names exported by ``from <module> import *``.
__all__ = [
    "torch_randint",
    "torch_random",
    "torch_shuffle",
    "torch_uniform",
    "torch_random_choices",
]
def torch_randint(
    low: int, high: int, generator: Optional[torch.Generator] = None
) -> int:
    """Draw one integer uniformly from the half-open range ``[low, high)``.

    Args:
        low: Inclusive lower bound.
        high: Exclusive upper bound. As a special case ``high == low`` is
            accepted and ``low`` is returned without sampling.
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.randint`` so the draw can be made reproducible.

    Returns:
        The sampled value as a Python ``int``.

    Raises:
        AssertionError: If ``low > high``.
    """
    if low == high:
        # Degenerate range: there is nothing to sample, so torch.randint is
        # never called and the generator is not advanced.
        return low
    # Kept as an assert (not a ValueError) so existing callers that rely on
    # AssertionError keep working.
    assert low < high, f"low ({low}) must not be greater than high ({high})"
    return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
def torch_random(generator: Optional[torch.Generator] = None) -> float:
    """Draw one sample from the uniform distribution on ``[0, 1)``.

    Args:
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``
            so the draw can be made reproducible.

    Returns:
        The sampled value as a Python ``float``.
    """
    return float(torch.rand(1, generator=generator))
def torch_shuffle(
    src_list: list[Any], generator: Optional[torch.Generator] = None
) -> list[Any]:
    """Return a new list with the items of ``src_list`` in random order.

    The input list is left untouched; the permutation is produced by
    ``torch.randperm`` over the list's indices.

    Args:
        src_list: Items to shuffle.
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.randperm`` so the permutation can be made reproducible.

    Returns:
        A new list containing every item of ``src_list`` exactly once.
    """
    # Convert to plain ints once so the list is indexed with Python integers.
    rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
    return [src_list[i] for i in rand_indexes]
def torch_uniform(
    low: float, high: float, generator: Optional[torch.Generator] = None
) -> float:
    """Draw one sample from the uniform distribution on ``[low, high)``.

    Args:
        low: Lower bound of the interval.
        high: Upper bound of the interval.
        generator: Optional ``torch.Generator`` passed on to
            ``torch_random`` so the draw can be made reproducible.

    Returns:
        ``low`` plus a ``[0, 1)`` sample scaled by ``high - low``.
    """
    rand_val = torch_random(generator)
    # Affine map of the unit-interval sample onto the requested interval.
    return (high - low) * rand_val + low
def torch_random_choices(
    src_list: list[Any],
    generator: Optional[torch.Generator] = None,
    k: int = 1,
    weight_list: Optional[list[float]] = None,
) -> Union[Any, list]:
    """Sample ``k`` items from ``src_list`` with replacement.

    Args:
        src_list: Candidate items.
        generator: Optional ``torch.Generator`` used for every draw so the
            result can be made reproducible.
        k: Number of items to draw.
        weight_list: Optional per-item weights, one per entry of
            ``src_list``. When omitted, items are drawn uniformly.
            NOTE(review): the cumulative-weight scan below presumes
            non-negative weights -- confirm against callers.

    Returns:
        The single chosen item when ``k == 1``; otherwise a list with the
        ``k`` chosen items (an empty list for ``k == 0``).

    Raises:
        AssertionError: If ``weight_list`` is given and its length differs
            from ``len(src_list)``.
    """
    if weight_list is None:
        # Draw all k indices in one call and convert them to plain ints so
        # the list is indexed with Python integers rather than 0-d tensors.
        rand_idx = torch.randint(
            low=0, high=len(src_list), generator=generator, size=(k,)
        ).tolist()
        out_list = [src_list[i] for i in rand_idx]
    else:
        assert len(weight_list) == len(src_list)
        accumulate_weight_list = np.cumsum(weight_list)
        # Loop-invariant: the total weight is the last cumulative sum.
        total_weight = accumulate_weight_list[-1]
        out_list = []
        for _ in range(k):
            val = torch_uniform(0, total_weight, generator)
            # Pick the first item whose cumulative weight exceeds the drawn
            # value; if none does, the scan ends on the last item.
            active_id = 0
            for i, weight_val in enumerate(accumulate_weight_list):
                active_id = i
                if weight_val > val:
                    break
            out_list.append(src_list[active_id])
    # A single draw is unwrapped for caller convenience.
    return out_list[0] if k == 1 else out_list