# File: gomoku/DI-engine/ding/torch_utils/tests/test_backend_helper.py
# Repository page metadata: zjowowen · "init space" · 079c32c · 435 Bytes
import pytest
import torch
from ding.torch_utils.backend_helper import enable_tf32
@pytest.mark.cudatest
class TestBackendHelper:
    """
    Overview:
        Tests for the backend helper utilities. The whole class carries the
        ``cudatest`` pytest marker.
    """

    def test_tf32(self):
        """
        Overview:
            Smoke test for ``enable_tf32``: call it, then run one forward and
            backward pass through a small linear layer and check that a
            gradient was produced for the layer's weight.
        """
        enable_tf32()

        # Tiny 3 -> 4 linear layer fed with a single random sample.
        linear = torch.nn.Linear(3, 4)
        sample = torch.randn(1, 3)

        # Reduce the output to a scalar so backward() needs no explicit gradient.
        loss = linear(sample).sum()

        linear.zero_grad()
        loss.backward()

        weight_grad = linear.weight.grad
        assert weight_grad is not None