import torch
import torch.nn as nn

from models.util import MyResNet34

|
class audio2poseLSTM(nn.Module):
    """Autoregressively predicts a 6-dim pose per audio window from a reference image."""

    def __init__(self):
        super(audio2poseLSTM, self).__init__()

        # ResNet-34 encoders: one embeds the reference image, the other embeds
        # each audio window, both into 256-dim feature vectors.
        self.em_pose = MyResNet34(256, 1)
        self.em_audio = MyResNet34(256, 1)

        # The LSTM consumes the concatenated audio and pose embeddings (512-dim)
        # and carries a 256-dim hidden state across the sequence.
        self.lstm = nn.LSTM(512, 256, num_layers=2, bias=True, batch_first=True)

        # Project each hidden state to a 6-dim pose vector.
        self.output = nn.Linear(256, 6)
|
    def forward(self, x):
        # Embed the reference image; it also seeds the recurrent input.
        pose_em = self.em_pose(x["img"])
        bs = pose_em.shape[0]

        # Zero-initialized hidden/cell states; constants, so no requires_grad needed.
        zero_state = torch.zeros((2, bs, 256), device=pose_em.device)
        cur_state = (zero_state, zero_state)
        img_em = pose_em

        # Audio arrives as (batch, seq, H, W); fold the sequence into the batch
        # dimension so the encoder processes every window in one pass.
        bs, seqlen, num, dims = x["audio"].shape
        audio = x["audio"].reshape(-1, 1, num, dims)
        audio_em = self.em_audio(audio).reshape(bs, seqlen, 256)

        # The pose decoded from the reference image alone is the first output.
        result = [self.output(img_em).unsqueeze(1)]

        # Step through the sequence, feeding each audio embedding concatenated
        # with the previous step's LSTM output.
        for i in range(seqlen):
            img_em, cur_state = self.lstm(
                torch.cat((audio_em[:, i:i + 1], img_em.unsqueeze(1)), dim=2),
                cur_state,
            )
            img_em = img_em.reshape(-1, 256)
            result.append(self.output(img_em).unsqueeze(1))

        # (batch, seqlen + 1, 6): the reference pose followed by one pose per window.
        return torch.cat(result, dim=1)
|
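
# A minimal smoke-test sketch, not part of the original model. It assumes
# MyResNet34's second argument is the input channel count and that the encoder
# maps a (N, 1, H, W) tensor to a (N, 256) embedding; the spatial sizes below
# are illustrative placeholders, not the shapes used in training.
if __name__ == "__main__":
    model = audio2poseLSTM().eval()
    dummy = {
        "img": torch.randn(2, 1, 64, 64),     # reference image: (batch, 1, H, W)
        "audio": torch.randn(2, 10, 64, 64),  # audio windows: (batch, seq, H, W)
    }
    with torch.no_grad():
        poses = model(dummy)
    # One pose from the reference image plus one per audio window.
    print(poses.shape)  # torch.Size([2, 11, 6]) under the assumptions above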