import logging
import math
import random
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from torch.nn.modules.normalization import LayerNorm

logger = logging.getLogger(__name__)


def circulant_mask(n: int, window: int) -> torch.Tensor:
    """Build the relative attention mask.

    The mask is computed once when the model is instantiated; for inputs shorter
    than the maximum length, a top-left submatrix of it is used. Entry (i, j)
    corresponds to a pair of token positions in the attention score matrix, and
    the mask zeroes out scores for pairs further apart than the window length.

    :param n: fixed size, set larger than the largest max sequence length across batches
    :param window: relative attention window length
    :return: relative attention mask of shape (n, n)
    """
    if window >= n:
        return torch.ones(n, n)

    circulant_t = torch.zeros(n, n)
    for offset in range(-window, window + 1):
        circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset)))
    return circulant_t
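# Quick sanity check (illustrative only, not executed): circulant_mask(5, 2)
# keeps attention within a +/-2 token window,
#   tensor([[1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.],
#           [0., 1., 1., 1., 1.],
#           [0., 0., 1., 1., 1.]])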
class SelfAttention(nn.Module):
    """Standard query/key/value self-attention with relative attention:
    scores outside the relative window are masked out, and a learnable bias
    encoding relative token position is added to the scores before the softmax."""

    def __init__(self, config: DictConfig, relative_attention: int):
        """Initialise the query, key, value and output projection layers.

        :param config: model config
        :type config: DictConfig
        :param relative_attention: relative attention window length
        :type relative_attention: int
        """
        super().__init__()

        self.config = config
        self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project)
        self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project)
        self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project)

        self.softmax = nn.Softmax(dim=-1)
        self.output_projection = nn.Linear(
            config.num_attention_project, config.num_embed_hidden
        )
        # learnable relative-position bias, one entry per offset in [-n // 2, n // 2]
        self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
        stdv = 1.0 / math.sqrt(self.bias.data.size(0))
        self.bias.data.uniform_(-stdv, stdv)
        self.relative_attention = relative_attention
        self.n = self.config.n
        self.half_n = self.n // 2
        self.register_buffer(
            "relative_mask",
            circulant_mask(config.tokens_len, self.relative_attention),
        )

    def forward(
        self, attn_input: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute self-attention and apply the output projection.

        The input is first projected to queries, keys and values. The attention
        mask removes pad-token positions, the relative attention mask removes
        positions outside the window, and the learnable relative-position bias is
        added to the scores before the softmax. The attention-weighted values are
        then passed through the output projection layer.

        :param attn_input: hidden states of shape (batch, seq_len, num_embed_hidden)
        :param attention_mask: 1 for real tokens, 0 for padding, shape (batch, seq_len)
        :return: projected output and the attention scores
        """
        self.T = attn_input.size()[1]

        _query = self.query(attn_input)
        _key = self.key(attn_input)
        _value = self.value(attn_input)

        # scaled dot-product scores, shape (batch, T, T)
        attention_scores = torch.matmul(_query, _key.transpose(1, 2))
        attention_scores = attention_scores / math.sqrt(
            self.config.num_attention_project
        )

        # push pad positions towards zero weight after the softmax
        extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0
        attention_scores = attention_scores + extended_attention_mask

        # drop scores for token pairs further apart than the relative window
        attention_scores = attention_scores.masked_fill(
            self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf")
        )

        # add the learnable bias indexed by relative position (j - i), centred at n // 2
        ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T))
        B_matrix = self.bias[self.n // 2 - ii + jj]

        attention_scores = attention_scores + B_matrix.unsqueeze(0)

        attention_scores = self.softmax(attention_scores)
        output = torch.matmul(attention_scores, _value)

        output = self.output_projection(output)

        return output, attention_scores
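# Shape sketch for SelfAttention.forward (illustrative): with batch size B and
# sequence length T, attn_input is (B, T, num_embed_hidden), the projected
# query/key/value are (B, T, num_attention_project), attention_scores are
# (B, T, T), and the returned output is (B, T, num_embed_hidden).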
class FeedForward1(nn.Module):
    """Position-wise feed-forward layer: Linear -> GELU -> Dropout -> Linear."""

    def __init__(
        self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0
    ):
        super().__init__()

        self.linear_1 = nn.Linear(input_hidden, intermediate_hidden)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear_2 = nn.Linear(intermediate_hidden, input_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.linear_1(x))
        return self.linear_2(self.dropout(x))


class SharedInnerBlock(nn.Module):
    """Transformer encoder block: relative self-attention and a feed-forward sub-layer,
    each followed by a residual connection and post-layer normalisation."""

    def __init__(self, config: DictConfig, relative_attn: int):
        super().__init__()

        self.config = config
        self.self_attention = SelfAttention(config, relative_attn)
        self.norm1 = LayerNorm(config.num_embed_hidden)
        self.dropout = nn.Dropout(config.dropout)
        self.ff1 = FeedForward1(
            config.num_embed_hidden, config.feed_forward1_hidden, config.dropout
        )
        self.norm2 = LayerNorm(config.num_embed_hidden)

    def forward(
        self, x: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_values_x, attn_scores = self.self_attention(x, attention_mask=attention_mask)
        x = x + new_values_x
        x = self.norm1(x)
        x = x + self.ff1(x)
        return self.norm2(x), attn_scores
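# Data flow through one SharedInnerBlock (post-norm residuals, as implemented above):
#   x = norm1(x + SelfAttention(x))
#   x = norm2(x + FeedForward1(x))
# The block also returns its attention scores so callers can inspect them.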
class MultiheadAttention(nn.Module):
    """Multi-head self-attention. Each head has size ``num_embed_hidden``, so the
    concatenated output is ``num_embed_hidden * num_attention_heads`` wide."""

    def __init__(self, config: DictConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads
        self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
        self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
        self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj)

        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, T, _ = hidden_states.size()

        # project and split into heads: (B, num_heads, T, head_size)
        k = (
            self.key(hidden_states)
            .view(B, T, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        q = (
            self.query(hidden_states)
            .view(B, T, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        v = (
            self.value(hidden_states)
            .view(B, T, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # scaled dot-product attention per head
        attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        if attention_mask is not None:
            # mask out pad tokens along the key dimension
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * -10000.0
            attention_scores = attention_scores + attention_mask

        attention_scores = F.softmax(attention_scores, dim=-1)
        attention_scores = self.dropout(attention_scores)

        y = attention_scores @ v

        # merge the heads back into (B, T, num_attn_proj)
        y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj)

        return y
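# Sizing note (follows from the code above): attention_head_size works out to
# num_embed_hidden, so with e.g. num_embed_hidden = 128 and 4 heads this module
# maps (B, T, 128) -> (B, T, 512); RNAFFrwd below consumes that widened dimension.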
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the embeddings, followed by dropout."""

    def __init__(self, model_config: DictConfig):
        super().__init__()
        self.dropout = nn.Dropout(p=model_config.dropout)
        self.num_embed_hidden = model_config.num_embed_hidden

        pe = torch.zeros(model_config.tokens_len, self.num_embed_hidden)
        position = torch.arange(
            0, model_config.tokens_len, dtype=torch.float
        ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.num_embed_hidden, 2).float()
            * (-math.log(10000.0) / self.num_embed_hidden)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, tokens_len, num_embed_hidden)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is batch-first (B, T, num_embed_hidden); slice the encoding to the input length
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
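# The buffer follows the standard Transformer formulation,
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d))
# with d = num_embed_hidden, giving nearby positions smoothly varying encodings.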
class RNAFFrwd(nn.Module):
    """Fully-connected 3-layer head that pools the per-token states into a single
    L2-normalised embedding per sequence."""

    def __init__(self, model_config: DictConfig):
        """
        :param model_config: model config; ``num_embed_hidden`` and
            ``num_attention_heads`` determine the pooled input width
        :type model_config: DictConfig
        """
        super().__init__()

        self.rna_ffwd_input_dim = (
            model_config.num_embed_hidden * model_config.num_attention_heads
        )
        self.linear_1 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
        self.linear_2 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)

        self.norm1 = LayerNorm(self.rna_ffwd_input_dim)
        self.norm2 = LayerNorm(self.rna_ffwd_input_dim)
        self.final = nn.Linear(self.rna_ffwd_input_dim, model_config.num_embed_hidden)
        self.orthogonal_initialization()

    def orthogonal_initialization(self):
        for l in [
            self.linear_1,
            self.linear_2,
        ]:
            torch.nn.init.orthogonal_(l.weight)

    def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor:
        sentence_lengths = attn_msk.sum(1)

        # square-root-of-N reduction: sum the token states and scale by 1 / sqrt(length)
        norms = 1 / torch.sqrt(sentence_lengths.double()).float()
        x = norms.unsqueeze(1) * torch.sum(x, dim=1)

        # two residual GELU layers with pre-normalisation, then project and L2-normalise
        x = x + F.gelu(self.linear_1(self.norm1(x)))
        x = x + F.gelu(self.linear_2(self.norm2(x)))

        return F.normalize(self.final(x), dim=1, p=2)
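# Pooling detail (from the code above): for a sequence with L real tokens the pooled
# vector is (1 / sqrt(L)) * sum_t x_t, where the sum runs over all positions while
# L is taken from the attention mask.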
class RNATransformer(nn.Module):
    """Transformer encoder for one token sequence: embedding, positional encoding,
    a stack of SharedInnerBlock layers with per-layer relative-attention windows,
    multi-head attention and the RNAFFrwd pooling head. In 'baseline' mode only the
    flattened embeddings are returned."""

    def __init__(self, model_config: DictConfig):
        super().__init__()
        self.num_embedd_hidden = model_config.num_embed_hidden
        self.encoder = nn.Embedding(
            model_config.vocab_size, model_config.num_embed_hidden
        )
        self.model_input = model_config.model_input
        if 'baseline' not in self.model_input:
            self.pos_encoder = PositionalEncoding(model_config)

            # one encoder block per entry in relative_attns, each with its own window
            self.transformer_layers = nn.ModuleList(
                [
                    SharedInnerBlock(model_config, int(window / model_config.window))
                    for window in model_config.relative_attns[
                        : model_config.num_encoder_layers
                    ]
                ]
            )
            self.MHA = MultiheadAttention(model_config)

            self.rna_ffrwd = RNAFFrwd(model_config)
        self.pad_id = 0

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        embedds = self.encoder(x)
        if 'baseline' not in self.model_input:
            output = self.pos_encoder(embedds)
            attention_mask = (x != self.pad_id).int()

            for l in self.transformer_layers:
                output, attn_scores = l(output, attention_mask)
            output = self.MHA(output)
            output = self.rna_ffrwd(output, attention_mask)
            return output, attn_scores
        else:
            embedds = torch.flatten(embedds, start_dim=1)
            return embedds, None
class GeneEmbeddModel(nn.Module):
    """Full model: an RNATransformer for the token sequence and another for the second
    input; their embeddings are fed to a classification head."""

    def __init__(
        self, main_config: DictConfig,
    ):
        super().__init__()
        self.train_config = main_config["train_config"]
        self.model_config = main_config["model_config"]
        self.device = self.train_config.device
        self.model_input = self.model_config["model_input"]
        self.false_input_perc = self.model_config["false_input_perc"]

        # size of the relative-position bias used by SelfAttention
        self.model_config.n = self.model_config.tokens_len * 2 + 1
        self.transformer_layers = RNATransformer(self.model_config)

        self.tokens_len = self.model_config.tokens_len

        # the same config object is reused (with adjusted length/vocab) for the second input
        self.model_config.tokens_len = self.model_config.second_input_token_len
        self.model_config.n = self.model_config.tokens_len * 2 + 1
        self.seq_vocab_size = self.model_config.vocab_size

        self.model_config.vocab_size = self.model_config.second_input_vocab_size

        self.second_input_model = RNATransformer(self.model_config)

        self.num_transformers = 2
        if self.model_input == 'seq':
            self.num_transformers = 1

        self.weight_decay = self.train_config.l2_weight_decay
        if 'baseline' in self.model_input:
            self.num_transformers = 1
            num_nodes = self.model_config.num_embed_hidden * self.tokens_len
            self.final_clf_1 = nn.Linear(num_nodes, self.model_config.num_classes)
        else:
            num_nodes = self.num_transformers * self.model_config.num_embed_hidden
            if self.num_transformers == 1:
                self.final_clf_1 = nn.Linear(num_nodes, self.model_config.num_classes)
            else:
                self.final_clf_1 = nn.Linear(num_nodes, num_nodes)
                self.final_clf_2 = nn.Linear(num_nodes, self.model_config.num_classes)
                self.relu = nn.ReLU()
                self.BN = nn.BatchNorm1d(num_nodes)
                self.dropout = nn.Dropout(0.6)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def distort_input(self, x):
        """Flip a random contiguous span covering ``false_input_perc`` of the real tokens
        in both the sequence part and the second-input part of every sample."""
        for sample_idx in range(x.shape[0]):
            seq_length = x[sample_idx, -1]
            num_tokens_flipped = int(self.false_input_perc * seq_length)
            max_start_flip_idx = seq_length - num_tokens_flipped

            random_feat_idx = random.randint(0, max_start_flip_idx - 1)
            x[sample_idx, random_feat_idx : random_feat_idx + num_tokens_flipped] = \
                torch.tensor(np.random.choice(range(1, self.seq_vocab_size - 1), size=num_tokens_flipped, replace=True))

            x[sample_idx, random_feat_idx + self.tokens_len : random_feat_idx + self.tokens_len + num_tokens_flipped] = \
                torch.tensor(np.random.choice(range(1, self.model_config.second_input_vocab_size - 1), size=num_tokens_flipped, replace=True))
        return x

    def forward(self, x: torch.Tensor, train: bool = False):
        if self.device == 'cuda':
            long_tensor = torch.cuda.LongTensor
        else:
            long_tensor = torch.LongTensor

        if train:
            if self.false_input_perc > 0:
                x = self.distort_input(x)

        # x holds [sequence tokens | second-input tokens | sequence length]
        gene_embedd, attn_scores_first = self.transformer_layers(
            x[:, : self.tokens_len].type(long_tensor)
        )
        attn_scores_second = None
        second_input_embedd, attn_scores_second = self.second_input_model(
            x[:, self.tokens_len : -1].type(long_tensor)
        )

        if self.num_transformers == 1:
            activations = self.final_clf_1(gene_embedd)
        else:
            out_clf_1 = self.final_clf_1(torch.cat((gene_embedd, second_input_embedd), 1))
            out = self.BN(out_clf_1)
            out = self.relu(out)
            out = self.dropout(out)
            activations = self.final_clf_2(out)

        if 'baseline' in self.model_input:
            # dummy attention scores so the baseline path returns the same structure
            attn_scores_first = torch.ones((1, 2, 2), device=x.device)

        return [gene_embedd, second_input_embedd, activations, attn_scores_first, attn_scores_second]
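# Illustrative usage sketch. The config keys match the ones accessed above, but the
# values are placeholders (not the repository defaults) chosen only to make the
# shapes line up:
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       "train_config": {"device": "cpu", "l2_weight_decay": 0.0},
#       "model_config": {
#           "model_input": "seq-struct", "false_input_perc": 0.0,
#           "tokens_len": 40, "second_input_token_len": 40,
#           "vocab_size": 100, "second_input_vocab_size": 100,
#           "num_embed_hidden": 32, "num_attention_project": 16,
#           "feed_forward1_hidden": 64, "num_attention_heads": 2,
#           "num_encoder_layers": 2, "relative_attns": [10, 20], "window": 2,
#           "dropout": 0.1, "num_classes": 5,
#       },
#   })
#   model = GeneEmbeddModel(cfg)
#   x = torch.randint(1, 100, (4, 40 + 40 + 1))  # [seq tokens | second-input tokens | length]
#   x[:, -1] = 40
#   gene_emb, second_emb, logits, attn_1, attn_2 = model(x)
#   # gene_emb / second_emb: (4, 32), logits: (4, num_classes)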