Spaces:
Running
Running
import requests | |
import tensorflow as tf | |
import pandas as pd | |
import numpy as np | |
from operator import add | |
from functools import reduce | |
from keras.models import load_model | |
import random | |
# Configure GPUs: enable on-demand memory growth for every physical GPU and
# make only the first one visible to TensorFlow. The device list is queried
# once and reused instead of being re-fetched for each step.
_gpus = tf.config.list_physical_devices('GPU')
for gpu in _gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if _gpus:
    tf.config.experimental.set_visible_devices(_gpus[0], 'GPU')
# One-hot code for each supported nucleotide, in A, C, G, T column order.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide sequence.

    Args:
        seq: String of nucleotides. Matching is case-insensitive; only the
            characters present in ``ntmap`` (A, C, G, T) are supported.

    Returns:
        Integer NumPy array of shape ``(1, len(seq), 4)`` where row ``i`` of
        the middle axis is the one-hot code of ``seq[i]``. An empty string
        yields an array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: If ``seq`` contains a character that is not in ``ntmap``.
    """
    # A list of per-base tuples is built in linear time; folding tuples with
    # ``reduce(add, ...)`` re-copies the accumulated tuple at every step and
    # fails outright on an empty sequence.
    rows = [ntmap[base] for base in seq.upper()]
    code_width = len(ntmap['A'])
    return np.array(rows, dtype=int).reshape((1, len(seq), code_width))
from keras.models import load_model | |
class DCModelOntar:
    """Wrapper around a saved Keras model used for on-target scoring."""

    def __init__(self, ontar_model_dir, is_reg=False):
        """Load the model stored at ``ontar_model_dir``.

        ``is_reg`` is accepted for interface compatibility; this class does
        not read it.
        """
        self.model = load_model(ontar_model_dir)

    def ontar_predict(self, x, channel_first=True):
        """Run the model on ``x`` and return its output flattened to 1-D.

        Args:
            x: Input array. When ``channel_first`` is true it must support
                ``transpose`` over four axes.
            channel_first: If true, reorder the axes of ``x`` to
                ``(0, 2, 3, 1)`` (axis 1 moved to the end) before predicting.

        Returns:
            The model's prediction passed through ``ravel()``.
        """
        model_input = x.transpose([0, 2, 3, 1]) if channel_first else x
        raw_output = self.model.predict(model_input)
        return raw_output.ravel()
# Function to predict on-target efficiency and format output
def format_prediction_output(gRNAs, model_path):
    """Score gRNA candidates with the on-target model and tabulate the results.

    Args:
        gRNAs: Iterable of records indexed as
            ``[sequence, chromosome, start, end, strand]``. Every sequence
            must be 23 nucleotides long, because each encoding is reshaped
            to ``(1, 4, 1, 23)``.
        model_path: Path passed to ``DCModelOntar`` to load the model.

    Returns:
        A list with one ``[chromosome, start, end, strand, sequence,
        prediction]`` row per input record, in input order.

    Raises:
        ValueError: If a sequence cannot be reshaped to ``(1, 4, 1, 23)``.
        KeyError: If a sequence contains a base ``get_seqcode`` cannot encode.
    """
    dcModel = DCModelOntar(model_path)
    records = list(gRNAs)
    if not records:
        return []

    # NOTE(review): get_seqcode returns a (1, 23, 4) array; reshaping it
    # straight to (1, 4, 1, 23) reinterprets the flat buffer instead of moving
    # the base axis into the channel position (a transpose would do that).
    # Left unchanged because the scores depend on matching whatever layout the
    # model was trained with -- confirm against the training pipeline.
    encoded = np.concatenate(
        [get_seqcode(record[0]).reshape(1, 4, 1, 23) for record in records]
    )

    # One batched predict call instead of one call per gRNA. The flattened
    # output is regrouped per record and the first value of each kept, which
    # is what ``prediction[0]`` selected for a single-sequence call.
    scores = dcModel.ontar_predict(encoded).reshape(len(records), -1)[:, 0]

    return [
        [record[1], record[2], record[3], record[4], record[0], score]
        for record, score in zip(records, scores)
    ]
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene by symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up under ``homo_sapiens``.

    Returns:
        The value stored under the ``Transcript`` key of the JSON response,
        or ``None`` when the request fails, the server does not answer with
        HTTP 200, or the response has no ``Transcript`` key. Each failure is
        reported with ``print``.
    """
    from urllib.parse import quote  # stdlib; only needed in this function

    # Percent-encode the symbol so characters such as "/" or "?" cannot alter
    # the request path or query string.
    symbol = quote(str(gene_symbol), safe="")
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol}?expand=1;content-type=application/json"
    try:
        # requests has no default timeout; without one a stalled server would
        # block this call indefinitely.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Transport failures follow the same report-and-return-None contract
        # as HTTP errors.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence for an identifier from the Ensembl REST API.

    Args:
        transcript_id: Identifier placed in the ``/sequence/id/`` request.

    Returns:
        The value stored under the ``seq`` key of the JSON response, or
        ``None`` when the request fails, the server does not answer with
        HTTP 200, or the response has no ``seq`` key. Each failure is
        reported with ``print``.
    """
    from urllib.parse import quote  # stdlib; only needed in this function

    # Percent-encode the identifier so it cannot alter the request path or
    # query string.
    identifier = quote(str(transcript_id), safe="")
    url = f"https://rest.ensembl.org/sequence/id/{identifier}?content-type=application/json"
    try:
        # requests has no default timeout; without one a stalled server would
        # block this call indefinitely.
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as exc:
        # Transport failures follow the same report-and-return-None contract
        # as HTTP errors.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
def find_crispr_targets(sequence, chr, start, strand, pam="NGG", target_length=20):
    """Find protospacer + PAM sites in ``sequence`` as given (one orientation).

    A site is reported at index ``i`` when the PAM pattern matches
    ``sequence[i:i + len(pam)]`` and at least ``target_length`` bases precede
    it. Matching is case-sensitive.

    Args:
        sequence: Nucleotide string to scan.
        chr: Chromosome label copied into every result.
        start: Coordinate assigned to ``sequence[0]``.
        strand: Strand label copied into every result.
        pam: PAM pattern. ``N`` matches any character; the IUPAC ambiguity
            codes R, Y, S, W, K, M, B, D, H and V match their base sets; any
            other character must match literally. The whole pattern is
            honoured, whatever its length.
        target_length: Number of bases taken upstream of the PAM.

    Returns:
        List of ``[target_seq, chr, tar_start, tar_end, strand]`` where
        ``target_seq`` is the protospacer followed by the PAM,
        ``tar_start = start + i - target_length`` and
        ``tar_end = start + i + len(pam)`` (one past the last base, counted
        the same way as ``tar_start``).
    """
    # NOTE(review): offsets are always added to ``start`` and ``strand`` is
    # only passed through; confirm the coordinates are what callers expect
    # for reverse-strand input.

    # IUPAC ambiguity codes. "N" is deliberately absent: it is treated as a
    # pure wildcard so that any character (including "N" itself) matches.
    iupac = {
        'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
        'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
    }
    # (offset within PAM, allowed characters) for every constrained position.
    constraints = [
        (offset, iupac.get(code, code))
        for offset, code in enumerate(pam)
        if code != 'N'
    ]
    pam_length = len(pam)

    targets = []
    # Start where a full protospacer fits before the PAM; stop where the PAM
    # would run past the end of the sequence.
    first_index = max(target_length, 0)
    for i in range(first_index, len(sequence) - pam_length + 1):
        if all(sequence[i + offset] in allowed for offset, allowed in constraints):
            target_seq = sequence[i - target_length:i + pam_length]
            tar_start = start + i - target_length
            tar_end = start + i + pam_length
            targets.append([target_seq, chr, tar_start, tar_end, strand])
    return targets
def process_gene(gene_symbol, model_path):
    """Find and score candidate gRNA sites across all transcripts of a gene.

    For every transcript returned by ``fetch_ensembl_transcripts`` the
    sequence is fetched with ``fetch_ensembl_sequence`` and scanned with
    ``find_crispr_targets``. All sites are then scored in a single
    ``format_prediction_output`` call, so the model is loaded once per gene
    rather than once per transcript.

    Args:
        gene_symbol: Gene symbol to look up.
        model_path: Path of the on-target model, forwarded to
            ``format_prediction_output``.

    Returns:
        The rows produced by ``format_prediction_output`` for all sites, in
        transcript order; an empty list when no transcript, sequence or site
        is found.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        return []

    gRNA_sites = []
    for transcript in transcripts:
        transcript_id = transcript['id']
        # Fall back to placeholders when the record lacks location fields.
        chrom = transcript.get('seq_region_name', 'unknown')
        start = transcript.get('start', 0)
        strand = transcript.get('strand', 'unknown')

        gene_sequence = fetch_ensembl_sequence(transcript_id)
        if gene_sequence:
            gRNA_sites.extend(
                find_crispr_targets(gene_sequence, chrom, start, strand)
            )

    if not gRNA_sites:
        return []
    return format_prediction_output(gRNA_sites, model_path)
# Function to save results as CSV
def save_to_csv(data, filename="crispr_results.csv"):
    """Write result rows to ``filename`` as CSV with a header and no index.

    Args:
        data: Rows to write; each row is expected to hold six values, one per
            header column.
        filename: Destination path for the CSV file.
    """
    # NOTE(review): if rows come from format_prediction_output, the first
    # column holds the chromosome although it is labelled "Gene ID" -- confirm
    # before relying on the header.
    header = ["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"]
    table = pd.DataFrame(data, columns=header)
    table.to_csv(filename, index=False)