"""Bottleneck adapter and spatio-temporal adapter modules."""

import math

import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, non-linearity, up-project, residual add."""

    def __init__(
        self, ds_factor, hidden_dim, ln_after=False, ln_before=False, dropout=0.1
    ):
        super().__init__()
        # The down-sampling factor must divide the hidden dimension evenly.
        assert not hidden_dim % ds_factor
        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
        # Small-variance init so the adapter starts close to an identity mapping.
        self.apply(self.init_weights)
        self.ln_after = ln_after
        self.ln_before = ln_before
        self.dropout = dropout
        if ln_after or ln_before:
            self.ln = nn.LayerNorm(hidden_dim)
        if dropout:
            self.dropout = nn.Dropout(dropout)

    def init_weights(self, m: nn.Module, std=1e-3):
        if isinstance(m, nn.Linear):
            # Truncated-normal-style init: small normal values clamped to +/- 2*std.
            torch.nn.init.normal_(m.weight, std=std)
            torch.nn.init.normal_(m.bias, std=std)
            m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
            m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
        elif isinstance(m, nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)

    def forward(self, hidden_states):
        if self.ln_before:
            residual = self.ln(hidden_states)
            residual = self.down(residual)
        else:
            residual = self.down(hidden_states)
        residual = self.act(residual)
        if self.dropout:
            residual = self.dropout(residual)
        residual = self.up(residual)
        if self.ln_after:
            # Normalize the adapter branch itself (normalizing hidden_states here
            # would discard the down/up computation).
            residual = self.ln(residual)
        return hidden_states + residual
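
# Usage sketch (illustrative, not part of the original module); the hidden size,
# ds_factor, and tensor shapes below are assumed values:
#     adapter = Adapter(ds_factor=8, hidden_dim=768, ln_before=True)
#     hidden = torch.randn(2, 16, 768)   # (batch, tokens, hidden_dim)
#     out = adapter(hidden)              # same shape: (2, 16, 768)
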
class ST_Adapter(nn.Module):
    """Spatio-temporal adapter: bottleneck MLP with a depthwise temporal conv."""

    def __init__(self, ds_factor, hidden_dim):
        super().__init__()
        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
        # Depthwise 1D convolution over the temporal dimension (one filter per
        # bottleneck channel).
        self.conv = nn.Conv1d(
            hidden_dim // ds_factor,
            hidden_dim // ds_factor,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_dim // ds_factor,
        )
        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
        # Zero-init the conv and all biases so the adapter branch outputs zeros at
        # the start of training, i.e. the module begins as an identity mapping.
        nn.init.constant_(self.conv.weight, 0.)
        nn.init.constant_(self.conv.bias, 0.)
        nn.init.constant_(self.down.bias, 0.)
        nn.init.constant_(self.up.bias, 0.)

    def forward(self, x):
        # x: (N, T, C) = (batch, frames, channels)
        N, T, C = x.size()
        ori_x = x
        x = self.down(x)
        # nn.Conv1d expects (N, C, T), so swap the time and channel axes.
        x = x.permute(0, 2, 1).contiguous()
        x = self.conv(x)
        x = x.permute(0, 2, 1).contiguous()
        x = self.up(x)
        # Residual connection around the adapter branch.
        x = x + ori_x
        return x
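

# Minimal smoke test (not part of the original module): checks that a freshly
# initialized ST_Adapter acts as an identity mapping, since its conv weights and
# all biases are zero-initialized. The ds_factor, hidden_dim, and tensor shape
# are illustrative assumptions.
if __name__ == "__main__":
    st_adapter = ST_Adapter(ds_factor=8, hidden_dim=768)
    frames = torch.randn(2, 10, 768)  # (batch, frames, hidden_dim)
    out = st_adapter(frames)
    assert out.shape == frames.shape
    assert torch.allclose(out, frames)
    print("ST_Adapter identity-at-init check passed:", tuple(out.shape))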