LTEnjoy committed on
Commit
5b69b32
1 Parent(s): 7cf5985

Update demo/modules/init_model.py

Browse files
Files changed (1) hide show
  1. demo/modules/init_model.py +117 -117
demo/modules/init_model.py CHANGED
@@ -1,118 +1,118 @@
1
- import faiss
2
- import numpy as np
3
- import pandas as pd
4
- import os
5
- import yaml
6
- import glob
7
-
8
- from easydict import EasyDict
9
- from utils.constants import sequence_level
10
- from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
11
- from tqdm import tqdm
12
-
13
-
14
- def load_model():
15
- model_config = {
16
- "protein_config": glob.glob(f"{config.model_dir}/esm2_*")[0],
17
- "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
18
- "structure_config": glob.glob(f"{config.model_dir}/foldseek_*")[0],
19
- "load_protein_pretrained": False,
20
- "load_text_pretrained": False,
21
- "from_checkpoint": glob.glob(f"{config.model_dir}/*.pt")[0]
22
- }
23
-
24
- model = ProTrekTrimodalModel(**model_config)
25
- model.eval()
26
- return model
27
-
28
-
29
- def load_faiss_index(index_path: str):
30
- if config.faiss_config.IO_FLAG_MMAP:
31
- index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
32
- else:
33
- index = faiss.read_index(index_path)
34
-
35
- index.metric_type = faiss.METRIC_INNER_PRODUCT
36
- return index
37
-
38
-
39
- def load_index():
40
- all_index = {}
41
-
42
- # Load protein sequence index
43
- all_index["sequence"] = {}
44
- for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."):
45
- db_name = db["name"]
46
- index_dir = db["index_dir"]
47
-
48
- index_path = f"{index_dir}/sequence.index"
49
- sequence_index = load_faiss_index(index_path)
50
-
51
- id_path = f"{index_dir}/ids.tsv"
52
- uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
53
-
54
- all_index["sequence"][db_name] = {"index": sequence_index, "ids": uniprot_ids}
55
-
56
- # Load protein structure index
57
- print("Loading structure index...")
58
- all_index["structure"] = {}
59
- for db in tqdm(config.structure_index_dir, desc="Loading structure index..."):
60
- db_name = db["name"]
61
- index_dir = db["index_dir"]
62
-
63
- index_path = f"{index_dir}/structure.index"
64
- structure_index = load_faiss_index(index_path)
65
-
66
- id_path = f"{index_dir}/ids.tsv"
67
- uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
68
-
69
- all_index["structure"][db_name] = {"index": structure_index, "ids": uniprot_ids}
70
-
71
- # Load text index
72
- all_index["text"] = {}
73
- valid_subsections = {}
74
- for db in tqdm(config.text_index_dir, desc="Loading text index..."):
75
- db_name = db["name"]
76
- index_dir = db["index_dir"]
77
- all_index["text"][db_name] = {}
78
- text_dir = f"{index_dir}/subsections"
79
-
80
- # Remove "Taxonomic lineage" from sequence_level. This is a special case which we don't need to index.
81
- valid_subsections[db_name] = set()
82
- sequence_level.add("Global")
83
- for subsection in tqdm(sequence_level):
84
- index_path = f"{text_dir}/{subsection.replace(' ', '_')}.index"
85
- if not os.path.exists(index_path):
86
- continue
87
-
88
- text_index = load_faiss_index(index_path)
89
-
90
- id_path = f"{text_dir}/{subsection.replace(' ', '_')}_ids.tsv"
91
- text_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
92
-
93
- all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
94
- valid_subsections[db_name].add(subsection)
95
-
96
- # Sort valid_subsections
97
- for db_name in valid_subsections:
98
- valid_subsections[db_name] = sorted(list(valid_subsections[db_name]))
99
-
100
- return all_index, valid_subsections
101
-
102
-
103
- # Load the config file
104
- root_dir = __file__.rsplit("/", 3)[0]
105
- config_path = f"{root_dir}/demo/config.yaml"
106
- with open(config_path, 'r', encoding='utf-8') as r:
107
- config = EasyDict(yaml.safe_load(r))
108
-
109
- device = "cuda"
110
-
111
- print("Loading model...")
112
- model = load_model()
113
- model.to(device)
114
-
115
- all_index, valid_subsections = load_index()
116
- print("Done...")
117
- # model = None
118
  # all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}
 
1
+ import faiss
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+ import yaml
6
+ import glob
7
+
8
+ from easydict import EasyDict
9
+ from utils.constants import sequence_level
10
+ from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
11
+ from tqdm import tqdm
12
+
13
+
14
def load_model():
    """Instantiate the ProTrek tri-modal model from files under ``config.model_dir``.

    ``protein_config``, ``structure_config`` and ``from_checkpoint`` are each
    the first path returned by ``glob`` for ``esm2_*``, ``foldseek_*`` and
    ``*.pt`` inside the model directory; ``text_config`` is the fixed
    ``BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext`` sub-directory.

    Returns:
        The ``ProTrekTrimodalModel`` instance, after ``eval()`` has been
        called on it.

    Raises:
        IndexError: If any of the three glob patterns matches nothing.
    """
    model_dir = config.model_dir

    # Resolve the on-disk artefacts first so the constructor call stays readable.
    protein_cfg = glob.glob(f"{model_dir}/esm2_*")[0]
    text_cfg = f"{model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    structure_cfg = glob.glob(f"{model_dir}/foldseek_*")[0]
    checkpoint = glob.glob(f"{model_dir}/*.pt")[0]

    model = ProTrekTrimodalModel(
        protein_config=protein_cfg,
        text_config=text_cfg,
        structure_config=structure_cfg,
        load_protein_pretrained=False,
        load_text_pretrained=False,
        from_checkpoint=checkpoint,
    )
    model.eval()
    return model
27
+
28
+
29
def load_faiss_index(index_path: str):
    """Read a FAISS index from disk and set it to inner-product metric.

    Args:
        index_path: Path of the serialized FAISS index file.

    Returns:
        The index returned by ``faiss.read_index`` with ``metric_type`` set
        to ``faiss.METRIC_INNER_PRODUCT``.
    """
    # Only pass the mmap I/O flag when the config asks for it; otherwise call
    # read_index with the path alone, exactly like the default invocation.
    use_mmap = config.faiss_config.IO_FLAG_MMAP
    extra_flags = (faiss.IO_FLAG_MMAP,) if use_mmap else ()

    index = faiss.read_index(index_path, *extra_flags)
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index
37
+
38
+
39
def _read_ids(id_path: str):
    """Read a header-less, tab-separated file and return its cells as a flat array."""
    return pd.read_csv(id_path, sep="\t", header=None).values.flatten()


def _load_db_indexes(db_configs, index_filename: str, desc: str) -> dict:
    """Load one FAISS index and its ``ids.tsv`` for every configured database.

    Args:
        db_configs: Iterable of mappings providing ``"name"`` and
            ``"index_dir"`` keys.
        index_filename: Name of the index file inside each ``index_dir``.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Dict mapping each database name to ``{"index": ..., "ids": ...}``.
    """
    loaded = {}
    for db in tqdm(db_configs, desc=desc):
        db_name = db["name"]
        index_dir = db["index_dir"]
        loaded[db_name] = {
            "index": load_faiss_index(f"{index_dir}/{index_filename}"),
            "ids": _read_ids(f"{index_dir}/ids.tsv"),
        }
    return loaded


def load_index():
    """Load every sequence, structure and text FAISS index listed in ``config``.

    Returns:
        A ``(all_index, valid_subsections)`` tuple:

        * ``all_index["sequence"][db_name]`` and
          ``all_index["structure"][db_name]`` are ``{"index", "ids"}`` dicts
          loaded from ``<index_dir>/sequence.index`` /
          ``<index_dir>/structure.index`` together with ``<index_dir>/ids.tsv``.
        * ``all_index["text"][db_name][subsection]`` is an ``{"index", "ids"}``
          dict for every subsection whose index file exists under
          ``<index_dir>/subsections``.
        * ``valid_subsections[db_name]`` is the sorted list of the subsections
          that were found for that text database.
    """
    all_index = {}

    # Load protein sequence index
    all_index["sequence"] = _load_db_indexes(
        config.sequence_index_dir, "sequence.index", "Loading sequence index..."
    )

    # Load protein structure index
    print("Loading structure index...")
    all_index["structure"] = _load_db_indexes(
        config.structure_index_dir, "structure.index", "Loading structure index..."
    )

    # Load text index
    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]
        text_dir = f"{index_dir}/subsections"

        # Make sure the "Global" subsection is looked up as well.
        # NOTE: this adds to the collection imported from utils.constants, so
        # every other importer sees "Global" too. The call is idempotent and is
        # kept as-is because other modules may rely on that side effect.
        sequence_level.add("Global")

        db_text_index = {}
        for subsection in tqdm(sequence_level):
            # File names use underscores where subsection names have spaces.
            stem = f"{text_dir}/{subsection.replace(' ', '_')}"
            index_path = f"{stem}.index"
            if not os.path.exists(index_path):
                # No index was built for this subsection in this database.
                continue

            db_text_index[subsection] = {
                "index": load_faiss_index(index_path),
                "ids": _read_ids(f"{stem}_ids.tsv"),
            }

        all_index["text"][db_name] = db_text_index
        # Sorted names of the subsections that were actually loaded.
        valid_subsections[db_name] = sorted(db_text_index)

    return all_index, valid_subsections
101
+
102
+
103
# Load the config file.
# The repository root is three levels above this file
# (<root>/demo/modules/init_model.py). Resolve it from the absolute path so the
# lookup also works when __file__ is relative or uses a non-"/" separator.
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
config_path = os.path.join(root_dir, "demo", "config.yaml")
with open(config_path, 'r', encoding='utf-8') as r:
    config = EasyDict(yaml.safe_load(r))

device = "cuda"

print("Loading model...")
model = load_model()
# Moving the model to `device` is currently disabled on purpose; re-enable the
# next line to place it on the GPU at import time.
# model.to(device)

all_index, valid_subsections = load_index()
print("Done...")

# Stub values for running the demo without loading the model or any index:
# model = None
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}