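# NOTE: the lines below are page-header residue from the hosting site,
# kept as comments so the module parses.
# Epsilon617
# init
# 283e8f1
# raw
# history blame
# 1.4 kB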
import torch
from torch import nn
import torch.nn.functional as F
class MLPProberBase(nn.Module):
    """MLP probe over (optionally layer-aggregated) features.

    When ``layer == "all"``, a learned softmax-weighted sum over dim 1 of the
    input (the layer axis) is computed first. The result is then passed
    through one Linear+ReLU per entry of ``hidden_layer_sizes`` and a final
    Linear projection to ``num_outputs``.

    Args:
        d: Size of the input's last dimension.
        layer: ``"all"`` to learn a weighted sum over layers; any other value
            skips aggregation and feeds the input to the MLP unchanged.
        num_outputs: Size of the output projection.
        hidden_layer_sizes: Widths of the hidden layers. Defaults to ``[512]``.
        n_tranformer_layer: Number of layers mixed when ``layer == "all"``.
            Defaults to 12.
    """

    def __init__(self, d=768, layer='all', num_outputs=87,
                 hidden_layer_sizes=None, n_tranformer_layer=12):
        super().__init__()
        # Copy so later mutation of the caller's list cannot desync the model.
        self.hidden_layer_sizes = (
            [512] if hidden_layer_sizes is None else list(hidden_layer_sizes)
        )
        self.num_layers = len(self.hidden_layer_sizes)
        self.layer = layer
        # Registered as hidden_0, hidden_1, ... so state_dict keys stay stable.
        for i, ld in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
            d = ld
        self.output = nn.Linear(d, num_outputs)
        # Attribute name (including its spelling) kept for compatibility.
        self.n_tranformer_layer = n_tranformer_layer
        self.init_aggregator()

    def init_aggregator(self):
        """Create the learned layer-mixing logits when ``layer == "all"``.

        The parameter has shape ``(1, n_tranformer_layer, 1)``. It is only
        created in that mode, so ``self.aggregator`` is absent otherwise.
        """
        if self.layer == "all":
            self.aggregator = nn.Parameter(
                torch.randn((1, self.n_tranformer_layer, 1))
            )

    def forward(self, x):
        """Compute probe outputs.

        Args:
            x: When ``layer == "all"``, a tensor whose dim 1 is the layer axis
                of size ``n_tranformer_layer``, e.g. ``(B, L, T, H)`` with
                T = number of chunks (1 or several), or ``(B, L, H)``.
                Otherwise ``x`` is fed to the MLP as-is (presumably already
                a single layer's features — confirm with callers).

        Returns:
            The MLP output: the input shape (with dim 1 removed when
            ``layer == "all"``) with the last dim replaced by ``num_outputs``.
        """
        if self.layer == "all":
            weights = F.softmax(self.aggregator, dim=1)
            # Reshape (1, L, 1) -> (1, L, 1, ..., 1) so the weights align with
            # dim 1 of x for any rank. Plain broadcasting of (1, L, 1) against
            # a 4-D (B, L, T, H) input would line L up with T instead.
            weights = weights.view(1, -1, *([1] * (x.dim() - 2)))
            x = (x * weights).sum(dim=1)
        for i in range(self.num_layers):
            x = F.relu(getattr(self, f"hidden_{i}")(x))
        return self.output(x)