import logging
import math
import os
import warnings
from random import randint

import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import as_strided
from omegaconf import DictConfig, open_dict

from ..utils import energy
from ..utils.file import save

logger = logging.getLogger(__name__)

class SeqTokenizer:
    '''
    Base tokenizer holding functionality that data-specific classes inherit from.
    '''

    def __init__(self, seqs_dot_bracket_labels: pd.DataFrame, config: DictConfig):
        self.seqs_dot_bracket_labels = seqs_dot_bracket_labels.reset_index(drop=True)

        # Shuffle the data when training; keep the original order at inference time.
        if not config["inference"]:
            self.seqs_dot_bracket_labels = self.seqs_dot_bracket_labels\
                .sample(frac=1)\
                .reset_index(drop=True)

        self.model_input = config["model_config"].model_input

        if config["train_config"].filter_seq_length:
            self.get_outlier_length_threshold()
            self.limit_seqs_to_range()
        else:
            self.max_length = self.seqs_dot_bracket_labels['Sequences'].str.len().max()
            self.min_length = 0

        with open_dict(config):
            config["model_config"]["max_length"] = np.int64(self.max_length).item()
            config["model_config"]["min_length"] = np.int64(self.min_length).item()

        self.window = config["model_config"].window
        # Number of tokens per sequence: ceil(max_length / window) without overlap,
        # or max_length - (window - 1) when a sliding (overlapping) window is used.
        self.tokens_len = math.ceil(self.max_length / self.window)
        if config["model_config"].tokenizer in ["overlap", "overlap_multi_window"]:
            self.tokens_len = int(self.max_length - (config["model_config"].window - 1))
        self.tokenizer = config["model_config"].tokenizer

        self.seq_len_dist = self.seqs_dot_bracket_labels['Sequences'].str.len().value_counts()

        self.seq_tokens_ids_dict = {}
        self.second_input_tokens_ids_dict = {}

        config["model_config"].num_classes = len(self.seqs_dot_bracket_labels['Labels'].unique())

        self.set_class_attr()
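
    # A quick sanity check of the token-count arithmetic above (hypothetical
    # values, not from any shipped config): with max_length=12 and window=3,
    # tokens_len is ceil(12 / 3) = 4 for the no-overlap tokenizer and
    # 12 - (3 - 1) = 10 for the overlapping one.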

    def get_outlier_length_threshold(self):
        '''
        Sets the max/min sequence length to the longest/shortest length that
        lies within two standard deviations of the mean sequence length.
        '''
        lengths_arr = self.seqs_dot_bracket_labels['Sequences'].str.len()
        mean = np.mean(lengths_arr)
        standard_deviation = np.std(lengths_arr)
        distance_from_mean = abs(lengths_arr - mean)
        in_distribution = distance_from_mean < 2 * standard_deviation

        inlier_lengths = np.sort(lengths_arr[in_distribution].unique())
        self.max_length = int(np.max(inlier_lengths))
        self.min_length = int(np.min(inlier_lengths))
        logger.info(f'Maximum and minimum sequence lengths are set to {self.max_length} and {self.min_length}')
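
    # Sketch of the two-sigma filter on hypothetical lengths: given nine
    # sequences of length 10 and one of length 100, mean = 19 and std = 27,
    # so |100 - 19| = 81 > 2 * 27 = 54; the length-100 sequence is an outlier
    # and max_length becomes 10.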

    def limit_seqs_to_range(self):
        '''
        Trims sequences longer than the maximum length and deletes sequences
        shorter than the minimum length.
        '''
        df = self.seqs_dot_bracket_labels
        min_to_be_deleted = []

        num_longer_seqs = sum(df['Sequences'].str.len() > self.max_length)
        if num_longer_seqs:
            logger.info(f"Number of sequences to be trimmed: {num_longer_seqs}")

        for idx, seq in enumerate(df['Sequences']):
            if len(seq) > self.max_length:
                # Assign via .loc to avoid pandas' chained-assignment pitfall.
                df.loc[idx, 'Sequences'] = seq[:self.max_length]
            elif len(seq) < self.min_length:
                # The index was reset above, so positional indices are also labels.
                min_to_be_deleted.append(idx)

        if len(min_to_be_deleted):
            df = df.drop(min_to_be_deleted).reset_index(drop=True)
            logger.info(f"Number of sequences shorter than the minimum length to be removed: {len(min_to_be_deleted)}")
        self.seqs_dot_bracket_labels = df

    def get_secondary_structure(self, sequences):
        # Fold the sequences and return the predicted dot-bracket structures.
        secondary = energy.fold_sequences(sequences.tolist())
        return secondary['structure_37'].values
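
    # Expected shape of the helper's output (illustrative values only):
    # energy.fold_sequences(["GGGAAACCC"]) should return a frame with a
    # 'structure_37' column holding dot-bracket strings, e.g. "(((...)))".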

    def chunkstring_overlap(self, string, window):
        # Sliding window with stride 1: yields len(string) - window + 1 chunks.
        return (
            string[0 + i : window + i] for i in range(0, len(string) - window + 1, 1)
        )

    def chunkstring_no_overlap(self, string, window):
        # Non-overlapping chunks; the last chunk may be shorter than the window.
        return (string[0 + i : window + i] for i in range(0, len(string), window))
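
    # A minimal sketch of the two chunking modes (hypothetical inputs):
    #
    #   >>> list(self.chunkstring_overlap("AACTAGA", 3))
    #   ['AAC', 'ACT', 'CTA', 'TAG', 'AGA']
    #   >>> list(self.chunkstring_no_overlap("AACTAGA", 3))
    #   ['AAC', 'TAG', 'A']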

    def tokenize_samples(self, window: int, sequences_to_be_tokenized: pd.DataFrame, inference: bool = False, tokenizer: str = "overlap") -> np.ndarray:
        """
        Tokenizes RNA sequences using a window of size `window`,
        with or without overlap according to the current tokenizer option.

        In case of overlap:
            example: Sequence: AACTAGA, window: 3
            output: AAC,ACT,CTA,TAG,AGA

        In case of no_overlap:
            example: Sequence: AACTAGA, window: 3
            output: AAC,TAG,A
        """
        if "overlap" in tokenizer:
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
        elif tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:
            raise ValueError(f"Unknown tokenizer option: {tokenizer}")

        samples_tokenized = []
        sample_token_ids = []
        if not self.seq_tokens_ids_dict:
            self.seq_tokens_ids_dict = {"pad": 0}

        for gen in feature_tokens_gen:
            sample_token_id = []
            sample_token = list(gen)
            sample_len = len(sample_token)

            # Pad every sample to the fixed token length.
            sample_token.extend(
                ["pad" for _ in range(int(self.tokens_len - sample_len))]
            )

            for token in sample_token:
                if token not in self.seq_tokens_ids_dict:
                    if not inference:
                        # Grow the vocabulary: ids are assigned in first-seen order.
                        token_id = len(self.seq_tokens_ids_dict.keys())
                        self.seq_tokens_ids_dict[token] = token_id
                    else:
                        logger.warning(f"The sequence token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        # Substitute a random known (non-pad) token; the dict maps
                        # token -> id, so pick a key by position rather than
                        # indexing the dict with the id itself.
                        token_id = randint(1, len(self.seq_tokens_ids_dict.keys()) - 1)
                        token = list(self.seq_tokens_ids_dict.keys())[token_id]

                sample_token_id.append(self.seq_tokens_ids_dict[token])

            sample_token_ids.append(np.array(sample_token_id))

            sample_token = np.array(sample_token)
            samples_tokenized.append(sample_token)

        return (np.array(samples_tokenized), np.array(sample_token_ids))
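
    # Sketch of the padding and id assignment (hypothetical values): with
    # window=3, tokens_len=5 and an empty vocabulary, tokenizing "AACTA" gives
    #   tokens:    ['AAC', 'ACT', 'CTA', 'pad', 'pad']
    #   token ids: [1, 2, 3, 0, 0]    # ids in first-seen order, 'pad' is 0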

    def tokenize_secondary_structure(self, window, sequences_to_be_tokenized, inference: bool = False, tokenizer="overlap") -> np.ndarray:
        """
        Tokenizes secondary structures using a window of size `window`,
        with or without overlap according to the current tokenizer option.

        In case of overlap:
            example: Structure: ...()..., window: 3
            output: ...,..(,.(),().,)..,...

        In case of no_overlap:
            example: Structure: ...()..., window: 3
            output: ...,().,..
        """
        samples_tokenized = []
        sample_token_ids = []
        if not self.second_input_tokens_ids_dict:
            self.second_input_tokens_ids_dict = {"pad": 0}

        if "overlap" in tokenizer:
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
        elif tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:
            raise ValueError(f"Unknown tokenizer option: {tokenizer}")

        for seq_idx, gen in enumerate(feature_tokens_gen):
            sample_token_id = []
            sample_token = list(gen)

            for token in sample_token:
                if token not in self.second_input_tokens_ids_dict:
                    if not inference:
                        # Grow the vocabulary: ids are assigned in first-seen order.
                        token_id = len(self.second_input_tokens_ids_dict.keys())
                        self.second_input_tokens_ids_dict[token] = token_id
                    else:
                        warnings.warn(f"The secondary structure token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        # Substitute a random known (non-pad) token; the dict maps
                        # token -> id, so pick a key by position rather than
                        # indexing the dict with the id itself.
                        token_id = randint(1, len(self.second_input_tokens_ids_dict.keys()) - 1)
                        token = list(self.second_input_tokens_ids_dict.keys())[token_id]

                sample_token_id.append(self.second_input_tokens_ids_dict[token])

            sample_token_ids.append(sample_token_id)
            samples_tokenized.append(sample_token)

        # Pad all samples to the fixed token length after id assignment.
        self.second_input_token_len = self.tokens_len
        for seq_idx, token in enumerate(sample_token_ids):
            sample_len = len(token)
            sample_token_ids[seq_idx].extend(
                [self.second_input_tokens_ids_dict["pad"] for _ in range(int(self.second_input_token_len - sample_len))]
            )
            samples_tokenized[seq_idx].extend(
                ["pad" for _ in range(int(self.second_input_token_len - sample_len))]
            )
            sample_token_ids[seq_idx] = np.array(sample_token_ids[seq_idx])
            samples_tokenized[seq_idx] = np.array(samples_tokenized[seq_idx])

        return (np.array(samples_tokenized), np.array(sample_token_ids))
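
    # Sketch (hypothetical values): with window=3 and tokens_len=6 and an empty
    # vocabulary, the structure "...()..." yields overlap tokens
    #   ['...', '..(', '.()', '().', ')..', '...']
    # and ids [1, 2, 3, 4, 5, 1] once '...' has been assigned id 1.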

    def set_class_attr(self):
        # Expose the raw columns as attributes for downstream tokenization.
        self.seq = self.seqs_dot_bracket_labels["Sequences"]
        if 'struct' in self.model_input:
            self.struct = self.seqs_dot_bracket_labels["Secondary"]

        self.labels = self.seqs_dot_bracket_labels['Labels']

    def prepare_multi_idx_pd(self, num_coln, pd_name, pd_value):
        # Build a DataFrame whose columns form a two-level MultiIndex:
        # (type of data, column index).
        iterables = [[pd_name], np.arange(num_coln)]
        index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
        return pd.DataFrame(columns=index, data=pd_value)
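
    # Sketch (hypothetical values):
    #   >>> df = self.prepare_multi_idx_pd(2, "tokens_id", np.array([[7, 8]]))
    #   >>> list(df.columns)
    #   [('tokens_id', 0), ('tokens_id', 1)]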

    def phase_sequence(self, sample_token_ids):
        '''
        Splits token ids into two phases: tokens at even positions and tokens
        at odd positions, zero-padding the odd phase when the shapes differ.
        '''
        phase0 = sample_token_ids[:, ::2]
        phase1 = sample_token_ids[:, 1::2]

        if phase0.shape != phase1.shape:
            phase1 = np.concatenate([phase1, np.zeros(phase1.shape[0])[..., np.newaxis]], axis=1)
        sample_token_ids = phase0

        return sample_token_ids, phase1
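
    # Sketch (hypothetical values):
    #   >>> self.phase_sequence(np.array([[1, 2, 3, 4, 5]]))
    #   (array([[1, 3, 5]]), array([[2., 4., 0.]]))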

    def custom_roll(self, arr, n_shifts_per_row):
        '''
        Shifts each row of a numpy array according to n_shifts_per_row
        (positive shifts roll right, negative shifts roll left).
        '''
        m = np.asarray(n_shifts_per_row)
        # Concatenate each row with its own first n-1 columns so that every
        # rotation of the row is available as a contiguous window.
        arr_roll = arr[:, [*range(arr.shape[1]), *range(arr.shape[1] - 1)]].copy()
        strd_0, strd_1 = arr_roll.strides
        n = arr.shape[1]
        # View all n rotations of every row, then pick the requested one per row.
        result = as_strided(arr_roll, (*arr.shape, n), (strd_0, strd_1, strd_1))

        return result[np.arange(arr.shape[0]), (n - m) % n]
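
    # Sketch (hypothetical values):
    #   >>> self.custom_roll(np.array([[0, 0, 3, 2, 1]]), [-2])
    #   array([[3, 2, 1, 0, 0]])   # each row rolled left by 2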

    def save_token_dicts(self):
        # Persist both token vocabularies in the current working directory.
        save(data=self.second_input_tokens_ids_dict, path=os.getcwd() + '/second_input_tokens_ids_dict')
        save(data=self.seq_tokens_ids_dict, path=os.getcwd() + '/seq_tokens_ids_dict')

    def get_tokenized_data(self, inference: bool = False):
        samples_tokenized, sample_token_ids = self.tokenize_samples(self.window, self.seq, inference)

        logger.info(f'Vocab size for primary sequences: {len(self.seq_tokens_ids_dict.keys())}')
        logger.info(f'Vocab size for secondary structure: {len(self.second_input_tokens_ids_dict.keys())}')
        logger.info(f'Number of sequences used for tokenization: {samples_tokenized.shape[0]}')

        if "comp" in self.model_input:
            # Build the complement of each sequence; the %temp% placeholders
            # make the swaps order-independent (e.g. "GATC" -> "CTAG").
            self.seq_comp = []
            for feature in self.seq:
                feature = feature.replace('A', '%temp%').replace('T', 'A')\
                    .replace('C', '%temp2%').replace('G', 'C')\
                    .replace('%temp%', 'T').replace('%temp2%', 'G')
                self.seq_comp.append(feature)

            # Temporarily swap out the sequence vocabulary so the complement
            # strand gets its own token ids.
            self.seq_tokens_ids_dict_temp = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = None

            _, seq_comp_str_token_ids = self.tokenize_samples(self.window, self.seq_comp, inference)
            sec_input_value = seq_comp_str_token_ids

            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = self.seq_tokens_ids_dict_temp

        if "struct" in self.model_input:
            _, sec_str_token_ids = self.tokenize_secondary_structure(self.window, self.struct, inference)
            sec_input_value = sec_str_token_ids

        if "seq-seq" in self.model_input:
            sample_token_ids, sec_input_value = self.phase_sequence(sample_token_ids)
            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict

        if "seq-rev" in self.model_input or "baseline" in self.model_input or self.model_input == 'seq':
            # Reverse each sequence, then roll the padding zeros back to the end.
            sample_token_ids_rev = sample_token_ids[:, ::-1]
            n_zeros = np.count_nonzero(sample_token_ids_rev == 0, axis=1)
            sec_input_value = self.custom_roll(sample_token_ids_rev, -n_zeros)
            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict

        # Unpadded length of every sample: the count of non-zero token ids.
        seqs_length = list(sum(sample_token_ids.T != 0))

        labels_df = self.prepare_multi_idx_pd(1, "Labels", self.labels.values)
        tokens_id_df = self.prepare_multi_idx_pd(sample_token_ids.shape[1], "tokens_id", sample_token_ids)
        tokens_df = self.prepare_multi_idx_pd(samples_tokenized.shape[1], "tokens", samples_tokenized)
        sec_input_df = self.prepare_multi_idx_pd(sec_input_value.shape[1], 'second_input', sec_input_value)
        seqs_length_df = self.prepare_multi_idx_pd(1, "seqs_length", seqs_length)

        all_df = labels_df.join(tokens_df).join(tokens_id_df).join(sec_input_df).join(seqs_length_df)

        self.save_token_dicts()
        return all_df
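

# A minimal usage sketch (hypothetical data and config; the config keys mirror
# the ones read in __init__ and are assumptions, not a shipped file):
#
#   df = pd.DataFrame({"Sequences": ["AACTAGA", "GGCTTAG"],
#                      "Labels": [0, 1]})
#   tokenizer = SeqTokenizer(df, config)        # config: omegaconf DictConfig
#   all_df = tokenizer.get_tokenized_data()     # multi-index DataFrame with
#                                               # Labels, tokens, tokens_id,
#                                               # second_input and seqs_length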