File size: 677 Bytes
1d5604f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
#!/usr/bin/env python3
# coding=utf-8
import torch
import torch.nn as nn

from model.module.bilinear import Bilinear
class Biaffine(nn.Module):
    """Biaffine scoring layer.

    Sums three terms, relying on broadcasting to add them together:
      * the bilinear interaction ``self.bilinear(x, y)``,
      * a linear projection of ``x`` (``linear_1``) with a singleton axis
        inserted at dim 2, and
      * a linear projection of ``y`` (``linear_2``) with a singleton axis
        inserted at dim 1.
    The two linear layers are built without a bias, so the only additive bias
    (if any) is the one owned by the bilinear layer.

    Args:
        input_dim: feature size shared by ``x`` and ``y``.
        output_dim: number of output channels produced by each term.
        bias: forwarded to ``Bilinear`` as its ``bias`` flag.
        bias_init: optional initial value for the bilinear bias. Anything
            ``torch.as_tensor`` accepts (tensor, float, sequence) whose shape
            broadcasts to the bias shape. Values are copied, so the parameter
            never shares storage with the caller's object.

    Raises:
        ValueError: if ``bias_init`` is given but the bilinear layer has no
            bias to initialize.
    """

    def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
        super().__init__()
        self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
        self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
        self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)

        if bias_init is not None:
            bias_param = getattr(self.bilinear, "bias", None)
            if bias_param is None:
                raise ValueError(
                    "bias_init was given, but the bilinear layer has no bias "
                    "(was it built with bias=False?)"
                )
            # Copy the values in place rather than rebinding `.data`. Rebinding
            # would make the parameter share storage with the caller's tensor
            # (optimizer updates would leak into it, and into any other module
            # initialised from the same tensor) and would silently adopt that
            # tensor's dtype, device and shape.
            with torch.no_grad():
                bias_param.copy_(torch.as_tensor(bias_init))

    def forward(self, x, y):
        """Return the biaffine score of ``x`` against ``y``.

        NOTE(review): shapes are not checked here. Presumably ``x`` is
        (batch, n_x, input_dim), ``y`` is (batch, n_y, input_dim) and
        ``self.bilinear(x, y)`` is (batch, n_x, n_y, output_dim), giving a
        (batch, n_x, n_y, output_dim) result -- confirm against Bilinear.
        """
        pairwise = self.bilinear(x, y)
        # Insert singleton axes so the unary terms broadcast against the
        # bilinear term (see the shape note above).
        x_term = self.linear_1(x).unsqueeze(2)
        y_term = self.linear_2(y).unsqueeze(1)
        return pairwise + x_term + y_term
|