supercat666 committed on
Commit
7ef3dbe
1 Parent(s): dd69d15

updated cas9on

Browse files
Files changed (3) hide show
  1. app.py +10 -4
  2. cas9_model/on-cla.h5 +2 -2
  3. cas9on.py +14 -52
app.py CHANGED
@@ -90,7 +90,6 @@ if selected_model == 'Cas9':
90
  key='target_selection'
91
  )
92
 
93
- # Actions based on the selected enzyme
94
  if target_selection == 'on-target':
95
  # Gene symbol entry
96
  gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')
@@ -107,10 +106,17 @@ if selected_model == 'Cas9':
107
 
108
  # On-target results display
109
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
110
- st.write('On-target predictions:', st.session_state['on_target_results'])
 
 
 
 
 
111
  if 'full_on_target_results' in st.session_state:
112
- # Provide a download button for the full results
113
- full_predictions_csv = cas9on.convert_df(st.session_state['full_on_target_results'])
 
 
114
  st.download_button(
115
  label='Download on-target predictions',
116
  data=full_predictions_csv,
 
90
  key='target_selection'
91
  )
92
 
 
93
  if target_selection == 'on-target':
94
  # Gene symbol entry
95
  gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')
 
106
 
107
  # On-target results display
108
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
109
+ # Convert the results to a pandas DataFrame for better display
110
+ df = pd.DataFrame(st.session_state['on_target_results'],
111
+ columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
112
+ st.write('On-target predictions:')
113
+ st.dataframe(df)
114
+
115
  if 'full_on_target_results' in st.session_state:
116
+ # Convert full results to a CSV for download
117
+ full_df = pd.DataFrame(st.session_state['full_on_target_results'],
118
+ columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
119
+ full_predictions_csv = full_df.to_csv(index=False).encode('utf-8')
120
  st.download_button(
121
  label='Download on-target predictions',
122
  data=full_predictions_csv,
cas9_model/on-cla.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5acf8f740cf326052ad08db2ca71d7204526c61f6a9fcdca36e15004bc16ad04
3
- size 34044032
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3426146f71d42c25fdc2baa959d3c43d23404c5f9200a064701bb86788d38fe9
3
+ size 34040392
cas9on.py CHANGED
@@ -19,42 +19,11 @@ ntmap = {'A': (1, 0, 0, 0),
19
  'G': (0, 0, 1, 0),
20
  'T': (0, 0, 0, 1)
21
  }
22
- epimap = {'A': 1, 'N': 0}
23
 
24
  def get_seqcode(seq):
25
  return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape(
26
  (1, len(seq), -1))
27
 
28
-
29
- def get_epicode(eseq):
30
- return np.array(list(map(lambda c: epimap[c], eseq))).reshape(1, len(eseq), -1)
31
-
32
- class Episgt:
33
- def __init__(self, fpath, num_epi_features, with_y=True):
34
- self._fpath = fpath
35
- self._ori_df = pd.read_csv(fpath, sep='\t', index_col=None, header=None)
36
- self._num_epi_features = num_epi_features
37
- self._with_y = with_y
38
- self._num_cols = num_epi_features + 2 if with_y else num_epi_features + 1
39
- self._cols = list(self._ori_df.columns)[-self._num_cols:]
40
- self._df = self._ori_df[self._cols]
41
-
42
- @property
43
- def length(self):
44
- return len(self._df)
45
-
46
- def get_dataset(self, x_dtype=np.float32, y_dtype=np.float32):
47
- x_seq = np.concatenate(list(map(get_seqcode, self._df[self._cols[0]])))
48
- x_epis = np.concatenate([np.concatenate(list(map(get_epicode, self._df[col]))) for col in
49
- self._cols[1: 1 + self._num_epi_features]], axis=-1)
50
- x = np.concatenate([x_seq, x_epis], axis=-1).astype(x_dtype)
51
- x = x.transpose(0, 2, 1)
52
- if self._with_y:
53
- y = np.array(self._df[self._cols[-1]]).astype(y_dtype)
54
- return x, y
55
- else:
56
- return x
57
-
58
  from keras.models import load_model
59
  class DCModelOntar:
60
  def __init__(self, ontar_model_dir, is_reg=False):
@@ -66,32 +35,25 @@ class DCModelOntar:
66
  yp = self.model.predict(x)
67
  return yp.ravel()
68
 
69
- # Function to generate random epigenetic data
70
- def generate_random_epigenetic_data(length):
71
- return ''.join(random.choice('AN') for _ in range(length))
72
 
73
  # Function to predict on-target efficiency and format output
74
- def format_prediction_output(gRNA_sites, gene_id, model_path):
75
  dcModel = DCModelOntar(model_path)
76
  formatted_data = []
77
 
78
- for gRNA in gRNA_sites:
79
  # Encode the gRNA sequence
80
- encoded_seq = get_seqcode(gRNA).reshape(-1,4,1,23)
81
- #encoded_seq = np.expand_dims(encoded_seq, axis=2) # Adjust the shape for the model
82
-
83
- # Generate random epigenetic features (as placeholders)
84
- ctcf = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
85
- dnase = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
86
- h3k4me3 = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
87
- rrbs = get_epicode(generate_random_epigenetic_data(len(gRNA))).reshape(-1,1,1,23)
88
 
89
  # Predict on-target efficiency using the model
90
- input = np.concatenate((encoded_seq, ctcf, dnase, h3k4me3, rrbs), axis=1)
91
- prediction = dcModel.ontar_predict(input)
92
 
93
  # Format output
94
- formatted_data.append([gene_id, "start_pos", "end_pos", "strand", gRNA, ctcf, dnase, h3k4me3, rrbs, prediction[0]])
 
 
 
 
95
 
96
  return formatted_data
97
 
@@ -123,7 +85,7 @@ def fetch_ensembl_sequence(transcript_id):
123
  print(f"Error fetching sequence data from Ensembl: {response.text}")
124
  return None
125
 
126
- def find_crispr_targets(sequence, pam="NGG", target_length=20):
127
  targets = []
128
  len_sequence = len(sequence)
129
 
@@ -131,11 +93,12 @@ def find_crispr_targets(sequence, pam="NGG", target_length=20):
131
  if sequence[i + 1:i + 3] == pam[1:]:
132
  if i >= target_length:
133
  target_seq = sequence[i - target_length:i + 3]
134
- targets.append(target_seq)
 
 
135
 
136
  return targets
137
 
138
-
139
  def process_gene(gene_symbol, model_path):
140
  transcripts = fetch_ensembl_transcripts(gene_symbol)
141
  all_data = []
@@ -156,6 +119,5 @@ def process_gene(gene_symbol, model_path):
156
  # Function to save results as CSV
157
  def save_to_csv(data, filename="crispr_results.csv"):
158
  df = pd.DataFrame(data,
159
- columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "CTCF", "Dnase", "H3K4me3", "RRBS",
160
- "Prediction"])
161
  df.to_csv(filename, index=False)
 
19
  'G': (0, 0, 1, 0),
20
  'T': (0, 0, 0, 1)
21
  }
 
22
 
23
  def get_seqcode(seq):
24
  return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape(
25
  (1, len(seq), -1))
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  from keras.models import load_model
28
  class DCModelOntar:
29
  def __init__(self, ontar_model_dir, is_reg=False):
 
35
  yp = self.model.predict(x)
36
  return yp.ravel()
37
 
 
 
 
38
 
39
  # Function to predict on-target efficiency and format output
40
+ def format_prediction_output(gRNAs, model_path):
41
  dcModel = DCModelOntar(model_path)
42
  formatted_data = []
43
 
44
+ for gRNA in gRNAs:
45
  # Encode the gRNA sequence
46
+ encoded_seq = get_seqcode(gRNA[0]).reshape(-1,4,1,23)
 
 
 
 
 
 
 
47
 
48
  # Predict on-target efficiency using the model
49
+ prediction = dcModel.ontar_predict(encoded_seq)
 
50
 
51
  # Format output
52
+ chr = gRNA[1]
53
+ start = gRNA[2]
54
+ end = gRNA[3]
55
+ strand = gRNA[4]
56
+ formatted_data.append([chr, start, end, strand, gRNA[0], prediction[0]])
57
 
58
  return formatted_data
59
 
 
85
  print(f"Error fetching sequence data from Ensembl: {response.text}")
86
  return None
87
 
88
+ def find_crispr_targets(sequence, chr, start, strand, pam="NGG", target_length=20):
89
  targets = []
90
  len_sequence = len(sequence)
91
 
 
93
  if sequence[i + 1:i + 3] == pam[1:]:
94
  if i >= target_length:
95
  target_seq = sequence[i - target_length:i + 3]
96
+ tar_start = start + i - target_length
97
+ tar_end = start + i + 3
98
+ targets.append([target_seq, chr, tar_start, tar_end, strand])
99
 
100
  return targets
101
 
 
102
  def process_gene(gene_symbol, model_path):
103
  transcripts = fetch_ensembl_transcripts(gene_symbol)
104
  all_data = []
 
119
  # Function to save results as CSV
120
  def save_to_csv(data, filename="crispr_results.csv"):
121
  df = pd.DataFrame(data,
122
+ columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
 
123
  df.to_csv(filename, index=False)