import torch
from torch import nn
import numpy as np


class Projector(nn.Module):
    """Two-layer MLP that projects input embeddings into a shared embedding space."""

    def __init__(
        self,
        input_embedding_dim: int = 300,
        final_embedding_dim: int = 512,
        dropout: float = 0.2,
    ):
        super().__init__()
        # First linear layer maps the input embeddings to a 512-dim hidden space.
        self.fc1 = nn.Linear(input_embedding_dim, 512)
        self.fn1 = nn.LeakyReLU()
        # Second linear layer maps the hidden space to the final embedding dimension.
        self.fc2 = nn.Linear(512, final_embedding_dim)
        self.fn2 = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(final_embedding_dim)
        # Learnable temperature, initialised to log(1/0.07) (CLIP-style); not used in forward().
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x):
        x = self.fc1(x)
        x = self.fn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.fn2(x)
        x = self.layer_norm(x)
        return x
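

# Minimal usage sketch (not part of the original file): the batch size and random
# input below are assumptions, chosen only to show the expected shapes when
# projecting 300-dim input embeddings into the 512-dim output space.
if __name__ == "__main__":
    projector = Projector(input_embedding_dim=300, final_embedding_dim=512)
    embeddings = torch.randn(8, 300)   # hypothetical batch of 8 input embeddings
    projected = projector(embeddings)  # shape: (8, 512)
    print(projected.shape, projector.temperature.exp().item())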