CRISPRTool / app.py
supercat666's picture
fixed bugs
ba43ebe
raw
history blame
No virus
23.8 kB
import os
import tiger
import cas9on
import cas9off
import cas12
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
from pygenomeviz import Genbank, GenomeViz
import numpy as np
from pathlib import Path
# Render the landing-page documentation from crisprTool.md, followed by a divider
intro_markdown = Path('crisprTool.md').read_text()
st.markdown(intro_markdown, unsafe_allow_html=True)
st.divider()

# CRISPR systems offered in the selector; the chosen value picks the UI branch rendered further down
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
selected_model = st.selectbox(label='Select CRISPR model:', options=CRISPR_MODELS, key='selected_model')

# Model file paths handed to cas9on.process_gene and cas12.process_gene respectively
cas9on_path = 'cas9_model/on-cla.h5'
cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'
@st.cache_data
def convert_df(df):
    """Serialize ``df`` to CSV and return it as UTF-8 encoded bytes.

    Wrapped in ``st.cache_data`` so the conversion is not recomputed on every rerun.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
def mode_change_callback():
    """Keep the off-target checkbox consistent with the selected run mode.

    In the 'all' and 'titration' run modes the off-target checkbox is cleared and
    disabled; in every other mode it is enabled again.
    """
    modes_without_off_target = {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}  # TODO: support titration
    if st.session_state.mode not in modes_without_off_target:
        st.session_state.disable_off_target_checkbox = False
        return
    st.session_state.check_off_targets = False
    st.session_state.disable_off_target_checkbox = True
def progress_update(update_text, percent_complete):
    """Show a status message and progress bar inside the module-level ``progress`` placeholder.

    ``percent_complete`` is given on a 0-100 scale and is divided by 100 before
    being passed to ``st.progress``.
    """
    status_area = progress.container()
    with status_area:
        st.write(update_text)
        fraction_done = percent_complete / 100
        st.progress(fraction_done)
def initiate_run():
    """Validate the user's transcript input and stage it for prediction.

    Used as the ``on_click`` callback of the 'Get predictions!' button. Clears any
    previous results and errors, builds a transcript DataFrame from either the manual
    text entry or the uploaded fasta file, normalizes the sequences (upper case,
    U -> T), and validates them. On success the DataFrame is stored in
    ``st.session_state.transcripts``; on failure a message is stored in
    ``st.session_state.input_error`` and ``transcripts`` stays ``None``.
    """
    import tempfile  # local import: only the fasta-upload path needs it

    state = st.session_state

    # initialize state variables
    for key in ('transcripts', 'input_error', 'on_target', 'titration', 'off_target'):
        state[key] = None

    # initialize transcript DataFrame
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif state.entry_method == ENTRY_METHODS['fasta']:
        upload = state.fasta_entry
        if upload is not None:
            try:
                fasta_text = upload.getvalue().decode('utf-8')
            except UnicodeDecodeError:
                # report a readable error instead of crashing the callback
                state.input_error = 'Fasta file must be UTF-8 encoded text'
                return
            # Write into a private temporary directory rather than the working directory:
            # the file name comes from the user, so it must not be able to overwrite app
            # files, and the directory is removed even if loading raises.
            with tempfile.TemporaryDirectory() as tmp_dir:
                safe_name = os.path.basename(upload.name) or 'upload.fasta'
                fasta_path = os.path.join(tmp_dir, safe_name)
                with open(fasta_path, 'w', encoding='utf-8') as f:
                    f.write(fasta_text)
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)

    # convert to upper case as used by tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain the tokens known to tiger.NUCLEOTIDE_TOKENS
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # run model if we have any transcripts (the main script picks this up on the next rerun)
    elif len(transcripts) > 0:
        state.transcripts = transcripts
# Cas9 UI: on-target prediction by gene symbol, or off-target prediction from on/off sequence pairs
if selected_model == 'Cas9':
    # Use a radio button so that only one of on-target / off-target is active at a time
    target_selection = st.radio(
        "Select either on-target or off-target:",
        ('on-target', 'off-target'),
        key='target_selection'
    )
    if target_selection == 'on-target':
        # Gene symbol entry
        gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')
        if 'current_gene_symbol' not in st.session_state:
            st.session_state['current_gene_symbol'] = ""

        def clean_up_old_files(gene_symbol):
            """Delete the GenBank and BED files generated for ``gene_symbol``, if they exist."""
            genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
            bed_file_path = f"{gene_symbol}_crispr_targets.bed"
            if os.path.exists(genbank_file_path):
                os.remove(genbank_file_path)
            if os.path.exists(bed_file_path):
                os.remove(bed_file_path)

        # Clean up files of the previously predicted gene once a different symbol is entered
        if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
            clean_up_old_files(st.session_state['current_gene_symbol'])

        # Prediction button
        predict_button = st.button('Predict on-target')

        # Process predictions
        if predict_button and gene_symbol:
            predictions, gene_sequence = cas9on.process_gene(gene_symbol, cas9on_path)
            sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
            # Persist everything the display code needs: Streamlit reruns the script on every
            # interaction (including download clicks), and locals from this run do not survive.
            st.session_state['current_gene_symbol'] = gene_symbol
            st.session_state['on_target_results'] = sorted_predictions
            st.session_state['on_target_gene_sequence'] = (gene_symbol, gene_sequence)

        top_predictions = st.session_state.get('on_target_results')
        if top_predictions:
            # Include "Target" in the DataFrame's columns
            df = pd.DataFrame(top_predictions,
                              columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])

            # Plotly figure with one horizontal segment per ranked prediction
            fig = go.Figure()
            for i, prediction in enumerate(top_predictions, start=1):
                chrom, start, end, strand, target, gRNA, pred_score = prediction
                # Accept the strand as either an int or a string so '+' is not silently lost
                strand_label = '+' if str(strand) in ('1', '+1', '+') else '-'
                fig.add_trace(go.Scatter(
                    x=[start, end],
                    y=[i, i],  # Y-values are just the rank of the prediction
                    mode='lines+markers+text',
                    name=f"gRNA: {gRNA}",
                    text=[f"Rank: {i}", ""],  # Text at the start position only
                    hoverinfo='text',
                    hovertext=[
                        f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {strand_label}<br>Prediction Score: {pred_score:.4f}",
                        ""
                    ],
                ))

            # Update the layout of the plot
            fig.update_layout(
                title='Top 10 gRNA Sequences by Prediction Score',
                xaxis_title='Genomic Position',
                yaxis_title='Rank',
                yaxis=dict(showticklabels=False)
                # Hide the y-axis labels since the rank is indicated in the hovertext
            )

            # Display the plot and the table
            st.plotly_chart(fig)
            st.write('Top on-target predictions:')
            st.dataframe(df)

            # Files are named after the gene the results were computed for, and are only
            # generated when the stored sequence belongs to that same gene.
            result_symbol = st.session_state['current_gene_symbol']
            stored_symbol, stored_sequence = st.session_state.get('on_target_gene_sequence', (None, None))
            gene_sequence = stored_sequence if stored_symbol == result_symbol else None
            if gene_sequence:  # Ensure gene_sequence is not empty
                genbank_file_path = f"{result_symbol}_crispr_targets.gb"
                cas9on.generate_genbank_file_from_df(df, gene_sequence, result_symbol, genbank_file_path)
                bed_file_path = f"{result_symbol}_crispr_targets.bed"
                cas9on.create_bed_file_from_df(df, bed_file_path)

                # Download button for the GenBank file
                with open(genbank_file_path, "rb") as file:
                    st.download_button(
                        label="Download GenBank File",
                        data=file,
                        file_name=genbank_file_path,
                        mime="text/x-genbank"
                    )
                # Download button for the BED file
                with open(bed_file_path, "rb") as file:
                    st.download_button(label="Download BED File", data=file,
                                       file_name=bed_file_path, mime="text/plain")

                # TODO: GenBank visualization with pyGenomeViz (currently disabled)
                # gv = GenomeViz(
                #     feature_track_ratio=0.3,
                #     tick_track_ratio=0.5,
                #     tick_style="axis",
                # )
                #
                # # Load the GenBank file
                # gbk = Genbank(genbank_file_path)
                #
                # # Add a feature track to the GenomeViz object
                # track = gv.add_feature_track(gbk.name, gbk.range_size)
                #
                # # Add all features from the GenBank file to the track
                # track.add_genbank_features(gbk)
                #
                # # Plot the figure and display it in Streamlit
                # fig = gv.plotfig()
                # st.pyplot(fig)

    elif target_selection == 'off-target':
        ENTRY_METHODS = dict(
            manual='Manual entry of target sequence',
            txt="txt file upload"
        )
        if __name__ == '__main__':
            # app initialization for Cas9 off-target
            if 'target_sequence' not in st.session_state:
                st.session_state.target_sequence = None
            if 'input_error' not in st.session_state:
                st.session_state.input_error = None
            if 'off_target_results' not in st.session_state:
                st.session_state.off_target_results = None

            # target sequence entry
            st.selectbox(
                label='How would you like to provide target sequences?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.target_sequence is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter on/off sequences:',
                    key='manual_entry',
                    placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
                    disabled=st.session_state.target_sequence is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                st.file_uploader(
                    label='Upload a txt file:',
                    key='txt_entry',
                    disabled=st.session_state.target_sequence is not None
                )

            # prediction button
            predictions = None  # stays None unless a prediction is actually made in this run
            if st.button('Predict off-target'):
                had_input = False
                if st.session_state.entry_method == ENTRY_METHODS['manual']:
                    user_input = st.session_state.manual_entry
                    if user_input:  # Check if user_input is not empty
                        had_input = True
                        predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
                elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                    uploaded_file = st.session_state.txt_entry
                    if uploaded_file is not None:
                        had_input = True
                        # Read the uploaded file content
                        file_content = uploaded_file.getvalue().decode("utf-8")
                        # NOTE(review): file content is passed with input_type='manual' -- confirm
                        # against cas9off.process_input_and_predict that this is intended.
                        predictions = cas9off.process_input_and_predict(file_content, input_type='manual')
                st.session_state.off_target_results = predictions
                # Clicking with nothing entered used to leave `predictions` unbound; report it instead
                st.session_state.input_error = (
                    None if had_input else 'Please provide on/off target sequences before predicting.'
                )

            progress = st.empty()

            # input error display
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # off-target results display
            off_target_results = st.empty()
            if st.session_state.off_target_results is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target_results) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target_results)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target_results),
                            file_name='off_target_results.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('No significant off-target effects detected!')
            else:
                off_target_results.empty()

            # running the CRISPR-Net model for off-target predictions
            # (only runs when some code sets st.session_state.target_sequence)
            if st.session_state.target_sequence is not None:
                st.session_state.off_target_results = cas9off.predict_off_targets(
                    target_sequence=st.session_state.target_sequence,
                    status_update_fn=progress_update
                )
                st.session_state.target_sequence = None
                # st.experimental_rerun is deprecated in favor of st.rerun; prefer the new name when present
                rerun = getattr(st, 'rerun', None) or st.experimental_rerun
                rerun()
elif selected_model == 'Cas12':
# Gene symbol entry
gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')
# Initialize the current_gene_symbol in the session state if it doesn't exist
if 'current_gene_symbol' not in st.session_state:
st.session_state['current_gene_symbol'] = ""
# Prediction button
predict_button = st.button('Predict on-target')
# Function to clean up old files
def clean_up_old_files(gene_symbol):
genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
bed_file_path = f"{gene_symbol}_crispr_targets.bed"
if os.path.exists(genbank_file_path):
os.remove(genbank_file_path)
if os.path.exists(bed_file_path):
os.remove(bed_file_path)
# Clean up files if a new gene symbol is entered
if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
clean_up_old_files(st.session_state['current_gene_symbol'])
# Process predictions
if predict_button and gene_symbol:
# Update the current gene symbol
st.session_state['current_gene_symbol'] = gene_symbol
# Run the prediction process
predictions, gene_sequence = cas12.process_gene(gene_symbol,cas12_path)
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
st.session_state['on_target_results'] = sorted_predictions
# Visualization and file generation
if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
df = pd.DataFrame(st.session_state['on_target_results'],
columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
# Now create a Plotly plot with the sorted_predictions
fig = go.Figure()
# Iterate over the sorted predictions to create the plot
for i, prediction in enumerate(sorted_predictions, start=1):
# Extract data for plotting
chrom, start, end, strand, Target, gRNA, pred_score = prediction
# Strand is not used in this plot, but you could use it to determine marker symbol, for example
fig.add_trace(go.Scatter(
x=[start, end],
y=[i, i], # Y-values are just the rank of the prediction
mode='lines+markers+text',
name=f"gRNA: {gRNA}",
text=[f"Rank: {i}", ""], # Text at the start position only
hoverinfo='text',
hovertext=[
f"Rank: {i}<br>Chromosome: {chrom}<br>Target: {Target}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
""
],
))
# Update the layout of the plot
fig.update_layout(
title='Top 10 gRNA Sequences by Prediction Score',
xaxis_title='Genomic Position',
yaxis_title='Rank',
yaxis=dict(showticklabels=False)
# We hide the y-axis labels since the rank is indicated in the hovertext
)
# Display the plot
st.plotly_chart(fig)
# Ensure gene_sequence is not empty before generating files
if gene_sequence:
genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
bed_file_path = f"{gene_symbol}_crispr_targets.bed"
# Generate GenBank file
cas12.generate_genbank_file_from_data(df, gene_sequence, gene_symbol, genbank_file_path)
# Generate BED file
cas12.generate_bed_file_from_data(df, bed_file_path)
st.write('Top on-target predictions:')
st.dataframe(df)
# Download buttons
with open(genbank_file_path, "rb") as file:
st.download_button(
label="Download GenBank File",
data=file,
file_name=genbank_file_path,
mime="text/x-genbank"
)
with open(bed_file_path, "rb") as file:
st.download_button(label="Download BED File", data=file,
file_name=bed_file_path, mime="text/plain")
# Clean up old files after download buttons are created
clean_up_old_files(gene_symbol)
elif selected_model == 'Cas13d':
ENTRY_METHODS = dict(
manual='Manual entry of single transcript',
fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
)
if __name__ == '__main__':
# app initialization
if 'mode' not in st.session_state:
st.session_state.mode = tiger.RUN_MODES['all']
st.session_state.disable_off_target_checkbox = True
if 'entry_method' not in st.session_state:
st.session_state.entry_method = ENTRY_METHODS['manual']
if 'transcripts' not in st.session_state:
st.session_state.transcripts = None
if 'input_error' not in st.session_state:
st.session_state.input_error = None
if 'on_target' not in st.session_state:
st.session_state.on_target = None
if 'titration' not in st.session_state:
st.session_state.titration = None
if 'off_target' not in st.session_state:
st.session_state.off_target = None
# mode selection
col1, col2 = st.columns([0.65, 0.35])
with col1:
st.radio(
label='What do you want to predict?',
options=tuple(tiger.RUN_MODES.values()),
key='mode',
on_change=mode_change_callback,
disabled=st.session_state.transcripts is not None,
)
with col2:
st.checkbox(
label='Find off-target effects (slow)',
key='check_off_targets',
disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
)
# transcript entry
st.selectbox(
label='How would you like to provide transcript(s) of interest?',
options=ENTRY_METHODS.values(),
key='entry_method',
disabled=st.session_state.transcripts is not None
)
if st.session_state.entry_method == ENTRY_METHODS['manual']:
st.text_input(
label='Enter a target transcript:',
key='manual_entry',
placeholder='Upper or lower case',
disabled=st.session_state.transcripts is not None
)
elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
st.file_uploader(
label='Upload a fasta file:',
key='fasta_entry',
disabled=st.session_state.transcripts is not None
)
# let's go!
st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
progress = st.empty()
# input error
error = st.empty()
if st.session_state.input_error is not None:
error.error(st.session_state.input_error, icon="🚨")
else:
error.empty()
# on-target results
on_target_results = st.empty()
if st.session_state.on_target is not None:
with on_target_results.container():
st.write('On-target predictions:', st.session_state.on_target)
st.download_button(
label='Download on-target predictions',
data=convert_df(st.session_state.on_target),
file_name='on_target.csv',
mime='text/csv'
)
else:
on_target_results.empty()
# titration results
titration_results = st.empty()
if st.session_state.titration is not None:
with titration_results.container():
st.write('Titration predictions:', st.session_state.titration)
st.download_button(
label='Download titration predictions',
data=convert_df(st.session_state.titration),
file_name='titration.csv',
mime='text/csv'
)
else:
titration_results.empty()
# off-target results
off_target_results = st.empty()
if st.session_state.off_target is not None:
with off_target_results.container():
if len(st.session_state.off_target) > 0:
st.write('Off-target predictions:', st.session_state.off_target)
st.download_button(
label='Download off-target predictions',
data=convert_df(st.session_state.off_target),
file_name='off_target.csv',
mime='text/csv'
)
else:
st.write('We did not find any off-target effects!')
else:
off_target_results.empty()
# keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
if st.session_state.transcripts is not None:
st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
transcripts=st.session_state.transcripts,
mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
check_off_targets=st.session_state.check_off_targets,
status_update_fn=progress_update
)
st.session_state.transcripts = None
st.experimental_rerun()