Spaces:
Running
Running
supercat666
committed on
Commit
•
dc94424
1
Parent(s):
73dcc35
add cas12
Browse files- app.py +101 -3
- cas12.py +175 -0
- cas12_model/Seq_deepCpf1_weights.h5 +3 -0
app.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import tiger
|
3 |
import cas9on
|
4 |
import cas9off
|
|
|
5 |
import pandas as pd
|
6 |
import streamlit as st
|
7 |
import plotly.graph_objs as go
|
@@ -18,6 +19,7 @@ CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
|
18 |
|
19 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
20 |
cas9on_path = 'cas9_model/on-cla.h5'
|
|
|
21 |
|
22 |
@st.cache_data
|
23 |
def convert_df(df):
|
@@ -287,9 +289,105 @@ if selected_model == 'Cas9':
|
|
287 |
st.experimental_rerun()
|
288 |
|
289 |
elif selected_model == 'Cas12':
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
elif selected_model == 'Cas13d':
|
294 |
ENTRY_METHODS = dict(
|
295 |
manual='Manual entry of single transcript',
|
|
|
2 |
import tiger
|
3 |
import cas9on
|
4 |
import cas9off
|
5 |
+
import cas12
|
6 |
import pandas as pd
|
7 |
import streamlit as st
|
8 |
import plotly.graph_objs as go
|
|
|
19 |
|
20 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
21 |
cas9on_path = 'cas9_model/on-cla.h5'
|
22 |
+
cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'
|
23 |
|
24 |
@st.cache_data
|
25 |
def convert_df(df):
|
|
|
289 |
st.experimental_rerun()
|
290 |
|
291 |
elif selected_model == 'Cas12':
|
292 |
+
# Cas12 page: look up a gene, score its candidate targets, then show the top
# hits as a plot, a table and downloadable GenBank/BED exports.
#
# Streamlit re-executes the script on every interaction (including a click on
# a download button), so everything the rendering code needs is read back from
# st.session_state rather than from locals that only exist on the run where
# the predict button was pressed.

# Gene symbol entry
gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')

# Initialize the current_gene_symbol in the session state if it doesn't exist
if 'current_gene_symbol' not in st.session_state:
    st.session_state['current_gene_symbol'] = ""

# Prediction button
predict_button = st.button('Predict on-target')


def clean_up_old_files(gene_symbol):
    """Delete the GenBank and BED exports written for gene_symbol, if present."""
    for extension in ("gb", "bed"):
        try:
            os.remove(f"{gene_symbol}_crispr_targets.{extension}")
        except FileNotFoundError:
            # Nothing was exported for this symbol (or it is already gone).
            pass


# Clean up files if a new gene symbol is entered
if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
    clean_up_old_files(st.session_state['current_gene_symbol'])

# Process predictions
if predict_button and gene_symbol:
    # Update the current gene symbol
    st.session_state['current_gene_symbol'] = gene_symbol

    # Run the prediction process and keep the ten best-scoring rows
    predictions, gene_sequence = cas12.process_gene(gene_symbol, cas12_path)
    st.session_state['on_target_results'] = sorted(
        predictions, key=lambda row: row[-1], reverse=True)[:10]
    st.session_state['cas12_gene_sequence'] = gene_sequence

# Visualization and file generation
# NOTE(review): 'on_target_results' may also be written by other model pages
# with a different row layout - confirm, and consider a Cas12-specific key.
results = st.session_state.get('on_target_results')
if results:
    stored_sequence = st.session_state.get('cas12_gene_sequence', '')
    # Name the exports after the gene that was actually predicted, not after
    # whatever is currently typed in the text box.
    result_symbol = st.session_state['current_gene_symbol']

    # cas12.format_prediction_output yields seven fields per row.
    df = pd.DataFrame(results,
                      columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])

    # Now create a Plotly plot with the stored predictions
    fig = go.Figure()

    for i, prediction in enumerate(results, start=1):
        chrom, start, end, strand, target_seq, gRNA, pred_score = prediction
        # Positions arrive as strings; plot them as numbers so the x axis is
        # a genomic coordinate axis rather than a list of categories.
        start, end = int(start), int(end)
        # Strand arrives as text ("1" / "-1"), so compare as text.
        strand_symbol = '+' if str(strand) in ('1', '+') else '-'
        fig.add_trace(go.Scatter(
            x=[start, end],
            y=[i, i],  # Y-values are just the rank of the prediction
            mode='lines+markers+text',
            name=f"gRNA: {gRNA}",
            text=[f"Rank: {i}", ""],  # Text at the start position only
            hoverinfo='text',
            hovertext=[
                f"Rank: {i}<br>Chromosome: {chrom}<br>Target: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {strand_symbol}<br>Prediction Score: {pred_score:.4f}",
                ""
            ],
        ))

    # Update the layout of the plot
    fig.update_layout(
        title='Top 10 gRNA Sequences by Prediction Score',
        xaxis_title='Genomic Position',
        yaxis_title='Rank',
        # We hide the y-axis labels since the rank is indicated in the hovertext
        yaxis=dict(showticklabels=False)
    )
    st.plotly_chart(fig)

    st.write('Top on-target predictions:')
    st.dataframe(df)

    # Ensure the gene sequence is not empty before generating files
    if stored_sequence:
        genbank_file_path = f"{result_symbol}_crispr_targets.gb"
        bed_file_path = f"{result_symbol}_crispr_targets.bed"

        try:
            cas12.generate_genbank_file_from_data(results, stored_sequence, result_symbol, genbank_file_path)
            cas12.generate_bed_file_from_data(results, bed_file_path)

            # Read the exports into memory so the files on disk can be removed
            # right away, whatever happens next.
            with open(genbank_file_path, "rb") as file:
                genbank_bytes = file.read()
            with open(bed_file_path, "rb") as file:
                bed_bytes = file.read()
        finally:
            clean_up_old_files(result_symbol)

        # Download buttons
        st.download_button(
            label="Download GenBank File",
            data=genbank_bytes,
            file_name=genbank_file_path,
            mime="text/x-genbank"
        )
        st.download_button(label="Download BED File", data=bed_bytes,
                           file_name=bed_file_path, mime="text/plain")
|
390 |
+
|
391 |
elif selected_model == 'Cas13d':
|
392 |
ENTRY_METHODS = dict(
|
393 |
manual='Manual entry of single transcript',
|
cas12.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from keras import Model
|
2 |
+
from keras.layers import Input
|
3 |
+
from keras.layers import Multiply
|
4 |
+
from keras.layers import Dense, Dropout, Activation, Flatten
|
5 |
+
from keras.layers import Convolution1D, AveragePooling1D
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import keras
|
9 |
+
import requests
|
10 |
+
from functools import reduce
|
11 |
+
from operator import add
|
12 |
+
from Bio.SeqRecord import SeqRecord
|
13 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
14 |
+
from Bio.Seq import Seq
|
15 |
+
from Bio import SeqIO
|
16 |
+
|
17 |
+
# One-hot code for each nucleotide; channel order is A, C, G, T.
ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }


def get_seqcode(seq):
    """One-hot encode a nucleotide sequence.

    Args:
        seq: String of A/C/G/T characters (case-insensitive).

    Returns:
        Integer array of shape ``(1, len(seq), 4)``; an empty string yields
        an array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: If ``seq`` contains a character other than A, C, G or T.
    """
    # Collect one row per base instead of concatenating tuples with
    # reduce(add, ...): that is linear in len(seq) and does not fail on an
    # empty sequence.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), len(ntmap)))
|
25 |
+
|
26 |
+
def Seq_DeepCpf1_model(input_shape):
    """Build the sequence-only Seq-deepCpf1 network (untrained).

    Architecture: 1-D convolution (80 filters, width 5, ReLU), average
    pooling (size 2), flatten, three dropout(0.3) + dense(ReLU) stages with
    80, 40 and 40 units, a final dropout(0.3) and a single linear output.

    Args:
        input_shape: Shape of one encoded sequence, passed to ``Input``.

    Returns:
        A Keras ``Model`` with one input and one output.
    """
    seq_input = Input(shape=input_shape)

    hidden = Convolution1D(80, 5, activation='relu')(seq_input)
    hidden = AveragePooling1D(2)(hidden)
    hidden = Flatten()(hidden)

    # Layers are created in the same order as the stored weights expect.
    for units in (80, 40, 40):
        hidden = Dropout(0.3)(hidden)
        hidden = Dense(units, activation='relu')(hidden)

    hidden = Dropout(0.3)(hidden)
    score = Dense(1, activation='linear')(hidden)

    return Model(inputs=[seq_input], outputs=[score])
|
41 |
+
|
42 |
+
# seq-ca model (DeepCpf1)
|
43 |
+
def DeepCpf1_model(input_shape):
    """Build the two-input DeepCpf1 network (untrained).

    The sequence branch is a 1-D convolution (80 filters, width 5, ReLU),
    average pooling (size 2), flatten, then three dropout(0.3) + dense(ReLU)
    stages with 80, 40 and 40 units. A second, scalar input ("CA" -
    presumably chromatin accessibility; confirm) goes through a 40-unit ReLU
    dense layer. Both branches are multiplied element-wise, followed by
    dropout(0.3) and a single linear output.

    Args:
        input_shape: Shape of one encoded sequence, passed to ``Input``.

    Returns:
        A Keras ``Model`` taking ``[sequence, ca]`` inputs with one output.
    """
    seq_input = Input(shape=input_shape)

    seq_branch = Convolution1D(80, 5, activation='relu')(seq_input)
    seq_branch = AveragePooling1D(2)(seq_branch)
    seq_branch = Flatten()(seq_branch)
    for units in (80, 40, 40):
        seq_branch = Dropout(0.3)(seq_branch)
        seq_branch = Dense(units, activation='relu')(seq_branch)

    ca_input = Input(shape=(1,))
    ca_branch = Dense(40, activation='relu')(ca_input)

    merged = Multiply()([seq_branch, ca_branch])
    merged = Dropout(0.3)(merged)
    score = Dense(1, activation='linear')(merged)

    return Model(inputs=[seq_input, ca_input], outputs=[score])
|
61 |
+
|
62 |
+
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up (typically typed by the user).

    Returns:
        The ``Transcript`` entry of the Ensembl response, or ``None`` when
        the request fails, the server answers with a non-200 status, or the
        response has no ``Transcript`` entry. Failures are reported with
        ``print``.
    """
    # Local import: only this function needs it.
    from urllib.parse import quote

    # The symbol is user input, so escape it before placing it in the path.
    safe_symbol = quote(gene_symbol.strip(), safe='')
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{safe_symbol}?expand=1;content-type=application/json"

    try:
        # Without a timeout a stalled connection would hang the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
|
75 |
+
|
76 |
+
def fetch_ensembl_sequence(transcript_id):
    """Download the sequence for an Ensembl identifier.

    Args:
        transcript_id: Ensembl stable ID to request.

    Returns:
        The ``seq`` string of the Ensembl response, or ``None`` when the
        request fails, the server answers with a non-200 status, or the
        response has no ``seq`` entry. Failures are reported with ``print``.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"

    try:
        # Without a timeout a stalled connection would hang the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
|
89 |
+
|
90 |
+
|
91 |
+
def find_crispr_targets(sequence, chr, start, strand, pam="TTTN", target_length=34):
    """Scan a sequence for candidate Cas12 target windows.

    Every window of ``target_length`` characters is examined. A window is a
    hit when the characters starting at offset 4 match ``pam``. The guide is
    the 20 characters that follow the PAM.

    Args:
        sequence: Nucleotide string to scan.
        chr: Chromosome/region name copied into every hit.
        start: Coordinate assigned to ``sequence[0]``.
        strand: Strand label copied (as text) into every hit.
        pam: PAM motif. ``N`` matches any character; every other position
            must match exactly. The default reproduces the former hard-coded
            ``TTT`` check.
        target_length: Window size in characters.

    Returns:
        A list of ``[target_seq, gRNA, chr, start, end, strand]`` rows, in
        sequence order, where ``start``, ``end`` and ``strand`` are strings.
    """
    pam_offset = 4   # characters of context kept in front of the PAM
    guide_length = 20
    guide_start = pam_offset + len(pam)
    guide_end = guide_start + guide_length

    # Only the non-wildcard PAM positions constrain a window.
    required = [(pam_offset + k, base)
                for k, base in enumerate(pam.upper()) if base != 'N']

    # NOTE(review): the window offset is always added to ``start``, whatever
    # ``strand`` is - confirm this is right for reverse-strand sequences.
    targets = []
    for i in range(len(sequence) - target_length + 1):
        target_seq = sequence[i:i + target_length]
        # Slicing (rather than indexing) makes a too-short window a mismatch.
        if all(target_seq[pos:pos + 1] == base for pos, base in required):
            targets.append([
                target_seq,
                target_seq[guide_start:guide_end],
                chr,
                str(start + i),
                str(start + i + target_length),
                str(strand),
            ])
    return targets
|
103 |
+
|
104 |
+
def format_prediction_output(targets, seq_deepCpf1):
    """Score candidate targets and build one output row per target.

    Args:
        targets: Rows laid out as
            ``[target_seq, gRNA, chr, start, end, strand]``.
        seq_deepCpf1: Model whose ``predict`` accepts the stacked output of
            ``get_seqcode`` and returns one row of scores per input.

    Returns:
        A list of ``[chr, start, end, strand, target_seq, gRNA, prediction]``
        rows in the same order as ``targets``.
    """
    if not targets:
        return []

    # Encode everything first and call predict() once for the whole batch;
    # calling it once per target pays the per-call overhead thousands of
    # times on a real gene.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    predictions = seq_deepCpf1.predict(encoded)

    formatted_data = []
    for target, prediction in zip(targets, predictions):
        target_seq, gRNA, chr, start, end, strand = target[:6]
        formatted_data.append([chr, start, end, strand, target_seq, gRNA, prediction[0]])
    return formatted_data
|
119 |
+
|
120 |
+
def process_gene(gene_symbol, model_path):
    """Find and score Cas12 targets for every transcript of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: Path of the weights file loaded into the Seq-deepCpf1
            model.

    Returns:
        ``(all_data, gene_sequence)``: the scored rows produced by
        ``format_prediction_output`` for all transcripts, and the fetched
        transcript sequences concatenated in transcript order. Both are
        empty when no transcript could be retrieved.
    """
    all_data = []
    sequence_parts = []

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        print("Failed to retrieve transcripts.")
        return all_data, ''

    # Build the model only once there is something to score.
    seq_deepCpf1 = Seq_DeepCpf1_model(input_shape=(34, 4))
    seq_deepCpf1.load_weights(model_path)

    scorable_bases = set('ACGT')
    for transcript in transcripts:
        transcript_sequence = fetch_ensembl_sequence(transcript['id']) or ''
        if not transcript_sequence:
            continue
        sequence_parts.append(transcript_sequence)

        # Search each transcript on its own sequence with its own
        # coordinates. Searching the running concatenation would re-report
        # earlier hits, report windows spanning two transcripts, and number
        # them from the wrong start position.
        targets = find_crispr_targets(
            transcript_sequence,
            transcript.get('seq_region_name', 'unknown'),
            transcript.get('start', 0),
            transcript.get('strand', 'unknown'),
        )
        # The encoder only knows A/C/G/T; drop windows it cannot encode
        # instead of failing the whole gene.
        targets = [target for target in targets
                   if set(target[0].upper()) <= scorable_bases]
        if targets:
            all_data.extend(format_prediction_output(targets, seq_deepCpf1))

    return all_data, ''.join(sequence_parts)
|
147 |
+
|
148 |
+
def create_genbank_features(formatted_data):
    """Turn scored target rows into Biopython sequence features.

    Args:
        formatted_data: Rows laid out as
            ``[chr, start, end, strand, target_seq, gRNA, prediction]``.

    Returns:
        One ``misc_feature`` ``SeqFeature`` per row, labelled with the gRNA
        and carrying the prediction score in its ``note`` qualifier.
    """
    # Strand reaches this function as text: "1"/"-1" from the prediction
    # rows, or "+"/"-". Comparing only against "+" marked every "1" row as
    # reverse strand. Unrecognised values become None (strand not specified).
    strand_codes = {'+': 1, '1': 1, '+1': 1, '-': -1, '-1': -1}

    features = []
    for data in formatted_data:
        location = FeatureLocation(
            start=int(data[1]),
            end=int(data[2]),
            strand=strand_codes.get(str(data[3]).strip()),
        )
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': data[5],  # gRNA as label
            'note': f"Prediction: {data[6]}"  # Prediction score in note
        })
        features.append(feature)
    return features
|
158 |
+
|
159 |
+
def generate_genbank_file_from_data(formatted_data, gene_sequence, gene_symbol, output_path):
    """Write the scored targets as an annotated GenBank record.

    Args:
        formatted_data: Scored target rows, converted to features by
            ``create_genbank_features``.
        gene_sequence: Nucleotide string used as the record's sequence.
        gene_symbol: Used as both the record id and name.
        output_path: Destination passed to ``SeqIO.write``.
    """
    target_features = create_genbank_features(formatted_data)

    genbank_record = SeqRecord(
        Seq(gene_sequence),
        id=gene_symbol,
        name=gene_symbol,
        description='CRISPR Cas12 predicted targets',
        features=target_features,
    )
    # The molecule type is set explicitly before writing the GenBank file.
    genbank_record.annotations["molecule_type"] = "DNA"

    SeqIO.write(genbank_record, output_path, "genbank")
|
165 |
+
|
166 |
+
def generate_bed_file_from_data(formatted_data, output_path):
    """Write the scored targets as a tab-separated BED file.

    Each row becomes one line:
    ``chrom, start, end, gRNA, score, strand``.

    Args:
        formatted_data: Rows laid out as
            ``[chr, start, end, strand, target_seq, gRNA, prediction]``.
        output_path: File to create (overwritten if it exists).
    """
    # The BED strand column must be "+", "-" or "."; the prediction rows
    # carry "1"/"-1", so translate instead of writing the raw value.
    bed_strands = {'+': '+', '1': '+', '+1': '+', '-': '-', '-1': '-'}

    lines = []
    for data in formatted_data:
        chrom, start, end = data[0], data[1], data[2]
        strand = bed_strands.get(str(data[3]).strip(), '.')
        gRNA, score = data[5], data[6]
        lines.append(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")

    with open(output_path, 'w') as bed_file:
        bed_file.writelines(lines)
|
cas12_model/Seq_deepCpf1_weights.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c52c1f93169ea1da55d4cb464f4d948551b9aeafb9ee47dc55fa76e23486526d
|
3 |
+
size 1285864
|