File size: 10,363 Bytes
e93c659
 
 
 
 
 
 
 
 
 
fdf5616
6a7c0e6
306ab4d
 
6a7c0e6
 
 
83a2e73
6a7c0e6
 
83a2e73
6a7c0e6
 
83a2e73
6a7c0e6
 
83a2e73
6a7c0e6
 
4d05090
bc09cb1
 
 
4d05090
83a2e73
 
 
e93c659
83a2e73
 
 
 
e93c659
 
83a2e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e93c659
83a2e73
 
 
 
 
 
e93c659
83a2e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e93c659
83a2e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
import tempfile
from pathlib import Path

import pandas as pd
import streamlit as st

import tiger

# Ways a user can supply transcripts to the Cas13d (TIGER) form: internal key -> label shown in the UI.
ENTRY_METHODS = {
    'manual': 'Manual entry of single transcript',
    'fasta': "Fasta file upload (supports multiple transcripts if they have unique ID's)",
}
# CRISPR model families offered in the top-level selector.
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']

selected_model = st.selectbox(label='Select CRISPR model:', options=CRISPR_MODELS, key='selected_model')

# Check if the selected model is Cas9
if selected_model == 'Cas9':
    # Display buttons for the Cas9 model
    if st.checkbox('SPCas9_U6'):
        # Placeholder for action when SPCas9_U6 is clicked
        pass
    if st.checkbox('SPCas9_t7'):
        # Placeholder for action when SPCas9_t7 is clicked
        pass
    if st.checkbox('eSPCas9'):
        # Placeholder for action when eSPCas9 is clicked
        pass
    if st.checkbox('SPCas9_HF1'):
        # Placeholder for action when SPCas9_HF1 is clicked
        pass
elif selected_model == 'Cas12':
        # Placeholder for Cas12 model loading
        # TODO: Implement Cas12 model loading logic
        raise NotImplementedError("Cas12 model loading not implemented yet.")
elif selected_model == 'Cas13d':
        ENTRY_METHODS = dict(
        manual='Manual entry of single transcript',
        fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
        )
        @st.cache_data
        def convert_df(df):
            """Serialize ``df`` to UTF-8 encoded CSV bytes, as fed to ``st.download_button``.

            Wrapped in ``st.cache_data`` so the conversion is not recomputed on every script rerun.
            """
            csv_text = df.to_csv()
            return csv_text.encode('utf-8')


        def mode_change_callback():
            """Keep the off-target checkbox in sync with the newly selected run mode.

            For the 'all' and 'titration' run modes the off-target option is switched off
            and its checkbox disabled; for any other mode the checkbox is re-enabled.
            """
            # TODO: support titration
            off_targets_unavailable = st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}
            if off_targets_unavailable:
                st.session_state.check_off_targets = False
            st.session_state.disable_off_target_checkbox = off_targets_unavailable


        def progress_update(update_text, percent_complete):
            """Show a status message and progress bar inside the ``progress`` placeholder.

            Handed to ``tiger.tiger_exhibit`` as its ``status_update_fn``. ``percent_complete``
            is divided by 100 before being passed to ``st.progress``.
            """
            with progress.container():
                st.write(update_text)
                fraction_done = percent_complete / 100
                st.progress(fraction_done)


        def _load_uploaded_fasta(uploaded_file):
            """Parse an uploaded fasta file with ``tiger.load_transcripts`` and return the result.

            The upload is copied into a private temporary directory instead of the working
            directory. Concurrent sessions uploading the same file name therefore cannot
            overwrite each other, and a client-supplied name cannot point outside that directory.
            The copy is removed even when parsing raises.
            """
            file_name = Path(uploaded_file.name).name  # keep only the final path component
            if file_name in ('', '.', '..'):
                # the client-supplied name has no usable final component
                file_name = 'upload.fasta'
            with tempfile.TemporaryDirectory() as tmp_dir:
                fasta_path = os.path.join(tmp_dir, file_name)
                with open(fasta_path, 'w', encoding='utf-8') as f:
                    f.write(uploaded_file.getvalue().decode('utf-8'))
                # parse while the temporary directory still exists; it is deleted on leaving the block
                return tiger.load_transcripts([fasta_path], enforce_unique_ids=False)


        def initiate_run():
            """Validate the user's transcript input and queue it for prediction (button callback).

            Resets all per-run session state, then builds a transcript DataFrame from the manual
            entry or the uploaded fasta file. Sequences are upper-cased and U is mapped to T.
            If validation fails, ``st.session_state.input_error`` holds a message.
            Otherwise the transcripts are stored in ``st.session_state.transcripts``, where the
            main script body picks them up and runs them through ``tiger.tiger_exhibit``.
            """

            # initialize state variables
            st.session_state.transcripts = None
            st.session_state.input_error = None
            st.session_state.on_target = None
            st.session_state.titration = None
            st.session_state.off_target = None

            # initialize transcript DataFrame
            transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

            # manual entry
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                transcripts = pd.DataFrame({
                    tiger.ID_COL: ['ManualEntry'],
                    tiger.SEQ_COL: [st.session_state.manual_entry]
                }).set_index(tiger.ID_COL)

            # fasta file upload
            elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
                if st.session_state.fasta_entry is not None:
                    transcripts = _load_uploaded_fasta(st.session_state.fasta_entry)

            # convert to upper case as used by tokenizer
            transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

            # ensure all transcripts have unique identifiers
            if transcripts.index.has_duplicates:
                st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

            # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
            elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
                st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

            # ensure all transcripts satisfy length requirements
            elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
                st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

            # run model if we have any transcripts
            elif len(transcripts) > 0:
                st.session_state.transcripts = transcripts


        if __name__ == '__main__':

            def _show_results(placeholder, results, title, download_label, file_name, empty_message=None):
                """Render a results table and a CSV download button into ``placeholder``, or clear it.

                Args:
                    placeholder: slot created with ``st.empty()`` to render into.
                    results: prediction table (serialized with ``convert_df``), or None to clear the slot.
                    title: caption written alongside the table.
                    download_label: label of the download button.
                    file_name: file name offered for the downloaded CSV.
                    empty_message: if given, written instead of the table when ``results`` has no rows.
                """
                if results is None:
                    placeholder.empty()
                    return
                with placeholder.container():
                    if empty_message is not None and len(results) == 0:
                        st.write(empty_message)
                        return
                    st.write(title, results)
                    st.download_button(
                        label=download_label,
                        data=convert_df(results),
                        file_name=file_name,
                        mime='text/csv'
                    )

            # app initialization
            if 'mode' not in st.session_state:
                st.session_state.mode = tiger.RUN_MODES['all']
                st.session_state.disable_off_target_checkbox = True
            if 'entry_method' not in st.session_state:
                st.session_state.entry_method = ENTRY_METHODS['manual']
            for state_key in ('transcripts', 'input_error', 'on_target', 'titration', 'off_target'):
                if state_key not in st.session_state:
                    st.session_state[state_key] = None

            # inputs are locked while queued transcripts are waiting to be run
            inputs_locked = st.session_state.transcripts is not None

            # title and documentation
            st.markdown(Path('tiger.md').read_text(encoding='utf-8'), unsafe_allow_html=True)
            st.divider()

            # mode selection
            col1, col2 = st.columns([0.65, 0.35])
            with col1:
                st.radio(
                    label='What do you want to predict?',
                    options=tuple(tiger.RUN_MODES.values()),
                    key='mode',
                    on_change=mode_change_callback,
                    disabled=inputs_locked,
                )
            with col2:
                st.checkbox(
                    label='Find off-target effects (slow)',
                    key='check_off_targets',
                    disabled=st.session_state.disable_off_target_checkbox or inputs_locked
                )

            # transcript entry
            st.selectbox(
                label='How would you like to provide transcript(s) of interest?',
                options=tuple(ENTRY_METHODS.values()),
                key='entry_method',
                disabled=inputs_locked
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter a target transcript:',
                    key='manual_entry',
                    placeholder='Upper or lower case',
                    disabled=inputs_locked
                )
            elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
                st.file_uploader(
                    label='Upload a fasta file:',
                    key='fasta_entry',
                    disabled=inputs_locked
                )

            # let's go!
            st.button(label='Get predictions!', on_click=initiate_run, disabled=inputs_locked)
            progress = st.empty()

            # input error
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # results (each gets its own placeholder, in this order)
            on_target_results = st.empty()
            _show_results(on_target_results, st.session_state.on_target,
                          'On-target predictions:', 'Download on-target predictions', 'on_target.csv')
            titration_results = st.empty()
            _show_results(titration_results, st.session_state.titration,
                          'Titration predictions:', 'Download titration predictions', 'titration.csv')
            off_target_results = st.empty()
            _show_results(off_target_results, st.session_state.off_target,
                          'Off-target predictions:', 'Download off-target predictions', 'off_target.csv',
                          empty_message='We did not find any off-target effects!')

            # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
            if st.session_state.transcripts is not None:
                st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
                    transcripts=st.session_state.transcripts,
                    mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
                    check_off_targets=st.session_state.check_off_targets,
                    status_update_fn=progress_update
                )
                st.session_state.transcripts = None
                # st.experimental_rerun is deprecated in favour of st.rerun (added in Streamlit 1.27)
                # and is gone from newer releases; prefer st.rerun, fall back for older versions.
                rerun = getattr(st, 'rerun', None) or st.experimental_rerun
                rerun()
else:
    raise ValueError(f"Unknown model: {model_name}")