CRISPRTool / cas9on.py
supercat666's picture
added cas9 off
0d0c645
raw
history blame
2.68 kB
import tensorflow as tf
import pandas as pd
import numpy as np
from operator import add
from functools import reduce
# Configure TensorFlow GPU usage at import time: turn on memory growth for
# every physical GPU, then make only the first GPU visible to TensorFlow.
_gpus = tf.config.list_physical_devices('GPU')
for gpu in _gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if _gpus:
    tf.config.experimental.set_visible_devices(_gpus[0], 'GPU')
# One-hot code per nucleotide; columns are ordered A, C, G, T.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}
# Per-position code for an epigenetic annotation string: 'A' -> 1, 'N' -> 0
# (presumably "annotated" / "not annotated" — TODO confirm against data spec).
epimap = {'A': 1, 'N': 0}


def get_seqcode(seq):
    """One-hot encode a nucleotide sequence.

    The sequence is upper-cased first, so lowercase bases are accepted.

    Args:
        seq: string over A/C/G/T (any case). May be empty.

    Returns:
        Integer array of shape ``(1, len(seq), 4)``.

    Raises:
        KeyError: if ``seq`` contains a character not in ``ntmap``.
    """
    # Build the (len, 4) matrix in one go. Folding tuples with reduce(add)
    # re-copied the growing tuple for every base (quadratic) and raised on "".
    rows = [ntmap[base] for base in seq.upper()]
    # Explicit width 4 (instead of -1) keeps the empty case well-defined.
    return np.array(rows, dtype=int).reshape(1, len(rows), 4)


def get_epicode(eseq):
    """Encode an epigenetic annotation string as a single 0/1 channel.

    Upper-cased first, for consistency with ``get_seqcode``.

    Args:
        eseq: string over 'A'/'N' (any case). May be empty.

    Returns:
        Integer array of shape ``(1, len(eseq), 1)``.

    Raises:
        KeyError: if ``eseq`` contains a character not in ``epimap``.
    """
    codes = [epimap[mark] for mark in eseq.upper()]
    return np.array(codes, dtype=int).reshape(1, len(codes), 1)
class Episgt:
    """Loader for a tab-separated, header-less ``.episgt`` table.

    Only the trailing columns of the file are used, in this order: one
    sequence column, ``num_epi_features`` epigenetic-annotation columns and,
    when ``with_y`` is true, a final label column. Any leading columns are
    ignored.
    """

    def __init__(self, fpath, num_epi_features, with_y=True):
        """Read ``fpath`` and select the trailing columns described above.

        Args:
            fpath: path to the tab-separated file (no header row).
            num_epi_features: number of epigenetic-annotation columns.
            with_y: whether the last column holds labels.

        Raises:
            ValueError: if ``num_epi_features`` is negative or the file has
                fewer columns than the layout requires.
        """
        if num_epi_features < 0:
            raise ValueError(f"num_epi_features must be >= 0, got {num_epi_features}")
        self._fpath = fpath
        self._ori_df = pd.read_csv(fpath, sep='\t', index_col=None, header=None)
        self._num_epi_features = num_epi_features
        self._with_y = with_y
        # sequence column + epigenetic columns (+ label column)
        self._num_cols = num_epi_features + 2 if with_y else num_epi_features + 1
        all_cols = list(self._ori_df.columns)
        # Without this check the trailing-column slice below would silently
        # take every column and mislabel them.
        if len(all_cols) < self._num_cols:
            raise ValueError(
                f"{fpath!r} has {len(all_cols)} columns; "
                f"expected at least {self._num_cols}"
            )
        self._cols = all_cols[-self._num_cols:]
        self._df = self._ori_df[self._cols]

    @property
    def length(self):
        """Number of rows in the table."""
        return len(self._df)

    def get_dataset(self, x_dtype=np.float32, y_dtype=np.float32):
        """Encode the table as model input (and labels).

        Args:
            x_dtype: dtype of the returned ``x``.
            y_dtype: dtype of the returned ``y``.

        Returns:
            ``x`` with axes (row, feature, position), where the 4 one-hot
            nucleotide channels come first, followed by one channel per
            epigenetic column. Returns ``(x, y)`` with 1-D ``y`` when
            ``with_y`` is true, otherwise ``x`` alone. All sequences are
            assumed to share one length (required by the concatenation).
        """
        seq_col = self._cols[0]
        epi_cols = self._cols[1:1 + self._num_epi_features]
        # (rows, seq_len, 4) one-hot nucleotides.
        parts = [np.concatenate([get_seqcode(s) for s in self._df[seq_col]])]
        # One (rows, seq_len, 1) channel per epigenetic column. Appending to a
        # list avoids np.concatenate([]) failing when there are none.
        for col in epi_cols:
            parts.append(np.concatenate([get_epicode(e) for e in self._df[col]]))
        x = np.concatenate(parts, axis=-1).astype(x_dtype)
        # Move the feature axis in front of the position axis.
        x = x.transpose(0, 2, 1)
        if self._with_y:
            y = self._df[self._cols[-1]].to_numpy().astype(y_dtype)
            return x, y
        return x
from keras.models import load_model
class DCModelOntar:
    """Thin wrapper around a saved Keras on-target model."""

    def __init__(self, ontar_model_dir, is_reg=False):
        """Load the model stored at ``ontar_model_dir``.

        Args:
            ontar_model_dir: path passed to ``keras.models.load_model``.
            is_reg: accepted for backward compatibility but currently has no
                effect; classification and regression models are loaded the
                same way.
        """
        # The former ``if is_reg`` had two identical branches; collapsed so
        # the flag no longer suggests a difference that does not exist.
        self.model = load_model(ontar_model_dir)

    def ontar_predict(self, x, channel_first=True):
        """Run the model on ``x`` and return its output flattened to 1-D.

        Args:
            x: 4-D array. When ``channel_first`` is true its axes are permuted
                ``(0, 1, 2, 3) -> (0, 2, 3, 1)``, i.e. axis 1 is moved last,
                before prediction.
            channel_first: whether ``x`` carries its channel axis at position 1.

        Returns:
            ``self.model.predict(x)`` raveled to 1-D.
        """
        if channel_first:
            x = x.transpose([0, 2, 3, 1])
        return self.model.predict(x).ravel()
def predict(file_path='eg_cls_on_target.episgt', model_path='on-cla.h5',
            num_epi_features=4):
    """Score every row of an ``.episgt`` file with the on-target model.

    Calling with no arguments reproduces the original hard-coded example.

    Args:
        file_path: tab-separated input read by ``Episgt``. It must include a
            trailing label column, because it is loaded with ``with_y=True``
            and that flag determines which trailing columns are selected.
        model_path: saved Keras model loaded by ``DCModelOntar``.
        num_epi_features: number of epigenetic-annotation columns.

    Returns:
        The model's predictions flattened to 1-D (presumably one score per
        input row; this depends on the model's output shape).
    """
    input_data = Episgt(file_path, num_epi_features=num_epi_features, with_y=True)
    # Labels are loaded only to keep the column layout aligned; unused here.
    x, _ = input_data.get_dataset()
    # Insert a singleton axis at position 2: (rows, features, 1, positions).
    # Per the original note, the example file yields shape [100, 8, 1, 23].
    x = np.expand_dims(x, axis=2)
    dc_model = DCModelOntar(model_path)
    return dc_model.ontar_predict(x)