|
"""Reverse-complement equivariant modules. |
|
|
|
""" |
|
from collections import OrderedDict |
|
from typing import Optional |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
try: |
|
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn |
|
except ImportError: |
|
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None |
|
|
|
|
|
class RCPSEmbedding(nn.Module):
    """Embedding layer that supports reverse-complement (RC) equivariance.

    A single embedding table is shared between the forward strand and the RC strand. The two
    embeddings are concatenated along the channel dimension, so the output has ``2 * d_model``
    channels.
    """

    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
        """
        Args:
            vocab_size: Size of vocabulary.
            d_model: Dimensionality of the embedding table. The output of `forward` has
                `2 * d_model` channels, i.e. the table has 1/2 the output dim.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints
                or int-like strings (as produced by a JSON round trip).
            **factory_kwargs: Forwarded unchanged to `nn.Embedding`.
        """
        super().__init__()
        # Order the complements by token id, not by dict insertion order: position i of the buffer
        # must hold the complement of token i. Insertion order is not reliable (e.g. string keys
        # that were sorted lexicographically on serialization put "10" before "2").
        by_token_id = sorted(dict(complement_map).items(), key=lambda item: int(item[0]))
        self.register_buffer(
            "complement_map",
            torch.tensor([complement for _, complement in by_token_id], dtype=torch.long),
        )
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    @property
    def weight(self):
        """Embedding weights."""
        return self.embedding.weight

    def set_weight(self, value):
        """Set embedding weights."""
        self.embedding.weight = value

    def rc(self, x):
        """Reverse-complement a tensor of input_ids.

        The ids are flipped along the last (length) dimension and each id is replaced by its
        complement.

        Args:
            x: Integer tensor of token ids whose last dimension is the sequence length.
        Returns:
            Tensor of the same shape as `x` holding the reverse-complemented ids.
        """
        # Direct indexing into the lookup table gives the same result as a batched gather for
        # (batch, seq_len) input, and additionally works for any number of leading dimensions.
        return self.complement_map[torch.flip(x, dims=[-1])]

    def forward(self, input_ids):
        """Reverse-complement equivariant forward pass.

        This embedding module doubles the output dimensionality to support reverse-complement
        equivariance.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
        Returns:
            Embedding tensor of shape (batch_size, seq_len, d_model * 2)
        """
        fwd_out = self.embedding(input_ids)
        # Embed the RC sequence, then flip length and channels back so that position i of the
        # second half lines up with position i of the forward half.
        rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])

        return torch.cat([fwd_out, rc_out], dim=-1)
|
|
|
|
|
class RCPSWrapper(nn.Module):
    """Turn an arbitrary ``nn.Module`` into a reverse-complement (RC) equivariant module.

    The wrapped module is applied twice with shared parameters: once to the first half of the
    channels, and once to the RC of the second half.

    Reference: Zhou et al. (2022), "Towards a Better Understanding of Reverse-Complement
    Equivariance for Deep Learning Models in Regulatory Genomics",
    https://proceedings.mlr.press/v165/zhou22a.html
    """

    def __init__(self, submodule: nn.Module):
        """Store the module that is shared between the forward and RC passes."""
        super().__init__()
        self.submodule = submodule

    @staticmethod
    def rc(x):
        """Return `x` flipped along both the length (dim=-2) and channel (dim=-1) dimensions."""
        return x.flip(dims=[-2, -1])

    def forward(self, x, **kwargs):
        """Reverse-complement equivariant forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels). The first `channels // 2`
                channels are the forward-strand half, the rest are the RC-strand half.
            **kwargs: Passed unchanged to both calls of the wrapped module.
        Returns:
            The two submodule outputs concatenated along the last dimension, i.e. a tensor whose
            channel size is the sum of the two output channel sizes.
        """
        half = x.shape[-1] // 2
        fwd_in, rc_in = x[..., :half], x[..., half:]

        # Forward strand: run the submodule on the first half as-is.
        fwd_out = self.submodule(fwd_in, **kwargs)

        # RC strand: move into RC orientation, run the same submodule, then undo the flip so
        # the result is aligned with the forward half again.
        rc_out = self.rc(self.submodule(self.rc(rc_in), **kwargs))

        return torch.cat((fwd_out, rc_out), dim=-1)
|
|
|
|
|
class RCPSAddNormWrapper(RCPSWrapper):
    """RC equivariant "add residual, then normalise" layer.

    The wrapped norm is shared between the forward-strand half and the RC-strand half of the
    channels.
    """

    def __init__(self, submodule: nn.Module):
        """Wrap `submodule`; its `weight.dtype` is used as the compute dtype in `forward`."""
        super().__init__(submodule)

    def forward(self, x, residual=None):
        """Add the residual (if any) to `x` and normalise each strand half with the shared norm.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
        Returns:
            Tuple `(normed, residual)`. `normed` is the normalised sum. `residual` is the
            un-normalised sum `x + residual`, or `x` itself when no residual was given.
        """
        half = x.shape[-1] // 2
        norm = self.submodule
        norm_dtype = norm.weight.dtype

        if residual is None:
            # Nothing to add: the input itself becomes the residual stream.
            summed_fwd = x[..., :half]
            summed_rc = self.rc(x[..., half:])
            new_residual = x
        else:
            summed_fwd = x[..., :half] + residual[..., :half]
            # The RC half is summed in RC orientation and flipped back when re-assembled.
            summed_rc = self.rc(x[..., half:]) + self.rc(residual[..., half:])
            new_residual = torch.cat([summed_fwd, self.rc(summed_rc)], dim=-1)

        # Cast to the norm's parameter dtype before normalising each half.
        normed_fwd = norm(summed_fwd.to(dtype=norm_dtype))
        normed_rc = norm(summed_rc.to(dtype=norm_dtype))

        return torch.cat([normed_fwd, self.rc(normed_rc)], dim=-1), new_residual
|
|
|
|
|
class RCPSMambaBlock(nn.Module):
    """Residual block that applies a norm and then a mixer, both in an RC equivariant way."""

    def __init__(
        self,
        dim,
        mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

        Args:
            dim: Size passed to both `mixer_cls` and `norm_cls`.
            mixer_cls: Callable invoked as `mixer_cls(dim)`; the result is wrapped in `RCPSWrapper`.
            norm_cls: Callable invoked as `norm_cls(dim)` to build the norm.
            fused_add_norm: If True, `forward` uses the fused add+norm functions imported from
                `mamba_ssm`; otherwise the norm is wrapped in `RCPSAddNormWrapper`.
            residual_in_fp32: If True, the residual is kept in float32.
            device: Accepted but not used by this block.
            dtype: Accepted but not used by this block.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = RCPSWrapper(mixer_cls(dim))
        norm_f = norm_cls(dim)
        # The fused kernels need the bare norm (for weight/bias/eps); the plain path needs the
        # RC-aware add+norm wrapper.
        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual)).
            inference_params: inference parameters for mixer.
        Returns:
            Tuple `(hidden_states, residual)`: the mixer output and the updated residual.
        Raises:
            ImportError: If `fused_add_norm` is enabled but the fused kernels from `mamba_ssm`
                could not be imported.
        """
        if not self.fused_add_norm:
            hidden_states, residual = self.norm(hidden_states, residual=residual)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # The module-level import falls back to None when mamba_ssm is unavailable; fail with
            # a clear message instead of an opaque "isinstance() arg 2 must be a type" TypeError.
            if RMSNorm is None or layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError(
                    "fused_add_norm=True requires RMSNorm, layer_norm_fn and rms_norm_fn from "
                    "mamba_ssm.ops.triton.layernorm, which could not be imported."
                )
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            half = hidden_states.shape[-1] // 2

            # NOTE(review): this path treats the SECOND channel half as the forward strand and
            # the FIRST half as the RC strand, the opposite of RCPSAddNormWrapper. It is left
            # unchanged on purpose: changing it would alter the numerics of any model already
            # trained through this path. Confirm before unifying the two paths.
            hidden_states_fwd, residual_fwd = fused_add_norm_fn(
                hidden_states[..., half:],
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., half:] if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

            hidden_states_rc, residual_rc = fused_add_norm_fn(
                hidden_states[..., :half].flip(dims=[-2, -1]),
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., :half].flip(dims=[-2, -1]) if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            # Flip the RC results back before re-assembling the two halves.
            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for mixer.

        Keep for compatibility with original Mamba Block.

        `self.mixer` is the RC wrapper, which does not define this method, so the call is
        delegated to the wrapped mixer (`self.mixer.submodule`).
        """
        return self.mixer.submodule.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
|
|
|
|
class RCPSLMHead(nn.Module):
    """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""

    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
        """
        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
        equivariant, i.e. 0.5 times the actual input dim.

        Args:
            true_dim: Number of channels in each strand half of the input.
            vocab_size: Number of output logits.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints
                or int-like strings (as produced by a JSON round trip).
            **factory_kwargs: Forwarded unchanged to `nn.Linear`.
        """
        super().__init__()
        # Order the complements by token id, not by dict insertion order: position i of the buffer
        # must hold the complement of token i. Insertion order is not reliable (e.g. string keys
        # that were sorted lexicographically on serialization put "10" before "2").
        by_token_id = sorted(dict(complement_map).items(), key=lambda item: int(item[0]))
        self.register_buffer(
            "complement_map",
            torch.tensor([complement for _, complement in by_token_id], dtype=torch.long),
        )
        self.true_dim = true_dim
        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)

    @property
    def weight(self):
        """LM head weights."""
        return self.lm_head.weight

    def set_weight(self, value):
        """Set LM head weights."""
        self.lm_head.weight = value

    def forward(self, x):
        """Project both strand halves to logits with shared weights and sum them.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
        Returns:
            Logits whose last dimension is the vocabulary.
        Raises:
            ValueError: If the last dimension of `x` is not `2 * true_dim`.
        """
        n_channels = x.shape[-1]
        # Raise explicitly rather than `assert`, which is stripped when Python runs with -O.
        if n_channels != 2 * self.true_dim:
            raise ValueError(
                "Input must have 2 * true_dim channels "
                f"(expected {2 * self.true_dim}, got {n_channels})."
            )
        half = n_channels // 2
        fwd_logits = F.linear(x[..., :half], self.weight, bias=self.lm_head.bias)
        # RC half: undo the channel flip, and permute the weight rows by the complement map so
        # the logit for token v is computed with the weights of its complement.
        rc_logits = F.linear(
            torch.flip(x[..., half:], dims=[-1]),
            self.weight[self.complement_map, :],
            bias=self.lm_head.bias,
        )
        return fwd_logits + rc_logits
|
|