|
import logging
import random
from contextlib import redirect_stdout
from pathlib import Path
from random import randint
from typing import Dict, List

import numpy as np
import pandas as pd

from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
from ..utils.energy import fold_sequences
from ..utils.file import load
from ..utils.tcga_post_analysis_utils import Results_Handler
from ..utils.utils import (get_model, infer_from_pd,
                           prepare_inference_results_tcga,
                           update_config_with_inference_params)
from .seq_tokenizer import SeqTokenizer

logger = logging.getLogger(__name__)


class IDModelAugmenter:
    '''
    Augments the dataset with the predictions of the ID (in-distribution) models.
    It first predicts the sub-classes of the no-annotation (NA) set using the ID
    models, then computes the Levenshtein distance between each NA sequence and
    its closest neighbor in the training set. If that distance is below a
    threshold, the sequence is considered familiar.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config']['mapping_dict_path'])

    def predict_transforna_na(self) -> pd.DataFrame:
        mc_or_sc = 'major_class' if 'major_class' in self.config['model_config']['clf_target'] else 'sub_class'
        inference_config = update_config_with_inference_params(self.config, mc_or_sc=mc_or_sc, path_to_models=self.config['path_to_models'])
        model_path = inference_config['inference_settings']["model_path"]
        logger.info(f"Augmenting hico sequences based on predictions from model at: {model_path}")

        # the embedds folder lives two levels above the model file
        embedds_path = '/'.join(model_path.split('/')[:-2]) + '/embedds'

        results: Results_Handler = Results_Handler(embedds_path=embedds_path, splits=['train', 'no_annotation'])
        results.get_knn_model()
        threshold = load(results.analysis_path + "/novelty_model_coef")["Threshold"]
        sequences = results.splits_df_dict['no_annotation_df'][results.seq_col].values[:, 0]
        with redirect_stdout(None):
            root_dir = Path(__file__).parents[3].absolute()
            inference_config, net = get_model(inference_config, root_dir)

        original_infer_pd = pd.Series(sequences, name="Sequences").to_frame()
        logger.info('predicting sub-classes for the NA set using the ID models')
        predicted_labels, logits, _, _, all_data, max_len, net, infer_pd = infer_from_pd(inference_config, net, original_infer_pd, SeqTokenizer)

        prepare_inference_results_tcga(inference_config, predicted_labels, logits, all_data, max_len)
        infer_pd = all_data["infere_rna_seq"]

        logger.info('computing the Levenshtein distance for the NA set')
        _, _, _, _, _, lev_dist = get_closest_ngbr_per_split(results, 'no_annotation')

        logger.info(f'number of familiar sequences based on entropy novelty prediction: {sum(infer_pd["Is Familiar?"])}')
        # override the entropy-based novelty call with the Levenshtein-based one
        infer_pd['Is Familiar?'] = [lv < threshold for lv in lev_dist]
        infer_pd['Threshold'] = threshold
        logger.info(f'number of new familiar sequences based on Levenshtein distance: {np.sum(infer_pd["Is Familiar?"])}')
        return infer_pd.rename_axis("Sequence").reset_index()
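    # Illustrative shape of the frame returned above (sequence and label values
    # are made up; the column names come from the code in this method):
    #
    #   Sequence           Net-Label    Is Familiar?   Threshold
    #   TCGAGGTAGTAGGTT    miR-X-5p     True           0.26
    #   AAGCTTAGCTAGCTA    rRNA_bin-4   False          0.26
    #
    # 'Net-Label' holds the sub-class predicted by the ID model and
    # 'Is Familiar?' the Levenshtein-based novelty call.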
|
|
|
    def include_id_model_predictions(self):
        pred_df = self.predict_transforna_na()
        curr_labels = set(self.df.Labels.cat.categories)
        pred_labels = set(pred_df['Net-Label'].unique())
        self.df['Labels'] = self.df['Labels'].cat.add_categories(pred_labels.difference(curr_labels))

        # keep only familiar predictions whose sequences exist in the tcga df
        familiar_df = pred_df[pred_df['Is Familiar?']].set_index('Sequence')
        familiar_df = familiar_df[familiar_df.index.isin(self.df.index)]

        # indexing with familiar_df.index keeps each label aligned with its sequence
        self.df.loc[familiar_df.index, 'Labels'] = familiar_df['Net-Label'].values

    def get_augmented_df(self):
        self.include_id_model_predictions()
        return self.df
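    # Minimal usage sketch (the config variable below is hypothetical; the real
    # config is the training config this package builds elsewhere, carrying the
    # 'train_config', 'model_config' and 'path_to_models' entries read above):
    #
    #   augmenter = IDModelAugmenter(df=tcga_df, config=config)
    #   tcga_df = augmenter.get_augmented_df()   # familiar NA rows relabeled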
|
|
|
class RecombinedSeqAugmenter:
    '''
    Augments the dataset with recombined sequences. One sequence is sampled per
    sub-class; the samples are shuffled and paired, each pair is fused, and a
    randomly sized window around the fusion point is kept.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        self.df = df
        self.config = config

    def create_recombined_seqs(self, class_label: str = 'recombined'):
        # sample one sequence (index value) per sub-class
        unique_labels = self.df.Labels.value_counts()[self.df.Labels.value_counts() >= 1].index.tolist()
        samples = [self.df[self.df['Labels'] == label].sample(1).index[0] for label in unique_labels]

        # pairing requires an even number of samples
        if len(samples) % 2 != 0:
            samples = samples[:-1]
        np.random.shuffle(samples)

        samples_set1 = samples[:len(samples)//2]
        samples_set2 = samples[len(samples)//2:]

        recombined_set = []
        for i in range(len(samples_set1)):
            recombined_seq = samples_set1[i] + samples_set2[i]

            # fusion point, jittered by up to 5 nt in either direction
            recombined_index = len(samples_set1[i])
            offset = randint(-5, 5)
            recombined_index += offset

            # keep an 18-30 nt window centered on the (jittered) fusion point
            random_half_len = randint(18, 30) // 2
            random_seq = recombined_seq[max(0, recombined_index - random_half_len):recombined_index + random_half_len]
            recombined_set.append(random_seq)

        recombined_df = pd.DataFrame(index=recombined_set,
                                     data=[class_label] * len(recombined_set),
                                     columns=['Labels'])
        return recombined_df
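    # Worked example (hypothetical sequences): fusing a 22-nt and a 20-nt sample
    # gives a 42-nt sequence with the fusion point at index 22. With offset = -2
    # and random_half_len = 12, the kept window is recombined_seq[8:32], a 24-nt
    # chimera spanning the junction.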
|
|
|
    def get_augmented_df(self):
        recombined_df = self.create_recombined_seqs()
        return recombined_df

class RandomSeqAugmenter:
    '''
    Augments the dataset with random sequences in the same length range as the
    tcga sequences.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        self.df = df
        self.config = config
        self.num_seqs = 500
        self.min_len = 18
        self.max_len = 30

    def get_random_seq(self):
        random_seqs = []
        # track already-used sequences (including the dataset itself)
        seen = set(self.df.index)
        while len(random_seqs) < self.num_seqs:
            random_seq = ''.join(random.choices(['A', 'C', 'G', 'T'], k=randint(self.min_len, self.max_len)))
            if random_seq not in seen:
                random_seqs.append(random_seq)
                seen.add(random_seq)

        return pd.DataFrame(index=random_seqs, data=['random'] * len(random_seqs),
                            columns=['Labels'])

    def get_augmented_df(self):
        random_df = self.get_random_seq()
        return random_df

class PrecursorAugmenter:
    def __init__(self, df: pd.DataFrame, config: Dict):
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config'].mapping_dict_path)
        self.precursor_df = self.load_precursor_file()
        self.trained_on = config.trained_on

        # sub-classes with fewer samples than this are not augmented
        self.min_num_samples_per_sc: int = 1
        if self.trained_on == 'id':
            self.min_num_samples_per_sc = 8

        self.min_bin_size = 20
        self.max_bin_size = 30
        self.min_seq_len = 18
        self.max_seq_len = 30

    def load_precursor_file(self):
        try:
            precursor_df = pd.read_csv(self.config['train_config'].precursor_file_path, index_col=0)
            return precursor_df
        except Exception:
            logger.info('Could not load precursor file')
            return None

    def compute_dynamic_bin_size(self, precursor_len: int, name: str = None) -> List[int]:
        '''
        Splits a precursor into bins of size max_bin_size.
        If the last bin is smaller than min_bin_size, the precursor is re-split
        into bins of size max_bin_size-1, and so on, until the last bin is at
        least min_bin_size. If the bin size reaches min_bin_size and the last
        bin is still too small, the last two bins are merged, so the maximum
        possible bin size is min_bin_size + (min_bin_size - 1) = 39.
        '''
        def split_precursor_to_bins(precursor_len, max_bin_size):
            '''
            Splits a precursor into bins of size max_bin_size (the last bin
            holds the remainder).
            '''
            precursor_bin_lens = []
            for i in range(0, precursor_len, max_bin_size):
                if i + max_bin_size < precursor_len:
                    precursor_bin_lens.append(max_bin_size)
                else:
                    precursor_bin_lens.append(precursor_len - i)
            return precursor_bin_lens

        if precursor_len < self.min_bin_size:
            return [precursor_len]
        else:
            precursor_bin_lens = split_precursor_to_bins(precursor_len, self.max_bin_size)
            reduced_len = self.max_bin_size - 1
            while precursor_bin_lens[-1] < self.min_bin_size:
                precursor_bin_lens = split_precursor_to_bins(precursor_len, reduced_len)
                reduced_len -= 1
                if reduced_len < self.min_bin_size:
                    # no valid bin size left: merge the last two bins
                    precursor_bin_lens[-2] += precursor_bin_lens[-1]
                    precursor_bin_lens = precursor_bin_lens[:-1]
                    break

        return precursor_bin_lens
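    # Worked example with min_bin_size=20 and max_bin_size=30: a 95-nt precursor
    # first splits into [30, 30, 30, 5]; the 5-nt tail is below min_bin_size, so
    # the bin size is reduced step by step ([29, 29, 29, 8], [28, 28, 28, 11],
    # ...) until [25, 25, 25, 20], whose last bin is valid. For a 39-nt
    # precursor no reduction helps, so the last two bins are merged into [39].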
|
|
|
    def get_bin_with_max_overlap(self, precursor_len: int, start_frag_pos: int, frag_len: int, name) -> int:
        '''
        Returns the 1-indexed number of the bin that overlaps the most with the
        given fragment.
        '''
        precursor_bin_lens = self.compute_dynamic_bin_size(precursor_len=precursor_len, name=name)
        bin_no = 0
        for i, bin_len in enumerate(precursor_bin_lens):
            if start_frag_pos < bin_len:
                # fragment starts within this bin
                overlap = min(bin_len - start_frag_pos, frag_len)
                # assign the fragment to this bin only if more than half of it
                # falls inside; otherwise assign it to the next bin
                if overlap > frag_len / 2:
                    bin_no = i
                else:
                    bin_no = i + 1
                break
            else:
                start_frag_pos -= bin_len
        return bin_no + 1
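    # Worked example: for a 60-nt precursor binned as [30, 30], a 20-nt fragment
    # starting at position 25 overlaps the first bin by only 5 nt (< half its
    # length), so it is assigned to bin 2; the same fragment starting at
    # position 15 overlaps by 15 nt (> half) and is assigned to bin 1.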
|
|
|
    def get_precursor_info(self, mc: str, sc: str):
        xRNA_df = self.precursor_df.loc[self.precursor_df.small_RNA_class_annotation == mc]
        xRNA_df.index = xRNA_df.index.str.replace('|', '-', regex=False)
        prec_name = sc.split('_bin-')[0]

        if mc in ['snoRNA', 'lncRNA', 'protein_coding', 'miscRNA']:
            prec_name = mc + '-' + prec_name
            prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)]

            if prec_row_df.empty:
                # retry with the pseudo-class annotation
                xRNA_df = self.precursor_df.loc[self.precursor_df.small_RNA_class_annotation == 'pseudo_' + mc]
                xRNA_df.index = xRNA_df.index.str.replace('|', '-', regex=False)
                prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)]
                if prec_row_df.empty:
                    logger.info(f'precursor {prec_name} not found in HBDxBase')
                    return None, prec_name

            prec_row_df = prec_row_df.iloc[0]
        else:
            prec_row_df = xRNA_df.loc[f'{mc}-{prec_name}']

        precursor = prec_row_df.sequence
        return precursor, prec_name
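    # Example lookup (hypothetical names): for sc = 'SNORD115_bin-3' and
    # mc = 'snoRNA', prec_name becomes 'snoRNA-SNORD115' and the precursor row
    # is matched by substring against the HBDxBase index; for mc = 'miRNA' the
    # row is fetched directly via the exact index f'{mc}-{prec_name}'.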
|
|
|
    def populate_from_bin(self, sc: str, precursor: str, prec_name: str, existing_seqs: List[str]):
        '''
        First extracts the bin number from the sub-class name, then samples in
        three regimes, always ensuring that the overlap with the middle
        (target) bin is the largest:
        1. sample starting in the previous bin
        2. sample reaching into the next bin
        3. sample from within the middle bin
        The start index begins at the middle position of the previous bin and
        is incremented until the end of the current bin.
        '''
        bin_no = int(sc.split('_bin-')[1])
        bins = self.compute_dynamic_bin_size(len(precursor), prec_name)
        if len(bins) == 1:
            return pd.DataFrame()

        # bin numbers in sub-class names are 1-indexed
        bin_no -= 1

        if bin_no >= 1:
            previous_bin_start = sum(bins[:bin_no - 1])
            previous_bin_size = bins[bin_no - 1]
        else:
            # the target bin is the first bin: there is no previous bin
            previous_bin_start = 0
            previous_bin_size = 0

        middle_bin_start = sum(bins[:bin_no])
        next_bin_start = sum(bins[:bin_no + 1])

        middle_bin_size = bins[bin_no]
        try:
            next_bin_size = bins[bin_no + 1]
        except IndexError:
            next_bin_size = 0

        start_idx = previous_bin_start + previous_bin_size // 2 + 1
        sampled_seqs = []

        while start_idx < middle_bin_start + middle_bin_size:
            if start_idx < middle_bin_start:
                # starting in the previous bin: draw a length long enough that
                # more than half of the fragment lies in the middle bin
                max_overlap_prev = middle_bin_start - start_idx
                end_idx = start_idx + randint(max(self.min_seq_len, max_overlap_prev * 2 + 1), self.max_seq_len)
            else:
                max_overlap_curr = next_bin_start - start_idx
                max_overlap_next = (start_idx + self.max_seq_len) - next_bin_start
                max_overlap_next = min(max_overlap_next, next_bin_size)
                if max_overlap_curr <= 9 or (max_overlap_next == 0 and max_overlap_curr < self.min_seq_len):
                    end_idx = -1
                else:
                    # cap the length so the overlap with the next bin never
                    # exceeds the overlap with the middle bin
                    end_idx = start_idx + randint(self.min_seq_len, min(self.max_seq_len, self.max_seq_len - max_overlap_next + max_overlap_curr - 1))

            if end_idx == -1:
                break

            tmp_seq = precursor[start_idx:end_idx]

            assert self.min_seq_len <= len(tmp_seq) <= self.max_seq_len, f'length of tmp_seq is {len(tmp_seq)}'
            if tmp_seq not in existing_seqs:
                sampled_seqs.append(tmp_seq)
            start_idx += 1

        # sanity check: every sampled fragment should map back to the target bin
        for frag in sampled_seqs:
            # str.find returns only the first occurrence; the list wrapper keeps
            # the loop below ready for a multi-occurrence search
            all_occ = precursor.find(frag)
            if not isinstance(all_occ, list):
                all_occ = [all_occ]

            for occ in all_occ:
                curr_bin_no = self.get_bin_with_max_overlap(len(precursor), occ, len(frag), ' ')

                # occurrences far from the target bin are ignored
                if abs(curr_bin_no - (bin_no + 1)) > 1:
                    continue
                assert curr_bin_no == bin_no + 1, f'curr_bin_no is {curr_bin_no} and bin_no is {bin_no + 1}'

        return pd.DataFrame(index=sampled_seqs, data=[sc] * len(sampled_seqs),
                            columns=['Labels'])
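    # Illustrative pass, assuming bins [25, 25, 25, 20] and a sub-class name
    # ending in '_bin-2': the middle bin spans [25, 50) and sampling starts at
    # index 13 (middle of the previous bin). While start_idx < 25, fragment
    # lengths are drawn so that more than half of each fragment falls inside
    # bin 2; afterwards, lengths are capped so the fragment never overlaps
    # bin 3 more than it overlaps bin 2.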
|
|
|
    def populate_scs_with_bins(self):
        augmented_df = pd.DataFrame(columns=['Labels'])

        # only augment sub-classes with enough samples
        label_counts = self.df.Labels.value_counts()
        unique_labels = label_counts[label_counts >= self.min_num_samples_per_sc].index.tolist()
        scs_list = []
        scs_before = []
        scs_after = []
        for sc in unique_labels:
            if isinstance(sc, str) and '_bin-' in sc:
                try:
                    mc = self.mapping_dict[sc]
                except KeyError:
                    # fallback mapping from sub-class name fragments to major classes
                    fallback = {'miR': 'miRNA', 'tRNA': 'tRNA', 'rRNA': 'rRNA',
                                'snRNA': 'snRNA', 'snoRNA': 'snoRNA', 'SNO': 'snoRNA',
                                'RPL37A': 'protein_coding', 'SNHG1': 'lncRNA'}
                    mc = next((v for k, v in fallback.items() if k in sc), None)
                    if mc is None:
                        logger.info(f'No mapping for {sc}')
                        continue
                existing_seqs = self.df[self.df['Labels'] == sc].index
                scs_list.append(sc)
                scs_before.append(len(existing_seqs))

                precursor, prec_name = self.get_precursor_info(mc, sc)
                if precursor is None:
                    scs_after.append(0)
                    continue
                sc2_df = self.populate_from_bin(sc, precursor, prec_name, existing_seqs)
                augmented_df = pd.concat([augmented_df, sc2_df])
                scs_after.append(len(sc2_df))

        scs_dict = {'sub_class': scs_list, 'Number of samples before': scs_before, 'Number of samples after': scs_after}
        scs_df = pd.DataFrame(scs_dict)
        scs_df.to_csv('frequency_per_sub_class_df.csv')

        return augmented_df

    def get_augmented_df(self):
        return self.populate_scs_with_bins()

class DataAugmenter:
    '''
    Sets the labels of the dataset to major-class or sub-class labels based on
    the clf_target:
        major class: miRNA, tRNA, ...
        sub class: mir-192-3p, rRNA-bin-30, ...
    The dataset is then augmented with:
    1. random sequences
    2. recombined sequences
    3. sequences sampled from the precursor file (sub-class targets only)
    4. if the models are to be trained on the full set: predictions for
       sequences that previously had no annotation or a low-confidence one but
       were predicted to be familiar by the ID models
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config'].mapping_dict_path)
        self.trained_on = config.trained_on
        self.clf_target = config['model_config'].clf_target
        logger.info(f'Augmenting the dataset for {self.clf_target}')
        self.set_labels()

        self.precursor_augmenter = PrecursorAugmenter(self.df, self.config)
        self.random_augmenter = RandomSeqAugmenter(self.df, self.config)
        self.recombined_augmenter = RecombinedSeqAugmenter(self.df, self.config)
        self.id_model_augmenter = IDModelAugmenter(self.df, self.config)

    def set_labels(self):
        if 'hico' not in self.clf_target:
            self.df['Labels'] = self.df['subclass_name'].str.split(';', expand=True)[0]
        else:
            self.df['Labels'] = self.df['subclass_name'][self.df['hico'] == True]

        self.df['Labels'] = self.df['Labels'].astype('category')

    def convert_to_major_class_labels(self):
        if 'major_class' in self.clf_target:
            self.df['Labels'] = self.df['Labels'].map(self.mapping_dict).astype('category')
            # drop sequences whose mapped label is ambiguous (contains ';')
            self.df = self.df[~self.df['Labels'].str.contains(';').fillna(False)]

    def combine_df(self, new_var_df: pd.DataFrame):
        duplicated_df = new_var_df[new_var_df.index.isin(self.df.index)]

        if len(duplicated_df):
            logger.info(f'Number of duplicated sequences removed from the augmented data: {duplicated_df.shape[0]}')

        # drop duplicates and shuffle
        new_var_df = new_var_df[~new_var_df.index.isin(self.df.index)].sample(frac=1)

        for col in self.df.columns:
            if col not in new_var_df.columns:
                new_var_df[col] = np.nan

        self.df = pd.concat([new_var_df, self.df])
        self.df.index = self.df.index.str.upper()
        self.df.Labels = self.df.Labels.astype('category')
        return self.df

    def annotate_artificial_affix_seqs(self):
        # sequences failing the five-prime adapter filter are artificial affixes
        aa_seqs = self.df[self.df['five_prime_adapter_filter'] == 0].index.tolist()
        self.df['Labels'] = self.df['Labels'].cat.add_categories('artificial_affix')
        self.df.loc[aa_seqs, 'Labels'] = 'artificial_affix'

    def full_pipeline(self):
        self.df = self.id_model_augmenter.get_augmented_df()

    def post_augmentation(self):
        random_df = self.random_augmenter.get_augmented_df()

        if 'sub_class' in self.clf_target:
            df = self.precursor_augmenter.get_augmented_df()
        else:
            df = pd.DataFrame()
        recombined_df = self.recombined_augmenter.get_augmented_df()
        df = pd.concat([df, recombined_df, random_df])
        self.df['Labels'] = self.df['Labels'].cat.add_categories({'random', 'recombined'})
        self.combine_df(df)

        self.convert_to_major_class_labels()
        self.annotate_artificial_affix_seqs()
        self.df['Labels'] = self.df['Labels'].cat.remove_unused_categories()
        self.df['Sequences'] = self.df.index.tolist()

        if 'struct' in self.config['model_config'].model_input:
            self.df['Secondary'] = fold_sequences(self.df.index.tolist(), temperature=37)['structure_37'].values

        return self.df

    def get_augmented_df(self):
        if self.trained_on == 'full':
            self.full_pipeline()
        return self.post_augmentation()
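
# Minimal end-to-end sketch (hypothetical variable names; the config object is
# the training config this package builds elsewhere, supporting both key and
# attribute access as used throughout this module):
#
#   augmenter = DataAugmenter(df=tcga_df, config=config)
#   augmented_df = augmenter.get_augmented_df()
#   # augmented_df now holds the original sequences plus 'random', 'recombined'
#   # and precursor-derived rows, with 'Labels' as a categorical column.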