|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
|
|
class Linear(nn.Module):
    """
    A wrapper class for nn.Linear.

    The wrapped layer's weight is initialized with Xavier (Glorot) uniform
    initialization, and its bias, when present, is initialized to zeros.

    Args:
        in_features: forwarded to ``nn.Linear`` as ``in_features``.
        out_features: forwarded to ``nn.Linear`` as ``out_features``.
        bias: forwarded to ``nn.Linear``; when false the wrapped layer has
            no bias and only the weight is initialized.
        device: forwarded to ``nn.Linear`` unchanged.
        dtype: forwarded to ``nn.Linear`` unchanged.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Pass the factory options by keyword so they cannot be confused
        # with any other positional parameter of nn.Linear.
        self.linear = nn.Linear(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        nn.init.xavier_uniform_(self.linear.weight)
        # nn.Linear leaves ``bias`` as None when it is disabled, so check the
        # parameter itself rather than the constructor flag.
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the linear layer.

        Args:
            x: input tensor handed to the wrapped ``nn.Linear``.

        Returns:
            The output of the wrapped ``nn.Linear`` applied to ``x``.
        """
        return self.linear(x)
|
|