Spaces:
Running
Running
supercat666
commited on
Commit
•
ce4236e
1
Parent(s):
a5afc1a
fixed cas9on
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ st.divider()
|
|
13 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
14 |
|
15 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
16 |
-
|
17 |
|
18 |
@st.cache_data
|
19 |
def convert_df(df):
|
@@ -92,8 +92,43 @@ if selected_model == 'Cas9':
|
|
92 |
|
93 |
# Actions based on the selected enzyme
|
94 |
if target_selection == 'on-target':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
elif target_selection == 'off-target':
|
99 |
ENTRY_METHODS = dict(
|
|
|
13 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
14 |
|
15 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
16 |
+
cas9on_path = '/cas9_model/on-cla.h5'
|
17 |
|
18 |
@st.cache_data
|
19 |
def convert_df(df):
|
|
|
92 |
|
93 |
# Actions based on the selected enzyme
|
94 |
if target_selection == 'on-target':
|
95 |
+
# app initialization for Cas9 on-target
|
96 |
+
if 'gene_symbol' not in st.session_state:
|
97 |
+
st.session_state.gene_symbol = None
|
98 |
+
if 'on_target_results' not in st.session_state:
|
99 |
+
st.session_state.on_target_results = None
|
100 |
+
|
101 |
+
# Gene symbol entry
|
102 |
+
st.text_input(
|
103 |
+
label='Enter a Gene Symbol:',
|
104 |
+
key='gene_symbol_entry',
|
105 |
+
placeholder='e.g., BRCA1'
|
106 |
+
)
|
107 |
|
108 |
+
# prediction button
|
109 |
+
if st.button('Predict on-target'):
|
110 |
+
gene_symbol = st.session_state.gene_symbol_entry
|
111 |
+
if gene_symbol: # Check if gene_symbol is not empty
|
112 |
+
predictions = cas9on.process_gene(gene_symbol, cas9on_path)
|
113 |
+
st.session_state.on_target_results = predictions[:10] # Store only first 10 for display
|
114 |
+
|
115 |
+
# on-target results display
|
116 |
+
on_target_results = st.empty()
|
117 |
+
if st.session_state.on_target_results is not None:
|
118 |
+
with on_target_results.container():
|
119 |
+
if len(st.session_state.on_target_results) > 0:
|
120 |
+
st.write('On-target predictions:', st.session_state.on_target_results)
|
121 |
+
full_predictions = cas9on.process_gene(gene_symbol, cas9on_path) # Get full predictions for download
|
122 |
+
st.download_button(
|
123 |
+
label='Download on-target predictions',
|
124 |
+
data=cas9on.convert_df(full_predictions),
|
125 |
+
file_name='on_target_results.csv',
|
126 |
+
mime='text/csv'
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
st.write('No significant on-target effects detected!')
|
130 |
+
else:
|
131 |
+
on_target_results.empty()
|
132 |
|
133 |
elif target_selection == 'off-target':
|
134 |
ENTRY_METHODS = dict(
|
cas9on.py
CHANGED
@@ -1,8 +1,11 @@
|
|
|
|
1 |
import tensorflow as tf
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
from operator import add
|
5 |
from functools import reduce
|
|
|
|
|
6 |
|
7 |
# configure GPUs
|
8 |
for gpu in tf.config.list_physical_devices('GPU'):
|
@@ -18,7 +21,6 @@ ntmap = {'A': (1, 0, 0, 0),
|
|
18 |
}
|
19 |
epimap = {'A': 1, 'N': 0}
|
20 |
|
21 |
-
|
22 |
def get_seqcode(seq):
|
23 |
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape(
|
24 |
(1, len(seq), -1))
|
@@ -54,13 +56,9 @@ class Episgt:
|
|
54 |
return x
|
55 |
|
56 |
from keras.models import load_model
|
57 |
-
|
58 |
class DCModelOntar:
|
59 |
def __init__(self, ontar_model_dir, is_reg=False):
|
60 |
-
|
61 |
-
self.model = load_model(ontar_model_dir)
|
62 |
-
else:
|
63 |
-
self.model = load_model(ontar_model_dir)
|
64 |
|
65 |
def ontar_predict(self, x, channel_first=True):
|
66 |
if channel_first:
|
@@ -68,11 +66,96 @@ class DCModelOntar:
|
|
68 |
yp = self.model.predict(x)
|
69 |
return yp.ravel()
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
import tensorflow as tf
|
3 |
import pandas as pd
|
4 |
import numpy as np
|
5 |
from operator import add
|
6 |
from functools import reduce
|
7 |
+
from keras.models import load_model
|
8 |
+
import random
|
9 |
|
10 |
# configure GPUs
|
11 |
for gpu in tf.config.list_physical_devices('GPU'):
|
|
|
21 |
}
|
22 |
epimap = {'A': 1, 'N': 0}
|
23 |
|
|
|
24 |
def get_seqcode(seq):
|
25 |
return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape(
|
26 |
(1, len(seq), -1))
|
|
|
56 |
return x
|
57 |
|
58 |
from keras.models import load_model
|
|
|
59 |
class DCModelOntar:
|
60 |
def __init__(self, ontar_model_dir, is_reg=False):
|
61 |
+
self.model = load_model(ontar_model_dir)
|
|
|
|
|
|
|
62 |
|
63 |
def ontar_predict(self, x, channel_first=True):
|
64 |
if channel_first:
|
|
|
66 |
yp = self.model.predict(x)
|
67 |
return yp.ravel()
|
68 |
|
69 |
+
# Function to generate random epigenetic data
|
70 |
+
def generate_random_epigenetic_data(length):
|
71 |
+
return ''.join(random.choice('AN') for _ in range(length))
|
72 |
+
|
73 |
+
# Function to predict on-target efficiency and format output
|
74 |
+
def format_prediction_output(gRNA_sites, gene_id, model_path):
|
75 |
+
dcModel = DCModelOntar(model_path)
|
76 |
+
formatted_data = []
|
77 |
+
|
78 |
+
for gRNA in gRNA_sites:
|
79 |
+
# Encode the gRNA sequence
|
80 |
+
encoded_seq = get_seqcode(gRNA).reshape(-1,4,1,23)
|
81 |
+
#encoded_seq = np.expand_dims(encoded_seq, axis=2) # Adjust the shape for the model
|
82 |
+
|
83 |
+
# Generate random epigenetic features (as placeholders)
|
84 |
+
ctcf = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
|
85 |
+
dnase = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
|
86 |
+
h3k4me3 = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
|
87 |
+
rrbs = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
|
88 |
+
|
89 |
+
# Predict on-target efficiency using the model
|
90 |
+
input = np.concatenate((encoded_seq, ctcf, dnase, h3k4me3, rrbs), axis=1)
|
91 |
+
prediction = dcModel.ontar_predict(input)
|
92 |
+
|
93 |
+
# Format output
|
94 |
+
formatted_data.append([gene_id, "start_pos", "end_pos", "strand", gRNA, ctcf, dnase, h3k4me3, rrbs, prediction[0]])
|
95 |
+
|
96 |
+
return formatted_data
|
97 |
+
|
98 |
+
def fetch_ensembl_transcripts(gene_symbol):
|
99 |
+
url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
|
100 |
+
response = requests.get(url)
|
101 |
+
if response.status_code == 200:
|
102 |
+
gene_data = response.json()
|
103 |
+
if 'Transcript' in gene_data:
|
104 |
+
return gene_data['Transcript']
|
105 |
+
else:
|
106 |
+
print("No transcripts found for gene:", gene_symbol)
|
107 |
+
return None
|
108 |
+
else:
|
109 |
+
print(f"Error fetching gene data from Ensembl: {response.text}")
|
110 |
+
return None
|
111 |
+
|
112 |
+
def fetch_ensembl_sequence(transcript_id):
|
113 |
+
url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
|
114 |
+
response = requests.get(url)
|
115 |
+
if response.status_code == 200:
|
116 |
+
sequence_data = response.json()
|
117 |
+
if 'seq' in sequence_data:
|
118 |
+
return sequence_data['seq']
|
119 |
+
else:
|
120 |
+
print("No sequence found for transcript:", transcript_id)
|
121 |
+
return None
|
122 |
+
else:
|
123 |
+
print(f"Error fetching sequence data from Ensembl: {response.text}")
|
124 |
+
return None
|
125 |
+
|
126 |
+
def find_crispr_targets(sequence, pam="NGG", target_length=20):
|
127 |
+
targets = []
|
128 |
+
len_sequence = len(sequence)
|
129 |
+
|
130 |
+
for i in range(len_sequence - len(pam) + 1):
|
131 |
+
if sequence[i + 1:i + 3] == pam[1:]:
|
132 |
+
if i >= target_length:
|
133 |
+
target_seq = sequence[i - target_length:i + 3]
|
134 |
+
targets.append(target_seq)
|
135 |
+
|
136 |
+
return targets
|
137 |
+
|
138 |
+
|
139 |
+
def process_gene(gene_symbol, model_path):
|
140 |
+
transcripts = fetch_ensembl_transcripts(gene_symbol)
|
141 |
+
all_data = []
|
142 |
+
|
143 |
+
if transcripts:
|
144 |
+
for transcript in transcripts:
|
145 |
+
transcript_id = transcript['id']
|
146 |
+
gene_sequence = fetch_ensembl_sequence(transcript_id)
|
147 |
+
if gene_sequence:
|
148 |
+
gRNA_sites = find_crispr_targets(gene_sequence)
|
149 |
+
if gRNA_sites:
|
150 |
+
formatted_data = format_prediction_output(gRNA_sites, transcript_id, model_path)
|
151 |
+
all_data.extend(formatted_data)
|
152 |
+
|
153 |
+
return all_data
|
154 |
+
|
155 |
+
|
156 |
+
# Function to save results as CSV
|
157 |
+
def save_to_csv(data, filename="crispr_results.csv"):
|
158 |
+
df = pd.DataFrame(data,
|
159 |
+
columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "CTCF", "Dnase", "H3K4me3", "RRBS",
|
160 |
+
"Prediction"])
|
161 |
+
df.to_csv(filename, index=False)
|