emotion_recognition / model_LA.py
nouamanetazi's picture
nouamanetazi HF staff
linting
c731c61
raw
history blame
9.03 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.fc import MLP, FC
from layers.layer_norm import LayerNorm
# ------------------------------------
# ---------- Masking sequence --------
# ------------------------------------
def make_mask(feature):
    """Build a boolean padding mask from a feature tensor.

    A position is marked ``True`` (padding) when every value along the last
    axis of ``feature`` is zero. Two singleton axes are then inserted at
    positions 1 and 2, so e.g. a ``(batch, seq, feat)`` input produces a
    ``(batch, 1, 1, seq)`` mask.
    """
    magnitude = feature.abs().sum(dim=-1)
    is_padding = magnitude == 0
    return is_padding.unsqueeze(1).unsqueeze(2)
# ------------------------------
# ---------- Flattening --------
# ------------------------------
class AttFlat(nn.Module):
    """Attention-based pooling over the sequence axis.

    An ``MLP`` produces ``flat_glimpse`` score columns per position; each
    column is softmax-normalised over axis 1 and used as weights for a sum of
    ``x`` over that axis (one "glimpse" per column).

    Args:
        args: config object; reads ``hidden_size``, ``ff_size`` and
            ``dropout_r``.
        flat_glimpse: number of glimpses (weighted sums) to produce.
        merge: if True, concatenate the glimpses and project them with a
            linear layer to ``2 * hidden_size`` features.
    """

    def __init__(self, args, flat_glimpse, merge=False):
        super(AttFlat, self).__init__()
        self.args = args
        self.merge = merge
        self.flat_glimpse = flat_glimpse
        self.mlp = MLP(
            in_size=args.hidden_size,
            mid_size=args.ff_size,
            out_size=flat_glimpse,
            dropout_r=args.dropout_r,
            use_relu=True,
        )
        if merge:
            merged_width = args.hidden_size * flat_glimpse
            self.linear_merge = nn.Linear(merged_width, 2 * args.hidden_size)

    def forward(self, x, x_mask):
        """Pool ``x`` into ``flat_glimpse`` attention-weighted summaries.

        Args:
            x: input features; presumably ``(batch, seq, hidden_size)`` —
                TODO confirm at callers.
            x_mask: optional boolean mask as produced by ``make_mask``
                (``True`` entries get a -1e9 score, i.e. are ignored), or
                ``None`` for no masking.

        Returns:
            With ``merge``: ``linear_merge`` applied to the concatenated
            glimpses. Otherwise: the glimpses stacked along axis 1.
        """
        scores = self.mlp(x)
        if x_mask is not None:
            # Drop the two broadcast axes of the mask and add a trailing one
            # so it lines up with every glimpse column of ``scores``.
            pad = x_mask.squeeze(1).squeeze(1).unsqueeze(2)
            scores = scores.masked_fill(pad, -1e9)
        weights = F.softmax(scores, dim=1)

        glimpses = [
            torch.sum(weights[:, :, g : g + 1] * x, dim=1)
            for g in range(self.flat_glimpse)
        ]

        if not self.merge:
            return torch.stack(glimpses).transpose_(0, 1)
        return self.linear_merge(torch.cat(glimpses, dim=1))
# ------------------------
# ---- Self Attention ----
# ------------------------
class SA(nn.Module):
    """Self-attention layer.

    Two sub-layers, each followed by dropout, a residual add and LayerNorm
    (``norm(y + dropout(sublayer(y)))``): multi-head self-attention over
    ``y``, then a feed-forward net.
    """

    def __init__(self, args):
        super(SA, self).__init__()
        # Attribute names and creation order are kept stable: they determine
        # state_dict keys and the order parameters are randomly initialised.
        self.mhatt = MHAtt(args)
        self.ffn = FFN(args)
        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, y, y_mask):
        attended = self.mhatt(y, y, y, y_mask)
        y = self.norm1(y + self.dropout1(attended))
        transformed = self.ffn(y)
        return self.norm2(y + self.dropout2(transformed))
# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------
class SGA(nn.Module):
    """Self-attention, then guided (cross) attention, then a feed-forward net.

    ``x`` first attends to itself under ``x_mask``. The result then queries
    ``y`` (keys/values) under ``y_mask``. Every sub-layer is followed by
    dropout, a residual add and LayerNorm.
    """

    def __init__(self, args):
        super(SGA, self).__init__()
        # Keep names/creation order: they fix state_dict keys and init order.
        self.mhatt1 = MHAtt(args)
        self.mhatt2 = MHAtt(args)
        self.ffn = FFN(args)
        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)
        self.dropout3 = nn.Dropout(args.dropout_r)
        self.norm3 = LayerNorm(args.hidden_size)

    def forward(self, x, y, x_mask, y_mask):
        self_attended = self.mhatt1(v=x, k=x, q=x, mask=x_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        guided = self.mhatt2(v=y, k=y, q=x, mask=y_mask)
        x = self.norm2(x + self.dropout2(guided))

        return self.norm3(x + self.dropout3(self.ffn(x)))
# ------------------------------
# ---- Multi-Head Attention ----
# ------------------------------
class MHAtt(nn.Module):
    """Multi-head scaled dot-product attention.

    ``v``, ``k`` and ``q`` are each linearly projected and split into
    ``args.multi_head`` heads of size ``hidden_size // multi_head``. Each
    head runs scaled dot-product attention independently. The heads are
    then re-joined and passed through ``linear_merge``.

    Args:
        args: config object; reads ``hidden_size``, ``multi_head`` and
            ``dropout_r``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``multi_head``.
    """

    def __init__(self, args):
        super(MHAtt, self).__init__()
        if args.hidden_size % args.multi_head != 0:
            # With an indivisible size the head split truncates head_dim, and
            # ``view`` may silently regroup elements across sequence positions
            # (or fail later with an opaque shape error).
            raise ValueError(
                f"hidden_size ({args.hidden_size}) must be divisible by "
                f"multi_head ({args.multi_head})"
            )
        self.args = args
        self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
        self.dropout = nn.Dropout(args.dropout_r)

    def _split_heads(self, t, n_batches):
        """Reshape projected features to ``(batch, heads, -1, head_dim)``."""
        head_dim = self.args.hidden_size // self.args.multi_head
        return t.view(n_batches, -1, self.args.multi_head, head_dim).transpose(1, 2)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` to ``k``/``v`` across all heads.

        Args:
            v, k, q: value/key/query features; assumed
                ``(batch, seq, hidden_size)`` with ``v`` and ``k`` sharing a
                sequence length — TODO confirm at callers.
            mask: optional boolean tensor broadcastable to the
                ``(batch, heads, q_len, k_len)`` scores; ``True`` entries are
                masked out. ``None`` disables masking.

        Returns:
            Tensor of shape ``(batch, q_len, hidden_size)``.
        """
        n_batches = q.size(0)
        v = self._split_heads(self.linear_v(v), n_batches)
        k = self._split_heads(self.linear_k(k), n_batches)
        q = self._split_heads(self.linear_q(q), n_batches)

        atted = self.att(v, k, q, mask)
        # Re-join heads: (batch, heads, q_len, head_dim) -> (batch, q_len, hidden).
        atted = (
            atted.transpose(1, 2)
            .contiguous()
            .view(n_batches, -1, self.args.hidden_size)
        )
        return self.linear_merge(atted)

    def att(self, value, key, query, mask):
        """Scaled dot-product attention over the last two axes.

        ``mask`` (optional) is a boolean tensor broadcastable to the score
        matrix; ``True`` entries are set to -1e9 before the softmax so they
        receive effectively zero weight. Dropout is applied to the attention
        map before it weights ``value``.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value)
# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------
class FFN(nn.Module):
    """Feed-forward sub-layer wrapping the project ``MLP``.

    The MLP maps ``hidden_size -> ff_size -> hidden_size`` and is built with
    ``use_relu=True`` and ``dropout_r`` from ``args``.
    """

    def __init__(self, args):
        super(FFN, self).__init__()
        width = args.hidden_size
        self.mlp = MLP(
            in_size=width,
            mid_size=args.ff_size,
            out_size=width,
            dropout_r=args.dropout_r,
            use_relu=True,
        )

    def forward(self, x):
        out = self.mlp(x)
        return out
# ---------------------------
# ---- FF + norm -----------
# ---------------------------
class FFAndNorm(nn.Module):
    """LayerNorm the input, then apply a residual feed-forward sub-layer.

    Computes ``norm2(n + dropout2(ffn(n)))`` where ``n = norm1(x)``.
    """

    def __init__(self, args):
        super(FFAndNorm, self).__init__()
        # Names (including the unpaired ``dropout2``) are kept as-is because
        # they are state_dict keys; creation order fixes init RNG usage.
        self.ffn = FFN(args)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, x):
        normed = self.norm1(x)
        residual = normed + self.dropout2(self.ffn(normed))
        return self.norm2(residual)
class Block(nn.Module):
    """One encoder layer jointly updating the language (``x``) and audio (``y``) streams.

    ``x`` gets a self-attention update (``sa1``). ``y`` gets self-attention
    followed by guided attention over ``x`` (``sa3``). Both updates are added
    residually. Every layer except the last then refines each stream with a
    non-merging ``AttFlat`` plus dropout, a residual add and LayerNorm.

    Args:
        args: config; reads ``layer``, ``lang_seq_len``, ``audio_seq_len``,
            ``hidden_size`` and ``dropout_r`` (plus whatever sub-modules read).
        i: zero-based layer index; ``i == args.layer - 1`` marks the last
            layer, which has no AttFlat refinement.
    """

    def __init__(self, args, i):
        super(Block, self).__init__()
        self.args = args
        self.sa1 = SA(args)
        self.sa3 = SGA(args)
        self.last = i == args.layer - 1
        if self.last:
            return
        # NOTE(review): AttFlat with flat_glimpse=seq_len returns one glimpse
        # per position, which is added back onto the stream in forward; this
        # presumably requires inputs padded to exactly lang_seq_len /
        # audio_seq_len — TODO confirm.
        self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
        self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
        self.norm_l = LayerNorm(args.hidden_size)
        self.norm_i = LayerNorm(args.hidden_size)
        self.dropout = nn.Dropout(args.dropout_r)

    def forward(self, x, x_mask, y, y_mask):
        lang_update = self.sa1(x, x_mask)
        # The guided attention must see the *incoming* language features, so
        # both updates are computed before either residual is applied.
        audio_update = self.sa3(y, x, y_mask, x_mask)
        x = lang_update + x
        y = audio_update + y

        if self.last:
            return x, y

        lang_pooled = self.att_lang(x, x_mask)
        audio_pooled = self.att_audio(y, y_mask)
        x_out = self.norm_l(x + self.dropout(lang_pooled))
        y_out = self.norm_i(y + self.dropout(audio_pooled))
        return x_out, y_out
class Model_LA(nn.Module):
    """Language + audio fusion classifier.

    Pipeline:
    - Token ids are embedded, with weights copied from ``pretrained_emb``, and
      run through a single-layer LSTM.
    - Audio features are mapped to ``hidden_size`` by a linear adapter.
    - Both streams pass through ``args.layer`` encoder ``Block``s.
    - Each stream is pooled by a merging ``AttFlat``.
    - The pooled vectors are summed, LayerNorm-ed and projected to
      ``args.ans_size`` outputs.

    Args:
        args: model config.
        vocab_size: number of rows in the embedding table.
        pretrained_emb: numpy array copied into the embedding weights;
            presumably ``(vocab_size, word_embed_size)`` — TODO confirm.
    """

    def __init__(self, args, vocab_size, pretrained_emb):
        super(Model_LA, self).__init__()
        self.args = args
        # Creation order below is kept as-is: it determines the order in
        # which parameters consume the RNG during initialisation.

        # Language stream: word embeddings overwritten with the pretrained
        # (GloVe) vectors, feeding a single-layer batch-first LSTM.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=args.word_embed_size
        )
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.lstm_x = nn.LSTM(
            input_size=args.word_embed_size,
            hidden_size=args.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        # Audio stream: a linear adapter from audio_feat_size to hidden_size.
        # (An earlier version ran an audio LSTM, ``lstm_y``, here instead.)
        self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size)

        # Encoder stack.
        blocks = [Block(args, idx) for idx in range(args.layer)]
        self.enc_list = nn.ModuleList(blocks)

        # Flatten each stream before projection (merge=True yields
        # 2 * hidden_size features). ``attflat_img`` pools the audio stream;
        # the name is kept for state_dict compatibility.
        self.attflat_img = AttFlat(args, 1, merge=True)
        self.attflat_lang = AttFlat(args, 1, merge=True)

        # Classification head.
        self.proj_norm = LayerNorm(2 * args.hidden_size)
        self.proj = nn.Linear(2 * args.hidden_size, args.ans_size)

    def forward(self, x, y, _):
        """Run the full model on a batch.

        Args:
            x: token ids; ids equal to 0 are treated as padding by ``make_mask``.
            y: audio features; positions whose features are all zero are
                treated as padding.
            _: unused.

        Returns:
            Output of ``self.proj`` on the fused, normalised representation.
        """
        x_mask = make_mask(x.unsqueeze(2))
        y_mask = make_mask(y)

        x, _ = self.lstm_x(self.embedding(x))
        y = self.adapter(y)

        for depth, block in enumerate(self.enc_list):
            # Padding masks are supplied only to the first encoder block;
            # later blocks receive None (no masking).
            if depth == 0:
                x, y = block(x, x_mask, y, y_mask)
            else:
                x, y = block(x, None, y, None)

        pooled_lang = self.attflat_lang(x, None)
        pooled_audio = self.attflat_img(y, None)

        fused = self.proj_norm(pooled_lang + pooled_audio)
        return self.proj(fused)