Prot2Text-Large-v1-0 / utils_dataset.py
habdine's picture
Upload code
2f76f69 verified
raw
history blame
2.36 kB
import numpy as np
import csv
def _first_index_map(items):
    """Map each item to the index of its first occurrence in *items*.

    Equivalent to ``items.index(x)`` for every ``x`` but built once, so
    repeated lookups are O(1) instead of a linear scan each time.
    """
    mapping = {}
    for idx, item in enumerate(items):
        mapping.setdefault(item, idx)  # keep the first occurrence, like list.index
    return mapping


def load_GO_annot(filename):
    """Load Gene Ontology (GO) annotations from a tab-separated file.

    Layout read by this function, in order:

    * for each ontology in ``('mf', 'bp', 'cc')``: a header line, a line of
      tab-separated GO-term IDs, a header line, and a line of tab-separated
      GO-term names;
    * one more header line;
    * one row per protein: ``<protein>\\t<mf terms>\\t<bp terms>\\t<cc terms>``,
      where each terms column is a comma-separated list (possibly empty).

    Args:
        filename: Path of the TSV file.

    Returns:
        A tuple ``(prot2annot, goterms, gonames, counts)``:

        * ``prot2annot``: ``{protein: {ont: ndarray}}``. Each value is a
          float64 0/1 vector aligned with ``goterms[ont]``.
        * ``goterms``: ``{ont: list of GO-term IDs}``.
        * ``gonames``: ``{ont: list of GO-term names}``.
        * ``counts``: ``{ont: float ndarray}``. For each term, the number of
          protein rows annotated with it.

    Raises:
        ValueError: If a protein row references a GO term that is not listed
            on its ontology's term line.
    """
    onts = ['mf', 'bp', 'cc']
    prot2annot = {}
    goterms = {ont: [] for ont in onts}
    gonames = {ont: [] for ont in onts}
    # newline='' is what the csv module requires for files it reads.
    with open(filename, mode='r', newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # Per-ontology preamble: header, term IDs, header, term names.
        for ont in onts:
            next(reader, None)  # skip the header
            goterms[ont] = next(reader)
            next(reader, None)  # skip the header
            gonames[ont] = next(reader)
        next(reader, None)  # skip the header of the per-protein table

        term_index = {ont: _first_index_map(goterms[ont]) for ont in onts}
        counts = {ont: np.zeros(len(goterms[ont]), dtype=float) for ont in onts}
        for row in reader:
            if not row:
                # csv.reader yields [] for an empty line: no protein to record.
                continue
            prot, prot_goterms = row[0], row[1:]
            prot2annot[prot] = {}
            for i, ont in enumerate(onts):
                # A missing trailing column means "no annotations" for that ontology.
                field = prot_goterms[i] if i < len(prot_goterms) else ''
                lookup = term_index[ont]
                goterm_indices = []
                for goterm in field.split(','):
                    if goterm == '':
                        continue
                    if goterm not in lookup:
                        raise ValueError(
                            "GO term %r of protein %r is not listed for ontology %r"
                            % (goterm, prot, ont)
                        )
                    goterm_indices.append(lookup[goterm])
                annot = np.zeros(len(goterms[ont]))
                annot[goterm_indices] = 1.0
                prot2annot[prot][ont] = annot
                # Fancy-index += increments each distinct index once, so a term
                # repeated within one row is still counted once for that protein.
                counts[ont][goterm_indices] += 1.0
    return prot2annot, goterms, gonames, counts
def load_EC_annot(filename):
# Load EC annotations """
prot2annot = {}
with open(filename, mode='r') as tsvfile:
reader = csv.reader(tsvfile, delimiter='\t')
# molecular function
next(reader, None) # skip the headers
ec_numbers = {'ec': next(reader)}
next(reader, None) # skip the headers
counts = {'ec': np.zeros(len(ec_numbers['ec']), dtype=float)}
for row in reader:
prot, prot_ec_numbers = row[0], row[1]
ec_indices = [ec_numbers['ec'].index(ec_num) for ec_num in prot_ec_numbers.split(',')]
prot2annot[prot] = {'ec': np.zeros(len(ec_numbers['ec']), dtype=np.int64)}
prot2annot[prot]['ec'][ec_indices] = 1.0
counts['ec'][ec_indices] += 1