#!/usr/bin/env python3 | |
# coding=utf-8 | |
import torch.nn as nn | |
from model.module.bilinear import Bilinear | |
class Biaffine(nn.Module):
    """Biaffine scoring layer: a bilinear term plus one linear term per input.

    The output is ``bilinear(x, y) + linear_1(x).unsqueeze(2) +
    linear_2(y).unsqueeze(1)``. The two linear projections have no bias; the
    only optional bias belongs to the project-local ``Bilinear`` module.

    Args:
        input_dim: Feature size of both ``x`` and ``y`` (the last dimension
            consumed by both linear projections and the bilinear term).
        output_dim: Number of output features produced per scored pair.
        bias: Forwarded to ``Bilinear`` to enable or disable its bias.
        bias_init: Optional tensor whose values initialise the bilinear
            bias. It must be broadcastable to that bias. The values are
            copied, so the caller's tensor is not aliased.

    Raises:
        ValueError: If ``bias_init`` is given but the bilinear layer has no
            bias (for example when ``bias=False``).
    """

    def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
        super().__init__()
        # Unbiased per-side projections; the only optional bias is on the
        # bilinear term below.
        self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
        self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
        self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)

        if bias_init is not None:
            target = getattr(self.bilinear, "bias", None)
            if target is None:
                raise ValueError(
                    "bias_init was given but the bilinear layer has no bias "
                    "(was bias=False?)"
                )
            # Copy in place instead of rebinding `.data`. Rebinding would make
            # the parameter share storage with the caller's tensor, so optimizer
            # updates would mutate `bias_init`. It would also silently replace
            # the parameter's dtype, device and shape with those of `bias_init`.
            target.data.copy_(bias_init)

    def forward(self, x, y):
        """Score ``x`` against ``y`` with bilinear and linear terms.

        Args:
            x: Left input whose last dimension is ``input_dim``.
            y: Right input whose last dimension is ``input_dim``.

        Returns:
            The bilinear term plus both linear terms, combined by broadcasting.
        """
        # NOTE(review): assumes x is (batch, n, input_dim), y is
        # (batch, m, input_dim), and Bilinear returns (batch, n, m, output_dim).
        # TODO: confirm against model.module.bilinear. Under that assumption,
        # unsqueeze(2) broadcasts the x-term across every y position, and
        # unsqueeze(1) broadcasts the y-term across every x position.
        return (
            self.bilinear(x, y)
            + self.linear_1(x).unsqueeze(2)
            + self.linear_2(y).unsqueeze(1)
        )