Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2024 NVIDIA CORPORATION. | |
# Licensed under the MIT license. | |
import os | |
import sys | |
# to import modules from parent_dir | |
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.append(parent_dir) | |
import torch | |
from alias_free_activation.cuda import activation1d | |
from activations import SnakeBeta | |
def test_load_fused_kernels(): | |
try: | |
print("[Success] load_fused_kernels") | |
except ImportError as e: | |
print("[Fail] load_fused_kernels") | |
raise e | |
def test_anti_alias_activation(): | |
data = torch.rand((10, 10, 200), device="cuda") | |
# Check activations, Snake CUDA vs. Torch | |
fused_anti_alias_activation = activation1d.Activation1d( | |
activation=SnakeBeta(10), fused=True | |
).cuda() | |
fused_activation_output = fused_anti_alias_activation(data) | |
torch_anti_alias_activation = activation1d.Activation1d( | |
activation=SnakeBeta(10), fused=False | |
).cuda() | |
torch_activation_output = torch_anti_alias_activation(data) | |
test_result = (fused_activation_output - torch_activation_output).abs() | |
while test_result.dim() != 1: | |
test_result = test_result.mean(dim=-1) | |
diff = test_result.mean(dim=-1) | |
if diff <= 1e-3: | |
print( | |
f"\n[Success] test_fused_anti_alias_activation" | |
f"\n > mean_difference={diff}" | |
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}" | |
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" | |
) | |
else: | |
print( | |
f"\n[Fail] test_fused_anti_alias_activation" | |
f"\n > mean_difference={diff}, " | |
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, " | |
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}" | |
) | |
if __name__ == "__main__": | |
from alias_free_activation.cuda import load | |
load.load() | |
test_load_fused_kernels() | |
test_anti_alias_activation() | |