homemade_lo_vi / modules /wrapper.py
moiduy04's picture
Upload 18 files
bc1ada8
raw
history blame
692 Bytes
import torch.nn as nn
from torch import Tensor
class Linear(nn.Module):
    """A thin wrapper around ``nn.Linear`` with explicit parameter initialization.

    After constructing the underlying ``nn.Linear``, the weight matrix is
    re-initialized with Xavier/Glorot uniform initialization
    (``nn.init.xavier_uniform_``) and the bias, when present, is set to zeros.
    This overrides the default initialization that ``nn.Linear`` applies.

    Args:
        in_features: Size of each input sample (last dimension of the input).
        out_features: Size of each output sample.
        bias: If ``True``, the layer includes a learnable additive bias.
        device: Device on which to allocate the parameters (forwarded to
            ``nn.Linear``).
        dtype: dtype of the parameters (forwarded to ``nn.Linear``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Keep the submodule name ``linear``: state_dict keys are
        # ``linear.weight`` / ``linear.bias`` and checkpoints depend on them.
        # Keyword arguments avoid relying on nn.Linear's positional order.
        self.linear = nn.Linear(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        nn.init.xavier_uniform_(self.linear.weight)
        # Check the actual parameter rather than the flag: nn.Linear sets
        # ``bias`` to None when it was constructed without one.
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map ``x @ W.T + b`` via the wrapped ``nn.Linear``.

        Args:
            x: Input tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        return self.linear(x)