File size: 170 Bytes
a1d0506 |
1 2 3 4 5 6 7 8 9 10 |
from torch import nn

# Maps a backend name to its fully-connected (linear) layer class.
# "torch" is always available.
FC_CLASS_REGISTRY = {"torch": nn.Linear}

# NVIDIA Transformer Engine is optional: register its Linear under "te" only
# when the package imports cleanly. ``except Exception`` (rather than a bare
# ``except:``) still lets KeyboardInterrupt/SystemExit propagate, while any
# ordinary failure -- the package being absent, erroring during its own
# import, or lacking a ``Linear`` attribute -- simply leaves "te" unregistered.
try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    pass
|
a1d0506 |
1 2 3 4 5 6 7 8 9 10 |
from torch import nn

# Maps a backend name to its fully-connected (linear) layer class.
# "torch" is always available.
FC_CLASS_REGISTRY = {"torch": nn.Linear}

# NVIDIA Transformer Engine is optional: register its Linear under "te" only
# when the package imports cleanly. ``except Exception`` (rather than a bare
# ``except:``) still lets KeyboardInterrupt/SystemExit propagate, while any
# ordinary failure -- the package being absent, erroring during its own
# import, or lacking a ``Linear`` attribute -- simply leaves "te" unregistered.
try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    pass
|