uploaded TransfoRNA repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +674 -0
- README.md +145 -3
- conf/__init__.py +0 -0
- conf/hydra/job_logging/custom.yaml +13 -0
- conf/inference_settings/default.yaml +3 -0
- conf/main_config.yaml +51 -0
- conf/model/transforna.yaml +9 -0
- conf/readme.md +17 -0
- conf/train_model_configs/__init__.py +0 -0
- conf/train_model_configs/custom.py +76 -0
- conf/train_model_configs/premirna.py +68 -0
- conf/train_model_configs/sncrna.py +70 -0
- conf/train_model_configs/tcga.py +81 -0
- environment.yml +24 -0
- install.sh +33 -0
- kba_pipeline/README.md +58 -0
- kba_pipeline/environment.yml +20 -0
- kba_pipeline/src/annotate_from_mapping.py +751 -0
- kba_pipeline/src/make_anno.py +59 -0
- kba_pipeline/src/map_2_HBDxBase.py +318 -0
- kba_pipeline/src/precursor_bins.py +127 -0
- kba_pipeline/src/utils.py +163 -0
- requirements.txt +18 -0
- scripts/test_inference_api.py +29 -0
- scripts/train.sh +80 -0
- setup.py +40 -0
- transforna/__init__.py +7 -0
- transforna/__main__.py +54 -0
- transforna/bin/figure_scripts/figure_4_table_3.py +173 -0
- transforna/bin/figure_scripts/figure_5_S10_table_4.py +466 -0
- transforna/bin/figure_scripts/figure_6.ipynb +228 -0
- transforna/bin/figure_scripts/figure_S4.ipynb +368 -0
- transforna/bin/figure_scripts/figure_S5.py +94 -0
- transforna/bin/figure_scripts/figure_S8.py +56 -0
- transforna/bin/figure_scripts/figure_S9_S11.py +136 -0
- transforna/bin/figure_scripts/infer_lc_using_tcga.py +438 -0
- transforna/src/__init__.py +9 -0
- transforna/src/callbacks/LRCallback.py +174 -0
- transforna/src/callbacks/__init__.py +4 -0
- transforna/src/callbacks/criterion.py +165 -0
- transforna/src/callbacks/metrics.py +218 -0
- transforna/src/callbacks/tbWriter.py +6 -0
- transforna/src/inference/__init__.py +3 -0
- transforna/src/inference/inference_api.py +243 -0
- transforna/src/inference/inference_benchmark.py +48 -0
- transforna/src/inference/inference_tcga.py +38 -0
- transforna/src/model/__init__.py +2 -0
- transforna/src/model/model_components.py +449 -0
- transforna/src/model/skorchWrapper.py +364 -0
- transforna/src/novelty_prediction/__init__.py +2 -0
LICENSE
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 29 June 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works.
|
12 |
+
|
13 |
+
The licenses for most software and other practical works are designed
|
14 |
+
to take away your freedom to share and change the works. By contrast,
|
15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
16 |
+
share and change all versions of a program--to make sure it remains free
|
17 |
+
software for all its users. We, the Free Software Foundation, use the
|
18 |
+
GNU General Public License for most of our software; it applies also to
|
19 |
+
any other work released this way by its authors. You can apply it to
|
20 |
+
your programs, too.
|
21 |
+
|
22 |
+
When we speak of free software, we are referring to freedom, not
|
23 |
+
price. Our General Public Licenses are designed to make sure that you
|
24 |
+
have the freedom to distribute copies of free software (and charge for
|
25 |
+
them if you wish), that you receive source code or can get it if you
|
26 |
+
want it, that you can change the software or use pieces of it in new
|
27 |
+
free programs, and that you know you can do these things.
|
28 |
+
|
29 |
+
To protect your rights, we need to prevent others from denying you
|
30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
31 |
+
certain responsibilities if you distribute copies of the software, or if
|
32 |
+
you modify it: responsibilities to respect the freedom of others.
|
33 |
+
|
34 |
+
For example, if you distribute copies of such a program, whether
|
35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
36 |
+
freedoms that you received. You must make sure that they, too, receive
|
37 |
+
or can get the source code. And you must show them these terms so they
|
38 |
+
know their rights.
|
39 |
+
|
40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
43 |
+
|
44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
45 |
+
that there is no warranty for this free software. For both users' and
|
46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
47 |
+
changed, so that their problems will not be attributed erroneously to
|
48 |
+
authors of previous versions.
|
49 |
+
|
50 |
+
Some devices are designed to deny users access to install or run
|
51 |
+
modified versions of the software inside them, although the manufacturer
|
52 |
+
can do so. This is fundamentally incompatible with the aim of
|
53 |
+
protecting users' freedom to change the software. The systematic
|
54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
56 |
+
have designed this version of the GPL to prohibit the practice for those
|
57 |
+
products. If such problems arise substantially in other domains, we
|
58 |
+
stand ready to extend this provision to those domains in future versions
|
59 |
+
of the GPL, as needed to protect the freedom of users.
|
60 |
+
|
61 |
+
Finally, every program is threatened constantly by software patents.
|
62 |
+
States should not allow patents to restrict development and use of
|
63 |
+
software on general-purpose computers, but in those that do, we wish to
|
64 |
+
avoid the special danger that patents applied to a free program could
|
65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
66 |
+
patents cannot be used to render the program non-free.
|
67 |
+
|
68 |
+
The precise terms and conditions for copying, distribution and
|
69 |
+
modification follow.
|
70 |
+
|
71 |
+
TERMS AND CONDITIONS
|
72 |
+
|
73 |
+
0. Definitions.
|
74 |
+
|
75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
76 |
+
|
77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
78 |
+
works, such as semiconductor masks.
|
79 |
+
|
80 |
+
"The Program" refers to any copyrightable work licensed under this
|
81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
82 |
+
"recipients" may be individuals or organizations.
|
83 |
+
|
84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
85 |
+
in a fashion requiring copyright permission, other than the making of an
|
86 |
+
exact copy. The resulting work is called a "modified version" of the
|
87 |
+
earlier work or a work "based on" the earlier work.
|
88 |
+
|
89 |
+
A "covered work" means either the unmodified Program or a work based
|
90 |
+
on the Program.
|
91 |
+
|
92 |
+
To "propagate" a work means to do anything with it that, without
|
93 |
+
permission, would make you directly or secondarily liable for
|
94 |
+
infringement under applicable copyright law, except executing it on a
|
95 |
+
computer or modifying a private copy. Propagation includes copying,
|
96 |
+
distribution (with or without modification), making available to the
|
97 |
+
public, and in some countries other activities as well.
|
98 |
+
|
99 |
+
To "convey" a work means any kind of propagation that enables other
|
100 |
+
parties to make or receive copies. Mere interaction with a user through
|
101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
102 |
+
|
103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
104 |
+
to the extent that it includes a convenient and prominently visible
|
105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
106 |
+
tells the user that there is no warranty for the work (except to the
|
107 |
+
extent that warranties are provided), that licensees may convey the
|
108 |
+
work under this License, and how to view a copy of this License. If
|
109 |
+
the interface presents a list of user commands or options, such as a
|
110 |
+
menu, a prominent item in the list meets this criterion.
|
111 |
+
|
112 |
+
1. Source Code.
|
113 |
+
|
114 |
+
The "source code" for a work means the preferred form of the work
|
115 |
+
for making modifications to it. "Object code" means any non-source
|
116 |
+
form of a work.
|
117 |
+
|
118 |
+
A "Standard Interface" means an interface that either is an official
|
119 |
+
standard defined by a recognized standards body, or, in the case of
|
120 |
+
interfaces specified for a particular programming language, one that
|
121 |
+
is widely used among developers working in that language.
|
122 |
+
|
123 |
+
The "System Libraries" of an executable work include anything, other
|
124 |
+
than the work as a whole, that (a) is included in the normal form of
|
125 |
+
packaging a Major Component, but which is not part of that Major
|
126 |
+
Component, and (b) serves only to enable use of the work with that
|
127 |
+
Major Component, or to implement a Standard Interface for which an
|
128 |
+
implementation is available to the public in source code form. A
|
129 |
+
"Major Component", in this context, means a major essential component
|
130 |
+
(kernel, window system, and so on) of the specific operating system
|
131 |
+
(if any) on which the executable work runs, or a compiler used to
|
132 |
+
produce the work, or an object code interpreter used to run it.
|
133 |
+
|
134 |
+
The "Corresponding Source" for a work in object code form means all
|
135 |
+
the source code needed to generate, install, and (for an executable
|
136 |
+
work) run the object code and to modify the work, including scripts to
|
137 |
+
control those activities. However, it does not include the work's
|
138 |
+
System Libraries, or general-purpose tools or generally available free
|
139 |
+
programs which are used unmodified in performing those activities but
|
140 |
+
which are not part of the work. For example, Corresponding Source
|
141 |
+
includes interface definition files associated with source files for
|
142 |
+
the work, and the source code for shared libraries and dynamically
|
143 |
+
linked subprograms that the work is specifically designed to require,
|
144 |
+
such as by intimate data communication or control flow between those
|
145 |
+
subprograms and other parts of the work.
|
146 |
+
|
147 |
+
The Corresponding Source need not include anything that users
|
148 |
+
can regenerate automatically from other parts of the Corresponding
|
149 |
+
Source.
|
150 |
+
|
151 |
+
The Corresponding Source for a work in source code form is that
|
152 |
+
same work.
|
153 |
+
|
154 |
+
2. Basic Permissions.
|
155 |
+
|
156 |
+
All rights granted under this License are granted for the term of
|
157 |
+
copyright on the Program, and are irrevocable provided the stated
|
158 |
+
conditions are met. This License explicitly affirms your unlimited
|
159 |
+
permission to run the unmodified Program. The output from running a
|
160 |
+
covered work is covered by this License only if the output, given its
|
161 |
+
content, constitutes a covered work. This License acknowledges your
|
162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
163 |
+
|
164 |
+
You may make, run and propagate covered works that you do not
|
165 |
+
convey, without conditions so long as your license otherwise remains
|
166 |
+
in force. You may convey covered works to others for the sole purpose
|
167 |
+
of having them make modifications exclusively for you, or provide you
|
168 |
+
with facilities for running those works, provided that you comply with
|
169 |
+
the terms of this License in conveying all material for which you do
|
170 |
+
not control copyright. Those thus making or running the covered works
|
171 |
+
for you must do so exclusively on your behalf, under your direction
|
172 |
+
and control, on terms that prohibit them from making any copies of
|
173 |
+
your copyrighted material outside their relationship with you.
|
174 |
+
|
175 |
+
Conveying under any other circumstances is permitted solely under
|
176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
177 |
+
makes it unnecessary.
|
178 |
+
|
179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
180 |
+
|
181 |
+
No covered work shall be deemed part of an effective technological
|
182 |
+
measure under any applicable law fulfilling obligations under article
|
183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
184 |
+
similar laws prohibiting or restricting circumvention of such
|
185 |
+
measures.
|
186 |
+
|
187 |
+
When you convey a covered work, you waive any legal power to forbid
|
188 |
+
circumvention of technological measures to the extent such circumvention
|
189 |
+
is effected by exercising rights under this License with respect to
|
190 |
+
the covered work, and you disclaim any intention to limit operation or
|
191 |
+
modification of the work as a means of enforcing, against the work's
|
192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
193 |
+
technological measures.
|
194 |
+
|
195 |
+
4. Conveying Verbatim Copies.
|
196 |
+
|
197 |
+
You may convey verbatim copies of the Program's source code as you
|
198 |
+
receive it, in any medium, provided that you conspicuously and
|
199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
200 |
+
keep intact all notices stating that this License and any
|
201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
202 |
+
keep intact all notices of the absence of any warranty; and give all
|
203 |
+
recipients a copy of this License along with the Program.
|
204 |
+
|
205 |
+
You may charge any price or no price for each copy that you convey,
|
206 |
+
and you may offer support or warranty protection for a fee.
|
207 |
+
|
208 |
+
5. Conveying Modified Source Versions.
|
209 |
+
|
210 |
+
You may convey a work based on the Program, or the modifications to
|
211 |
+
produce it from the Program, in the form of source code under the
|
212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
213 |
+
|
214 |
+
a) The work must carry prominent notices stating that you modified
|
215 |
+
it, and giving a relevant date.
|
216 |
+
|
217 |
+
b) The work must carry prominent notices stating that it is
|
218 |
+
released under this License and any conditions added under section
|
219 |
+
7. This requirement modifies the requirement in section 4 to
|
220 |
+
"keep intact all notices".
|
221 |
+
|
222 |
+
c) You must license the entire work, as a whole, under this
|
223 |
+
License to anyone who comes into possession of a copy. This
|
224 |
+
License will therefore apply, along with any applicable section 7
|
225 |
+
additional terms, to the whole of the work, and all its parts,
|
226 |
+
regardless of how they are packaged. This License gives no
|
227 |
+
permission to license the work in any other way, but it does not
|
228 |
+
invalidate such permission if you have separately received it.
|
229 |
+
|
230 |
+
d) If the work has interactive user interfaces, each must display
|
231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
233 |
+
work need not make them do so.
|
234 |
+
|
235 |
+
A compilation of a covered work with other separate and independent
|
236 |
+
works, which are not by their nature extensions of the covered work,
|
237 |
+
and which are not combined with it such as to form a larger program,
|
238 |
+
in or on a volume of a storage or distribution medium, is called an
|
239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
240 |
+
used to limit the access or legal rights of the compilation's users
|
241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
242 |
+
in an aggregate does not cause this License to apply to the other
|
243 |
+
parts of the aggregate.
|
244 |
+
|
245 |
+
6. Conveying Non-Source Forms.
|
246 |
+
|
247 |
+
You may convey a covered work in object code form under the terms
|
248 |
+
of sections 4 and 5, provided that you also convey the
|
249 |
+
machine-readable Corresponding Source under the terms of this License,
|
250 |
+
in one of these ways:
|
251 |
+
|
252 |
+
a) Convey the object code in, or embodied in, a physical product
|
253 |
+
(including a physical distribution medium), accompanied by the
|
254 |
+
Corresponding Source fixed on a durable physical medium
|
255 |
+
customarily used for software interchange.
|
256 |
+
|
257 |
+
b) Convey the object code in, or embodied in, a physical product
|
258 |
+
(including a physical distribution medium), accompanied by a
|
259 |
+
written offer, valid for at least three years and valid for as
|
260 |
+
long as you offer spare parts or customer support for that product
|
261 |
+
model, to give anyone who possesses the object code either (1) a
|
262 |
+
copy of the Corresponding Source for all the software in the
|
263 |
+
product that is covered by this License, on a durable physical
|
264 |
+
medium customarily used for software interchange, for a price no
|
265 |
+
more than your reasonable cost of physically performing this
|
266 |
+
conveying of source, or (2) access to copy the
|
267 |
+
Corresponding Source from a network server at no charge.
|
268 |
+
|
269 |
+
c) Convey individual copies of the object code with a copy of the
|
270 |
+
written offer to provide the Corresponding Source. This
|
271 |
+
alternative is allowed only occasionally and noncommercially, and
|
272 |
+
only if you received the object code with such an offer, in accord
|
273 |
+
with subsection 6b.
|
274 |
+
|
275 |
+
d) Convey the object code by offering access from a designated
|
276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
277 |
+
Corresponding Source in the same way through the same place at no
|
278 |
+
further charge. You need not require recipients to copy the
|
279 |
+
Corresponding Source along with the object code. If the place to
|
280 |
+
copy the object code is a network server, the Corresponding Source
|
281 |
+
may be on a different server (operated by you or a third party)
|
282 |
+
that supports equivalent copying facilities, provided you maintain
|
283 |
+
clear directions next to the object code saying where to find the
|
284 |
+
Corresponding Source. Regardless of what server hosts the
|
285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
286 |
+
available for as long as needed to satisfy these requirements.
|
287 |
+
|
288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
289 |
+
you inform other peers where the object code and Corresponding
|
290 |
+
Source of the work are being offered to the general public at no
|
291 |
+
charge under subsection 6d.
|
292 |
+
|
293 |
+
A separable portion of the object code, whose source code is excluded
|
294 |
+
from the Corresponding Source as a System Library, need not be
|
295 |
+
included in conveying the object code work.
|
296 |
+
|
297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
298 |
+
tangible personal property which is normally used for personal, family,
|
299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
302 |
+
product received by a particular user, "normally used" refers to a
|
303 |
+
typical or common use of that class of product, regardless of the status
|
304 |
+
of the particular user or of the way in which the particular user
|
305 |
+
actually uses, or expects or is expected to use, the product. A product
|
306 |
+
is a consumer product regardless of whether the product has substantial
|
307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
308 |
+
the only significant mode of use of the product.
|
309 |
+
|
310 |
+
"Installation Information" for a User Product means any methods,
|
311 |
+
procedures, authorization keys, or other information required to install
|
312 |
+
and execute modified versions of a covered work in that User Product from
|
313 |
+
a modified version of its Corresponding Source. The information must
|
314 |
+
suffice to ensure that the continued functioning of the modified object
|
315 |
+
code is in no case prevented or interfered with solely because
|
316 |
+
modification has been made.
|
317 |
+
|
318 |
+
If you convey an object code work under this section in, or with, or
|
319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
320 |
+
part of a transaction in which the right of possession and use of the
|
321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
322 |
+
fixed term (regardless of how the transaction is characterized), the
|
323 |
+
Corresponding Source conveyed under this section must be accompanied
|
324 |
+
by the Installation Information. But this requirement does not apply
|
325 |
+
if neither you nor any third party retains the ability to install
|
326 |
+
modified object code on the User Product (for example, the work has
|
327 |
+
been installed in ROM).
|
328 |
+
|
329 |
+
The requirement to provide Installation Information does not include a
|
330 |
+
requirement to continue to provide support service, warranty, or updates
|
331 |
+
for a work that has been modified or installed by the recipient, or for
|
332 |
+
the User Product in which it has been modified or installed. Access to a
|
333 |
+
network may be denied when the modification itself materially and
|
334 |
+
adversely affects the operation of the network or violates the rules and
|
335 |
+
protocols for communication across the network.
|
336 |
+
|
337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
338 |
+
in accord with this section must be in a format that is publicly
|
339 |
+
documented (and with an implementation available to the public in
|
340 |
+
source code form), and must require no special password or key for
|
341 |
+
unpacking, reading or copying.
|
342 |
+
|
343 |
+
7. Additional Terms.
|
344 |
+
|
345 |
+
"Additional permissions" are terms that supplement the terms of this
|
346 |
+
License by making exceptions from one or more of its conditions.
|
347 |
+
Additional permissions that are applicable to the entire Program shall
|
348 |
+
be treated as though they were included in this License, to the extent
|
349 |
+
that they are valid under applicable law. If additional permissions
|
350 |
+
apply only to part of the Program, that part may be used separately
|
351 |
+
under those permissions, but the entire Program remains governed by
|
352 |
+
this License without regard to the additional permissions.
|
353 |
+
|
354 |
+
When you convey a copy of a covered work, you may at your option
|
355 |
+
remove any additional permissions from that copy, or from any part of
|
356 |
+
it. (Additional permissions may be written to require their own
|
357 |
+
removal in certain cases when you modify the work.) You may place
|
358 |
+
additional permissions on material, added by you to a covered work,
|
359 |
+
for which you have or can give appropriate copyright permission.
|
360 |
+
|
361 |
+
Notwithstanding any other provision of this License, for material you
|
362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
363 |
+
that material) supplement the terms of this License with terms:
|
364 |
+
|
365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
366 |
+
terms of sections 15 and 16 of this License; or
|
367 |
+
|
368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
369 |
+
author attributions in that material or in the Appropriate Legal
|
370 |
+
Notices displayed by works containing it; or
|
371 |
+
|
372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
373 |
+
requiring that modified versions of such material be marked in
|
374 |
+
reasonable ways as different from the original version; or
|
375 |
+
|
376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
377 |
+
authors of the material; or
|
378 |
+
|
379 |
+
e) Declining to grant rights under trademark law for use of some
|
380 |
+
trade names, trademarks, or service marks; or
|
381 |
+
|
382 |
+
f) Requiring indemnification of licensors and authors of that
|
383 |
+
material by anyone who conveys the material (or modified versions of
|
384 |
+
it) with contractual assumptions of liability to the recipient, for
|
385 |
+
any liability that these contractual assumptions directly impose on
|
386 |
+
those licensors and authors.
|
387 |
+
|
388 |
+
All other non-permissive additional terms are considered "further
|
389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
390 |
+
received it, or any part of it, contains a notice stating that it is
|
391 |
+
governed by this License along with a term that is a further
|
392 |
+
restriction, you may remove that term. If a license document contains
|
393 |
+
a further restriction but permits relicensing or conveying under this
|
394 |
+
License, you may add to a covered work material governed by the terms
|
395 |
+
of that license document, provided that the further restriction does
|
396 |
+
not survive such relicensing or conveying.
|
397 |
+
|
398 |
+
If you add terms to a covered work in accord with this section, you
|
399 |
+
must place, in the relevant source files, a statement of the
|
400 |
+
additional terms that apply to those files, or a notice indicating
|
401 |
+
where to find the applicable terms.
|
402 |
+
|
403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
404 |
+
form of a separately written license, or stated as exceptions;
|
405 |
+
the above requirements apply either way.
|
406 |
+
|
407 |
+
8. Termination.
|
408 |
+
|
409 |
+
You may not propagate or modify a covered work except as expressly
|
410 |
+
provided under this License. Any attempt otherwise to propagate or
|
411 |
+
modify it is void, and will automatically terminate your rights under
|
412 |
+
this License (including any patent licenses granted under the third
|
413 |
+
paragraph of section 11).
|
414 |
+
|
415 |
+
However, if you cease all violation of this License, then your
|
416 |
+
license from a particular copyright holder is reinstated (a)
|
417 |
+
provisionally, unless and until the copyright holder explicitly and
|
418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
419 |
+
holder fails to notify you of the violation by some reasonable means
|
420 |
+
prior to 60 days after the cessation.
|
421 |
+
|
422 |
+
Moreover, your license from a particular copyright holder is
|
423 |
+
reinstated permanently if the copyright holder notifies you of the
|
424 |
+
violation by some reasonable means, this is the first time you have
|
425 |
+
received notice of violation of this License (for any work) from that
|
426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
427 |
+
your receipt of the notice.
|
428 |
+
|
429 |
+
Termination of your rights under this section does not terminate the
|
430 |
+
licenses of parties who have received copies or rights from you under
|
431 |
+
this License. If your rights have been terminated and not permanently
|
432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
433 |
+
material under section 10.
|
434 |
+
|
435 |
+
9. Acceptance Not Required for Having Copies.
|
436 |
+
|
437 |
+
You are not required to accept this License in order to receive or
|
438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
440 |
+
to receive a copy likewise does not require acceptance. However,
|
441 |
+
nothing other than this License grants you permission to propagate or
|
442 |
+
modify any covered work. These actions infringe copyright if you do
|
443 |
+
not accept this License. Therefore, by modifying or propagating a
|
444 |
+
covered work, you indicate your acceptance of this License to do so.
|
445 |
+
|
446 |
+
10. Automatic Licensing of Downstream Recipients.
|
447 |
+
|
448 |
+
Each time you convey a covered work, the recipient automatically
|
449 |
+
receives a license from the original licensors, to run, modify and
|
450 |
+
propagate that work, subject to this License. You are not responsible
|
451 |
+
for enforcing compliance by third parties with this License.
|
452 |
+
|
453 |
+
An "entity transaction" is a transaction transferring control of an
|
454 |
+
organization, or substantially all assets of one, or subdividing an
|
455 |
+
organization, or merging organizations. If propagation of a covered
|
456 |
+
work results from an entity transaction, each party to that
|
457 |
+
transaction who receives a copy of the work also receives whatever
|
458 |
+
licenses to the work the party's predecessor in interest had or could
|
459 |
+
give under the previous paragraph, plus a right to possession of the
|
460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
461 |
+
the predecessor has it or can get it with reasonable efforts.
|
462 |
+
|
463 |
+
You may not impose any further restrictions on the exercise of the
|
464 |
+
rights granted or affirmed under this License. For example, you may
|
465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
466 |
+
rights granted under this License, and you may not initiate litigation
|
467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
468 |
+
any patent claim is infringed by making, using, selling, offering for
|
469 |
+
sale, or importing the Program or any portion of it.
|
470 |
+
|
471 |
+
11. Patents.
|
472 |
+
|
473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
474 |
+
License of the Program or a work on which the Program is based. The
|
475 |
+
work thus licensed is called the contributor's "contributor version".
|
476 |
+
|
477 |
+
A contributor's "essential patent claims" are all patent claims
|
478 |
+
owned or controlled by the contributor, whether already acquired or
|
479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
480 |
+
by this License, of making, using, or selling its contributor version,
|
481 |
+
but do not include claims that would be infringed only as a
|
482 |
+
consequence of further modification of the contributor version. For
|
483 |
+
purposes of this definition, "control" includes the right to grant
|
484 |
+
patent sublicenses in a manner consistent with the requirements of
|
485 |
+
this License.
|
486 |
+
|
487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
488 |
+
patent license under the contributor's essential patent claims, to
|
489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
490 |
+
propagate the contents of its contributor version.
|
491 |
+
|
492 |
+
In the following three paragraphs, a "patent license" is any express
|
493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
494 |
+
(such as an express permission to practice a patent or covenant not to
|
495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
496 |
+
party means to make such an agreement or commitment not to enforce a
|
497 |
+
patent against the party.
|
498 |
+
|
499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
500 |
+
and the Corresponding Source of the work is not available for anyone
|
501 |
+
to copy, free of charge and under the terms of this License, through a
|
502 |
+
publicly available network server or other readily accessible means,
|
503 |
+
then you must either (1) cause the Corresponding Source to be so
|
504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
506 |
+
consistent with the requirements of this License, to extend the patent
|
507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
508 |
+
actual knowledge that, but for the patent license, your conveying the
|
509 |
+
covered work in a country, or your recipient's use of the covered work
|
510 |
+
in a country, would infringe one or more identifiable patents in that
|
511 |
+
country that you have reason to believe are valid.
|
512 |
+
|
513 |
+
If, pursuant to or in connection with a single transaction or
|
514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
515 |
+
covered work, and grant a patent license to some of the parties
|
516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
517 |
+
or convey a specific copy of the covered work, then the patent license
|
518 |
+
you grant is automatically extended to all recipients of the covered
|
519 |
+
work and works based on it.
|
520 |
+
|
521 |
+
A patent license is "discriminatory" if it does not include within
|
522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
524 |
+
specifically granted under this License. You may not convey a covered
|
525 |
+
work if you are a party to an arrangement with a third party that is
|
526 |
+
in the business of distributing software, under which you make payment
|
527 |
+
to the third party based on the extent of your activity of conveying
|
528 |
+
the work, and under which the third party grants, to any of the
|
529 |
+
parties who would receive the covered work from you, a discriminatory
|
530 |
+
patent license (a) in connection with copies of the covered work
|
531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
532 |
+
for and in connection with specific products or compilations that
|
533 |
+
contain the covered work, unless you entered into that arrangement,
|
534 |
+
or that patent license was granted, prior to 28 March 2007.
|
535 |
+
|
536 |
+
Nothing in this License shall be construed as excluding or limiting
|
537 |
+
any implied license or other defenses to infringement that may
|
538 |
+
otherwise be available to you under applicable patent law.
|
539 |
+
|
540 |
+
12. No Surrender of Others' Freedom.
|
541 |
+
|
542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
543 |
+
otherwise) that contradict the conditions of this License, they do not
|
544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
546 |
+
License and any other pertinent obligations, then as a consequence you may
|
547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
548 |
+
to collect a royalty for further conveying from those to whom you convey
|
549 |
+
the Program, the only way you could satisfy both those terms and this
|
550 |
+
License would be to refrain entirely from conveying the Program.
|
551 |
+
|
552 |
+
13. Use with the GNU Affero General Public License.
|
553 |
+
|
554 |
+
Notwithstanding any other provision of this License, you have
|
555 |
+
permission to link or combine any covered work with a work licensed
|
556 |
+
under version 3 of the GNU Affero General Public License into a single
|
557 |
+
combined work, and to convey the resulting work. The terms of this
|
558 |
+
License will continue to apply to the part which is the covered work,
|
559 |
+
but the special requirements of the GNU Affero General Public License,
|
560 |
+
section 13, concerning interaction through a network will apply to the
|
561 |
+
combination as such.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU General Public License from time to time. Such new versions will
|
567 |
+
be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
635 |
+
Copyright (C) <year> <name of author>
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU General Public License
|
648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If the program does terminal interaction, make it output a short
|
653 |
+
notice like this when it starts in an interactive mode:
|
654 |
+
|
655 |
+
<program> Copyright (C) <year> <name of author>
|
656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
657 |
+
This is free software, and you are welcome to redistribute it
|
658 |
+
under certain conditions; type `show c' for details.
|
659 |
+
|
660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
661 |
+
parts of the General Public License. Of course, your program's commands
|
662 |
+
might be different; for a GUI interface, you would use an "about box".
|
663 |
+
|
664 |
+
You should also get your employer (if you work as a programmer) or school,
|
665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
667 |
+
<https://www.gnu.org/licenses/>.
|
668 |
+
|
669 |
+
The GNU General Public License does not permit incorporating your program
|
670 |
+
into proprietary programs. If your program is a subroutine library, you
|
671 |
+
may consider it more useful to permit linking proprietary applications with
|
672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
673 |
+
Public License instead of this License. But first, please read
|
674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
README.md
CHANGED
@@ -1,3 +1,145 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TransfoRNA
|
2 |
+
TransfoRNA is a **bioinformatics** and **machine learning** tool based on **Transformers** to provide annotations for 11 major classes (miRNA, rRNA, tRNA, snoRNA, protein
|
3 |
+
-coding/mRNA, lncRNA, YRNA, piRNA, snRNA, snoRNA and vtRNA) and 1923 sub-classes
|
4 |
+
for **human small RNAs and RNA fragments**. These are typically detected by RNA-seq NGS (next generation sequencing) data.
|
5 |
+
|
6 |
+
TransfoRNA can be trained on just the RNA sequences and optionally on additional information such as secondary structure. The result is a major and sub-class assignment combined with a novelty score (Normalized Levenshtein Distance) that quantifies the difference between the query sequence and the closest match found in the training set. Based on that it decides if the query sequence is novel or familiar. TransfoRNA uses a small curated set of ground truth labels obtained from common knowledge-based bioinformatics tools that map the sequences to transcriptome databases and a reference genome. Using TransfoRNA's framewok, the high confidence annotations in the TCGA dataset can be increased by 3 folds.
|
7 |
+
|
8 |
+
|
9 |
+
## Dataset (Objective):
|
10 |
+
- **The Cancer Genome Atlas, [TCGA](https://www.cancer.gov/about-nci/organization/ccg/research/structural-genomics/tcga)** offers sequencing data of small RNAs and is used to evaluate TransfoRNAs classification performance
|
11 |
+
- Sequences are annotated based on a knowledge-based annotation approach that provides annotations for ~2k different sub-classes belonging to 11 major classes.
|
12 |
+
- Knowledge-based annotations are divided into three sets of varying confidence levels: a **high-confidence (HICO)** set, a **low-confidence (LOCO)** set, and a **non-annotated (NA)** set for sequences that could not be annotated at all. Only HICO annotations are used for training.
|
13 |
+
- HICO RNAs cover ~2k sub-classes and constitute 19.6% of all RNAs found in TCGA. LOCO and NA sets comprise 66.9% and 13.6% of RNAs, respectively.
|
14 |
+
- HICO RNAs are further divided into **in-distribution, ID** (374 sub-classes) and **out-of-distribution, OOD** (1549 sub-classes) sets.
|
15 |
+
- Criteria for ID and OOD: Sub-class containing more than 8 sequences are considered ID, otherwise OOD.
|
16 |
+
- An additional **putative 5' adapter affixes set** contains 294 sequences known to be technical artefacts. The 5’-end perfectly matches the last five or more nucleotides of the 5’-adapter sequence, commonly used in small RNA sequencing.
|
17 |
+
- The knowledge-based annotation (KBA) pipline including installation guide is located under `kba_pipline`
|
18 |
+
|
19 |
+
## Models
|
20 |
+
There are 5 classifier models currently available, each with different input representation.
|
21 |
+
- Baseline:
|
22 |
+
- Input: (single input) Sequence
|
23 |
+
- Model: An embedding layer that converts sequences into vectors followed by a classification feed forward layer.
|
24 |
+
- Seq:
|
25 |
+
- Input: (single input) Sequence
|
26 |
+
- Model: A transformer based encoder model.
|
27 |
+
- Seq-Seq:
|
28 |
+
- Input: (dual inputs) Sequence divided into even and odd tokens.
|
29 |
+
- Model: A transformer encoder is placed for odd tokens and another for even tokens.
|
30 |
+
- Seq-Struct:
|
31 |
+
- Input: (dual inputs) Sequence + Secondary structure
|
32 |
+
- Model: A transformer encoder for the sequence and another for the secondary structure.
|
33 |
+
- Seq-Rev (best performant):
|
34 |
+
- Input: (dual inputs) Sequence
|
35 |
+
- Model: A transformer encoder for the sequence and another for the sequence reversed.
|
36 |
+
|
37 |
+
|
38 |
+
*Note: These (Transformer) based models show overlapping and distinct capabilities. Consequently, an ensemble model is created to leverage those capabilities.*
|
39 |
+
|
40 |
+
|
41 |
+
<img width="948" alt="Screenshot 2023-08-16 at 16 39 20" src="https://github.com/gitHBDX/TransfoRNA-Framework/assets/82571392/d7d092d8-8cbd-492a-9ccc-994ffdd5aa5f">
|
42 |
+
|
43 |
+
## Data Availability
|
44 |
+
Downloading the data and the models can be done from [here](https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0).
|
45 |
+
|
46 |
+
This will download three subfolders that should be kept on the same folder level as `src`:
|
47 |
+
- `data`: Contains three files:
|
48 |
+
- `TCGA` anndata with ~75k sequences and `var` columns containing the knowledge based annotations.
|
49 |
+
- `HBDXBase.csv` containing a list of RNA precursors which are then used for data augmentation.
|
50 |
+
- `subclass_to_annotation.json` holds mappings for every sub-class to major-class.
|
51 |
+
|
52 |
+
- `models`:
|
53 |
+
- `benchmark` : contains benchmark models trained on sncRNA and premiRNA data. (See additional datasets at the bottom)
|
54 |
+
- `tcga`: All models trained on the TCGA data; `TransfoRNA_ID` (for testing and validation) and `TransfoRNA_FULL` (the production version) containing higher RNA major and sub-class coverage. Each of the two folders contain all the models trained seperately on major-class and sub-class.
|
55 |
+
- `kba_pipeline`: contains mapping reference data required to run the knowledge based pipeline manually
|
56 |
+
## Repo Structure
|
57 |
+
- configs: Contains the configurations of each model, training and inference settings.
|
58 |
+
|
59 |
+
The `conf/main_config.yaml` file offers options to change the task, the training settings and the logging. The following shows all the options and permitted values for each option.
|
60 |
+
|
61 |
+
<img width="835" alt="Screenshot 2024-05-22 at 10 19 15" src="https://github.com/gitHBDX/TransfoRNA/assets/82571392/225d2c98-ed45-4ca7-9e86-557a73af702d">
|
62 |
+
|
63 |
+
- transforna contains two folders:
|
64 |
+
- `src` folder which contains transforna package. View transforna's architecture [here](https://github.com/gitHBDX/TransfoRNA/blob/master/transforna/src/readme.md).
|
65 |
+
- `bin` folder contains all scripts necessary for reproducing manuscript figures.
|
66 |
+
|
67 |
+
## Installation
|
68 |
+
|
69 |
+
The `install.sh` is a script that creates an transforna environment in which all the required packages for TransfoRNA are installed. Simply navigate to the root directory and run from terminal:
|
70 |
+
|
71 |
+
```
|
72 |
+
#make install script executable
|
73 |
+
chmod +x install.sh
|
74 |
+
|
75 |
+
|
76 |
+
#run script
|
77 |
+
./install.sh
|
78 |
+
```
|
79 |
+
|
80 |
+
## TransfoRNA API
|
81 |
+
In `transforna/src/inference/inference_api.py`, all the functionalities of transforna are offered as APIs. There are two functions of interest:
|
82 |
+
- `predict_transforna` : Computes for a set of sequences and for a given model, one of various options; the embeddings, logits, explanatory (similar) sequences, attentions masks or umap coordinates.
|
83 |
+
- `predict_transforna_all_models`: Same as `predict_transforna` but computes the desired option for all the models as well as aggregates the output of the ensemble model.
|
84 |
+
Both return a pandas dataframe containing the sequence along with the desired computation.
|
85 |
+
|
86 |
+
Check the script at `src/test_inference_api.py` for a basic demo on how to call the either of the APIs.
|
87 |
+
|
88 |
+
## Inference from terminal
|
89 |
+
For inference, two paths in `configs/inference_settings/default.yaml` have to be edited:
|
90 |
+
- `sequences_path`: The full path to a csv file containing the sequences for which annotations are to be inferred.
|
91 |
+
- `model_path`: The full path of the model. (currently this points to the Seq model)
|
92 |
+
|
93 |
+
Also in the `main_config.yaml`, make sure to edit the `model_name` to match the input expected by the loaded model.
|
94 |
+
- `model_name`: add the name of the model. One of `"seq"`,`"seq-seq"`,`"seq-struct"`,`"baseline"` or `"seq-rev"` (see above)
|
95 |
+
|
96 |
+
|
97 |
+
Then, navigate the repositories' root directory and run the following command:
|
98 |
+
|
99 |
+
```
|
100 |
+
python transforna/__main__.py inference=True
|
101 |
+
```
|
102 |
+
|
103 |
+
After inference, an `inference_output` folder will be created under `outputs/` which will include two files.
|
104 |
+
- `(model_name)_embedds.csv`: contains vector embedding per sequence in the inference set- (could be used for downstream tasks).
|
105 |
+
*Note: The embedds of each sequence will only be logged if `log_embedds` in the `main_config` is `True`.*
|
106 |
+
- `(model_name)_inference_results.csv`: Contains columns; Net-Label containing predicted label and Is Familiar? boolean column containing the models' novelty predictor output. (True: familiar/ False: Novel)
|
107 |
+
*Note: The output will also contain the logits of the model is `log_logits` in the `main_config` is `True`.*
|
108 |
+
|
109 |
+
|
110 |
+
## Train on custom data
|
111 |
+
TransfoRNA can be trained using input data as Anndata, csv or fasta. If the input is anndata, then `anndata.var` should contains all the sequences. Some changes has to be made (follow `configs/train_model_configs/tcga`):
|
112 |
+
|
113 |
+
In `configs/train_model_configs/custom`:
|
114 |
+
- `dataset_path_train` has to point to the input_data which should contain; a `sequence` column, a `small_RNA_class_annotation` coliumn indicating the major class if available (otherwise should be NaN), `five_prime_adapter_filter` specifies whether the sequence is considered a real sequence or an artifact (`True `for Real and `False` for artifact), a `subclass_name` column containing the sub-class name if available (otherwise should be NaN), and a boolean column `hico` indicating whether a sequence is high confidence or not.
|
115 |
+
- If sampling from the precursor is required in order to augment the sub-classes, the `precursor_file_path` should include precursors. Follow the scheme of the HBDxBase.csv and have a look at `PrecursorAugmenter` class in `transforna/src/processing/augmentation.py`
|
116 |
+
- `mapping_dict_path` should contain the mapping from sub class to major class. i.e: 'miR-141-5p' to 'miRNA'.
|
117 |
+
- `clf_target` sets the classification target of the mopdel and should be either `sub_class_hico` for training on targets in `subclass_name` or `major_class_hico` for training on targets in `small_RNA_class_annotation`. For both, only high confidence sequences are selected for training (based on `hico` column).
|
118 |
+
|
119 |
+
In configs/main_config, some changes should be made:
|
120 |
+
- change `task` to `custom` or to whatever name the `custom.py` has been renamed.
|
121 |
+
- set the `model_name` as desired.
|
122 |
+
|
123 |
+
For training TransfoRNA from the root directory:
|
124 |
+
```
|
125 |
+
python transforna/__main__.py
|
126 |
+
```
|
127 |
+
Using [Hydra](https://hydra.cc/), any option in the main config can be changed. For instance, to train a `Seq-Struct` TransfoRNA model without using a validation split:
|
128 |
+
```
|
129 |
+
python transforna/__main__.py train_split=False model_name='seq-struct'
|
130 |
+
```
|
131 |
+
After training, an output folder is automatically created in the root directory where training is logged.
|
132 |
+
The structure of the output folder is chosen by hydra to be `/day/time/results folders`. Results folders are a set of folders created during training:
|
133 |
+
- `ckpt`: (containing the latest checkpoint of the model)
|
134 |
+
- `embedds`:
|
135 |
+
- Contains a file per each split (train/valid/test/ood/na).
|
136 |
+
- Each file is a `csv` containing the sequences plus their embeddings (obtained by the model and represent numeric representation of a given RNA sequence) as well as the logits. The logits are values the models produce for each sequence, reflecting its confidence of a sequence belonging to a certain class.
|
137 |
+
- `meta`: A folder containing a `yaml` file with all the hyperparameters used for the current run.
|
138 |
+
- `analysis`: contains the learned novelty threshold seperating the in-distribution set(Familiar) from the out of distribution set (Novel).
|
139 |
+
- `figures`: some figures are saved containing the Normalized Levenstein Distance NLD, distribution per split.
|
140 |
+
|
141 |
+
|
142 |
+
## Additional Datasets (Objective):
|
143 |
+
- sncRNA, collected from [RFam](https://rfam.org/) (classification of RNA precursors into 13 classes)
|
144 |
+
- premiRNA [human miRNAs](http://www.mirbase.org)(classification of true vs pseudo precursors)
|
145 |
+
|
conf/__init__.py
ADDED
File without changes
|
conf/hydra/job_logging/custom.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: 1
|
2 |
+
formatters:
|
3 |
+
simple:
|
4 |
+
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
5 |
+
handlers:
|
6 |
+
console:
|
7 |
+
class: logging.StreamHandler
|
8 |
+
formatter: simple
|
9 |
+
stream: ext://sys.stdout
|
10 |
+
root:
|
11 |
+
handlers: [console]
|
12 |
+
|
13 |
+
disable_existing_loggers: false
|
conf/inference_settings/default.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
infere_original_testset: false
|
2 |
+
model_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq/ckpt/model_params_tcga.pt
|
3 |
+
sequences_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/data/inference_set.csv
|
conf/main_config.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- model: transforna
|
3 |
+
- inference_settings: default
|
4 |
+
- override hydra/job_logging: disabled
|
5 |
+
|
6 |
+
task: tcga # tcga,sncrna or premirna or custom (for a custom dataset)
|
7 |
+
|
8 |
+
|
9 |
+
train_config:
|
10 |
+
_target_: train_model_configs.${task}.GeneEmbeddTrainConfig
|
11 |
+
|
12 |
+
model_config:
|
13 |
+
_target_: train_model_configs.${task}.GeneEmbeddModelConfig
|
14 |
+
|
15 |
+
#train settings
|
16 |
+
model_name: seq #seq, seq-seq, seq-struct, seq-reverse, or baseline
|
17 |
+
trained_on: full #full(production, more coverage) or id (for test/eval purposes)
|
18 |
+
path_to_models: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/ #edit path to point to models/tcga/ directory: will be used if trained_on is full
|
19 |
+
inference: False # Should TransfoRNA be used for inference or train? True or False
|
20 |
+
#if inference is true, should the logits be logged?
|
21 |
+
log_logits: False
|
22 |
+
|
23 |
+
|
24 |
+
train_split: True # True or False
|
25 |
+
valid_size: 0.15 # 0 < valid_size < 1
|
26 |
+
|
27 |
+
#CV
|
28 |
+
cross_val: True # True or False
|
29 |
+
num_replicates: 1 # Integer, num_replicates for cross-validation
|
30 |
+
|
31 |
+
#seed
|
32 |
+
seed: 1 # Integer
|
33 |
+
device_number: 1 # Integer, select GPU
|
34 |
+
|
35 |
+
|
36 |
+
#logging sequence embeddings + metrics to tensorboard
|
37 |
+
log_embedds: True # True or False
|
38 |
+
tensorboard: False # True or False
|
39 |
+
|
40 |
+
#disable hydra output
|
41 |
+
hydra:
|
42 |
+
run:
|
43 |
+
dir: ./outputs/${now:%Y-%m-%d}/${model_name}
|
44 |
+
searchpath:
|
45 |
+
- file:///nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf
|
46 |
+
# output_subdir: null #uncomment to disable hydra output
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
conf/model/transforna.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
skorch_model:
|
2 |
+
_target_: transforna.Net
|
3 |
+
module: transforna.GeneEmbeddModel
|
4 |
+
criterion: transforna.LossFunction
|
5 |
+
max_epochs: 0 #infered from task specific train config
|
6 |
+
optimizer: torch.optim.AdamW
|
7 |
+
device: cuda
|
8 |
+
batch_size: 64
|
9 |
+
iterator_train__shuffle: True
|
conf/readme.md
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The `inference_settings` has a default yaml containing four keys:
|
2 |
+
-`sequences_path`: The full path of the file containing the sequences for which their annotations are to be infered.
|
3 |
+
- `model_path`: the full path of the model to be used for inference.
|
4 |
+
- `model_name`: A model name indicating the inputs the model expects. One of `seq`,`seq-seq`,`seq-struct`,`seq-reverse` or `baseline`
|
5 |
+
- `infere_original_testset`: True/False indicating whether inference should be computed on the original test set.
|
6 |
+
|
7 |
+
`model` contains the skeleton of the model used, the optimizer, loss function and device. All models are built using [skorch](https://skorch.readthedocs.io/en/latest/?badge=latest)
|
8 |
+
|
9 |
+
`train_model_configs` contain the hyperparameters for each dataset; tcga, sncrna and premirna:
|
10 |
+
|
11 |
+
- Each file contains the model and the train config.
|
12 |
+
|
13 |
+
- Model config: contains the model hyperparameters, sequence tokenization scheme and allows for choosing the model.
|
14 |
+
|
15 |
+
- Train config: contains training settings such as the learning rate hyper parameters as well as `dataset_path_train`.
|
16 |
+
- `dataset_path_train`: should point to the dataset [(Anndata)](https://anndata.readthedocs.io/en/latest/) used for training.
|
17 |
+
|
conf/train_model_configs/__init__.py
ADDED
File without changes
|
conf/train_model_configs/custom.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
dirname, _ = os.path.split(os.path.dirname(__file__))
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class GeneEmbeddModelConfig:
|
11 |
+
|
12 |
+
model_input: str = "" #will be infered
|
13 |
+
|
14 |
+
num_embed_hidden: int = 100 #30 for exp, 100 for rest
|
15 |
+
ff_input_dim:int = 0 #is infered later, equals gene expression len
|
16 |
+
ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico
|
17 |
+
feed_forward1_hidden: int = 256
|
18 |
+
num_attention_project: int = 64
|
19 |
+
num_encoder_layers: int = 1
|
20 |
+
dropout: float = 0.2
|
21 |
+
n: int = 121
|
22 |
+
relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11])
|
23 |
+
num_attention_heads: int = 5
|
24 |
+
|
25 |
+
window: int = 2
|
26 |
+
tokens_len: int = math.ceil(max_length / window)
|
27 |
+
second_input_token_len: int = 0 # is infered during runtime
|
28 |
+
vocab_size: int = 0 # is infered during runtime
|
29 |
+
second_input_vocab_size: int = 0 # is infered during runtime
|
30 |
+
tokenizer: str = (
|
31 |
+
"overlap" # either overlap or no_overlap or overlap_multi_window
|
32 |
+
)
|
33 |
+
|
34 |
+
clf_target:str = 'm' # sub_class_hico or major_class_hico. hico = high confidence
|
35 |
+
num_classes: int = 0 #will be infered during runtime
|
36 |
+
class_mappings:List = field(default_factory=lambda: [])#will be infered during runtime
|
37 |
+
class_weights :List = field(default_factory=lambda: [])
|
38 |
+
# how many extra window sizes other than deafault window
|
39 |
+
temperatures: List = field(default_factory=lambda: [0,10])
|
40 |
+
|
41 |
+
tokens_mapping_dict: Dict = None
|
42 |
+
false_input_perc:float = 0.0
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class GeneEmbeddTrainConfig:
|
47 |
+
dataset_path_train: str = 'path/to/anndata.h5ad'
|
48 |
+
precursor_file_path: str = 'path/to/precursor_file.csv' #if not provided, sampling from the precurosr will not be done
|
49 |
+
mapping_dict_path: str = 'path/to/mapping_dict.json' #required for mapping sub class to major class, i.e: mir-568-3p to miRNA
|
50 |
+
device: str = "cuda"
|
51 |
+
l2_weight_decay: float = 0.05
|
52 |
+
batch_size: int = 512
|
53 |
+
|
54 |
+
batch_per_epoch:int = 0 # is infered during runtime
|
55 |
+
|
56 |
+
label_smoothing_sim:float = 0.2
|
57 |
+
label_smoothing_clf:float = 0.0
|
58 |
+
# learning rate
|
59 |
+
learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
|
60 |
+
lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
|
61 |
+
lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
|
62 |
+
# TODO: 122 is the number of train batches per epoch, should be infered and set
|
63 |
+
# warmup batch should be during the form epoch*(train batch per epoch)
|
64 |
+
warmup_epoch: int = 10 # how many batches linear warm up for
|
65 |
+
final_epoch: int = 20 # final batch of training when want learning rate
|
66 |
+
|
67 |
+
top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears during the top k, the correctly classified
|
68 |
+
cross_val: bool = False
|
69 |
+
labels_mapping_path: str = None
|
70 |
+
filter_seq_length:bool = False
|
71 |
+
|
72 |
+
num_augment_exp:int = 20
|
73 |
+
shuffle_exp: bool = False
|
74 |
+
|
75 |
+
max_epochs: int = 3000
|
76 |
+
|
conf/train_model_configs/premirna.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class GeneEmbeddModelConfig:
|
7 |
+
# input dim for the embedding and positional encoders
|
8 |
+
# as well as all k,q,v input/output dims
|
9 |
+
model_input: str = "seq-struct"
|
10 |
+
num_embed_hidden: int = 256
|
11 |
+
ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
|
12 |
+
feed_forward1_hidden: int = 1024
|
13 |
+
num_attention_project: int = 64
|
14 |
+
num_encoder_layers: int = 1
|
15 |
+
dropout: float = 0.5
|
16 |
+
n: int = 121
|
17 |
+
relative_attns: List = field(default_factory=lambda: [int(112), int(112), 6*3, 8*3, 10*3, 11*3])
|
18 |
+
num_attention_heads: int = 1
|
19 |
+
|
20 |
+
window: int = 2
|
21 |
+
tokens_len: int = 0 #will be infered later
|
22 |
+
second_input_token_len: int = 0 # is infered in runtime
|
23 |
+
vocab_size: int = 0 # is infered in runtime
|
24 |
+
second_input_vocab_size: int = 0 # is infered in runtime
|
25 |
+
tokenizer: str = (
|
26 |
+
"overlap" # either overlap or no_overlap or overlap_multi_window
|
27 |
+
)
|
28 |
+
num_classes: int = 0 #will be infered in runtime
|
29 |
+
class_weights :List = field(default_factory=lambda: [])
|
30 |
+
tokens_mapping_dict: dict = None
|
31 |
+
|
32 |
+
#false input percentage
|
33 |
+
false_input_perc:float = 0.1
|
34 |
+
model_input: str = "seq-struct"
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class GeneEmbeddTrainConfig:
|
38 |
+
dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/premirna/train"
|
39 |
+
dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/premirna/test"
|
40 |
+
datset_path_additional_testset: str = "/data/hbdx_ldap_local/analysis/data/premirna/"
|
41 |
+
labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/premirna/labels_mapping_dict.pkl"
|
42 |
+
device: str = "cuda"
|
43 |
+
l2_weight_decay: float = 1e-5
|
44 |
+
batch_size: int = 64
|
45 |
+
|
46 |
+
batch_per_epoch: int = 0 #will be infered later
|
47 |
+
label_smoothing_sim:float = 0.0
|
48 |
+
label_smoothing_clf:float = 0.0
|
49 |
+
|
50 |
+
# learning rate
|
51 |
+
learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
|
52 |
+
lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
|
53 |
+
lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
|
54 |
+
# TODO: 122 is the number of train batches per epoch, should be infered and set
|
55 |
+
# warmup batch should be in the form epoch*(train batch per epoch)
|
56 |
+
warmup_epoch: int = 10 # how many batches linear warm up for
|
57 |
+
final_epoch: int = 20 # final batch of training when want learning rate
|
58 |
+
|
59 |
+
top_k: int = int(
|
60 |
+
0.05 * batch_size
|
61 |
+
) # if the corresponding rna/GE appears in the top k, the correctly classified
|
62 |
+
label_smoothing: float = 0.0
|
63 |
+
cross_val: bool = False
|
64 |
+
filter_seq_length:bool = True
|
65 |
+
train_epoch: int = 3000
|
66 |
+
max_epochs: int = 3500
|
67 |
+
|
68 |
+
|
conf/train_model_configs/sncrna.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from pickletools import int4
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class GeneEmbeddModelConfig:
|
9 |
+
# input dim for the embedding and positional encoders
|
10 |
+
# as well as all k,q,v input/output dims
|
11 |
+
model_input: str = "seq-struct"
|
12 |
+
num_embed_hidden: int = 256
|
13 |
+
ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
|
14 |
+
feed_forward1_hidden: int = 1024
|
15 |
+
num_attention_project: int = 64
|
16 |
+
num_encoder_layers: int = 2
|
17 |
+
dropout: float = 0.3
|
18 |
+
n: int = 121
|
19 |
+
window:int = 4
|
20 |
+
relative_attns: List = field(default_factory=lambda: [int(360), int(360)])
|
21 |
+
num_attention_heads: int = 4
|
22 |
+
|
23 |
+
tokens_len: int = 0 #will be infered later
|
24 |
+
second_input_token_len:int = 0 # is infered in runtime
|
25 |
+
vocab_size: int = 0 # is infered in runtime
|
26 |
+
second_input_vocab_size: int = 0 # is infered in runtime
|
27 |
+
tokenizer: str = (
|
28 |
+
"overlap" # either overlap or no_overlap or overlap_multi_window
|
29 |
+
)
|
30 |
+
# how many extra window sizes other than deafault window
|
31 |
+
num_classes: int = 0 #will be infered in runtime
|
32 |
+
class_weights :List = field(default_factory=lambda: [])
|
33 |
+
tokens_mapping_dict: dict = None
|
34 |
+
|
35 |
+
#false input percentage
|
36 |
+
false_input_perc:float = 0.2
|
37 |
+
|
38 |
+
model_input: str = "seq-struct"
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class GeneEmbeddTrainConfig:
|
43 |
+
dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/train.h5ad"
|
44 |
+
dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/test.h5ad"
|
45 |
+
labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/sncRNA/labels_mapping_dict.pkl"
|
46 |
+
device: str = "cuda"
|
47 |
+
l2_weight_decay: float = 1e-5
|
48 |
+
batch_size: int = 64
|
49 |
+
|
50 |
+
batch_per_epoch:int = 0 #will be infered later
|
51 |
+
label_smoothing_sim:float = 0.0
|
52 |
+
label_smoothing_clf:float = 0.0
|
53 |
+
|
54 |
+
# learning rate
|
55 |
+
learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
|
56 |
+
lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
|
57 |
+
lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
|
58 |
+
# TODO: 122 is the number of train batches per epoch, should be infered and set
|
59 |
+
# warmup batch should be in the form epoch*(train batch per epoch)
|
60 |
+
warmup_epoch: int = 10 # how many batches linear warm up for
|
61 |
+
final_epoch: int = 20 # final batch of training when want learning rate
|
62 |
+
|
63 |
+
top_k: int = int(
|
64 |
+
0.05 * batch_size
|
65 |
+
) # if the corresponding rna/GE appears in the top k, the correctly classified
|
66 |
+
label_smoothing: float = 0.0
|
67 |
+
cross_val: bool = False
|
68 |
+
filter_seq_length:bool = True
|
69 |
+
train_epoch: int = 800
|
70 |
+
max_epochs:int = 1000
|
conf/train_model_configs/tcga.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
dirname, _ = os.path.split(os.path.dirname(__file__))
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class GeneEmbeddModelConfig:
|
11 |
+
|
12 |
+
model_input: str = "" #will be infered
|
13 |
+
|
14 |
+
num_embed_hidden: int = 100 #30 for exp, 100 for rest
|
15 |
+
ff_input_dim:int = 0 #is infered later, equals gene expression len
|
16 |
+
ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico
|
17 |
+
feed_forward1_hidden: int = 256
|
18 |
+
num_attention_project: int = 64
|
19 |
+
num_encoder_layers: int = 1
|
20 |
+
dropout: float = 0.2
|
21 |
+
n: int = 121
|
22 |
+
relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11])
|
23 |
+
num_attention_heads: int = 5
|
24 |
+
|
25 |
+
window: int = 2
|
26 |
+
# 200 is max rna length.
|
27 |
+
# TODO: if tokenizer is overlap, then max_length should be 60
|
28 |
+
# otherwise, will get cuda error, maybe dask can help
|
29 |
+
max_length: int = 40
|
30 |
+
tokens_len: int = math.ceil(max_length / window)
|
31 |
+
second_input_token_len: int = 0 # is infered during runtime
|
32 |
+
vocab_size: int = 0 # is infered during runtime
|
33 |
+
second_input_vocab_size: int = 0 # is infered during runtime
|
34 |
+
tokenizer: str = (
|
35 |
+
"overlap" # either overlap or no_overlap or overlap_multi_window
|
36 |
+
)
|
37 |
+
|
38 |
+
clf_target:str = 'sub_class_hico' # sub_class, major_class, sub_class_hico or major_class_hico. hico = high confidence
|
39 |
+
num_classes: int = 0 #will be infered during runtime
|
40 |
+
class_mappings:List = field(default_factory=lambda: [])#will be infered during runtime
|
41 |
+
class_weights :List = field(default_factory=lambda: [])
|
42 |
+
# how many extra window sizes other than deafault window
|
43 |
+
temperatures: List = field(default_factory=lambda: [0,10])
|
44 |
+
|
45 |
+
tokens_mapping_dict: Dict = None
|
46 |
+
false_input_perc:float = 0.0
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class GeneEmbeddTrainConfig:
|
51 |
+
dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
|
52 |
+
precursor_file_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/HBDxBase.csv'
|
53 |
+
mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
|
54 |
+
device: str = "cuda"
|
55 |
+
l2_weight_decay: float = 0.05
|
56 |
+
batch_size: int = 512
|
57 |
+
|
58 |
+
batch_per_epoch:int = 0 # is infered during runtime
|
59 |
+
|
60 |
+
label_smoothing_sim:float = 0.2
|
61 |
+
label_smoothing_clf:float = 0.0
|
62 |
+
# learning rate
|
63 |
+
learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
|
64 |
+
lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
|
65 |
+
lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
|
66 |
+
# TODO: 122 is the number of train batches per epoch, should be infered and set
|
67 |
+
# warmup batch should be during the form epoch*(train batch per epoch)
|
68 |
+
warmup_epoch: int = 10 # how many batches linear warm up for
|
69 |
+
final_epoch: int = 20 # final batch of training when want learning rate
|
70 |
+
|
71 |
+
top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears during the top k, the correctly classified
|
72 |
+
cross_val: bool = False
|
73 |
+
labels_mapping_path: str = None
|
74 |
+
filter_seq_length:bool = False
|
75 |
+
|
76 |
+
num_augment_exp:int = 20
|
77 |
+
shuffle_exp: bool = False
|
78 |
+
|
79 |
+
max_epochs: int = 3000
|
80 |
+
|
81 |
+
|
environment.yml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: transforna
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- bioconda
|
5 |
+
- conda-forge
|
6 |
+
dependencies:
|
7 |
+
- anndata==0.8.0
|
8 |
+
- dill==0.3.6
|
9 |
+
- hydra-core==1.3.0
|
10 |
+
- imbalanced-learn==0.9.1
|
11 |
+
- matplotlib==3.5.3
|
12 |
+
- numpy==1.22.3
|
13 |
+
- omegaconf==2.2.2
|
14 |
+
- pandas==1.5.2
|
15 |
+
- plotly==5.10.0
|
16 |
+
- PyYAML==6.0
|
17 |
+
- rich==12.6.0
|
18 |
+
- viennarna=2.5.0=py39h98c8e45_0
|
19 |
+
- scanpy==1.9.1
|
20 |
+
- scikit-learn==1.2.0
|
21 |
+
- skorch==0.12.1
|
22 |
+
- pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0
|
23 |
+
- tensorboard==2.11.2
|
24 |
+
- Levenshtein==0.21.0
|
install.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
python_version="3.9"
|
3 |
+
# Initialize conda
|
4 |
+
eval "$(conda shell.bash hook)"
|
5 |
+
#print the current environment
|
6 |
+
echo "The current environment is $CONDA_DEFAULT_ENV."
|
7 |
+
while [[ "$CONDA_DEFAULT_ENV" != "base" ]]; do
|
8 |
+
conda deactivate
|
9 |
+
done
|
10 |
+
|
11 |
+
#if conde transforna not found in the list of environments, then create the environment
|
12 |
+
if [[ $(conda env list | grep "transforna") == "" ]]; then
|
13 |
+
conda create -n transforna python=$python_version -y
|
14 |
+
conda activate transforna
|
15 |
+
conda install -c anaconda setuptools -y
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
fi
|
20 |
+
conda activate transforna
|
21 |
+
|
22 |
+
echo "The current environment is transforna."
|
23 |
+
pip install setuptools==59.5.0
|
24 |
+
# Uninstall TransfoRNA using pip
|
25 |
+
pip uninstall -y TransfoRNA
|
26 |
+
|
27 |
+
rm -rf dist TransfoRNA.egg-info
|
28 |
+
|
29 |
+
|
30 |
+
# Reinstall TransfoRNA using pip
|
31 |
+
python setup.py sdist
|
32 |
+
pip install dist/TransfoRNA-0.0.1.tar.gz
|
33 |
+
rm -rf TransfoRNA.egg-info dist
|
kba_pipeline/README.md
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The HBDx knowledge-based annotation (KBA) pipeline for small RNA sequences
|
2 |
+
|
3 |
+
Most small RNA annotation tools map the sequences sequentially to different small RNA class specific reference databases, which prioritizes the distinct small RNA classes and conceals potential assignment ambiguities. The annotation strategy used here, maps the sequences to the reference sequences of all small RNA classes at the same time starting with zero mismatch tolerance. Unmapped sequences are intended to map with iterating mismatch tolerance up to three mismatches. To reduce ambiguity, sequences are first mapped to the standard coding and non-coding genes with increasing mismatch tolerance. Only then the unassigned sequences are mapped to pseudogenes in the same manner. Additionally, all small RNA sequences are checked for potential bacterial or viral origin, for genomic overlap to human transposable element loci and whether they contain potential prefixes of the 5‘ adapter.
|
4 |
+
|
5 |
+
In cases of multiple assignments per sequence (multiple precursors could be the origin of the sequence), the ambigous annotation is resolved if
|
6 |
+
a) all assigned precursors overlap with the genomic region of the precursor with the shortest genomic context -> the subclass name of the precursor with the shortest genomic context is used OR if
|
7 |
+
b) a bin of the assigned subclass names is at the 5' or 3' end of the respective precursor -> the subclass name matching the precursor end is used.
|
8 |
+
In cases where subclass names of a) and b) are not identical, the subclass name of method a) is assigned.
|
9 |
+
|
10 |
+
|
11 |
+
![kba_pipeline_scheme_v05](https://github.com/gitHBDX/TransfoRNA/assets/79092907/62bf9e36-c7c7-4ff5-b747-c2c651281b42)
|
12 |
+
|
13 |
+
|
14 |
+
a) Schematic overview of the knowledge-based annotation (KBA) strategy applied for TransfoRNA.
|
15 |
+
|
16 |
+
b) Schematic overview of the miRNA annotation of the custom annotation (isomiR definition based on recent miRNA research [1]).
|
17 |
+
|
18 |
+
c) Schematic overview of the tRNA annotation of the custom annotation (inspired by UNITAS sub-classification [2]).
|
19 |
+
|
20 |
+
d) Binning strategy used in the custom annotation for the remaining RNA major classes. The number of nucleotides per bin is constant for each precursor sequence and ranges between 20 and 39 nucleotides. Assignments are based on the bin with the highest overlap to the sequence of interest.
|
21 |
+
|
22 |
+
e) Filtering steps that were applied to obtain the set of HICO annotations that were used for training of the TransfoRNA models.
|
23 |
+
|
24 |
+
|
25 |
+
## Install environment
|
26 |
+
|
27 |
+
```bash
|
28 |
+
cd kba_pipeline
|
29 |
+
conda env create --file environment.yml
|
30 |
+
```
|
31 |
+
|
32 |
+
## Run annotation pipeline
|
33 |
+
|
34 |
+
<b>Prerequisites:</b>
|
35 |
+
- [ ] the sequences to be annotated need to be stored as fasta format in the `kba_pipeline/data` folder
|
36 |
+
- [ ] the reference files for mapping need to be stored in the `kba_pipeline/references` folder (the required subfolders `HBDxBase`, `hg38` and `bacterial_viral` can be downloaded together with the TransfoRNA models from https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0)
|
37 |
+
|
38 |
+
```bash
|
39 |
+
conda activate hbdx_kba
|
40 |
+
cd src
|
41 |
+
python make_anno.py --fasta_file your_sequences_to_be_annotated.fa
|
42 |
+
```
|
43 |
+
|
44 |
+
This script calls two major functions:
|
45 |
+
- <b>map_2_HBDxBase</b>: sequential mismatch mapping to HBDxBase and genome
|
46 |
+
- <b>annotate_from_mapping</b>: generate sequence annotation based on mapping outputs
|
47 |
+
|
48 |
+
The main annotation file `sRNA_anno_aggregated_on_seq.csv` will be generated in the folder `outputs`
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
## References
|
53 |
+
|
54 |
+
[1] Tomasello, Luisa, Rosario Distefano, Giovanni Nigita, and Carlo M. Croce. 2021. “The MicroRNA Family Gets Wider: The IsomiRs Classification and Role.” Frontiers in Cell and Developmental Biology 9 (June): 1–15. https://doi.org/10.3389/fcell.2021.668648.
|
55 |
+
|
56 |
+
[2] Gebert, Daniel, Charlotte Hewel, and David Rosenkranz. 2017. “Unitas: The Universal Tool for Annotation of Small RNAs.” BMC Genomics 18 (1): 1–14. https://doi.org/10.1186/s12864-017-4031-9.
|
57 |
+
|
58 |
+
|
kba_pipeline/environment.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: hbdx_kba
|
2 |
+
channels:
|
3 |
+
- bioconda
|
4 |
+
- conda-forge
|
5 |
+
dependencies:
|
6 |
+
- anndata=0.8.0
|
7 |
+
- bedtools=2.30.0
|
8 |
+
- biopython=1.79
|
9 |
+
- bowtie=1.3.1
|
10 |
+
- joblib>=1.2.0
|
11 |
+
- pyfastx=0.8.4
|
12 |
+
- pytest
|
13 |
+
- python=3.10.6
|
14 |
+
- pyyaml
|
15 |
+
- rich
|
16 |
+
- samtools=1.16.1
|
17 |
+
- tqdm
|
18 |
+
- viennarna=2.5.1
|
19 |
+
- levenshtein
|
20 |
+
- pip
|
kba_pipeline/src/annotate_from_mapping.py
ADDED
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
######################################################################################################
|
2 |
+
# annotate sequences based on mapping results
|
3 |
+
######################################################################################################
|
4 |
+
#%%
|
5 |
+
import os
|
6 |
+
import logging
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
from difflib import get_close_matches
|
11 |
+
from Levenshtein import distance
|
12 |
+
import json
|
13 |
+
|
14 |
+
from joblib import Parallel, delayed
|
15 |
+
import multiprocessing
|
16 |
+
|
17 |
+
|
18 |
+
from utils import (fasta2df, fasta2df_subheader,log_time, reverse_complement)
|
19 |
+
from precursor_bins import get_bin_with_max_overlap
|
20 |
+
|
21 |
+
|
22 |
+
log = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
pd.options.mode.chained_assignment = None
|
25 |
+
|
26 |
+
|
27 |
+
######################################################################################################
|
28 |
+
# paths to reference and mapping files
|
29 |
+
######################################################################################################
|
30 |
+
|
31 |
+
version = '_v4'
|
32 |
+
|
33 |
+
HBDxBase_csv = f'../../references/HBDxBase/HBDxBase_all{version}.csv'
|
34 |
+
miRBase_mature_path = '../../references/HBDxBase/miRBase/mature.fa'
|
35 |
+
mat_miRNA_pos_path = '../../references/HBDxBase/miRBase/hsa_mature_position.txt'
|
36 |
+
|
37 |
+
mapped_file = 'seqsmapped2HBDxBase_combined.txt'
|
38 |
+
unmapped_file = 'tmp_seqs3mm2HBDxBase_pseudo__unmapped.fa'
|
39 |
+
TE_file = 'tmp_seqsmapped2genome_intersect_TE.txt'
|
40 |
+
mapped_genome_file = 'seqsmapped2genome_combined.txt'
|
41 |
+
toomanyloci_genome_file = 'tmp_seqs0mm2genome__toomanyalign.fa'
|
42 |
+
unmapped_adapter_file = 'tmp_seqs3mm2adapters__unmapped.fa'
|
43 |
+
unmapped_genome_file = 'tmp_seqs0mm2genome__unmapped.fa'
|
44 |
+
unmapped_bacterial_file = 'tmp_seqs0mm2bacterial__unmapped.fa'
|
45 |
+
unmapped_viral_file = 'tmp_seqs0mm2viral__unmapped.fa'
|
46 |
+
|
47 |
+
|
48 |
+
sRNA_anno_file = 'sRNA_anno_from_mapping.csv'
|
49 |
+
aggreg_sRNA_anno_file = 'sRNA_anno_aggregated_on_seq.csv'
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
#%%
|
54 |
+
######################################################################################################
|
55 |
+
# specific functions
|
56 |
+
######################################################################################################
|
57 |
+
|
58 |
+
@log_time(log)
|
59 |
+
def extract_general_info(mapping_file):
|
60 |
+
# load mapping file
|
61 |
+
mapping_df = pd.read_csv(mapping_file, sep='\t', header=None)
|
62 |
+
mapping_df.columns = ['tmp_seq_id','reference','ref_start','sequence','other_alignments','mm_descriptors']
|
63 |
+
|
64 |
+
# add precursor length + number of bins that will be used for names
|
65 |
+
HBDxBase_df = pd.read_csv(HBDxBase_csv, index_col=0)
|
66 |
+
HBDxBase_df = HBDxBase_df[['precursor_length','precursor_bins','pseudo_class']].reset_index()
|
67 |
+
HBDxBase_df.rename(columns={'index': "reference"}, inplace=True)
|
68 |
+
mapping_df = mapping_df.merge(HBDxBase_df, left_on='reference', right_on='reference', how='left')
|
69 |
+
|
70 |
+
# extract information
|
71 |
+
mapping_df.loc[:,'mms'] = mapping_df.mm_descriptors.fillna('').str.count('>')
|
72 |
+
mapping_df.loc[:,'mm_descriptors'] = mapping_df.mm_descriptors.str.replace(',', ';')
|
73 |
+
mapping_df.loc[:,'small_RNA_class_annotation'] = mapping_df.reference.str.split('|').str[0]
|
74 |
+
mapping_df.loc[:,'subclass_type'] = mapping_df.reference.str.split('|').str[2]
|
75 |
+
mapping_df.loc[:,'precursor_name_full'] = mapping_df.reference.str.split('|').str[1].str.split('|').str[0]
|
76 |
+
mapping_df.loc[:,'precursor_name'] = mapping_df.precursor_name_full.str.split('__').str[0].str.split('|').str[0]
|
77 |
+
mapping_df.loc[:,'seq_length'] = mapping_df.sequence.apply(lambda x: len(x))
|
78 |
+
mapping_df.loc[:,'ref_end'] = mapping_df.ref_start + mapping_df.seq_length - 1
|
79 |
+
mapping_df.loc[:,'mitochondrial'] = np.where(mapping_df.reference.str.contains(r'(\|MT-)|(12S)|(16S)'), 'mito', 'nuclear')
|
80 |
+
|
81 |
+
return mapping_df
|
82 |
+
|
83 |
+
|
84 |
+
#%%
|
85 |
+
@log_time(log)
|
86 |
+
def tRNA_annotation(mapping_df):
|
87 |
+
"""Extract tRNA specific annotation from mapping.
|
88 |
+
"""
|
89 |
+
# keep only tRNA leader/trailer with right cutting sites (+/- 5nt)
|
90 |
+
# leader
|
91 |
+
tRF_leader_df = mapping_df[mapping_df['subclass_type'] == 'leader_tRF']
|
92 |
+
# assign as misc-leader-tRF if exceeding defined cutting site range
|
93 |
+
tRF_leader_df.loc[:,'subclass_type'] = np.where((tRF_leader_df.ref_start + tRF_leader_df.sequence.apply(lambda x: len(x))).between(45, 55, inclusive='both'), 'leader_tRF', 'misc-leader-tRF')
|
94 |
+
|
95 |
+
# trailer
|
96 |
+
tRF_trailer_df = mapping_df[mapping_df['subclass_type'] == 'trailer_tRF']
|
97 |
+
# assign as misc-trailer-tRF if exceeding defined cutting site range
|
98 |
+
tRF_trailer_df.loc[:,'subclass_type'] = np.where(tRF_trailer_df.ref_start.between(0, 5, inclusive='both'), 'trailer_tRF', 'misc-trailer-tRF')
|
99 |
+
|
100 |
+
# define tRF subclasses (leader_tRF and trailer_tRF have been assigned previously)
|
101 |
+
# NOTE: allow more flexibility at ends (similar to miRNA annotation)
|
102 |
+
tRNAs_df = mapping_df[((mapping_df['small_RNA_class_annotation'] == 'tRNA') & mapping_df['subclass_type'].isna())]
|
103 |
+
tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '5p-tR-half'
|
104 |
+
tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '5p-tRF'
|
105 |
+
tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 6) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '3p-tR-half'
|
106 |
+
tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)).between(3,6,inclusive='neither')) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-tRF'
|
107 |
+
tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-CCA-tRF'
|
108 |
+
tRNAs_df.loc[tRNAs_df.subclass_type.isna(),'subclass_type'] = 'misc-tRF'
|
109 |
+
# add ref_iso flag
|
110 |
+
tRNAs_df['tRNA_ref_iso'] = np.where(
|
111 |
+
(
|
112 |
+
(tRNAs_df.ref_start == 0)
|
113 |
+
| ((tRNAs_df.ref_end + 1) == tRNAs_df.precursor_length)
|
114 |
+
| ((tRNAs_df.ref_end + 1) == (tRNAs_df.precursor_length - 3))
|
115 |
+
), 'reftRF', 'isotRF'
|
116 |
+
)
|
117 |
+
# concat tRNA, leader & trailer dfs
|
118 |
+
tRNAs_df = pd.concat([tRNAs_df, tRF_leader_df, tRF_trailer_df],axis=0)
|
119 |
+
# adjust precursor name and create tRNA name
|
120 |
+
tRNAs_df['precursor_name'] = tRNAs_df.precursor_name.str.extract(r"((tRNA-...-...)|(MT-..)|(tRX-...-...)|(tRNA-i...-...))", expand=True)[0]
|
121 |
+
tRNAs_df['subclass_name'] = tRNAs_df.subclass_type + '__' + tRNAs_df.precursor_name
|
122 |
+
|
123 |
+
return tRNAs_df
|
124 |
+
|
125 |
+
#%%
|
126 |
+
def faustrules_check(row):
|
127 |
+
"""Check if isomiRs follow Faustrules (based on Tomasello et al. 2021).
|
128 |
+
"""
|
129 |
+
|
130 |
+
# mark seqs that are not in range +/- 2nt of mature start
|
131 |
+
# check if ref_start.between(miRNAs_df.mature_start-2, miRNAs_df.mature_start+2, inclusive='both')]
|
132 |
+
ref_start = row['ref_start']
|
133 |
+
mature_start = row['mature_start']
|
134 |
+
|
135 |
+
if ref_start < mature_start - 2 or ref_start > mature_start + 2:
|
136 |
+
return False
|
137 |
+
|
138 |
+
# mark seqs with mismatch unless A>G or C>T in seed region (= position 0-8) or 3' polyA/polyT (max 3nt)
|
139 |
+
if pd.isna(row['mm_descriptors']):
|
140 |
+
return True
|
141 |
+
|
142 |
+
seed_region_positions = set(range(9))
|
143 |
+
non_templated_ends = {'A', 'AA', 'AAA', 'T', 'TT', 'TTT'}
|
144 |
+
|
145 |
+
sequence = row['sequence']
|
146 |
+
mm_descriptors = row['mm_descriptors'].split(';')
|
147 |
+
|
148 |
+
seed_region_mismatches = 0
|
149 |
+
three_prime_end_mismatches = 0
|
150 |
+
|
151 |
+
for descriptor in mm_descriptors:
|
152 |
+
pos, change = descriptor.split(':')
|
153 |
+
pos = int(pos)
|
154 |
+
original, new = change.split('>')
|
155 |
+
|
156 |
+
if pos in seed_region_positions and (original == 'A' and new == 'G' or original == 'C' and new == 'T'):
|
157 |
+
seed_region_mismatches += 1
|
158 |
+
|
159 |
+
if pos >= len(sequence) - 3 and sequence[pos:] in non_templated_ends:
|
160 |
+
three_prime_end_mismatches += 1
|
161 |
+
|
162 |
+
total_mismatches = seed_region_mismatches + three_prime_end_mismatches
|
163 |
+
|
164 |
+
return total_mismatches == len(mm_descriptors)
|
165 |
+
|
166 |
+
@log_time(log)
|
167 |
+
def miRNA_annotation(mapping_df):
|
168 |
+
"""Extract miRNA specific annotation from mapping. RaH Faustrules are applied.
|
169 |
+
"""
|
170 |
+
|
171 |
+
miRNAs_df = mapping_df[mapping_df.small_RNA_class_annotation == 'miRNA']
|
172 |
+
|
173 |
+
nr_missing_alignments_expected = len(miRNAs_df.loc[miRNAs_df.duplicated(['tmp_seq_id','reference'], keep='first'),:])
|
174 |
+
|
175 |
+
# load positions of mature miRNAs within precursor
|
176 |
+
miRNA_pos_df = pd.read_csv(mat_miRNA_pos_path, sep='\t')
|
177 |
+
miRNA_pos_df.drop(columns=['precursor_length'], inplace=True)
|
178 |
+
miRNAs_df = miRNAs_df.merge(miRNA_pos_df, left_on='precursor_name_full', right_on='name_precursor', how='left')
|
179 |
+
|
180 |
+
# load mature miRNA sequences from miRBase
|
181 |
+
miRBase_mature_df = fasta2df_subheader(miRBase_mature_path,0)
|
182 |
+
# subset to human miRNAs
|
183 |
+
miRBase_mature_df = miRBase_mature_df.loc[miRBase_mature_df.index.str.contains('hsa-'),:]
|
184 |
+
miRBase_mature_df.index = miRBase_mature_df.index.str.replace('hsa-','')
|
185 |
+
miRBase_mature_df.reset_index(inplace=True)
|
186 |
+
miRBase_mature_df.columns = ['name_mature','ref_miR_seq']
|
187 |
+
# add 'ref_miR_seq'
|
188 |
+
miRNAs_df = miRNAs_df.merge(miRBase_mature_df, left_on='name_mature', right_on='name_mature', how='left')
|
189 |
+
|
190 |
+
# for each duplicated tmp_seq_id/reference combi, keep the one lowest lev dist of sequence to ref_miR_seq
|
191 |
+
miRNAs_df['lev_dist'] = miRNAs_df.apply(lambda x: distance(x['sequence'], x['ref_miR_seq']), axis=1)
|
192 |
+
miRNAs_df = miRNAs_df.sort_values(by=['tmp_seq_id','lev_dist'], ascending=[True, True]).drop_duplicates(['tmp_seq_id','reference'], keep='first')
|
193 |
+
|
194 |
+
# add ref_iso flag
|
195 |
+
miRNAs_df['miRNA_ref_iso'] = np.where(
|
196 |
+
(
|
197 |
+
(miRNAs_df.ref_start == miRNAs_df.mature_start)
|
198 |
+
& (miRNAs_df.ref_end == miRNAs_df.mature_end)
|
199 |
+
& (miRNAs_df.mms == 0)
|
200 |
+
), 'refmiR', 'isomiR'
|
201 |
+
)
|
202 |
+
|
203 |
+
# apply RaH Faustrules
|
204 |
+
miRNAs_df['faustrules_check'] = miRNAs_df.apply(faustrules_check, axis=1)
|
205 |
+
|
206 |
+
# set miRNA_ref_iso to 'misc-miR' if faustrules_check is False
|
207 |
+
miRNAs_df.loc[~miRNAs_df.faustrules_check,'miRNA_ref_iso'] = 'misc-miR'
|
208 |
+
|
209 |
+
# set subclass_name to name_mature if faustrules_check is True, else use precursor_name
|
210 |
+
miRNAs_df['subclass_name'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.precursor_name)
|
211 |
+
|
212 |
+
# store name_mature for functional analysis as miRNA_names, set miR- to mir- if faustrules_check is False
|
213 |
+
miRNAs_df['miRNA_names'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.name_mature.str.replace('miR-', 'mir-'))
|
214 |
+
|
215 |
+
# add subclass (NOTE: in cases where subclass is not part of mature name, use position relative to precursor half to define group )
|
216 |
+
miRNAs_df['subclass_type'] = np.where(miRNAs_df.name_mature.str.endswith('5p'), '5p', np.where(miRNAs_df.name_mature.str.endswith('3p'), '3p', 'tbd'))
|
217 |
+
miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start < miRNAs_df.precursor_length/2)), 'subclass_type'] = '5p'
|
218 |
+
miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start >= miRNAs_df.precursor_length/2)), 'subclass_type'] = '3p'
|
219 |
+
|
220 |
+
# subset to relevant columns
|
221 |
+
miRNAs_df = miRNAs_df[list(mapping_df.columns) + ['subclass_name','miRNA_ref_iso','miRNA_names','ref_miR_seq']]
|
222 |
+
|
223 |
+
return miRNAs_df, nr_missing_alignments_expected
|
224 |
+
|
225 |
+
|
226 |
+
#%%
|
227 |
+
######################################################################################################
|
228 |
+
# annotation of other sRNA classes
|
229 |
+
######################################################################################################
|
230 |
+
def get_bin_with_max_overlap_parallel(df):
|
231 |
+
return df.apply(get_bin_with_max_overlap, axis=1)
|
232 |
+
|
233 |
+
def applyParallel(df, func):
|
234 |
+
retLst = Parallel(n_jobs=multiprocessing.cpu_count())(delayed(func)(group) for group in np.array_split(df,30))
|
235 |
+
return pd.concat(retLst)
|
236 |
+
|
237 |
+
|
238 |
+
@log_time(log)
|
239 |
+
def other_sRNA_annotation_new_binning(mapping_df):
|
240 |
+
"""Generate subclass_name for non-tRNA/miRNA sRNAs by precursor-binning.
|
241 |
+
New binning approach: bin size is dynamically determined by the precursor length. Assignments are based on the bin with the highest overlap.
|
242 |
+
"""
|
243 |
+
|
244 |
+
other_sRNAs_df = mapping_df[~((mapping_df.small_RNA_class_annotation == 'miRNA') | (mapping_df.small_RNA_class_annotation == 'tRNA'))]
|
245 |
+
|
246 |
+
#create empty columns; bin start and bin end
|
247 |
+
other_sRNAs_df['bin_start'] = ''
|
248 |
+
other_sRNAs_df['bin_end'] = ''
|
249 |
+
|
250 |
+
other_sRNAs_df = applyParallel(other_sRNAs_df, get_bin_with_max_overlap_parallel)
|
251 |
+
|
252 |
+
return other_sRNAs_df
|
253 |
+
|
254 |
+
|
255 |
+
#%%
|
256 |
+
@log_time(log)
|
257 |
+
def extract_sRNA_class_specific_info(mapping_df):
|
258 |
+
tRNAs_df = tRNA_annotation(mapping_df)
|
259 |
+
miRNAs_df, nr_missing_alignments_expected = miRNA_annotation(mapping_df)
|
260 |
+
other_sRNAs_df = other_sRNA_annotation_new_binning(mapping_df)
|
261 |
+
|
262 |
+
# add miRNA columns
|
263 |
+
tRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq'])
|
264 |
+
other_sRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq'])
|
265 |
+
|
266 |
+
# re-concat sRNA class dfs
|
267 |
+
sRNA_anno_df = pd.concat([miRNAs_df, tRNAs_df, other_sRNAs_df],axis=0)
|
268 |
+
|
269 |
+
# TEST if alignments were lost or duplicated
|
270 |
+
assert ((len(mapping_df) - nr_missing_alignments_expected) == len(sRNA_anno_df)), "alignments were lost or duplicated"
|
271 |
+
|
272 |
+
return sRNA_anno_df
|
273 |
+
|
274 |
+
#%%
|
275 |
+
def get_nth_nt(row):
|
276 |
+
return row['sequence'][int(row['PTM_position_in_seq'])-1]
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
#%%
|
281 |
+
@log_time(log)
|
282 |
+
def aggregate_info_per_seq(sRNA_anno_df):
|
283 |
+
# fillna of 'subclass_name_bin_pos' with 'subclass_name'
|
284 |
+
sRNA_anno_df['subclass_name_bin_pos'] = sRNA_anno_df['subclass_name_bin_pos'].fillna(sRNA_anno_df['subclass_name'])
|
285 |
+
# get aggregated info per seq
|
286 |
+
aggreg_per_seq_df = sRNA_anno_df.groupby(['sequence']).agg({'small_RNA_class_annotation': lambda x: ';'.join(sorted(x.unique())), 'pseudo_class': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_type': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_name': lambda x: ';'.join(sorted(x.unique())), 'subclass_name_bin_pos': lambda x: ';'.join(sorted(x.unique())), 'miRNA_names': lambda x: ';'.join(x.fillna('').unique()), 'precursor_name_full': lambda x: ';'.join(sorted(x.unique())), 'mms': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'reference': lambda x: len(x), 'mitochondrial': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'ref_miR_seq': lambda x: ';'.join(x.fillna('').unique())})
|
287 |
+
aggreg_per_seq_df['miRNA_names'] = aggreg_per_seq_df.miRNA_names.str.replace(r';$','', regex=True)
|
288 |
+
aggreg_per_seq_df['ref_miR_seq'] = aggreg_per_seq_df.ref_miR_seq.str.replace(r';$','', regex=True)
|
289 |
+
aggreg_per_seq_df['mms'] = aggreg_per_seq_df['mms'].astype(int)
|
290 |
+
|
291 |
+
# re-add 'miRNA_ref_iso','tRNA_ref_iso'
|
292 |
+
refmir_df = sRNA_anno_df[['sequence','miRNA_ref_iso','tRNA_ref_iso']]
|
293 |
+
refmir_df.drop_duplicates('sequence', inplace=True)
|
294 |
+
refmir_df.set_index('sequence', inplace=True)
|
295 |
+
aggreg_per_seq_df = aggreg_per_seq_df.merge(refmir_df, left_index=True, right_index=True, how='left')
|
296 |
+
|
297 |
+
# TEST if sequences were lost
|
298 |
+
assert (len(aggreg_per_seq_df) == len(sRNA_anno_df.sequence.unique())), "sequences were lost by aggregation"
|
299 |
+
|
300 |
+
# load unmapped seqs, if it exits
|
301 |
+
if os.path.exists(unmapped_file):
|
302 |
+
unmapped_df = fasta2df(unmapped_file)
|
303 |
+
unmapped_df = pd.DataFrame(data='no_annotation', index=unmapped_df.sequence, columns=aggreg_per_seq_df.columns)
|
304 |
+
unmapped_df['mms'] = np.nan
|
305 |
+
unmapped_df['reference'] = np.nan
|
306 |
+
unmapped_df['pseudo_class'] = True # set no annotation as pseudo_class
|
307 |
+
|
308 |
+
# merge mapped and unmapped
|
309 |
+
annotation_df = pd.concat([aggreg_per_seq_df,unmapped_df])
|
310 |
+
else:
|
311 |
+
annotation_df = aggreg_per_seq_df.copy()
|
312 |
+
|
313 |
+
# load mapping to genome file
|
314 |
+
mapping_genome_df = pd.read_csv(mapped_genome_file, index_col=0, sep='\t', header=None)
|
315 |
+
mapping_genome_df.columns = ['strand','reference','ref_start','sequence','other_alignments','mm_descriptors']
|
316 |
+
mapping_genome_df = mapping_genome_df[['strand','reference','ref_start','sequence','other_alignments']]
|
317 |
+
|
318 |
+
# use reverse complement of 'sequence' for 'strand' == '-'
|
319 |
+
mapping_genome_df.loc[:,'sequence'] = np.where(mapping_genome_df.strand == '-', mapping_genome_df.sequence.apply(lambda x: reverse_complement(x)), mapping_genome_df.sequence)
|
320 |
+
|
321 |
+
# get aggregated info per seq
|
322 |
+
aggreg_per_seq__genome_df = mapping_genome_df.groupby('sequence').agg({'reference': lambda x: ';'.join(sorted(x.unique())), 'other_alignments': lambda x: len(x)})
|
323 |
+
aggreg_per_seq__genome_df['other_alignments'] = aggreg_per_seq__genome_df['other_alignments'].astype(int)
|
324 |
+
|
325 |
+
# number of genomic loci
|
326 |
+
genomic_loci_df = pd.DataFrame(mapping_genome_df.sequence.value_counts())
|
327 |
+
genomic_loci_df.columns = ['num_genomic_loci_maps']
|
328 |
+
|
329 |
+
# load too many aligments seqs
|
330 |
+
if os.path.exists(toomanyloci_genome_file):
|
331 |
+
toomanyloci_genome_df = fasta2df(toomanyloci_genome_file)
|
332 |
+
toomanyloci_genome_df = pd.DataFrame(data=101, index=toomanyloci_genome_df.sequence, columns=genomic_loci_df.columns)
|
333 |
+
else:
|
334 |
+
toomanyloci_genome_df = pd.DataFrame(columns=genomic_loci_df.columns)
|
335 |
+
|
336 |
+
# load unmapped seqs
|
337 |
+
if os.path.exists(unmapped_genome_file):
|
338 |
+
unmapped_genome_df = fasta2df(unmapped_genome_file)
|
339 |
+
unmapped_genome_df = pd.DataFrame(data=0, index=unmapped_genome_df.sequence, columns=genomic_loci_df.columns)
|
340 |
+
else:
|
341 |
+
unmapped_genome_df = pd.DataFrame(columns=genomic_loci_df.columns)
|
342 |
+
|
343 |
+
# concat toomanyloci, unmapped, and genomic_loci
|
344 |
+
num_genomic_loci_maps_df = pd.concat([genomic_loci_df,toomanyloci_genome_df,unmapped_genome_df])
|
345 |
+
|
346 |
+
# merge to annotation_df
|
347 |
+
annotation_df = annotation_df.merge(num_genomic_loci_maps_df, left_index=True, right_index=True, how='left')
|
348 |
+
annotation_df.reset_index(inplace=True)
|
349 |
+
|
350 |
+
# add 'miRNA_seed'
|
351 |
+
annotation_df.loc[:,"miRNA_seed"] = np.where(annotation_df.small_RNA_class_annotation.str.contains('miRNA', na=False), annotation_df.sequence.str[1:9], "")
|
352 |
+
|
353 |
+
# TEST if nan values in 'num_genomic_loci_maps'
|
354 |
+
assert (annotation_df.num_genomic_loci_maps.isna().any() == False), "nan values in 'num_genomic_loci_maps'"
|
355 |
+
|
356 |
+
return annotation_df
|
357 |
+
|
358 |
+
|
359 |
+
|
360 |
+
|
361 |
+
#%%
|
362 |
+
@log_time(log)
|
363 |
+
def get_five_prime_adapter_info(annotation_df, five_prime_adapter):
|
364 |
+
adapter_df = pd.DataFrame(index=annotation_df.sequence)
|
365 |
+
|
366 |
+
min_length = 6
|
367 |
+
|
368 |
+
is_prefixed = None
|
369 |
+
print("5' adapter affixes:")
|
370 |
+
for l in range(0, len(five_prime_adapter) - min_length):
|
371 |
+
is_prefixed_l = adapter_df.index.str.startswith(five_prime_adapter[l:])
|
372 |
+
print(f"{five_prime_adapter[l:].ljust(30, ' ')}{is_prefixed_l.sum()}")
|
373 |
+
adapter_df.loc[adapter_df.index.str.startswith(five_prime_adapter[l:]), "five_prime_adapter_length"] = len(five_prime_adapter[l:])
|
374 |
+
if is_prefixed is None:
|
375 |
+
is_prefixed = is_prefixed_l
|
376 |
+
else:
|
377 |
+
is_prefixed |= is_prefixed_l
|
378 |
+
|
379 |
+
print(f"There are {is_prefixed.sum()} prefixed features.")
|
380 |
+
print("\n")
|
381 |
+
|
382 |
+
adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].fillna(0)
|
383 |
+
adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].astype('int')
|
384 |
+
adapter_df['five_prime_adapter_filter'] = np.where(adapter_df['five_prime_adapter_length'] == 0, True, False)
|
385 |
+
adapter_df = adapter_df.reset_index()
|
386 |
+
|
387 |
+
return adapter_df
|
388 |
+
|
389 |
+
#%%
|
390 |
+
@log_time(log)
|
391 |
+
def reduce_ambiguity(annotation_df: pd.DataFrame) -> pd.DataFrame:
|
392 |
+
"""Reduce ambiguity by
|
393 |
+
|
394 |
+
a) using subclass_name of precursor with shortest genomic context, if all other assigned precursors overlap with its genomic region
|
395 |
+
|
396 |
+
b) using subclass_name whose bin is at the 5' or 3' end of the precursor
|
397 |
+
|
398 |
+
Parameters
|
399 |
+
----------
|
400 |
+
annotation_df : pd.DataFrame
|
401 |
+
A DataFrame containing the annotation of the sequences (var)
|
402 |
+
|
403 |
+
Returns
|
404 |
+
-------
|
405 |
+
pd.DataFrame
|
406 |
+
An improved version of the input DataFrame with reduced ambiguity
|
407 |
+
"""
|
408 |
+
|
409 |
+
# extract ambigious assignments for subclass name
|
410 |
+
ambigious_matches_df = annotation_df[annotation_df.subclass_name.str.contains(';',na=False)]
|
411 |
+
if len(ambigious_matches_df) == 0:
|
412 |
+
print('No ambigious assignments for subclass name found.')
|
413 |
+
return annotation_df
|
414 |
+
clear_matches_df = annotation_df[~annotation_df.subclass_name.str.contains(';',na=False)]
|
415 |
+
|
416 |
+
# extract required information from HBDxBase
|
417 |
+
HBDxBase_all_df = pd.read_csv(HBDxBase_csv, index_col=0)
|
418 |
+
bin_dict = HBDxBase_all_df[['precursor_name','precursor_bins']].set_index('precursor_name').to_dict()['precursor_bins']
|
419 |
+
sRNA_class_dict = HBDxBase_all_df[['precursor_name','small_RNA_class_annotation']].set_index('precursor_name').to_dict()['small_RNA_class_annotation']
|
420 |
+
pseudo_class_dict = HBDxBase_all_df[['precursor_name','pseudo_class']].set_index('precursor_name').to_dict()['pseudo_class']
|
421 |
+
sc_type_dict = HBDxBase_all_df[['precursor_name','subclass_type']].set_index('precursor_name').to_dict()['subclass_type']
|
422 |
+
genomic_context_bed = HBDxBase_all_df[['chr','start','end','precursor_name','score','strand']]
|
423 |
+
genomic_context_bed.columns = ['seq_id','start','end','name','score','strand']
|
424 |
+
genomic_context_bed.reset_index(drop=True, inplace=True)
|
425 |
+
genomic_context_bed['genomic_length'] = genomic_context_bed.end - genomic_context_bed.start
|
426 |
+
|
427 |
+
|
428 |
+
def get_overlaps(genomic_context_bed: pd.DataFrame, name: str = None, complement: bool = False) -> list:
|
429 |
+
"""Get genomic overlap of a given precursor name
|
430 |
+
|
431 |
+
Parameters
|
432 |
+
----------
|
433 |
+
genomic_context_bed : pd.DataFrame
|
434 |
+
A DataFrame containing genomic locations of precursors in bed format
|
435 |
+
with column names: 'chr','start','end','precursor_name','score','strand'
|
436 |
+
name : str
|
437 |
+
The name of the precursor to get genomic context for
|
438 |
+
complement : bool
|
439 |
+
If True, return all precursors that do not overlap with the given precursor
|
440 |
+
|
441 |
+
Returns
|
442 |
+
-------
|
443 |
+
list
|
444 |
+
A list containing the precursors in the genomic (anti-)context of the given precursor
|
445 |
+
(including the precursor itself)
|
446 |
+
"""
|
447 |
+
series_OI = genomic_context_bed[genomic_context_bed['name'] == name]
|
448 |
+
start = series_OI['start'].values[0]
|
449 |
+
end = series_OI['end'].values[0]
|
450 |
+
seq_id = series_OI['seq_id'].values[0]
|
451 |
+
strand = series_OI['strand'].values[0]
|
452 |
+
|
453 |
+
overlap_df = genomic_context_bed.copy()
|
454 |
+
|
455 |
+
condition = (((overlap_df.start > start) &
|
456 |
+
(overlap_df.start < end)) |
|
457 |
+
((overlap_df.end > start) &
|
458 |
+
(overlap_df.end < end)) |
|
459 |
+
((overlap_df.start < start) &
|
460 |
+
(overlap_df.end > start)) |
|
461 |
+
((overlap_df.start == start) &
|
462 |
+
(overlap_df.end == end)) |
|
463 |
+
((overlap_df.start == start) &
|
464 |
+
(overlap_df.end > end)) |
|
465 |
+
((overlap_df.start < start) &
|
466 |
+
(overlap_df.end == end)))
|
467 |
+
if not complement:
|
468 |
+
overlap_df = overlap_df[condition]
|
469 |
+
else:
|
470 |
+
overlap_df = overlap_df[~condition]
|
471 |
+
overlap_df = overlap_df[overlap_df.seq_id == seq_id]
|
472 |
+
if strand is not None:
|
473 |
+
overlap_df = overlap_df[overlap_df.strand == strand]
|
474 |
+
overlap_list = overlap_df['name'].tolist()
|
475 |
+
return overlap_list
|
476 |
+
|
477 |
+
|
478 |
+
def check_genomic_ctx_of_smallest_prec(precursor_name: str) -> str:
|
479 |
+
"""Check for a given ambigious precursor assignment (several names separated by ';')
|
480 |
+
if all assigned precursors overlap with the genomic region
|
481 |
+
of the precursor with the shortest genomic context
|
482 |
+
|
483 |
+
Parameters
|
484 |
+
----------
|
485 |
+
precursor_name: str
|
486 |
+
A string containing several precursor names separated by ';'
|
487 |
+
|
488 |
+
Returns
|
489 |
+
-------
|
490 |
+
str
|
491 |
+
The precursor suggested to be used instead of the multi assignment,
|
492 |
+
or None if the ambiguity could not be resolved
|
493 |
+
"""
|
494 |
+
assigned_names = precursor_name.split(';')
|
495 |
+
|
496 |
+
tmp_genomic_context = genomic_context_bed[genomic_context_bed.name.isin(assigned_names)]
|
497 |
+
# get name of smallest genomic region
|
498 |
+
if len(tmp_genomic_context) > 0:
|
499 |
+
smallest_name = tmp_genomic_context.name[tmp_genomic_context.genomic_length.idxmin()]
|
500 |
+
# check if all assigned names are in overlap of smallest genomic region
|
501 |
+
if set(assigned_names).issubset(set(get_overlaps(genomic_context_bed,smallest_name))):
|
502 |
+
return smallest_name
|
503 |
+
else:
|
504 |
+
return None
|
505 |
+
else:
|
506 |
+
return None
|
507 |
+
|
508 |
+
def get_subclass_name(subclass_name: str, short_prec_match_new_name: str) -> str:
|
509 |
+
"""Get subclass name matching to a precursor name from a ambigious assignment (several names separated by ';')
|
510 |
+
|
511 |
+
Parameters
|
512 |
+
----------
|
513 |
+
subclass_name: str
|
514 |
+
A string containing several subclass names separated by ';'
|
515 |
+
short_prec_match_new_name: str
|
516 |
+
The name of the precursor to be used instead of the multi assignment
|
517 |
+
|
518 |
+
Returns
|
519 |
+
-------
|
520 |
+
str
|
521 |
+
The subclass name suggested to be used instead of the multi assignment,
|
522 |
+
or None if the ambiguity could not be resolved
|
523 |
+
"""
|
524 |
+
if short_prec_match_new_name is not None:
|
525 |
+
matches = get_close_matches(short_prec_match_new_name,subclass_name.split(';'),cutoff=0.2)
|
526 |
+
if matches:
|
527 |
+
return matches[0]
|
528 |
+
else:
|
529 |
+
print(f"Could not find match for {short_prec_match_new_name} in {subclass_name}")
|
530 |
+
return subclass_name
|
531 |
+
else:
|
532 |
+
return None
|
533 |
+
|
534 |
+
|
535 |
+
def check_end_bins(subclass_name: str) -> str:
|
536 |
+
"""Check for a given ambigious subclass name assignment (several names separated by ';')
|
537 |
+
if ambiguity can be resolved by selecting the subclass name whose bin matches the 3'/5' end of the precursor
|
538 |
+
|
539 |
+
Parameters
|
540 |
+
----------
|
541 |
+
subclass_name: str
|
542 |
+
A string containing several subclass names separated by ';'
|
543 |
+
|
544 |
+
Returns
|
545 |
+
-------
|
546 |
+
str
|
547 |
+
The subclass name suggested to be used instead of the multi assignment,
|
548 |
+
or None if the ambiguity could not be resolved
|
549 |
+
"""
|
550 |
+
for name in subclass_name.split(';'):
|
551 |
+
if '_bin-' in name:
|
552 |
+
name_parts = name.split('_bin-')
|
553 |
+
if name_parts[0] in bin_dict and bin_dict[name_parts[0]] == int(name_parts[1]):
|
554 |
+
return name
|
555 |
+
elif int(name_parts[1]) == 1:
|
556 |
+
return name
|
557 |
+
return None
|
558 |
+
|
559 |
+
|
560 |
+
def adjust_4_resolved_cases(row: pd.Series) -> tuple:
|
561 |
+
"""For a resolved ambiguous subclass names return adjusted values of
|
562 |
+
precursor_name_full, small_RNA_class_annotation, pseudo_class, and subclass_type
|
563 |
+
|
564 |
+
Parameters
|
565 |
+
----------
|
566 |
+
row: pd.Series
|
567 |
+
A row of the var annotation containing the columns 'subclass_name', 'precursor_name_full',
|
568 |
+
'small_RNA_class_annotation', 'pseudo_class', 'subclass_type', and 'ambiguity_resolved'
|
569 |
+
|
570 |
+
Returns
|
571 |
+
-------
|
572 |
+
tuple
|
573 |
+
A tuple containing the adjusted values of 'precursor_name_full', 'small_RNA_class_annotation',
|
574 |
+
'pseudo_class', and 'subclass_type' for resolved ambiguous cases and the original values for unresolved cases
|
575 |
+
"""
|
576 |
+
if row.ambiguity_resolved:
|
577 |
+
matches_prec = get_close_matches(row.subclass_name, row.precursor_name_full.split(';'), cutoff=0.2)
|
578 |
+
if matches_prec:
|
579 |
+
return matches_prec[0], sRNA_class_dict[matches_prec[0]], pseudo_class_dict[matches_prec[0]], sc_type_dict[matches_prec[0]]
|
580 |
+
return row.precursor_name_full, row.small_RNA_class_annotation, row.pseudo_class, row.subclass_type
|
581 |
+
|
582 |
+
|
583 |
+
# resolve ambiguity by checking genomic context of smallest precursor
|
584 |
+
ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.precursor_name_full.apply(check_genomic_ctx_of_smallest_prec)
|
585 |
+
ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.apply(lambda x: get_subclass_name(x.subclass_name, x.short_prec_match_new_name), axis=1)
|
586 |
+
ambigious_matches_df['short_prec_match'] = ambigious_matches_df['short_prec_match_new_name'].notnull()
|
587 |
+
|
588 |
+
# resolve ambiguity by checking if bin matches 3'/5' end of precursor
|
589 |
+
ambigious_matches_df['end_bin_match_new_name'] = ambigious_matches_df.subclass_name.apply(check_end_bins)
|
590 |
+
ambigious_matches_df['end_bin_match'] = ambigious_matches_df['end_bin_match_new_name'].notnull()
|
591 |
+
|
592 |
+
# check if short_prec_match and end_bin_match are equal in any case
|
593 |
+
test_df = ambigious_matches_df[((ambigious_matches_df.short_prec_match == True) & (ambigious_matches_df.end_bin_match == True))]
|
594 |
+
if not (test_df.short_prec_match_new_name == test_df.end_bin_match_new_name).all():
|
595 |
+
print('Number of cases where short_prec_match is not matching end_bin_match_new_name:',len(test_df[(test_df.short_prec_match_new_name != test_df.end_bin_match_new_name)]))
|
596 |
+
|
597 |
+
# replace subclass_name with short_prec_match_new_name or end_bin_match_new_name
|
598 |
+
# NOTE: if short_prec_match and end_bin_match are True, short_prec_match_new_name is used
|
599 |
+
ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.end_bin_match_new_name if x.end_bin_match == True else x.subclass_name, axis=1)
|
600 |
+
ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.short_prec_match_new_name if x.short_prec_match == True else x.subclass_name, axis=1)
|
601 |
+
|
602 |
+
# generate column 'ambiguity_resolved' which is True if short_prec_match and/or end_bin_match is True
|
603 |
+
ambigious_matches_df['ambiguity_resolved'] = ambigious_matches_df.short_prec_match | ambigious_matches_df.end_bin_match
|
604 |
+
print("Ambiguity resolved?\n",ambigious_matches_df.ambiguity_resolved.value_counts(normalize=True))
|
605 |
+
|
606 |
+
# for resolved ambiguous matches, adjust precursor_name_full, small_RNA_class_annotation, pseudo_class, subclass_type
|
607 |
+
ambigious_matches_df[['precursor_name_full','small_RNA_class_annotation','pseudo_class','subclass_type']] = ambigious_matches_df.apply(adjust_4_resolved_cases, axis=1, result_type='expand')
|
608 |
+
|
609 |
+
# drop temporary columns
|
610 |
+
ambigious_matches_df.drop(columns=['short_prec_match_new_name','short_prec_match','end_bin_match_new_name','end_bin_match'], inplace=True)
|
611 |
+
|
612 |
+
# concat with clear_matches_df
|
613 |
+
clear_matches_df['ambiguity_resolved'] = False
|
614 |
+
improved_annotation_df = pd.concat([clear_matches_df, ambigious_matches_df], axis=0)
|
615 |
+
improved_annotation_df = improved_annotation_df.reindex(annotation_df.index)
|
616 |
+
|
617 |
+
return improved_annotation_df
|
618 |
+
|
619 |
+
#%%
|
620 |
+
######################################################################################################
|
621 |
+
# HICO (=high confidence) annotation
|
622 |
+
######################################################################################################
|
623 |
+
@log_time(log)
|
624 |
+
def add_hico_annotation(annotation_df, five_prime_adapter):
|
625 |
+
"""For miRNAs only use hico annotation if part of miRBase hico set AND refmiR
|
626 |
+
"""
|
627 |
+
|
628 |
+
# add 'TE_annotation'
|
629 |
+
TE_df = pd.read_csv(TE_file, sep='\t', header=None, names=['sequence','TE_annotation'])
|
630 |
+
annotation_df = annotation_df.merge(TE_df, left_on='sequence', right_on='sequence', how='left')
|
631 |
+
|
632 |
+
# add 'bacterial' mapping filter
|
633 |
+
bacterial_unmapped_df = fasta2df(unmapped_bacterial_file)
|
634 |
+
annotation_df.loc[:,'bacterial'] = np.where(annotation_df.sequence.isin(bacterial_unmapped_df.sequence), False, True)
|
635 |
+
|
636 |
+
# add 'viral' mapping filter
|
637 |
+
viral_unmapped_df = fasta2df(unmapped_viral_file)
|
638 |
+
annotation_df.loc[:,'viral'] = np.where(annotation_df.sequence.isin(viral_unmapped_df.sequence), False, True)
|
639 |
+
|
640 |
+
# add 'adapter_mapping_filter' column
|
641 |
+
adapter_unmapped_df = fasta2df(unmapped_adapter_file)
|
642 |
+
annotation_df.loc[:,'adapter_mapping_filter'] = np.where(annotation_df.sequence.isin(adapter_unmapped_df.sequence), True, False)
|
643 |
+
|
644 |
+
# add filter column 'five_prime_adapter_filter' and column 'five_prime_adapter_length' indicating the length of the prefixed 5' adapter sequence
|
645 |
+
adapter_df = get_five_prime_adapter_info(annotation_df, five_prime_adapter)
|
646 |
+
annotation_df = annotation_df.merge(adapter_df, left_on='sequence', right_on='sequence', how='left')
|
647 |
+
|
648 |
+
# apply ambiguity reduction
|
649 |
+
annotation_df = reduce_ambiguity(annotation_df)
|
650 |
+
|
651 |
+
# add 'single_class_annotation'
|
652 |
+
annotation_df.loc[:,'single_class_annotation'] = np.where(annotation_df.small_RNA_class_annotation.str.contains(';',na=True), False, True)
|
653 |
+
|
654 |
+
# add 'single_name_annotation'
|
655 |
+
annotation_df.loc[:,'single_name_annotation'] = np.where(annotation_df.subclass_name.str.contains(';',na=True), False, True)
|
656 |
+
|
657 |
+
# add 'hypermapper' for sequences where more than 50 potential mapping references are recorded
|
658 |
+
annotation_df.loc[annotation_df.reference > 50,'subclass_name'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
|
659 |
+
annotation_df.loc[annotation_df.reference > 50,'subclass_name_bin_pos'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
|
660 |
+
annotation_df.loc[annotation_df.reference > 50,'precursor_name_full'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
|
661 |
+
|
662 |
+
annotation_df.loc[:,'mitochondrial'] = np.where(annotation_df.mitochondrial.str.contains('mito',na=False), True, False)
|
663 |
+
|
664 |
+
# add 'hico'
|
665 |
+
annotation_df.loc[:,'hico'] = np.where((
|
666 |
+
(annotation_df.mms == 0)
|
667 |
+
& (annotation_df.single_name_annotation == True)
|
668 |
+
& (annotation_df.TE_annotation.isna() == True)
|
669 |
+
& (annotation_df.bacterial == False)
|
670 |
+
& (annotation_df.viral == False)
|
671 |
+
& (annotation_df.adapter_mapping_filter == True)
|
672 |
+
& (annotation_df.five_prime_adapter_filter == True)
|
673 |
+
), True, False)
|
674 |
+
## NOTE: for miRNAs only use hico annotation if part of refmiR set
|
675 |
+
annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] = annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] & (annotation_df.miRNA_ref_iso == 'refmiR')
|
676 |
+
|
677 |
+
print(annotation_df[annotation_df.single_class_annotation == True].groupby('small_RNA_class_annotation').hico.value_counts())
|
678 |
+
|
679 |
+
return annotation_df
|
680 |
+
|
681 |
+
|
682 |
+
|
683 |
+
|
684 |
+
#%%
|
685 |
+
######################################################################################################
|
686 |
+
# annotation pipeline
|
687 |
+
######################################################################################################
|
688 |
+
@log_time(log)
|
689 |
+
def main(five_prime_adapter):
|
690 |
+
"""Executes 'annotate_from_mapping'.
|
691 |
+
|
692 |
+
Uses:
|
693 |
+
|
694 |
+
- HBDxBase_csv
|
695 |
+
- miRBase_mature_path
|
696 |
+
- mat_miRNA_pos_path
|
697 |
+
|
698 |
+
- mapping_file
|
699 |
+
- unmapped_file
|
700 |
+
- mapped_genome_file
|
701 |
+
- toomanyloci_genome_file
|
702 |
+
- unmapped_genome_file
|
703 |
+
|
704 |
+
- TE_file
|
705 |
+
- unmapped_adapter_file
|
706 |
+
- unmapped_bacterial_file
|
707 |
+
- unmapped_viral_file
|
708 |
+
- five_prime_adapter
|
709 |
+
|
710 |
+
"""
|
711 |
+
|
712 |
+
|
713 |
+
print('-------- extract general information for sequences that mapped to the HBDxBase --------')
|
714 |
+
mapped_info_df = extract_general_info(mapped_file)
|
715 |
+
print("\n")
|
716 |
+
|
717 |
+
print('-------- extract sRNA class specific information for sequences that mapped to the HBDxBase --------')
|
718 |
+
mapped_sRNA_anno_df = extract_sRNA_class_specific_info(mapped_info_df)
|
719 |
+
|
720 |
+
print('-------- save to file --------')
|
721 |
+
mapped_sRNA_anno_df.to_csv(sRNA_anno_file)
|
722 |
+
print("\n")
|
723 |
+
|
724 |
+
print('-------- aggregate information for mapped and unmapped sequences (HBDxBase & human genome) --------')
|
725 |
+
sRNA_anno_per_seq_df = aggregate_info_per_seq(mapped_sRNA_anno_df)
|
726 |
+
print("\n")
|
727 |
+
|
728 |
+
print('-------- add hico annotation (based on aggregated infos + mapping to viral/bacterial genomes + intersection with TEs) --------')
|
729 |
+
sRNA_anno_per_seq_df = add_hico_annotation(sRNA_anno_per_seq_df, five_prime_adapter)
|
730 |
+
print("\n")
|
731 |
+
|
732 |
+
print('-------- save to file --------')
|
733 |
+
# set sequence as index again
|
734 |
+
sRNA_anno_per_seq_df.set_index('sequence', inplace=True)
|
735 |
+
sRNA_anno_per_seq_df.to_csv(aggreg_sRNA_anno_file)
|
736 |
+
print("\n")
|
737 |
+
|
738 |
+
print('-------- generate subclass_to_annotation dict --------')
|
739 |
+
result_df = sRNA_anno_per_seq_df[['subclass_name', 'small_RNA_class_annotation']].copy()
|
740 |
+
result_df.reset_index(drop=True, inplace=True)
|
741 |
+
result_df.drop_duplicates(inplace=True)
|
742 |
+
result_df = result_df[~result_df["subclass_name"].str.contains(";")]
|
743 |
+
subclass_to_annotation = dict(zip(result_df["subclass_name"],result_df["small_RNA_class_annotation"]))
|
744 |
+
with open('subclass_to_annotation.json', 'w') as fp:
|
745 |
+
json.dump(subclass_to_annotation, fp)
|
746 |
+
|
747 |
+
print('-------- delete tmp files --------')
|
748 |
+
os.system("rm *tmp_*")
|
749 |
+
|
750 |
+
|
751 |
+
#%%
|
kba_pipeline/src/make_anno.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#%%
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
|
7 |
+
from utils import make_output_dir,write_2_log,log_time
|
8 |
+
import map_2_HBDxBase as map_2_HBDxBase
|
9 |
+
import annotate_from_mapping as annotate_from_mapping
|
10 |
+
|
11 |
+
|
12 |
+
log = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
#%%
|
17 |
+
# get command line arguments
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument('--five_prime_adapter', type=str, default='GTTCAGAGTTCTACAGTCCGACGATC')
|
20 |
+
parser.add_argument('--fasta_file', type=str, help="Required to provide: --fasta_file sequences_to_be_annotated.fa") # NOTE: needs to be stored in "data" folder
|
21 |
+
args = parser.parse_args()
|
22 |
+
if not args.fasta_file:
|
23 |
+
parser.print_help()
|
24 |
+
exit()
|
25 |
+
five_prime_adapter = args.five_prime_adapter
|
26 |
+
sequence_file = args.fasta_file
|
27 |
+
|
28 |
+
#%%
|
29 |
+
@log_time(log)
|
30 |
+
def main(five_prime_adapter, sequence_file):
|
31 |
+
"""Executes 'make_anno'.
|
32 |
+
1. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes.
|
33 |
+
2. Extracts information from mapping files.
|
34 |
+
3. Generates annotation columns and final annotation dataframe.
|
35 |
+
|
36 |
+
Uses:
|
37 |
+
|
38 |
+
- sequence_file
|
39 |
+
- five_prime_adapter
|
40 |
+
|
41 |
+
"""
|
42 |
+
output_dir = make_output_dir(sequence_file)
|
43 |
+
os.chdir(output_dir)
|
44 |
+
|
45 |
+
log_folder = "log"
|
46 |
+
if not os.path.exists(log_folder):
|
47 |
+
os.makedirs(log_folder)
|
48 |
+
write_2_log(f"{log_folder}/make_anno.log")
|
49 |
+
|
50 |
+
# add name of sequence_file to log file
|
51 |
+
with open(f"{log_folder}/make_anno.log", "a") as ofile:
|
52 |
+
ofile.write(f"Sequence file: {sequence_file}\n")
|
53 |
+
|
54 |
+
map_2_HBDxBase.main("../../data/" + sequence_file)
|
55 |
+
annotate_from_mapping.main(five_prime_adapter)
|
56 |
+
|
57 |
+
|
58 |
+
main(five_prime_adapter, sequence_file)
|
59 |
+
# %%
|
kba_pipeline/src/map_2_HBDxBase.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
######################################################################################################
|
2 |
+
# map sequences to HBDxBase
|
3 |
+
######################################################################################################
|
4 |
+
#%%
|
5 |
+
import os
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from utils import fasta2df,log_time
|
9 |
+
|
10 |
+
log = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
######################################################################################################
|
14 |
+
# paths to reference files
|
15 |
+
######################################################################################################
|
16 |
+
|
17 |
+
version = '_v4'
|
18 |
+
|
19 |
+
HBDxBase_index_path = f'../../references/HBDxBase/HBDxBase{version}'
|
20 |
+
HBDxBase_pseudo_index_path = f'../../references/HBDxBase/HBDxBase_pseudo{version}'
|
21 |
+
genome_index_path = '../../references/hg38/genome'
|
22 |
+
adapter_index_path = '../../references/HBDxBase/adapters'
|
23 |
+
TE_path = '../../references/hg38/TE.bed'
|
24 |
+
bacterial_index_path = '../../references/bacterial_viral/all_bacterial_refseq_with_human_host__201127.index'
|
25 |
+
viral_index_path = '../../references/bacterial_viral/viral_refseq_with_human_host__201127.index'
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
#%%
|
31 |
+
######################################################################################################
|
32 |
+
# specific functions
|
33 |
+
######################################################################################################
|
34 |
+
|
35 |
+
@log_time(log)
|
36 |
+
def prepare_input_files(seq_input):
|
37 |
+
|
38 |
+
# check if seq_input is path or list
|
39 |
+
if type(seq_input) == str:
|
40 |
+
# get seqs in dataset
|
41 |
+
seqs = fasta2df(seq_input)
|
42 |
+
seqs = seqs.sequence
|
43 |
+
elif type(seq_input) == list:
|
44 |
+
seqs = seq_input
|
45 |
+
else:
|
46 |
+
raise ValueError('seq_input must be either path to fasta file or list of sequences')
|
47 |
+
|
48 |
+
# add number of sequences to log file
|
49 |
+
log_folder = "log"
|
50 |
+
with open(f"{log_folder}/make_anno.log", "a") as ofile:
|
51 |
+
ofile.write(f"KBA pipeline based on HBDxBase{version}\n")
|
52 |
+
ofile.write(f"Number of sequences to be annotated: {str(len(seqs))}\n")
|
53 |
+
|
54 |
+
if type(seq_input) == str:
|
55 |
+
with open('seqs.fa', 'w') as ofile_1:
|
56 |
+
for i in range(len(seqs)):
|
57 |
+
ofile_1.write(">" + seqs.index[i] + "\n" + seqs[i] + "\n")
|
58 |
+
else:
|
59 |
+
with open('seqs.fa', 'w') as ofile_1:
|
60 |
+
for i in range(len(seqs)):
|
61 |
+
ofile_1.write(">seq_" + str(i) + "\n" + seqs[i] + "\n")
|
62 |
+
|
63 |
+
@log_time(log)
|
64 |
+
def map_seq_2_HBDxBase(
|
65 |
+
number_mm,
|
66 |
+
fasta_in_file,
|
67 |
+
out_prefix
|
68 |
+
):
|
69 |
+
|
70 |
+
bowtie_index_file = HBDxBase_index_path
|
71 |
+
|
72 |
+
os.system(
|
73 |
+
f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
74 |
+
--al {out_prefix + str(number_mm) + 'mm2HBDxBase__mapped.fa'} \
|
75 |
+
--un {out_prefix + str(number_mm) + 'mm2HBDxBase__unmapped.fa'} \
|
76 |
+
{out_prefix + str(number_mm) + 'mm2HBDxBase.txt'}"
|
77 |
+
)
|
78 |
+
@log_time(log)
|
79 |
+
def map_seq_2_HBDxBase_pseudo(
|
80 |
+
number_mm,
|
81 |
+
fasta_in_file,
|
82 |
+
out_prefix
|
83 |
+
):
|
84 |
+
|
85 |
+
bowtie_index_file = HBDxBase_pseudo_index_path
|
86 |
+
|
87 |
+
os.system(
|
88 |
+
f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
89 |
+
--al {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__mapped.fa'} \
|
90 |
+
--un {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__unmapped.fa'} \
|
91 |
+
{out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo.txt'}"
|
92 |
+
)
|
93 |
+
# -a Report all valid alignments per read
|
94 |
+
# --norc No mapping to reverse strand
|
95 |
+
# -v Report alignments with at most <int> mismatches
|
96 |
+
# -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense
|
97 |
+
# -suppress Suppress columns of output in the default output mode
|
98 |
+
# -x The basename of the Bowtie, or Bowtie 2, index to be searched
|
99 |
+
|
100 |
+
@log_time(log)
|
101 |
+
def map_seq_2_adapters(
|
102 |
+
fasta_in_file,
|
103 |
+
out_prefix
|
104 |
+
):
|
105 |
+
|
106 |
+
bowtie_index_file = adapter_index_path
|
107 |
+
|
108 |
+
os.system(
|
109 |
+
f"bowtie -a --best --strata --norc -v 3 -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
110 |
+
--al {out_prefix + '3mm2adapters__mapped.fa'} \
|
111 |
+
--un {out_prefix + '3mm2adapters__unmapped.fa'} \
|
112 |
+
{out_prefix + '3mm2adapters.txt'}"
|
113 |
+
)
|
114 |
+
# -a --best --strata Specifying --strata in addition to -a and --best causes bowtie to report only those alignments in the best alignment “stratum”. The alignments in the best stratum are those having the least number of mismatches
|
115 |
+
# --norc No mapping to reverse strand
|
116 |
+
# -v Report alignments with at most <int> mismatches
|
117 |
+
# -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense
|
118 |
+
# -suppress Suppress columns of output in the default output mode
|
119 |
+
# -x The basename of the Bowtie, or Bowtie 2, index to be searched
|
120 |
+
|
121 |
+
|
122 |
+
@log_time(log)
|
123 |
+
def map_seq_2_genome(
|
124 |
+
fasta_in_file,
|
125 |
+
out_prefix
|
126 |
+
):
|
127 |
+
|
128 |
+
bowtie_index_file = genome_index_path
|
129 |
+
|
130 |
+
os.system(
|
131 |
+
f"bowtie -a -v 0 -f -m 100 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
132 |
+
--max {out_prefix + '0mm2genome__toomanyalign.fa'} \
|
133 |
+
--un {out_prefix + '0mm2genome__unmapped.fa'} \
|
134 |
+
{out_prefix + '0mm2genome.txt'}"
|
135 |
+
)
|
136 |
+
# -a Report all valid alignments per read
|
137 |
+
# -v Report alignments with at most <int> mismatches
|
138 |
+
# -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense
|
139 |
+
# -m Suppress all alignments for a particular read if more than <int> reportable alignments exist for it
|
140 |
+
# -suppress Suppress columns of output in the default output mode
|
141 |
+
# -x The basename of the Bowtie, or Bowtie 2, index to be searched
|
142 |
+
|
143 |
+
|
144 |
+
@log_time(log)
|
145 |
+
def map_seq_2_bacterial_viral(
|
146 |
+
fasta_in_file,
|
147 |
+
out_prefix
|
148 |
+
):
|
149 |
+
|
150 |
+
bowtie_index_file = bacterial_index_path
|
151 |
+
|
152 |
+
os.system(
|
153 |
+
f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
154 |
+
--al {out_prefix + '0mm2bacterial__mapped.fa'} \
|
155 |
+
--max {out_prefix + '0mm2bacterial__toomanyalign.fa'} \
|
156 |
+
--un {out_prefix + '0mm2bacterial__unmapped.fa'} \
|
157 |
+
{out_prefix + '0mm2bacterial.txt'}"
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
bowtie_index_file = viral_index_path
|
162 |
+
|
163 |
+
os.system(
|
164 |
+
f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
|
165 |
+
--al {out_prefix + '0mm2viral__mapped.fa'} \
|
166 |
+
--max {out_prefix + '0mm2viral__toomanyalign.fa'} \
|
167 |
+
--un {out_prefix + '0mm2viral__unmapped.fa'} \
|
168 |
+
{out_prefix + '0mm2viral.txt'}"
|
169 |
+
)
|
170 |
+
# -a Report all valid alignments per read
|
171 |
+
# -v Report alignments with at most <int> mismatches
|
172 |
+
# -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense
|
173 |
+
# -m Suppress all alignments for a particular read if more than <int> reportable alignments exist for it
|
174 |
+
# -suppress Suppress columns of output in the default output mode
|
175 |
+
# -x The basename of the Bowtie, or Bowtie 2, index to be searched
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
#%%
|
182 |
+
######################################################################################################
|
183 |
+
# mapping pipeline
|
184 |
+
######################################################################################################
|
185 |
+
@log_time(log)
|
186 |
+
def main(sequence_file):
|
187 |
+
"""Executes 'map_2_HBDxBase'. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes.
|
188 |
+
|
189 |
+
Uses:
|
190 |
+
|
191 |
+
- HBDxBase_index_path
|
192 |
+
- HBDxBase_pseudo_index_path
|
193 |
+
- genome_index_path
|
194 |
+
- bacterial_index_path
|
195 |
+
- viral_index_path
|
196 |
+
- sequence_file
|
197 |
+
|
198 |
+
"""
|
199 |
+
|
200 |
+
prepare_input_files(sequence_file)
|
201 |
+
|
202 |
+
# sequential mm mapping to HBDxBase
|
203 |
+
print('-------- map to HBDxBase --------')
|
204 |
+
|
205 |
+
print('-------- mapping seqs (0 mm) --------')
|
206 |
+
map_seq_2_HBDxBase(
|
207 |
+
0,
|
208 |
+
'seqs.fa',
|
209 |
+
'tmp_seqs'
|
210 |
+
)
|
211 |
+
|
212 |
+
print('-------- mapping seqs (1 mm) --------')
|
213 |
+
map_seq_2_HBDxBase(
|
214 |
+
1,
|
215 |
+
'tmp_seqs0mm2HBDxBase__unmapped.fa',
|
216 |
+
'tmp_seqs'
|
217 |
+
)
|
218 |
+
|
219 |
+
print('-------- mapping seqs (2 mm) --------')
|
220 |
+
map_seq_2_HBDxBase(
|
221 |
+
2,
|
222 |
+
'tmp_seqs1mm2HBDxBase__unmapped.fa',
|
223 |
+
'tmp_seqs'
|
224 |
+
)
|
225 |
+
|
226 |
+
print('-------- mapping seqs (3 mm) --------')
|
227 |
+
map_seq_2_HBDxBase(
|
228 |
+
3,
|
229 |
+
'tmp_seqs2mm2HBDxBase__unmapped.fa',
|
230 |
+
'tmp_seqs'
|
231 |
+
)
|
232 |
+
|
233 |
+
# sequential mm mapping to Pseudo-HBDxBase
|
234 |
+
print('-------- map to Pseudo-HBDxBase --------')
|
235 |
+
|
236 |
+
print('-------- mapping seqs (0 mm) --------')
|
237 |
+
map_seq_2_HBDxBase_pseudo(
|
238 |
+
0,
|
239 |
+
'tmp_seqs3mm2HBDxBase__unmapped.fa',
|
240 |
+
'tmp_seqs'
|
241 |
+
)
|
242 |
+
|
243 |
+
print('-------- mapping seqs (1 mm) --------')
|
244 |
+
map_seq_2_HBDxBase_pseudo(
|
245 |
+
1,
|
246 |
+
'tmp_seqs0mm2HBDxBase_pseudo__unmapped.fa',
|
247 |
+
'tmp_seqs'
|
248 |
+
)
|
249 |
+
|
250 |
+
print('-------- mapping seqs (2 mm) --------')
|
251 |
+
map_seq_2_HBDxBase_pseudo(
|
252 |
+
2,
|
253 |
+
'tmp_seqs1mm2HBDxBase_pseudo__unmapped.fa',
|
254 |
+
'tmp_seqs'
|
255 |
+
)
|
256 |
+
|
257 |
+
print('-------- mapping seqs (3 mm) --------')
|
258 |
+
map_seq_2_HBDxBase_pseudo(
|
259 |
+
3,
|
260 |
+
'tmp_seqs2mm2HBDxBase_pseudo__unmapped.fa',
|
261 |
+
'tmp_seqs'
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
+
# concatenate files
|
266 |
+
print('-------- concatenate mapping files --------')
|
267 |
+
os.system("cat tmp_seqs0mm2HBDxBase.txt tmp_seqs1mm2HBDxBase.txt tmp_seqs2mm2HBDxBase.txt tmp_seqs3mm2HBDxBase.txt tmp_seqs0mm2HBDxBase_pseudo.txt tmp_seqs1mm2HBDxBase_pseudo.txt tmp_seqs2mm2HBDxBase_pseudo.txt tmp_seqs3mm2HBDxBase_pseudo.txt > seqsmapped2HBDxBase_combined.txt")
|
268 |
+
|
269 |
+
print('\n')
|
270 |
+
|
271 |
+
# mapping to adapters (allowing for 3 mms)
|
272 |
+
print('-------- map to adapters (3 mm) --------')
|
273 |
+
map_seq_2_adapters(
|
274 |
+
'seqs.fa',
|
275 |
+
'tmp_seqs'
|
276 |
+
)
|
277 |
+
|
278 |
+
# mapping to genome (more than 50 alignments are not reported)
|
279 |
+
print('-------- map to human genome --------')
|
280 |
+
|
281 |
+
print('-------- mapping seqs (0 mm) --------')
|
282 |
+
map_seq_2_genome(
|
283 |
+
'seqs.fa',
|
284 |
+
'tmp_seqs'
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
## concatenate files
|
289 |
+
print('-------- concatenate mapping files --------')
|
290 |
+
os.system("cp tmp_seqs0mm2genome.txt seqsmapped2genome_combined.txt")
|
291 |
+
|
292 |
+
print('\n')
|
293 |
+
|
294 |
+
## intersect genome mapping hits with TE.bed
|
295 |
+
print('-------- intersect genome mapping hits with TE.bed --------')
|
296 |
+
# convert to BED format
|
297 |
+
os.system("awk 'BEGIN {FS= \"\t\"; OFS=\"\t\"} {print $3, $4, $4+length($5)-1, $5, 111, $2}' seqsmapped2genome_combined.txt > tmp_seqsmapped2genome_combined.bed")
|
298 |
+
# intersect with TE.bed (force strandedness -> fetch only sRNA_sequence and TE_name -> aggregate TE annotation on sequences)
|
299 |
+
os.system(f"bedtools intersect -a tmp_seqsmapped2genome_combined.bed -b {TE_path} -wa -wb -s" + "| awk '{print $4,$10}' | awk '{a[$1]=a[$1]\";\"$2} END {for(i in a) print i\"\t\"substr(a[i],2)}' > tmp_seqsmapped2genome_intersect_TE.txt")
|
300 |
+
|
301 |
+
# mapping to bacterial and viral genomes (more than 10 alignments are not reported)
|
302 |
+
print('-------- map to bacterial and viral genome --------')
|
303 |
+
|
304 |
+
print('-------- mapping seqs (0 mm) --------')
|
305 |
+
map_seq_2_bacterial_viral(
|
306 |
+
'seqs.fa',
|
307 |
+
'tmp_seqs'
|
308 |
+
)
|
309 |
+
|
310 |
+
## concatenate files
|
311 |
+
print('-------- concatenate mapping files --------')
|
312 |
+
os.system("cat tmp_seqs0mm2bacterial.txt tmp_seqs0mm2viral.txt > seqsmapped2bacterialviral_combined.txt")
|
313 |
+
|
314 |
+
print('\n')
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
|
kba_pipeline/src/precursor_bins.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import pandas as pd
|
3 |
+
from typing import List
|
4 |
+
from collections.abc import Callable
|
5 |
+
|
6 |
+
def load_HBDxBase():
|
7 |
+
version = '_v4'
|
8 |
+
HBDxBase_file = f'../../references/HBDxBase/HBDxBase_all{version}.csv'
|
9 |
+
HBDxBase_df = pd.read_csv(HBDxBase_file, index_col=0)
|
10 |
+
HBDxBase_df.loc[:,'precursor_bins'] = (HBDxBase_df.precursor_length/25).astype(int)
|
11 |
+
return HBDxBase_df
|
12 |
+
|
13 |
+
def compute_dynamic_bin_size(precursor_len:int, name:str=None, min_bin_size:int=20, max_bin_size:int=30) -> List[int]:
|
14 |
+
'''
|
15 |
+
This function splits precursor to bins of size max_bin_size
|
16 |
+
if the last bin is smaller than min_bin_size, it will split the precursor to bins of size max_bin_size-1
|
17 |
+
This process will continue until the last bin is larger than min_bin_size.
|
18 |
+
if the min bin size is reached and still the last bin is smaller than min_bin_size, the last two bins will be merged.
|
19 |
+
so the maximimum bin size possible would be min_bin_size+(min_bin_size-1) = 39
|
20 |
+
'''
|
21 |
+
def split_precursor_to_bins(precursor_len,max_bin_size):
|
22 |
+
'''
|
23 |
+
This function splits precursor to bins of size max_bin_size
|
24 |
+
'''
|
25 |
+
precursor_bin_lens = []
|
26 |
+
for i in range(0, precursor_len, max_bin_size):
|
27 |
+
if i+max_bin_size < precursor_len:
|
28 |
+
precursor_bin_lens.append(max_bin_size)
|
29 |
+
else:
|
30 |
+
precursor_bin_lens.append(precursor_len-i)
|
31 |
+
return precursor_bin_lens
|
32 |
+
|
33 |
+
if precursor_len < min_bin_size:
|
34 |
+
return [precursor_len]
|
35 |
+
else:
|
36 |
+
precursor_bin_lens = split_precursor_to_bins(precursor_len,max_bin_size)
|
37 |
+
reduced_len = max_bin_size-1
|
38 |
+
while precursor_bin_lens[-1] < min_bin_size:
|
39 |
+
precursor_bin_lens = split_precursor_to_bins(precursor_len,reduced_len)
|
40 |
+
reduced_len -= 1
|
41 |
+
if reduced_len < min_bin_size:
|
42 |
+
#add last two bins together
|
43 |
+
precursor_bin_lens[-2] += precursor_bin_lens[-1]
|
44 |
+
precursor_bin_lens = precursor_bin_lens[:-1]
|
45 |
+
break
|
46 |
+
|
47 |
+
return precursor_bin_lens
|
48 |
+
|
49 |
+
def get_bin_no_from_pos(precursor_len:int,position:int,name:str=None,min_bin_size:int=20,max_bin_size:int=30) -> int:
|
50 |
+
'''
|
51 |
+
This function returns the bin number of a position in a precursor
|
52 |
+
bins start from 1
|
53 |
+
'''
|
54 |
+
precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
|
55 |
+
bin_no = 0
|
56 |
+
for i,bin_len in enumerate(precursor_bin_lens):
|
57 |
+
if position < bin_len:
|
58 |
+
bin_no = i
|
59 |
+
break
|
60 |
+
else:
|
61 |
+
position -= bin_len
|
62 |
+
return bin_no+1
|
63 |
+
|
64 |
+
def get_bin_with_max_overlap(row) -> int:
|
65 |
+
'''
|
66 |
+
This function returns the bin number of a fragment that overlaps the most with the fragment
|
67 |
+
'''
|
68 |
+
precursor_len = row.precursor_length
|
69 |
+
start_frag_pos = row.ref_start
|
70 |
+
frag_len = row.seq_length
|
71 |
+
name = row.precursor_name_full
|
72 |
+
min_bin_size = 20
|
73 |
+
max_bin_size = 30
|
74 |
+
precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
|
75 |
+
bin_no = 0
|
76 |
+
for i,bin_len in enumerate(precursor_bin_lens):
|
77 |
+
if start_frag_pos < bin_len:
|
78 |
+
#get overlap with curr bin
|
79 |
+
overlap = min(bin_len-start_frag_pos,frag_len)
|
80 |
+
|
81 |
+
if overlap > frag_len/2:
|
82 |
+
bin_no = i
|
83 |
+
else:
|
84 |
+
bin_no = i+1
|
85 |
+
break
|
86 |
+
|
87 |
+
else:
|
88 |
+
start_frag_pos -= bin_len
|
89 |
+
#get bin start and bin end
|
90 |
+
bin_start,bin_end = sum(precursor_bin_lens[:bin_no]),sum(precursor_bin_lens[:bin_no+1])
|
91 |
+
row['bin_start'] = bin_start
|
92 |
+
row['bin_end'] = bin_end
|
93 |
+
row['subclass_name'] = name + '_bin-' + str(bin_no+1)
|
94 |
+
row['precursor_bins'] = len(precursor_bin_lens)
|
95 |
+
row['subclass_name_bin_pos'] = name + '_binpos-' + str(bin_start) + ':' + str(bin_end)
|
96 |
+
return row
|
97 |
+
|
98 |
+
def convert_bin_to_pos(precursor_len:int,bin_no:int,bin_function:Callable=compute_dynamic_bin_size,name:str=None,min_bin_size:int=20,max_bin_size:int=30):
|
99 |
+
'''
|
100 |
+
This function returns the start and end position of a bin
|
101 |
+
'''
|
102 |
+
precursor_bin_lens = bin_function(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
|
103 |
+
start_pos = 0
|
104 |
+
end_pos = 0
|
105 |
+
for i,bin_len in enumerate(precursor_bin_lens):
|
106 |
+
if i+1 == bin_no:
|
107 |
+
end_pos = start_pos+bin_len
|
108 |
+
break
|
109 |
+
else:
|
110 |
+
start_pos += bin_len
|
111 |
+
return start_pos,end_pos
|
112 |
+
|
113 |
+
#main
|
114 |
+
if __name__ == '__main__':
|
115 |
+
#read hbdxbase
|
116 |
+
HBDxBase_df = load_HBDxBase()
|
117 |
+
min_bin_size = 20
|
118 |
+
max_bin_size = 30
|
119 |
+
#select indices of precurosrs that include 'rRNA' but not 'pseudo'
|
120 |
+
rRNA_df = HBDxBase_df[HBDxBase_df.index.str.contains('rRNA') * ~HBDxBase_df.index.str.contains('pseudo')]
|
121 |
+
|
122 |
+
#get bin of index 1
|
123 |
+
bins = compute_dynamic_bin_size(len(rRNA_df.iloc[0].sequence),rRNA_df.iloc[0].name,min_bin_size,max_bin_size)
|
124 |
+
bin_no = get_bin_no_from_pos(len(rRNA_df.iloc[0].sequence),name=rRNA_df.iloc[0].name,position=1)
|
125 |
+
annotation_bin = get_bin_with_max_overlap(len(rRNA_df.iloc[0].sequence),start_frag_pos=1,frag_len=50,name=rRNA_df.iloc[0].name)
|
126 |
+
|
127 |
+
# %%
|
kba_pipeline/src/utils.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
import errno
|
4 |
+
from pathlib import Path
|
5 |
+
from Bio.SeqIO.FastaIO import SimpleFastaParser
|
6 |
+
from datetime import datetime
|
7 |
+
from getpass import getuser
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from rich.logging import RichHandler
|
11 |
+
from functools import wraps
|
12 |
+
from time import perf_counter
|
13 |
+
from typing import Callable
|
14 |
+
|
15 |
+
default_path = '../outputs/'
|
16 |
+
|
17 |
+
def humanize_time(time_in_seconds: float, /) -> str:
|
18 |
+
"""Return a nicely human-readable string of a time_in_seconds.
|
19 |
+
|
20 |
+
Parameters
|
21 |
+
----------
|
22 |
+
time_in_seconds : float
|
23 |
+
Time in seconds, (not full seconds).
|
24 |
+
|
25 |
+
Returns
|
26 |
+
-------
|
27 |
+
str
|
28 |
+
A description of the time in one of the forms:
|
29 |
+
- 300.1 ms
|
30 |
+
- 4.5 sec
|
31 |
+
- 5 min 43.1 sec
|
32 |
+
"""
|
33 |
+
sgn = "" if time_in_seconds >= 0 else "- "
|
34 |
+
time_in_seconds = abs(time_in_seconds)
|
35 |
+
if time_in_seconds < 1:
|
36 |
+
return f"{sgn}{time_in_seconds*1e3:.1f} ms"
|
37 |
+
elif time_in_seconds < 60:
|
38 |
+
return f"{sgn}{time_in_seconds:.1f} sec"
|
39 |
+
else:
|
40 |
+
return f"{sgn}{int(time_in_seconds//60)} min {time_in_seconds%60:.1f} sec"
|
41 |
+
|
42 |
+
|
43 |
+
class log_time:
|
44 |
+
"""A decorator / context manager to log the time a certain function / code block took.
|
45 |
+
|
46 |
+
Usage either with:
|
47 |
+
|
48 |
+
@log_time(log)
|
49 |
+
def function_getting_logged_every_time(…):
|
50 |
+
…
|
51 |
+
|
52 |
+
producing:
|
53 |
+
|
54 |
+
function_getting_logged_every_time took 5 sec.
|
55 |
+
|
56 |
+
or:
|
57 |
+
|
58 |
+
with log_time(log, "Name of this codeblock"):
|
59 |
+
…
|
60 |
+
|
61 |
+
producing:
|
62 |
+
|
63 |
+
Name of this codeblock took 5 sec.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, logger: logging.Logger, name: str = None):
|
67 |
+
"""
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
logger : logging.Logger
|
71 |
+
The logger to use for logging the time, if None use print.
|
72 |
+
name : str, optional
|
73 |
+
The name in the message, when used as a decorator this defaults to the function name, by default None
|
74 |
+
"""
|
75 |
+
self.logger = logger
|
76 |
+
self.name = name
|
77 |
+
|
78 |
+
def __call__(self, func: Callable):
|
79 |
+
if self.name is None:
|
80 |
+
self.name = func.__qualname__
|
81 |
+
|
82 |
+
@wraps(func)
|
83 |
+
def inner(*args, **kwds):
|
84 |
+
with self:
|
85 |
+
return func(*args, **kwds)
|
86 |
+
|
87 |
+
return inner
|
88 |
+
|
89 |
+
def __enter__(self):
|
90 |
+
self.start_time = perf_counter()
|
91 |
+
|
92 |
+
def __exit__(self, *exc):
|
93 |
+
self.exit_time = perf_counter()
|
94 |
+
|
95 |
+
time_delta = humanize_time(self.exit_time - self.start_time)
|
96 |
+
if self.logger is None:
|
97 |
+
print(f"{self.name} took {time_delta}.")
|
98 |
+
else:
|
99 |
+
self.logger.info(f"{self.name} took {time_delta}.")
|
100 |
+
|
101 |
+
|
102 |
+
def write_2_log(log_file):
|
103 |
+
# Setup logging
|
104 |
+
log_file_handler = logging.FileHandler(log_file)
|
105 |
+
log_file_handler.setLevel(logging.INFO)
|
106 |
+
log_file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
107 |
+
log_rich_handler = RichHandler()
|
108 |
+
log_rich_handler.setLevel(logging.INFO) #cli_args.log_level
|
109 |
+
log_rich_handler.setFormatter(logging.Formatter("%(message)s"))
|
110 |
+
logging.basicConfig(level=logging.INFO, datefmt="[%X]", handlers=[log_file_handler, log_rich_handler])
|
111 |
+
|
112 |
+
|
113 |
+
def fasta2df(path):
|
114 |
+
with open(path) as fasta_file:
|
115 |
+
identifiers = []
|
116 |
+
seqs = []
|
117 |
+
for header, sequence in SimpleFastaParser(fasta_file):
|
118 |
+
identifiers.append(header)
|
119 |
+
seqs.append(sequence)
|
120 |
+
|
121 |
+
fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence'])
|
122 |
+
fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T'))
|
123 |
+
return fasta_df
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
def fasta2df_subheader(path, id_pos):
|
128 |
+
with open(path) as fasta_file:
|
129 |
+
identifiers = []
|
130 |
+
seqs = []
|
131 |
+
for header, sequence in SimpleFastaParser(fasta_file):
|
132 |
+
identifiers.append(header.split(None)[id_pos])
|
133 |
+
seqs.append(sequence)
|
134 |
+
|
135 |
+
fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence'])
|
136 |
+
fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T'))
|
137 |
+
return fasta_df
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
def build_bowtie_index(bowtie_index_file):
|
142 |
+
#index_example = Path(bowtie_index_file + '.1.ebwt')
|
143 |
+
#if not index_example.is_file():
|
144 |
+
print('-------- index is build --------')
|
145 |
+
os.system(f"bowtie-build {bowtie_index_file + '.fa'} {bowtie_index_file}")
|
146 |
+
#else: print('-------- previously built index is used --------')
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
def make_output_dir(fasta_file):
|
151 |
+
output_dir = default_path + datetime.now().strftime('%Y-%m-%d') + ('__') + fasta_file.replace('.fasta', '').replace('.fa', '') + '/'
|
152 |
+
try:
|
153 |
+
os.makedirs(output_dir)
|
154 |
+
except OSError as e:
|
155 |
+
if e.errno != errno.EEXIST:
|
156 |
+
raise # This was not a "directory exist" error..
|
157 |
+
return output_dir
|
158 |
+
|
159 |
+
|
160 |
+
def reverse_complement(seq):
|
161 |
+
complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
|
162 |
+
return ''.join([complement[base] for base in seq[::-1]])
|
163 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
anndata==0.8.0
|
2 |
+
dill==0.3.6
|
3 |
+
hydra-core==1.3.0
|
4 |
+
imbalanced-learn==0.9.1
|
5 |
+
matplotlib==3.5.3
|
6 |
+
numpy==1.22.3
|
7 |
+
omegaconf==2.2.2
|
8 |
+
pandas==1.5.2
|
9 |
+
plotly==5.10.0
|
10 |
+
PyYAML==6.0
|
11 |
+
rich==12.6.0
|
12 |
+
viennarna==2.5.0a5
|
13 |
+
scanpy==1.9.1
|
14 |
+
scikit_learn==1.2.0
|
15 |
+
skorch==0.12.1
|
16 |
+
torch==1.10.1
|
17 |
+
tensorboard==2.11.2
|
18 |
+
Levenshtein==0.21.0
|
scripts/test_inference_api.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transforna import predict_transforna, predict_transforna_all_models
|
2 |
+
|
3 |
+
seqs = [
|
4 |
+
'AACGAAGCTCGACTTTTAAGG',
|
5 |
+
'GTCCACCCCAAAGCGTAGG']
|
6 |
+
|
7 |
+
path_to_models = '/path/to/tcga/models/'
|
8 |
+
sc_preds_id_df = predict_transforna_all_models(seqs,path_to_models = path_to_models) #/models/tcga/
|
9 |
+
#%%
|
10 |
+
#get sc predictions for models trained on id (in distribution)
|
11 |
+
sc_preds_id_df = predict_transforna(seqs, model="seq",trained_on='id',path_to_models = path_to_models)
|
12 |
+
#get sc predictions for models trained on full (all sub classes)
|
13 |
+
sc_preds_df = predict_transforna(seqs, model="seq",path_to_models = path_to_models)
|
14 |
+
#predict using models trained on major class
|
15 |
+
mc_preds_df = predict_transforna(seqs, model="seq",mc_or_sc='major_class',path_to_models = path_to_models)
|
16 |
+
#get logits
|
17 |
+
logits_df = predict_transforna(seqs, model="seq",logits_flag=True,path_to_models = path_to_models)
|
18 |
+
#get embedds
|
19 |
+
embedd_df = predict_transforna(seqs, model="seq",embedds_flag=True,path_to_models = path_to_models)
|
20 |
+
#get the top 4 similar sequences
|
21 |
+
sim_df = predict_transforna(seqs, model="seq",similarity_flag=True,n_sim=4,path_to_models = path_to_models)
|
22 |
+
#get umaps
|
23 |
+
umaps_df = predict_transforna(seqs, model="seq",umaps_flag=True,path_to_models = path_to_models)
|
24 |
+
|
25 |
+
|
26 |
+
all_preds_df = predict_transforna_all_models(seqs,path_to_models=path_to_models)
|
27 |
+
all_preds_df
|
28 |
+
|
29 |
+
# %%
|
scripts/train.sh
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#data_time for hydra output folder
|
2 |
+
get_data_time(){
|
3 |
+
date=$(ls outputs/ | head -n 1)
|
4 |
+
time=$(ls outputs/*/ | head -n 1)
|
5 |
+
date=$date
|
6 |
+
time=$time
|
7 |
+
}
|
8 |
+
|
9 |
+
train_model(){
|
10 |
+
python -m transforna --config-dir="/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf"\
|
11 |
+
model_name=$1 trained_on=$2 num_replicates=$4
|
12 |
+
|
13 |
+
get_data_time
|
14 |
+
#rename the folder to model_name
|
15 |
+
mv outputs/$date/$time outputs/$date/$3
|
16 |
+
ls outputs/$date/
|
17 |
+
rm -rf models/tcga/TransfoRNA_${2^^}/$5/$3
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
mv -f outputs/$date/$3 models/tcga/TransfoRNA_${2^^}/$5/
|
22 |
+
rm -rf outputs/
|
23 |
+
|
24 |
+
}
|
25 |
+
#activate transforna environment
|
26 |
+
eval "$(conda shell.bash hook)"
|
27 |
+
conda activate transforna
|
28 |
+
|
29 |
+
#create the models folder if it does not exist
|
30 |
+
if [[ ! -d "models/tcga/TransfoRNA_ID/major_class" ]]; then
|
31 |
+
mkdir -p models/tcga/TransfoRNA_ID/major_class
|
32 |
+
fi
|
33 |
+
if [[ ! -d "models/tcga/TransfoRNA_FULL/sub_class" ]]; then
|
34 |
+
mkdir -p models/tcga/TransfoRNA_FULL/sub_class
|
35 |
+
fi
|
36 |
+
if [[ ! -d "models/tcga/TransfoRNA_ID/sub_class" ]]; then
|
37 |
+
mkdir -p models/tcga/TransfoRNA_ID/sub_class
|
38 |
+
fi
|
39 |
+
if [[ ! -d "models/tcga/TransfoRNA_FULL/major_class" ]]; then
|
40 |
+
mkdir -p models/tcga/TransfoRNA_FULL/major_class
|
41 |
+
fi
|
42 |
+
#remove the outputs folder
|
43 |
+
rm -rf outputs
|
44 |
+
|
45 |
+
|
46 |
+
#define models
|
47 |
+
models=("seq" "seq-seq" "seq-rev" "seq-struct" "baseline")
|
48 |
+
models_capitalized=("Seq" "Seq-Seq" "Seq-Rev" "Seq-Struct" "Baseline")
|
49 |
+
|
50 |
+
|
51 |
+
num_replicates=5
|
52 |
+
|
53 |
+
|
54 |
+
############train major_class_hico
|
55 |
+
|
56 |
+
##replace clf_target:str = 'sub_class_hico' to clf_target:str = 'major_class_hico' in ../conf/train_model_configs/tcga.py
|
57 |
+
sed -i "s/clf_target:str = 'sub_class_hico'/clf_target:str = 'major_class_hico'/g" conf/train_model_configs/tcga.py
|
58 |
+
#print the file content
|
59 |
+
cat conf/train_model_configs/tcga.py
|
60 |
+
#loop and train
|
61 |
+
for i in ${!models[@]}; do
|
62 |
+
echo "Training model ${models_capitalized[$i]} for id on major_class"
|
63 |
+
train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "major_class"
|
64 |
+
echo "Training model ${models[$i]} for full on major_class"
|
65 |
+
train_model ${models[$i]} full ${models_capitalized[$i]} 1 "major_class"
|
66 |
+
done
|
67 |
+
|
68 |
+
|
69 |
+
############train sub_class_hico
|
70 |
+
|
71 |
+
#replace clf_target:str = 'major_class_hico' to clf_target:str = 'sub_class_hico' in ../conf/train_model_configs/tcga.py
|
72 |
+
sed -i "s/clf_target:str = 'major_class_hico'/clf_target:str = 'sub_class_hico'/g" conf/train_model_configs/tcga.py
|
73 |
+
|
74 |
+
for i in ${!models[@]}; do
|
75 |
+
echo "Training model ${models_capitalized[$i]} for id on sub_class"
|
76 |
+
train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "sub_class"
|
77 |
+
echo "Training model ${models[$i]} for full on sub_class"
|
78 |
+
train_model ${models[$i]} full ${models_capitalized[$i]} 1 "sub_class"
|
79 |
+
done
|
80 |
+
|
setup.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import find_packages, setup
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='TransfoRNA',
|
5 |
+
version='0.0.1',
|
6 |
+
description='TransfoRNA: Navigating the Uncertainties of Small RNA Annotation with an Adaptive Machine Learning Strategy',
|
7 |
+
url='https://github.com/gitHBDX/TransfoRNA',
|
8 |
+
author='YasserTaha,JuliaJehn',
|
9 |
+
author_email='ytaha@hb-dx.com,jjehn@hb-dx.com,tsikosek@hb-dx.com',
|
10 |
+
classifiers=[
|
11 |
+
'Development Status :: 4 - Beta',
|
12 |
+
'Intended Audience :: Biological Researchers',
|
13 |
+
'License :: OSI Approved :: MIT License',
|
14 |
+
'Programming Language :: Python :: 3.9',
|
15 |
+
],
|
16 |
+
packages=find_packages(include=['transforna', 'transforna.*']),
|
17 |
+
install_requires=[
|
18 |
+
"anndata==0.8.0",
|
19 |
+
"dill==0.3.6",
|
20 |
+
"hydra-core==1.3.0",
|
21 |
+
"imbalanced-learn==0.9.1",
|
22 |
+
"matplotlib==3.5.3",
|
23 |
+
"numpy==1.22.3",
|
24 |
+
"omegaconf==2.2.2",
|
25 |
+
"pandas==1.5.2",
|
26 |
+
"plotly==5.10.0",
|
27 |
+
"PyYAML==6.0",
|
28 |
+
"rich==12.6.0",
|
29 |
+
"viennarna>=2.5.0a5",
|
30 |
+
"scanpy==1.9.1",
|
31 |
+
"scikit_learn==1.2.0",
|
32 |
+
"skorch==0.12.1",
|
33 |
+
"torch==1.10.1",
|
34 |
+
"tensorboard==2.16.2",
|
35 |
+
"Levenshtein==0.21.0"
|
36 |
+
],
|
37 |
+
python_requires='>=3.9',
|
38 |
+
#move yaml files to package
|
39 |
+
package_data={'': ['*.yaml']},
|
40 |
+
)
|
transforna/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .src.callbacks import *
|
2 |
+
from .src.inference import *
|
3 |
+
from .src.model import *
|
4 |
+
from .src.novelty_prediction import *
|
5 |
+
from .src.processing import *
|
6 |
+
from .src.train import *
|
7 |
+
from .src.utils import *
|
transforna/__main__.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import hydra
|
7 |
+
from hydra.core.hydra_config import HydraConfig
|
8 |
+
from hydra.utils import instantiate
|
9 |
+
from omegaconf import DictConfig
|
10 |
+
|
11 |
+
from transforna import compute_cv, infer_benchmark, infer_tcga, train
|
12 |
+
|
13 |
+
warnings.filterwarnings("ignore")
|
14 |
+
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
def add_config_to_sys_path():
|
19 |
+
cfg = HydraConfig.get()
|
20 |
+
config_path = [path["path"] for path in cfg.runtime.config_sources if path["schema"] == "file"][0]
|
21 |
+
sys.path.append(config_path)
|
22 |
+
|
23 |
+
#transforna could called from anywhere:
|
24 |
+
#python -m transforna --config-dir = /path/to/configs
|
25 |
+
@hydra.main(config_path='../conf', config_name="main_config")
|
26 |
+
def main(cfg: DictConfig) -> None:
|
27 |
+
add_config_to_sys_path()
|
28 |
+
#get path of hydra outputs folder
|
29 |
+
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
|
30 |
+
|
31 |
+
path = os.getcwd()
|
32 |
+
#init train and model config
|
33 |
+
cfg['train_config'] = instantiate(cfg['train_config']).__dict__
|
34 |
+
cfg['model_config'] = instantiate(cfg['model_config']).__dict__
|
35 |
+
|
36 |
+
#update model config with the name of the model
|
37 |
+
cfg['model_config']["model_input"] = cfg["model_name"]
|
38 |
+
|
39 |
+
#inference or train
|
40 |
+
if cfg["inference"]:
|
41 |
+
logger.info(f"Started inference on {cfg['task']}")
|
42 |
+
if cfg['task'] == 'tcga':
|
43 |
+
return infer_tcga(cfg,path=path)
|
44 |
+
else:
|
45 |
+
return infer_benchmark(cfg,path=path)
|
46 |
+
else:
|
47 |
+
if cfg["cross_val"]:
|
48 |
+
compute_cv(cfg,path,output_dir=output_dir)
|
49 |
+
|
50 |
+
else:
|
51 |
+
train(cfg,path=path,output_dir=output_dir)
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
main()
|
transforna/bin/figure_scripts/figure_4_table_3.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
#%%
|
4 |
+
#read all files ending with dist_df in bin/lc_files/
|
5 |
+
import pandas as pd
|
6 |
+
import glob
|
7 |
+
from plotly import graph_objects as go
|
8 |
+
from transforna import load,predict_transforna
|
9 |
+
all_df = pd.DataFrame()
|
10 |
+
files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/*lev_dist_df.csv')
|
11 |
+
for file in files:
|
12 |
+
df = pd.read_csv(file)
|
13 |
+
all_df = pd.concat([all_df,df])
|
14 |
+
all_df = all_df.drop(columns=['Unnamed: 0'])
|
15 |
+
all_df.loc[all_df.split.isnull(),'split'] = 'NA'
|
16 |
+
|
17 |
+
#%%
|
18 |
+
#draw a box plot for the Ensemble model for each of the splits using seaborn
|
19 |
+
ensemble_df = all_df[all_df.Model == 'Ensemble'].reset_index(drop=True)
|
20 |
+
#remove other_affixes
|
21 |
+
ensemble_df = ensemble_df[ensemble_df.split != 'other_affixes'].reset_index(drop=True)
|
22 |
+
#rename
|
23 |
+
ensemble_df['split'] = ensemble_df['split'].replace({'5\'A-affixes':'Putative 5’-adapter prefixes','Fused':'Recombined'})
|
24 |
+
ensemble_df['split'] = ensemble_df['split'].replace({'Relaxed-miRNA':'Isomirs'})
|
25 |
+
#%%
|
26 |
+
#plot the boxplot using seaborn
|
27 |
+
import seaborn as sns
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
sns.set_theme(style="whitegrid")
|
30 |
+
sns.set(rc={'figure.figsize':(15,10)})
|
31 |
+
sns.set(font_scale=1.5)
|
32 |
+
order = ['LC-familiar','LC-novel','Random','Putative 5’-adapter prefixes','Recombined','NA','LOCO','Isomirs']
|
33 |
+
ax = sns.boxplot(x="split", y="NLD", data=ensemble_df, palette="Set3",order=order,showfliers = True)
|
34 |
+
|
35 |
+
#add Novelty Threshold line
|
36 |
+
ax.axhline(y=ensemble_df['Novelty Threshold'].mean(), color='g', linestyle='--',xmin=0,xmax=1)
|
37 |
+
#annotate mean of Novelty Threshold
|
38 |
+
ax.annotate('NLD threshold', xy=(1.5, ensemble_df['Novelty Threshold'].mean()), xytext=(1.5, ensemble_df['Novelty Threshold'].mean()-0.07), arrowprops=dict(facecolor='black', shrink=0.05))
|
39 |
+
#rename
|
40 |
+
ax.set_xticklabels(['LC-Familiar','LC-Novel','Random','5’-adapter artefacts','Recombined','NA','LOCO','IsomiRs'])
|
41 |
+
#add title
|
42 |
+
ax.set_facecolor('None')
|
43 |
+
plt.title('Levenshtein Distance Distribution per Split on LC')
|
44 |
+
ax.set(xlabel='Split', ylabel='Normalized Levenshtein Distance')
|
45 |
+
#save legend
|
46 |
+
plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)
|
47 |
+
plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_no_out_boxplot.svg',dpi=400)
|
48 |
+
#tilt x axis labels
|
49 |
+
plt.xticks(rotation=-22.5)
|
50 |
+
#save svg
|
51 |
+
plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.svg',dpi=1000)
|
52 |
+
##save png
|
53 |
+
plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.png',dpi=1000)
|
54 |
+
#%%
|
55 |
+
bars = [r for r in ax.get_children()]
|
56 |
+
colors = []
|
57 |
+
for c in bars[:-1]:
|
58 |
+
try: colors.append(c.get_facecolor())
|
59 |
+
except: pass
|
60 |
+
isomir_color = colors[len(order)-1]
|
61 |
+
isomir_color = [255*x for x in isomir_color]
|
62 |
+
#covert to rgb('r','g','b','a')
|
63 |
+
isomir_color = 'rgb(%s,%s,%s,%s)'%(isomir_color[0],isomir_color[1],isomir_color[2],isomir_color[3])
|
64 |
+
|
65 |
+
#%%
|
66 |
+
relaxed_mirna_df = all_df[all_df.split == 'Relaxed-miRNA']
|
67 |
+
models = relaxed_mirna_df.Model.unique()
|
68 |
+
percentage_dict = {}
|
69 |
+
for model in models:
|
70 |
+
model_df = relaxed_mirna_df[relaxed_mirna_df.Model == model]
|
71 |
+
#compute the % of sequences with NLD < Novelty Threshold for each model
|
72 |
+
percentage_dict[model] = len(model_df[model_df['NLD'] > model_df['Novelty Threshold']])/len(model_df)
|
73 |
+
percentage_dict[model]*=100
|
74 |
+
|
75 |
+
fig = go.Figure()
|
76 |
+
for model in ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev','Ensemble']:
|
77 |
+
fig.add_trace(go.Bar(x=[model],y=[percentage_dict[model]],name=model,marker_color=isomir_color))
|
78 |
+
#add percentage on top of each bar
|
79 |
+
fig.add_annotation(x=model,y=percentage_dict[model]+2,text='%s%%'%(round(percentage_dict[model],2)),showarrow=False)
|
80 |
+
#increase size of annotation
|
81 |
+
fig.update_annotations(dict(font_size=13))
|
82 |
+
#add title in the center
|
83 |
+
fig.update_layout(title='Percentage of Isomirs considered novel per model')
|
84 |
+
fig.update_layout(xaxis_tickangle=+22.5)
|
85 |
+
fig.update_layout(showlegend=False)
|
86 |
+
#make transparent background
|
87 |
+
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
|
88 |
+
#y axis label
|
89 |
+
fig.update_yaxes(title_text='Percentage of Novel Sequences')
|
90 |
+
#save svg
|
91 |
+
fig.show()
|
92 |
+
#save svg
|
93 |
+
#fig.write_image('relaxed_mirna_novel_perc_lc_barplot.svg')
|
94 |
+
#%%
|
95 |
+
#here we explore the false familiar of the ood lc set
|
96 |
+
ood_df = pd.read_csv('/nfs/home/yat_ldap/VS_Projects/TransfoRNA/bin/lc_files/LC-novel_lev_dist_df.csv')
|
97 |
+
mapping_dict_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json'
|
98 |
+
mapping_dict = load(mapping_dict_path)
|
99 |
+
|
100 |
+
LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/LC__ngs__DI_HB_GEL-23.1.2.h5ad'
|
101 |
+
ad = load(LC_path)
|
102 |
+
#%%
|
103 |
+
model = 'Ensemble'
|
104 |
+
ood_seqs = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == True).values].Sequence.tolist()
|
105 |
+
ood_predicted_labels = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == True).values].Labels.tolist()
|
106 |
+
ood_actual_labels = ad.var.loc[ood_seqs]['subclass_name'].values.tolist()
|
107 |
+
from transforna import correct_labels
|
108 |
+
ood_predicted_labels = correct_labels(ood_predicted_labels,ood_actual_labels,mapping_dict)
|
109 |
+
|
110 |
+
#get indices where ood_predicted_labels == ood_actual_labels
|
111 |
+
correct_indices = [i for i, x in enumerate(ood_predicted_labels) if x != ood_actual_labels[i]]
|
112 |
+
#remove the indices from ood_seqs, ood_predicted_labels, ood_actual_labels
|
113 |
+
ood_seqs = [ood_seqs[i] for i in correct_indices]
|
114 |
+
ood_predicted_labels = [ood_predicted_labels[i] for i in correct_indices]
|
115 |
+
ood_actual_labels = [ood_actual_labels[i] for i in correct_indices]
|
116 |
+
#get the major class of the actual labels
|
117 |
+
ood_actual_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_actual_labels]
|
118 |
+
ood_predicted_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_predicted_labels ]
|
119 |
+
#get frequencies of each major class
|
120 |
+
from collections import Counter
|
121 |
+
ood_actual_major_class_freq = Counter(ood_actual_major_class)
|
122 |
+
ood_predicted_major_class_freq = Counter(ood_predicted_major_class)
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
# %%
|
127 |
+
import plotly.express as px
|
128 |
+
major_classes = list(ood_actual_major_class_freq.keys())
|
129 |
+
|
130 |
+
ood_seqs_len = [len(seq) for seq in ood_seqs]
|
131 |
+
ood_seqs_len_freq = Counter(ood_seqs_len)
|
132 |
+
fig = px.bar(x=list(ood_seqs_len_freq.keys()),y=list(ood_seqs_len_freq.values()))
|
133 |
+
fig.show()
|
134 |
+
|
135 |
+
#%%
|
136 |
+
import plotly.graph_objects as go
|
137 |
+
fig = go.Figure()
|
138 |
+
for major_class in major_classes:
|
139 |
+
len_dist = [len(ood_seqs[i]) for i, x in enumerate(ood_actual_major_class) if x == major_class]
|
140 |
+
len_dist_freq = Counter(len_dist)
|
141 |
+
fig.add_trace(go.Bar(x=list(len_dist_freq.keys()),y=list(len_dist_freq.values()),name=major_class))
|
142 |
+
#stack
|
143 |
+
fig.update_layout(barmode='stack')
|
144 |
+
#make transparent background
|
145 |
+
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
|
146 |
+
#set y axis label to Count and x axis label to Length
|
147 |
+
fig.update_layout(yaxis_title='Count',xaxis_title='Length')
|
148 |
+
#set title
|
149 |
+
fig.update_layout(title_text="Length distribution of false familiar sequences per major class")
|
150 |
+
#save as svg
|
151 |
+
fig.write_image('false_familiar_length_distribution_per_major_class_stacked.svg')
|
152 |
+
fig.show()
|
153 |
+
|
154 |
+
# %%
|
155 |
+
#for each model, for each split, print Is Familiar? == True and print the number of sequences
|
156 |
+
for model in all_df.Model.unique():
|
157 |
+
print('\n\n')
|
158 |
+
model_df = all_df[all_df.Model == model]
|
159 |
+
num_hicos = 0
|
160 |
+
total_samples = 0
|
161 |
+
for split in ['LC-familiar','LC-novel','LOCO','NA','Relaxed-miRNA']:
|
162 |
+
|
163 |
+
split_df = model_df[model_df.split == split]
|
164 |
+
#print('Model: %s, Split: %s, Familiar: %s, Number of Sequences: %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True]),len(split_df)))
|
165 |
+
#print model, split %
|
166 |
+
print('%s %s %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True])/len(split_df)*100))
|
167 |
+
if split != 'LC-novel':
|
168 |
+
num_hicos+=len(split_df[split_df['Is Familiar?'] == True])
|
169 |
+
total_samples+=len(split_df)
|
170 |
+
#print % of hicos
|
171 |
+
print('%s %s %s'%(model,'HICO',num_hicos/total_samples*100))
|
172 |
+
print(total_samples)
|
173 |
+
# %%
|
transforna/bin/figure_scripts/figure_5_S10_table_4.py
ADDED
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#in this file, the progression of the number of hicos per major class is computed per model
|
2 |
+
#this is done before ID, after FULL.
|
3 |
+
#%%
|
4 |
+
from transforna import load
|
5 |
+
from transforna import predict_transforna,predict_transforna_all_models
|
6 |
+
import pandas as pd
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
def compute_overlap_models_ensemble(full_df:pd.DataFrame,mapping_dict:dict):
|
11 |
+
full_copy_df = full_df.copy()
|
12 |
+
full_copy_df['MC_Labels'] = full_copy_df['Net-Label'].map(mapping_dict)
|
13 |
+
#filter is familiar
|
14 |
+
full_copy_df = full_copy_df[full_copy_df['Is Familiar?']].set_index('Sequence')
|
15 |
+
#count the predicted miRNAs per each Model
|
16 |
+
full_copy_df.groupby('Model').MC_Labels.value_counts()
|
17 |
+
|
18 |
+
#for eaach of the models and for each of the mc classes, get the overlap between the models predicting a certain mc and the ensemble having the same prediction
|
19 |
+
models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']
|
20 |
+
mcs = full_copy_df.MC_Labels.value_counts().index.tolist()
|
21 |
+
mc_stats = {}
|
22 |
+
novel_resid = {}
|
23 |
+
mcs_predicted_by_only_one_model = {}
|
24 |
+
#add all mcs as keys to mc_stats and add all models as keys in every mc
|
25 |
+
for mc in mcs:
|
26 |
+
mc_stats[mc] = {}
|
27 |
+
novel_resid[mc] = {}
|
28 |
+
mcs_predicted_by_only_one_model[mc] = {}
|
29 |
+
for model in models:
|
30 |
+
mc_stats[mc][model] = 0
|
31 |
+
novel_resid[mc][model] = 0
|
32 |
+
mcs_predicted_by_only_one_model[mc][model] = 0
|
33 |
+
|
34 |
+
for mc in mcs:
|
35 |
+
ensemble_xrna = full_copy_df[full_copy_df.Model == 'Ensemble'].iloc[full_copy_df[full_copy_df.Model == 'Ensemble'].MC_Labels.str.contains(mc).values].index.tolist()
|
36 |
+
for model in models:
|
37 |
+
model_xrna = full_copy_df[full_copy_df.Model == model].iloc[full_copy_df[full_copy_df.Model == model].MC_Labels.str.contains(mc).values].index.tolist()
|
38 |
+
common_xrna = set(ensemble_xrna).intersection(set(model_xrna))
|
39 |
+
try:
|
40 |
+
mc_stats[mc][model] = len(common_xrna)/len(ensemble_xrna)
|
41 |
+
except ZeroDivisionError:
|
42 |
+
mc_stats[mc][model] = 0
|
43 |
+
#check how many sequences exist in ensemble but not in model
|
44 |
+
try:
|
45 |
+
novel_resid[mc][model] = len(set(ensemble_xrna).difference(set(model_xrna)))/len(ensemble_xrna)
|
46 |
+
except ZeroDivisionError:
|
47 |
+
novel_resid[mc][model] = 0
|
48 |
+
#check how many sequences exist in model and in ensemble but not in other models
|
49 |
+
other_models_xrna = []
|
50 |
+
for other_model in models:
|
51 |
+
if other_model != model:
|
52 |
+
other_models_xrna.extend(full_copy_df[full_copy_df.Model == other_model].iloc[full_copy_df[full_copy_df.Model == other_model].MC_Labels.str.contains(mc).values].index.tolist())
|
53 |
+
#check how many of model_xrna are not in other_models_xrna and are in ensemble_xrna
|
54 |
+
try:
|
55 |
+
mcs_predicted_by_only_one_model[mc][model] = len(set(model_xrna).difference(set(other_models_xrna)).intersection(set(ensemble_xrna)))/len(ensemble_xrna)
|
56 |
+
except ZeroDivisionError:
|
57 |
+
mcs_predicted_by_only_one_model[mc][model] = 0
|
58 |
+
|
59 |
+
return models,mc_stats,novel_resid,mcs_predicted_by_only_one_model
|
60 |
+
|
61 |
+
|
62 |
+
def plot_bar_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model):
|
63 |
+
#plot the result as bar plot per mc
|
64 |
+
import plotly.graph_objects as go
|
65 |
+
import numpy as np
|
66 |
+
import plotly.express as px
|
67 |
+
#square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL
|
68 |
+
#add cascaded bar plot for novel resid. one per mc per model
|
69 |
+
positions = np.arange(len(models))
|
70 |
+
fig = go.Figure()
|
71 |
+
for model in models:
|
72 |
+
fig.add_trace(go.Bar(
|
73 |
+
x=list(mc_stats.keys()),
|
74 |
+
y=[mc_stats[mc][model] for mc in mc_stats.keys()],
|
75 |
+
name=model,
|
76 |
+
marker_color=px.colors.qualitative.Plotly[models.index(model)]
|
77 |
+
))
|
78 |
+
|
79 |
+
fig.add_trace(go.Bar(
|
80 |
+
x=list(mc_stats.keys()),
|
81 |
+
y=[mcs_predicted_by_only_one_model[mc][model] for mc in mc_stats.keys()],
|
82 |
+
#base = [mc_stats[mc][model] for mc in mc_stats.keys()],
|
83 |
+
name = 'novel',
|
84 |
+
marker_color='lightgrey'
|
85 |
+
))
|
86 |
+
fig.update_layout(title='Overlap between Ensemble and other models per MC class')
|
87 |
+
|
88 |
+
return fig
|
89 |
+
|
90 |
+
def plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap'):
|
91 |
+
'''
|
92 |
+
This function computes a heatmap of the overlap between the ensemble and the other models per mc class
|
93 |
+
input:
|
94 |
+
models: list of models
|
95 |
+
mc_stats: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent overlap between each model and the ensemble
|
96 |
+
novel_resid: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted by the ensemble as familiar but with specific model as novel
|
97 |
+
mcs_predicted_by_only_one_model: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted as familiar by only one model
|
98 |
+
what_to_plot: string. 'overlap' for overlap between ensemble and other models, 'novel' for novel resid, 'only_one_model' for mcs predicted as novel by only one model
|
99 |
+
|
100 |
+
'''
|
101 |
+
|
102 |
+
if what_to_plot == 'overlap':
|
103 |
+
plot_dict = mc_stats
|
104 |
+
elif what_to_plot == 'novel':
|
105 |
+
plot_dict = novel_resid
|
106 |
+
elif what_to_plot == 'only_one_model':
|
107 |
+
plot_dict = mcs_predicted_by_only_one_model
|
108 |
+
|
109 |
+
import plotly.figure_factory as ff
|
110 |
+
fig = ff.create_annotated_heatmap(
|
111 |
+
z=[[plot_dict[mc][model] for mc in plot_dict.keys()] for model in models],
|
112 |
+
x=list(plot_dict.keys()),
|
113 |
+
y=models,
|
114 |
+
annotation_text=[[str(round(plot_dict[mc][model],2)) for mc in plot_dict.keys()] for model in models],
|
115 |
+
font_colors=['black'],
|
116 |
+
colorscale='Blues'
|
117 |
+
)
|
118 |
+
#set x axis order
|
119 |
+
order_x_axis = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','lncRNA','piRNA','YRNA','vtRNA']
|
120 |
+
fig.update_xaxes(type='category',categoryorder='array',categoryarray=order_x_axis)
|
121 |
+
|
122 |
+
|
123 |
+
fig.update_xaxes(side='bottom')
|
124 |
+
if what_to_plot == 'overlap':
|
125 |
+
fig.update_layout(title='Overlap between Ensemble and other models per MC class')
|
126 |
+
elif what_to_plot == 'novel':
|
127 |
+
fig.update_layout(title='Novel resid between Ensemble and other models per MC class')
|
128 |
+
elif what_to_plot == 'only_one_model':
|
129 |
+
fig.update_layout(title='MCs predicted by only one model')
|
130 |
+
return fig
|
131 |
+
#%%
|
132 |
+
#read TCGA
|
133 |
+
dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
|
134 |
+
models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
|
135 |
+
tcga_df = load(dataset_path_train)
|
136 |
+
tcga_df.set_index('sequence',inplace=True)
|
137 |
+
loco_hico_na_stats_before = {}
|
138 |
+
loco_hico_na_stats_before['HICO'] = sum(tcga_df['hico'])/tcga_df.shape[0]
|
139 |
+
before_hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values
|
140 |
+
loco_hico_na_stats_before['LOCO'] = (sum(tcga_df.subclass_name != 'no_annotation') - sum(tcga_df['hico']))/tcga_df.shape[0]
|
141 |
+
before_loco_seqs = tcga_df[tcga_df.hico!=True][tcga_df.subclass_name != 'no_annotation'].index.values
|
142 |
+
loco_hico_na_stats_before['NA'] = sum(tcga_df.subclass_name == 'no_annotation')/tcga_df.shape[0]
|
143 |
+
before_na_seqs = tcga_df[tcga_df.subclass_name == 'no_annotation'].index.values
|
144 |
+
#load mapping dict
|
145 |
+
mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
|
146 |
+
mapping_dict = load(mapping_dict_path)
|
147 |
+
hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values
|
148 |
+
hicos_mc_before_id_stats = tcga_df.loc[hico_seqs].subclass_name.map(mapping_dict).value_counts()
|
149 |
+
#remove mcs with ; in them
|
150 |
+
#hicos_mc_before_id_stats = hicos_mc_before_id_stats[~hicos_mc_before_id_stats.index.str.contains(';')]
|
151 |
+
seqs_non_hico_id = tcga_df['subclass_name'][tcga_df['hico'] == False].index.values
|
152 |
+
id_df = predict_transforna(sequences=seqs_non_hico_id,model='Seq-Rev',trained_on='id',path_to_models=models_path)
|
153 |
+
id_df = id_df[id_df['Is Familiar?']].set_index('Sequence')
|
154 |
+
#print the percentage of sequences with no_annotation and with
|
155 |
+
print('Percentage of sequences with no annotation: %s'%(id_df[id_df['Net-Label'] == 'no_annotation'].shape[0]/id_df.shape[0]))
|
156 |
+
print('Percentage of sequences with annotation: %s'%(id_df[id_df['Net-Label'] != 'no_annotation'].shape[0]/id_df.shape[0]))
|
157 |
+
|
158 |
+
#%%
|
159 |
+
hicos_mc_after_id_stats = id_df['Net-Label'].map(mapping_dict).value_counts()
|
160 |
+
#remove mcs with ; in them
|
161 |
+
#hicos_mc_after_id_stats = hicos_mc_after_id_stats[~hicos_mc_after_id_stats.index.str.contains(';')]
|
162 |
+
#add missing major classes with zeros
|
163 |
+
for mc in hicos_mc_before_id_stats.index:
|
164 |
+
if mc not in hicos_mc_after_id_stats.index:
|
165 |
+
hicos_mc_after_id_stats[mc] = 0
|
166 |
+
hicos_mc_after_id_stats = hicos_mc_after_id_stats+hicos_mc_before_id_stats
|
167 |
+
|
168 |
+
#%%
|
169 |
+
seqs_non_hico_full = list(set(seqs_non_hico_id).difference(set(id_df.index.values)))
|
170 |
+
full_df = predict_transforna_all_models(sequences=seqs_non_hico_full,trained_on='full',path_to_models=models_path)
|
171 |
+
#UNCOMMENT TO COMPUTE BEFORE AND AFTER PER MC: table_4
|
172 |
+
#ensemble_df = full_df[full_df['Model']=='Ensemble']
|
173 |
+
#ensemble_df['Major Class'] = ensemble_df['Net-Label'].map(mapping_dict)
|
174 |
+
#new_hico_mcs= ensemble_df['Major Class'].value_counts()
|
175 |
+
#ann_hico_mcs = tcga_df[tcga_df['hico'] == True]['small_RNA_class_annotation'].value_counts()
|
176 |
+
|
177 |
+
#%%%
|
178 |
+
inspect_model = True
|
179 |
+
if inspect_model:
|
180 |
+
#from transforna import compute_overlap_models_ensemble,plot_heatmap_overlap_models_ensemble
|
181 |
+
models, mc_stats, novel_resid, mcs_predicted_by_only_one_model = compute_overlap_models_ensemble(full_df,mapping_dict)
|
182 |
+
fig = plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap')
|
183 |
+
fig.show()
|
184 |
+
|
185 |
+
#%%
|
186 |
+
df = full_df[full_df.Model == 'Ensemble']
|
187 |
+
df = df[df['Is Familiar?']].set_index('Sequence')
|
188 |
+
print('Percentage of sequences with no annotation: %s'%(df[df['Is Familiar?'] == False].shape[0]/df.shape[0]))
|
189 |
+
print('Percentage of sequences with annotation: %s'%(df[df['Is Familiar?'] == True].shape[0]/df.shape[0]))
|
190 |
+
hicos_mc_after_full_stats = df['Net-Label'].map(mapping_dict).value_counts()
|
191 |
+
#remove mcs with ; in them
|
192 |
+
#hicos_mc_after_full_stats = hicos_mc_after_full_stats[~hicos_mc_after_full_stats.index.str.contains(';')]
|
193 |
+
#add missing major classes with zeros
|
194 |
+
for mc in hicos_mc_after_id_stats.index:
|
195 |
+
if mc not in hicos_mc_after_full_stats.index:
|
196 |
+
hicos_mc_after_full_stats[mc] = 0
|
197 |
+
hicos_mc_after_full_stats = hicos_mc_after_full_stats + hicos_mc_after_id_stats
|
198 |
+
|
199 |
+
# %%
|
200 |
+
#reorder the index of the series
|
201 |
+
hicos_mc_before_id_stats = hicos_mc_before_id_stats.reindex(hicos_mc_after_full_stats.index)
|
202 |
+
hicos_mc_after_id_stats = hicos_mc_after_id_stats.reindex(hicos_mc_after_full_stats.index)
|
203 |
+
#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot
|
204 |
+
#%%
|
205 |
+
#%%
|
206 |
+
training_mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','YRNA','lncRNA']
|
207 |
+
hicos_mc_before_id_stats_train = hicos_mc_before_id_stats[training_mcs]
|
208 |
+
hicos_mc_after_id_stats_train = hicos_mc_after_id_stats[training_mcs]
|
209 |
+
hicos_mc_after_full_stats_train = hicos_mc_after_full_stats[training_mcs]
|
210 |
+
#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot
|
211 |
+
import plotly.graph_objects as go
|
212 |
+
import numpy as np
|
213 |
+
import plotly.io as pio
|
214 |
+
import plotly.express as px
|
215 |
+
|
216 |
+
#make a square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL
|
217 |
+
fig = go.Figure()
|
218 |
+
fig.add_trace(go.Bar(
|
219 |
+
x=hicos_mc_before_id_stats_train.index,
|
220 |
+
y=hicos_mc_before_id_stats_train.values,
|
221 |
+
name='Before ID',
|
222 |
+
marker_color='rgb(31, 119, 180)',
|
223 |
+
opacity = 0.5
|
224 |
+
))
|
225 |
+
fig.add_trace(go.Bar(
|
226 |
+
x=hicos_mc_after_id_stats_train.index,
|
227 |
+
y=hicos_mc_after_id_stats_train.values,
|
228 |
+
name='After ID',
|
229 |
+
marker_color='rgb(31, 119, 180)',
|
230 |
+
opacity=0.75
|
231 |
+
))
|
232 |
+
fig.add_trace(go.Bar(
|
233 |
+
x=hicos_mc_after_full_stats_train.index,
|
234 |
+
y=hicos_mc_after_full_stats_train.values,
|
235 |
+
name='After FULL',
|
236 |
+
marker_color='rgb(31, 119, 180)',
|
237 |
+
opacity=1
|
238 |
+
))
|
239 |
+
#make log scale
|
240 |
+
fig.update_layout(
|
241 |
+
title='Progression of the Number of HICOs per Major Class',
|
242 |
+
xaxis_tickfont_size=14,
|
243 |
+
yaxis=dict(
|
244 |
+
title='Number of HICOs',
|
245 |
+
titlefont_size=16,
|
246 |
+
tickfont_size=14,
|
247 |
+
),
|
248 |
+
xaxis=dict(
|
249 |
+
title='Major Class',
|
250 |
+
titlefont_size=16,
|
251 |
+
tickfont_size=14,
|
252 |
+
),
|
253 |
+
legend=dict(
|
254 |
+
x=0.8,
|
255 |
+
y=1.0,
|
256 |
+
bgcolor='rgba(255, 255, 255, 0)',
|
257 |
+
bordercolor='rgba(255, 255, 255, 0)'
|
258 |
+
),
|
259 |
+
barmode='group',
|
260 |
+
bargap=0.15,
|
261 |
+
bargroupgap=0.1
|
262 |
+
)
|
263 |
+
#make transparent background
|
264 |
+
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
|
265 |
+
#log scalew
|
266 |
+
fig.update_yaxes(type="log")
|
267 |
+
|
268 |
+
fig.update_layout(legend=dict(
|
269 |
+
yanchor="top",
|
270 |
+
y=0.99,
|
271 |
+
xanchor="left",
|
272 |
+
x=0.01
|
273 |
+
))
|
274 |
+
#tilt the x axis labels
|
275 |
+
fig.update_layout(xaxis_tickangle=22.5)
|
276 |
+
#set the range of the y axis
|
277 |
+
fig.update_yaxes(range=[0, 4.5])
|
278 |
+
fig.update_layout(legend=dict(
|
279 |
+
orientation="h",
|
280 |
+
yanchor="bottom",
|
281 |
+
y=1.02,
|
282 |
+
xanchor="right",
|
283 |
+
x=1
|
284 |
+
))
|
285 |
+
fig.write_image("progression_hicos_per_mc_train.svg")
|
286 |
+
fig.show()
|
287 |
+
#%%
|
288 |
+
eval_mcs = ['miRNA','miscRNA','piRNA','vtRNA']
|
289 |
+
hicos_mc_before_id_stats_eval = hicos_mc_before_id_stats[eval_mcs]
|
290 |
+
hicos_mc_after_full_stats_eval = hicos_mc_after_full_stats[eval_mcs]
|
291 |
+
|
292 |
+
hicos_mc_after_full_stats_eval.index = hicos_mc_after_full_stats_eval.index + '*'
|
293 |
+
hicos_mc_before_id_stats_eval.index = hicos_mc_before_id_stats_eval.index + '*'
|
294 |
+
#%%
|
295 |
+
#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot
|
296 |
+
import plotly.graph_objects as go
|
297 |
+
import numpy as np
|
298 |
+
import plotly.io as pio
|
299 |
+
import plotly.express as px
|
300 |
+
|
301 |
+
fig2 = go.Figure()
|
302 |
+
fig2.add_trace(go.Bar(
|
303 |
+
x=hicos_mc_before_id_stats_eval.index,
|
304 |
+
y=hicos_mc_before_id_stats_eval.values,
|
305 |
+
name='Before ID',
|
306 |
+
marker_color='rgb(31, 119, 180)',
|
307 |
+
opacity = 0.5
|
308 |
+
))
|
309 |
+
fig2.add_trace(go.Bar(
|
310 |
+
x=hicos_mc_after_full_stats_eval.index,
|
311 |
+
y=hicos_mc_after_full_stats_eval.values,
|
312 |
+
name='After FULL',
|
313 |
+
marker_color='rgb(31, 119, 180)',
|
314 |
+
opacity=1
|
315 |
+
))
|
316 |
+
#make log scale
|
317 |
+
fig2.update_layout(
|
318 |
+
title='Progression of the Number of HICOs per Major Class',
|
319 |
+
xaxis_tickfont_size=14,
|
320 |
+
yaxis=dict(
|
321 |
+
title='Number of HICOs',
|
322 |
+
titlefont_size=16,
|
323 |
+
tickfont_size=14,
|
324 |
+
),
|
325 |
+
xaxis=dict(
|
326 |
+
title='Major Class',
|
327 |
+
titlefont_size=16,
|
328 |
+
tickfont_size=14,
|
329 |
+
),
|
330 |
+
legend=dict(
|
331 |
+
x=0.8,
|
332 |
+
y=1.0,
|
333 |
+
bgcolor='rgba(255, 255, 255, 0)',
|
334 |
+
bordercolor='rgba(255, 255, 255, 0)'
|
335 |
+
),
|
336 |
+
barmode='group',
|
337 |
+
bargap=0.15,
|
338 |
+
bargroupgap=0.1
|
339 |
+
)
|
340 |
+
#make transparent background
|
341 |
+
fig2.update_layout(plot_bgcolor='rgba(0,0,0,0)')
|
342 |
+
#log scalew
|
343 |
+
fig2.update_yaxes(type="log")
|
344 |
+
|
345 |
+
fig2.update_layout(legend=dict(
|
346 |
+
yanchor="top",
|
347 |
+
y=0.99,
|
348 |
+
xanchor="left",
|
349 |
+
x=0.01
|
350 |
+
))
|
351 |
+
#tilt the x axis labels
|
352 |
+
fig2.update_layout(xaxis_tickangle=22.5)
|
353 |
+
#set the range of the y axis
|
354 |
+
fig2.update_yaxes(range=[0, 4.5])
|
355 |
+
#adjust bargap
|
356 |
+
fig2.update_layout(bargap=0.3)
|
357 |
+
fig2.update_layout(legend=dict(
|
358 |
+
orientation="h",
|
359 |
+
yanchor="bottom",
|
360 |
+
y=1.02,
|
361 |
+
xanchor="right",
|
362 |
+
x=1
|
363 |
+
))
|
364 |
+
#fig2.write_image("progression_hicos_per_mc_eval.svg")
|
365 |
+
fig2.show()
|
366 |
+
# %%
|
367 |
+
#append df and df_after_id
|
368 |
+
df_all_hico = df.append(id_df)
|
369 |
+
loco_hico_na_stats_after = {}
|
370 |
+
loco_hico_na_stats_after['HICO from NA'] = sum(df_all_hico.index.isin(before_na_seqs))/tcga_df.shape[0]
|
371 |
+
loco_pred_df = df_all_hico[df_all_hico.index.isin(before_loco_seqs)]
|
372 |
+
loco_anns_pd = tcga_df.loc[loco_pred_df.index].subclass_name.str.split(';',expand=True)
|
373 |
+
loco_anns_pd = loco_anns_pd.apply(lambda x: x.str.lower())
|
374 |
+
#duplicate labels in loco_pred_df * times as the num of columns in loco_anns_pd
|
375 |
+
loco_pred_labels_df = pd.DataFrame(np.repeat(loco_pred_df['Net-Label'].values,loco_anns_pd.shape[1]).reshape(loco_pred_df.shape[0],loco_anns_pd.shape[1])).set_index(loco_pred_df.index)
|
376 |
+
loco_pred_labels_df = loco_pred_labels_df.apply(lambda x: x.str.lower())
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
#%%
|
381 |
+
trna_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('_trna')).any(axis=1)
|
382 |
+
trna_loco_pred_df = loco_pred_labels_df[trna_mask_df]
|
383 |
+
#get trna_loco_anns_pd
|
384 |
+
trna_loco_anns_pd = loco_anns_pd[trna_mask_df]
|
385 |
+
#for trna_loco_pred_df, remove what prepends the __ and what appends the last -
|
386 |
+
trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('__').str[1])
|
387 |
+
trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('-').str[:-1].str.join('-'))
|
388 |
+
#compute overlap between trna_loco_pred_df and trna_loco_anns_pd
|
389 |
+
#for every value in trna_loco_pred_df, check if is part of the corresponding position in trna_loco_anns_pd
|
390 |
+
num_hico_trna_from_loco = 0
|
391 |
+
for idx,row in trna_loco_pred_df.iterrows():
|
392 |
+
trna_label = row[0]
|
393 |
+
num_hico_trna_from_loco += trna_loco_anns_pd.loc[idx].apply(lambda x: x!=None and trna_label in x).any()
|
394 |
+
|
395 |
+
|
396 |
+
#%%
|
397 |
+
#check if 'mir' or 'let' is in any of the values per row. the columns are numbered from 0 to len(loco_anns_pd.columns)
|
398 |
+
mir_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('mir')).any(axis=1)
|
399 |
+
let_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('let')).any(axis=1)
|
400 |
+
mir_or_let_mask_df = mir_mask_df | let_mask_df
|
401 |
+
mir_or_let_loco_pred_df = loco_pred_labels_df[mir_or_let_mask_df]
|
402 |
+
mir_or_let_loco_anns_pd = loco_anns_pd[mir_or_let_mask_df]
|
403 |
+
#for each value in mir_or_let_loco_pred_df, if the value contains two '-', remove the last one and what comes after it
|
404 |
+
mir_or_let_loco_anns_pd = mir_or_let_loco_anns_pd.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x)
|
405 |
+
mir_or_let_loco_pred_df = mir_or_let_loco_pred_df.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x)
|
406 |
+
#compute overlap between mir_or_let_loco_pred_df and mir_or_let_loco_anns_pd
|
407 |
+
num_hico_mir_from_loco = sum((mir_or_let_loco_anns_pd == mir_or_let_loco_pred_df).any(axis=1))
|
408 |
+
#%%
|
409 |
+
|
410 |
+
|
411 |
+
#get rest_loco_anns_pd
|
412 |
+
rest_loco_pred_df = loco_pred_labels_df[~mir_or_let_mask_df & ~trna_mask_df]
|
413 |
+
rest_loco_anns_pd = loco_anns_pd[~mir_or_let_mask_df & ~trna_mask_df]
|
414 |
+
|
415 |
+
num_hico_bins_from_loco = 0
|
416 |
+
for idx,row in rest_loco_pred_df.iterrows():
|
417 |
+
rest_rna_label = row[0].split('-')[0]
|
418 |
+
try:
|
419 |
+
bin_no = int(row[0].split('-')[1])
|
420 |
+
except:
|
421 |
+
continue
|
422 |
+
|
423 |
+
num_hico_bins_from_loco += rest_loco_anns_pd.loc[idx].apply(lambda x: x!=None and rest_rna_label == x.split('-')[0] and abs(int(x.split('-')[1])- bin_no)<=1).any()
|
424 |
+
|
425 |
+
loco_hico_na_stats_after['HICO from LOCO'] = (num_hico_trna_from_loco + num_hico_mir_from_loco + num_hico_bins_from_loco)/tcga_df.shape[0]
|
426 |
+
loco_hico_na_stats_after['LOCO from NA'] = loco_hico_na_stats_before['NA'] - loco_hico_na_stats_after['HICO from NA']
|
427 |
+
loco_hico_na_stats_after['LOCO from LOCO'] = loco_hico_na_stats_before['LOCO'] - loco_hico_na_stats_after['HICO from LOCO']
|
428 |
+
loco_hico_na_stats_after['HICO'] = loco_hico_na_stats_before['HICO']
|
429 |
+
|
430 |
+
# %%
|
431 |
+
|
432 |
+
import plotly.graph_objects as go
|
433 |
+
import plotly.io as pio
|
434 |
+
import plotly.express as px
|
435 |
+
|
436 |
+
color_mapping = {}
|
437 |
+
for key in loco_hico_na_stats_before.keys():
|
438 |
+
if key.startswith('HICO'):
|
439 |
+
color_mapping[key] = "rgb(51,160,44)"
|
440 |
+
elif key.startswith('LOCO'):
|
441 |
+
color_mapping[key] = "rgb(178,223,138)"
|
442 |
+
else:
|
443 |
+
color_mapping[key] = "rgb(251,154,153)"
|
444 |
+
colors = list(color_mapping.values())
|
445 |
+
fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_before.keys()), values=list(loco_hico_na_stats_before.values()),hole=.0,marker=dict(colors=colors),sort=False)])
|
446 |
+
fig.update_layout(title='Percentage of HICOs, LOCOs and NAs before ID')
|
447 |
+
fig.show()
|
448 |
+
#save figure as svg
|
449 |
+
#fig.write_image("pie_chart_before_id.svg")
|
450 |
+
|
451 |
+
# %%
|
452 |
+
|
453 |
+
color_mapping = {}
|
454 |
+
for key in loco_hico_na_stats_after.keys():
|
455 |
+
if key.startswith('HICO'):
|
456 |
+
color_mapping[key] = "rgb(51,160,44)"
|
457 |
+
elif key.startswith('LOCO'):
|
458 |
+
color_mapping[key] = "rgb(178,223,138)"
|
459 |
+
|
460 |
+
loco_hico_na_stats_after = {k: loco_hico_na_stats_after[k] for k in sorted(loco_hico_na_stats_after, key=lambda k: k.startswith('HICO'), reverse=True)}
|
461 |
+
|
462 |
+
fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_after.keys()), values=list(loco_hico_na_stats_after.values()),hole=.0,marker=dict(colors=list(color_mapping.values())),sort=False)])
|
463 |
+
fig.update_layout(title='Percentage of HICOs, LOCOs and NAs after ID')
|
464 |
+
fig.show()
|
465 |
+
#save figure as svg
|
466 |
+
#fig.write_image("pie_chart_after_id.svg")
|
transforna/bin/figure_scripts/figure_6.ipynb
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/nfs/home/yat_ldap/conda/envs/hbdx/envs/transforna/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from transforna import load,predict_transforna_all_models,predict_transforna,fold_sequences\n",
|
19 |
+
"models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'\n",
|
20 |
+
"lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'\n",
|
21 |
+
"tcga_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'\n",
|
22 |
+
"\n",
|
23 |
+
"tcga_df = load(tcga_path)\n",
|
24 |
+
"lc_df = load(lc_path)\n",
|
25 |
+
"\n",
|
26 |
+
"lc_df = lc_df[lc_df.sequence.str.len() <= 30]\n",
|
27 |
+
"\n",
|
28 |
+
"all_seqs = lc_df.sequence.tolist()+tcga_df.sequence.tolist()\n",
|
29 |
+
"\n",
|
30 |
+
"mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'\n",
|
31 |
+
"mapping_dict = load(mapping_dict_path)\n",
|
32 |
+
" "
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": null,
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"predictions = predict_transforna_all_models(all_seqs,trained_on='full',path_to_models=models_path)\n",
|
42 |
+
"predictions.to_csv('predictions_lc_tcga.csv',index=False)"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 2,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"#read predictions\n",
|
52 |
+
"predictions = load('predictions_lc_tcga.csv')"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
+
"metadata": {},
|
59 |
+
"outputs": [],
|
60 |
+
"source": [
|
61 |
+
"umaps = {}\n",
|
62 |
+
"models = predictions['Model'].unique()\n",
|
63 |
+
"for model in models:\n",
|
64 |
+
" if model == 'Ensemble':\n",
|
65 |
+
" continue\n",
|
66 |
+
" #get predictions\n",
|
67 |
+
" model_predictions = predictions[predictions['Model']==model]\n",
|
68 |
+
" #get is familiar rows\n",
|
69 |
+
" familiar_df = model_predictions[model_predictions['Is Familiar?']==True]\n",
|
70 |
+
" #get umap\n",
|
71 |
+
" umap_df = predict_transforna(model_predictions['Sequence'].tolist(),model=model,trained_on='full',path_to_models=models_path,umap_flag=True)\n",
|
72 |
+
" umaps[model] = umap_df"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
+
"\n",
|
82 |
+
"import plotly.express as px\n",
|
83 |
+
"import numpy as np\n",
|
84 |
+
"mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
|
85 |
+
"#filter out the classes that contain ;\n",
|
86 |
+
"mcs = [mc for mc in mcs if ';' not in mc]\n",
|
87 |
+
"colors = px.colors.qualitative.Plotly\n",
|
88 |
+
"color_mapping = dict(zip(mcs,colors))\n",
|
89 |
+
"for model,umap_df in umaps.items():\n",
|
90 |
+
" umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
|
91 |
+
" umap_df_copy = umap_df.copy()\n",
|
92 |
+
" #remove rows with Major Class containing ;\n",
|
93 |
+
" umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n",
|
94 |
+
" fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
|
95 |
+
" =['Sequence'],title=model,\\\n",
|
96 |
+
" width = 800, height=800,color_discrete_map=color_mapping)\n",
|
97 |
+
" fig.update_traces(marker=dict(size=1))\n",
|
98 |
+
" #white background\n",
|
99 |
+
" fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
100 |
+
" #only show UMAP1 from 4.3 to 11\n",
|
101 |
+
" fig.update_xaxes(range=[4.3,11])\n",
|
102 |
+
" #and UMAP2 from -2.3 to 6.8\n",
|
103 |
+
" fig.update_yaxes(range=[-2.3,6.8])\n",
|
104 |
+
" #fig.show()\n",
|
105 |
+
" fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n",
|
106 |
+
" fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "code",
|
111 |
+
"execution_count": null,
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [],
|
114 |
+
"source": [
|
115 |
+
"import plotly.express as px\n",
|
116 |
+
"import numpy as np\n",
|
117 |
+
"mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
|
118 |
+
"#filter out the classes that contain ;\n",
|
119 |
+
"mcs = [mc for mc in mcs if ';' not in mc]\n",
|
120 |
+
"colors = px.colors.qualitative.Plotly + px.colors.qualitative.Light24\n",
|
121 |
+
"color_mapping = dict(zip(mcs,colors))\n",
|
122 |
+
"for model,umap_df in umaps.items():\n",
|
123 |
+
" umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
|
124 |
+
" umap_df_copy = umap_df.copy()\n",
|
125 |
+
" #remove rows with Major Class containing ;\n",
|
126 |
+
" umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n",
|
127 |
+
" fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
|
128 |
+
" =['Sequence'],title=model,\\\n",
|
129 |
+
" width = 800, height=800,color_discrete_map=color_mapping)\n",
|
130 |
+
" fig.update_traces(marker=dict(size=1))\n",
|
131 |
+
" #white background\n",
|
132 |
+
" fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
133 |
+
" #fig.show()\n",
|
134 |
+
" fig.write_image(f'lc_figures/lc_tcga_umap_{model}.png')\n",
|
135 |
+
" fig.write_image(f'lc_figures/lc_tcga_umap_{model}.svg')\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"execution_count": null,
|
141 |
+
"metadata": {},
|
142 |
+
"outputs": [],
|
143 |
+
"source": [
|
144 |
+
"#plot umap using px.scatter for each model\n",
|
145 |
+
"import plotly.express as px\n",
|
146 |
+
"import numpy as np\n",
|
147 |
+
"mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
|
148 |
+
"#filter out the classes that contain ;\n",
|
149 |
+
"mcs = [mc for mc in mcs if ';' not in mc]\n",
|
150 |
+
"colors = px.colors.qualitative.Plotly\n",
|
151 |
+
"color_mapping = dict(zip(mcs,colors))\n",
|
152 |
+
"umap_df = umaps['Seq']\n",
|
153 |
+
"umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
|
154 |
+
"umap_df_copy = umap_df.copy()\n",
|
155 |
+
"#display points contained within the circle at center (7.9,2.5) and radius 4.3\n",
|
156 |
+
"umap_df_copy['distance'] = np.sqrt((umap_df_copy['UMAP1']-7.9)**2+(umap_df_copy['UMAP2']-2.5)**2)\n",
|
157 |
+
"umap_df_copy = umap_df_copy[umap_df_copy['distance']<=4.3]\n",
|
158 |
+
"#remove rows with Major Class containing ;\n",
|
159 |
+
"umap_df_copy = umap_df_copy[~umap_df_copy['Major Class'].str.contains(';')]\n",
|
160 |
+
"fig = px.scatter(umap_df_copy,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
|
161 |
+
" =['Sequence'],title=model,\\\n",
|
162 |
+
" width = 800, height=800,color_discrete_map=color_mapping)\n",
|
163 |
+
"fig.update_traces(marker=dict(size=1))\n",
|
164 |
+
"#white background\n",
|
165 |
+
"fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
166 |
+
"fig.show()\n",
|
167 |
+
"#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n",
|
168 |
+
"#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"#plot\n",
|
178 |
+
"sec_struct = fold_sequences(model_predictions['Sequence'].tolist())['structure_37']\n",
|
179 |
+
"#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
|
180 |
+
"sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 40,
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"umap_df = umaps['Seq-Struct']\n",
|
190 |
+
"fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color=sec_struct_ratio,hover_data=['Sequence'],title=model,\\\n",
|
191 |
+
" width = 800, height=800,color_continuous_scale='Viridis')\n",
|
192 |
+
"fig.update_traces(marker=dict(size=1))\n",
|
193 |
+
"fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
194 |
+
"#save\n",
|
195 |
+
"fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.png')\n",
|
196 |
+
"fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.svg')"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "code",
|
201 |
+
"execution_count": null,
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [],
|
204 |
+
"source": []
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"metadata": {
|
208 |
+
"kernelspec": {
|
209 |
+
"display_name": "transforna",
|
210 |
+
"language": "python",
|
211 |
+
"name": "python3"
|
212 |
+
},
|
213 |
+
"language_info": {
|
214 |
+
"codemirror_mode": {
|
215 |
+
"name": "ipython",
|
216 |
+
"version": 3
|
217 |
+
},
|
218 |
+
"file_extension": ".py",
|
219 |
+
"mimetype": "text/x-python",
|
220 |
+
"name": "python",
|
221 |
+
"nbconvert_exporter": "python",
|
222 |
+
"pygments_lexer": "ipython3",
|
223 |
+
"version": "3.9.18"
|
224 |
+
}
|
225 |
+
},
|
226 |
+
"nbformat": 4,
|
227 |
+
"nbformat_minor": 2
|
228 |
+
}
|
transforna/bin/figure_scripts/figure_S4.ipynb
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"\n",
|
10 |
+
"import pandas as pd\n",
|
11 |
+
"scores = {'major_class':{},'sub_class':{}}\n",
|
12 |
+
"models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']\n",
|
13 |
+
"models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n",
|
14 |
+
"for model1 in models:\n",
|
15 |
+
" summary_pd = pd.read_csv(models_path+'/major_class/'+model1+'/summary_pd.tsv',sep='\\t')\n",
|
16 |
+
" scores['major_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. Acc'].std()*100)+')'\n",
|
17 |
+
" summary_pd = pd.read_csv(models_path+'/sub_class/'+model1+'/summary_pd.tsv',sep='\\t')\n",
|
18 |
+
" scores['sub_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. Acc'].std()*100) +')'"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
+
"metadata": {},
|
25 |
+
"outputs": [
|
26 |
+
{
|
27 |
+
"data": {
|
28 |
+
"text/plain": [
|
29 |
+
"{'Baseline': '52.83789870060305+/- (1.0961119898709506)',\n",
|
30 |
+
" 'Seq': '97.70018230805728+/- (0.3819207447704567)',\n",
|
31 |
+
" 'Seq-Seq': '95.65091330992355+/- (0.4963151975035616)',\n",
|
32 |
+
" 'Seq-Struct': '97.71071590680333+/- (0.6173598637101496)',\n",
|
33 |
+
" 'Seq-Rev': '97.51224133899979+/- (0.3418133671042992)'}"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"execution_count": 2,
|
37 |
+
"metadata": {},
|
38 |
+
"output_type": "execute_result"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"scores['sub_class']"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 3,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"\n",
|
52 |
+
"import json\n",
|
53 |
+
"import pandas as pd\n",
|
54 |
+
"with open('/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json') as f:\n",
|
55 |
+
" mapping_dict = json.load(f)\n",
|
56 |
+
"\n",
|
57 |
+
"b_acc_sc_to_mc = {}\n",
|
58 |
+
"for model1 in models:\n",
|
59 |
+
" b_acc = []\n",
|
60 |
+
" for idx in range(5):\n",
|
61 |
+
" confusion_matrix = pd.read_csv(models_path+'/sub_class/'+model1+f'/embedds/confusion_matrix_{idx}.csv',sep=',',index_col=0)\n",
|
62 |
+
" confusion_matrix.index = confusion_matrix.index.map(mapping_dict)\n",
|
63 |
+
" confusion_matrix.columns = confusion_matrix.columns.map(mapping_dict)\n",
|
64 |
+
" confusion_matrix = confusion_matrix.groupby(confusion_matrix.index).sum().groupby(confusion_matrix.columns,axis=1).sum()\n",
|
65 |
+
" b_acc.append(confusion_matrix.values.diagonal().sum()/confusion_matrix.values.sum())\n",
|
66 |
+
" b_acc_sc_to_mc[model1] = str(pd.Series(b_acc).mean()*100)+'+/-'+' ('+str(pd.Series(b_acc).std()*100)+')'\n"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 4,
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [
|
74 |
+
{
|
75 |
+
"data": {
|
76 |
+
"text/plain": [
|
77 |
+
"{'Baseline': '89.6182558114013+/- (0.6372156071358975)',\n",
|
78 |
+
" 'Seq': '99.66714304286457+/- (0.1404591049684126)',\n",
|
79 |
+
" 'Seq-Seq': '99.40702944026852+/- (0.18268320317601783)',\n",
|
80 |
+
" 'Seq-Struct': '99.77114728744993+/- (0.06976258667467564)',\n",
|
81 |
+
" 'Seq-Rev': '99.70878801385821+/- (0.11954774341354062)'}"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
"execution_count": 4,
|
85 |
+
"metadata": {},
|
86 |
+
"output_type": "execute_result"
|
87 |
+
}
|
88 |
+
],
|
89 |
+
"source": [
|
90 |
+
"b_acc_sc_to_mc"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"execution_count": 5,
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"\n",
|
100 |
+
"import plotly.express as px\n",
|
101 |
+
"no_annotation_predictions = {}\n",
|
102 |
+
"for model1 in models:\n",
|
103 |
+
" #multiindex\n",
|
104 |
+
" no_annotation_predictions[model1] = pd.read_csv(models_path+'/sub_class/'+model1+'/embedds/no_annotation_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
|
105 |
+
" no_annotation_predictions[model1].set_index([('RNA Sequences','0')] ,inplace=True)\n",
|
106 |
+
" no_annotation_predictions[model1].index.name = 'RNA Sequences'\n",
|
107 |
+
" no_annotation_predictions[model1] = no_annotation_predictions[model1]['Logits'].idxmax(axis=1)\n"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "code",
|
112 |
+
"execution_count": null,
|
113 |
+
"metadata": {},
|
114 |
+
"outputs": [],
|
115 |
+
"source": [
|
116 |
+
"from transforna.src.utils.tcga_post_analysis_utils import correct_labels\n",
|
117 |
+
"import pandas as pd\n",
|
118 |
+
"correlation = pd.DataFrame(index=models,columns=models)\n",
|
119 |
+
"for model1 in models:\n",
|
120 |
+
" for model2 in models:\n",
|
121 |
+
" model1_predictions = correct_labels(no_annotation_predictions[model1],no_annotation_predictions[model2],mapping_dict)\n",
|
122 |
+
" is_equal = model1_predictions == no_annotation_predictions[model2].values\n",
|
123 |
+
" correlation.loc[model1,model2] = is_equal.sum()/len(is_equal)\n",
|
124 |
+
"font_size = 20\n",
|
125 |
+
"fig = px.imshow(correlation, color_continuous_scale='Blues')\n",
|
126 |
+
"#annotate\n",
|
127 |
+
"for i in range(len(models)):\n",
|
128 |
+
" for j in range(len(models)):\n",
|
129 |
+
" if i != j:\n",
|
130 |
+
" font = dict(color='black', size=font_size)\n",
|
131 |
+
" else:\n",
|
132 |
+
" font = dict(color='white', size=font_size) \n",
|
133 |
+
" \n",
|
134 |
+
" fig.add_annotation(\n",
|
135 |
+
" x=j, y=i,\n",
|
136 |
+
" text=str(round(correlation.iloc[i,j],2)),\n",
|
137 |
+
" showarrow=False,\n",
|
138 |
+
" font=font\n",
|
139 |
+
" )\n",
|
140 |
+
"\n",
|
141 |
+
"#set figure size: width and height\n",
|
142 |
+
"fig.update_layout(width=800, height=800)\n",
|
143 |
+
"\n",
|
144 |
+
"fig.update_layout(title='Correlation between models for each sub_class model')\n",
|
145 |
+
"#set x and y axis to Models\n",
|
146 |
+
"fig.update_xaxes(title_text='Models', tickfont=dict(size=font_size))\n",
|
147 |
+
"fig.update_yaxes(title_text='Models', tickfont=dict(size=font_size))\n",
|
148 |
+
"fig.show()\n",
|
149 |
+
"#save\n",
|
150 |
+
"fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.png')\n",
|
151 |
+
"fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.svg')"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 7,
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [],
|
159 |
+
"source": [
|
160 |
+
"#create umap for every model from embedds folder\n",
|
161 |
+
"models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n",
|
162 |
+
"\n",
|
163 |
+
"#read\n",
|
164 |
+
"sc_embedds = {}\n",
|
165 |
+
"mc_embedds = {}\n",
|
166 |
+
"sc_to_mc_labels = {}\n",
|
167 |
+
"sc_labels = {}\n",
|
168 |
+
"mc_labels = {}\n",
|
169 |
+
"for model in models:\n",
|
170 |
+
" df = pd.read_csv(models_path+'/sub_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
|
171 |
+
" sc_embedds[model] = df['RNA Embedds'].values\n",
|
172 |
+
" sc_labels[model] = df['Labels']['0']\n",
|
173 |
+
" sc_to_mc_labels[model] = sc_labels[model].map(mapping_dict).values\n",
|
174 |
+
"\n",
|
175 |
+
" #major class\n",
|
176 |
+
" df = pd.read_csv(models_path+'/major_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
|
177 |
+
" mc_embedds[model] = df['RNA Embedds'].values\n",
|
178 |
+
" mc_labels[model] = df['Labels']['0']"
|
179 |
+
]
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"cell_type": "code",
|
183 |
+
"execution_count": 8,
|
184 |
+
"metadata": {},
|
185 |
+
"outputs": [],
|
186 |
+
"source": [
|
187 |
+
"import umap\n",
|
188 |
+
"#compute umap coordinates\n",
|
189 |
+
"sc_umap_coords = {}\n",
|
190 |
+
"mc_umap_coords = {}\n",
|
191 |
+
"for model in models:\n",
|
192 |
+
" sc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(sc_embedds[model])\n",
|
193 |
+
" mc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(mc_embedds[model])"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": null,
|
199 |
+
"metadata": {},
|
200 |
+
"outputs": [],
|
201 |
+
"source": [
|
202 |
+
"#plot umap\n",
|
203 |
+
"import plotly.express as px\n",
|
204 |
+
"import numpy as np\n",
|
205 |
+
"\n",
|
206 |
+
"mcs = np.unique(sc_to_mc_labels[models[0]])\n",
|
207 |
+
"colors = px.colors.qualitative.Plotly\n",
|
208 |
+
"color_mapping = dict(zip(mcs,colors))\n",
|
209 |
+
"for model in models:\n",
|
210 |
+
" fig = px.scatter(x=sc_umap_coords[model][:,0],y=sc_umap_coords[model][:,1],color=sc_to_mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n",
|
211 |
+
"\n",
|
212 |
+
" hover_data={'Major Class':sc_labels[model],'Sub Class':sc_to_mc_labels[model]},color_discrete_map=color_mapping)\n",
|
213 |
+
"\n",
|
214 |
+
" fig.update_traces(marker=dict(size=1))\n",
|
215 |
+
" #white background\n",
|
216 |
+
" fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
217 |
+
"\n",
|
218 |
+
" fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.svg')\n",
|
219 |
+
" fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.png')\n",
|
220 |
+
" fig.show()\n",
|
221 |
+
"\n",
|
222 |
+
" #plot umap for major class\n",
|
223 |
+
" fig = px.scatter(x=mc_umap_coords[model][:,0],y=mc_umap_coords[model][:,1],color=mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n",
|
224 |
+
"\n",
|
225 |
+
" hover_data={'Major Class':mc_labels[model]},color_discrete_map=color_mapping)\n",
|
226 |
+
" fig.update_traces(marker=dict(size=1))\n",
|
227 |
+
" #white background\n",
|
228 |
+
" fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
229 |
+
"\n",
|
230 |
+
" fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.svg')\n",
|
231 |
+
" fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.png')\n",
|
232 |
+
" fig.show()"
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"cell_type": "code",
|
237 |
+
"execution_count": null,
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [],
|
240 |
+
"source": [
|
241 |
+
"from transforna import fold_sequences\n",
|
242 |
+
"df = pd.read_csv(models_path+'/major_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
|
243 |
+
"sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n",
|
244 |
+
"#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
|
245 |
+
"sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n",
|
246 |
+
"fig = px.scatter(x=mc_umap_coords['Seq-Struct'][:,0],y=mc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n",
|
247 |
+
" hover_data={'Major Class':mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n",
|
248 |
+
"\n",
|
249 |
+
"fig.update_traces(marker=dict(size=3))\n",
|
250 |
+
"#white background\n",
|
251 |
+
"fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
252 |
+
"fig.show()\n",
|
253 |
+
"fig.write_image(models_path+'/major_class/Seq-Struct/figures/mc_umap_sec_struct.svg')\n"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": null,
|
259 |
+
"metadata": {},
|
260 |
+
"outputs": [],
|
261 |
+
"source": [
|
262 |
+
"\n",
|
263 |
+
"from transforna import fold_sequences\n",
|
264 |
+
"df = pd.read_csv(models_path+'/sub_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
|
265 |
+
"sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n",
|
266 |
+
"#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
|
267 |
+
"sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n",
|
268 |
+
"fig = px.scatter(x=sc_umap_coords['Seq-Struct'][:,0],y=sc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n",
|
269 |
+
" hover_data={'Major Class':mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n",
|
270 |
+
"\n",
|
271 |
+
"fig.update_traces(marker=dict(size=3))\n",
|
272 |
+
"#white background\n",
|
273 |
+
"fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
|
274 |
+
"fig.show()\n",
|
275 |
+
"fig.write_image(models_path+'/sub_class/Seq-Struct/figures/sc_umap_sec_struct.svg')"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "code",
|
280 |
+
"execution_count": 11,
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [],
|
283 |
+
"source": [
|
284 |
+
"from transforna import Results_Handler,get_closest_ngbr_per_split\n",
|
285 |
+
"\n",
|
286 |
+
"splits = ['train','valid','test','ood','artificial','no_annotation']\n",
|
287 |
+
"splits_to_plot = ['test','ood','random','recombined','artificial_affix']\n",
|
288 |
+
"renaming_dict= {'test':'ID (test)','ood':'Rare sub-classes','random':'Random','artificial_affix':'Putative 5\\'-adapter prefixes','recombined':'Recombined'}\n",
|
289 |
+
"\n",
|
290 |
+
"lev_dist_df = pd.DataFrame()\n",
|
291 |
+
"for model in models:\n",
|
292 |
+
" results = Results_Handler(models_path+f'/sub_class/{model}/embedds',splits=splits,read_dataset=True)\n",
|
293 |
+
" results.append_loco_variants()\n",
|
294 |
+
" results.get_knn_model()\n",
|
295 |
+
" \n",
|
296 |
+
" #compute levenstein distance per split\n",
|
297 |
+
" for split in splits_to_plot:\n",
|
298 |
+
" split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,split)\n",
|
299 |
+
" #create df from split and levenstein distance\n",
|
300 |
+
" lev_dist_split_df = pd.DataFrame({'split':split,'lev_dist':lev_dist,'seqs':split_seqs,'labels':split_labels,'top_n_seqs':top_n_seqs,'top_n_labels':top_n_labels})\n",
|
301 |
+
" #rename \n",
|
302 |
+
" lev_dist_split_df['split'] = lev_dist_split_df['split'].map(renaming_dict)\n",
|
303 |
+
" lev_dist_split_df['model'] = model\n",
|
304 |
+
" #append \n",
|
305 |
+
" lev_dist_df = pd.concat([lev_dist_df,lev_dist_split_df],axis=0)\n",
|
306 |
+
"\n"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"metadata": {},
|
313 |
+
"outputs": [],
|
314 |
+
"source": [
|
315 |
+
"#plot the distribution of lev_dist for each split for each model\n",
|
316 |
+
"model_thresholds = {'Baseline':0.267,'Seq':0.246,'Seq-Seq':0.272,'Seq-Struct': 0.242,'Seq-Rev':0.237}\n",
|
317 |
+
"model_aucs = {'Baseline':0.76,'Seq':0.97,'Seq-Seq':0.96,'Seq-Struct': 0.97,'Seq-Rev':0.97}\n",
|
318 |
+
"import seaborn as sns\n",
|
319 |
+
"import matplotlib.pyplot as plt\n",
|
320 |
+
"sns.set_theme(style=\"whitegrid\")\n",
|
321 |
+
"sns.set(rc={'figure.figsize':(15,10)})\n",
|
322 |
+
"sns.set(font_scale=1.5)\n",
|
323 |
+
"ax = sns.boxplot(x=\"model\", y=\"lev_dist\", hue=\"split\", data=lev_dist_df, palette=\"Set3\",order=models,showfliers = True)\n",
|
324 |
+
"#add title\n",
|
325 |
+
"ax.set_facecolor('None')\n",
|
326 |
+
"plt.title('Levenshtein Distance Distribution per Model on ID')\n",
|
327 |
+
"ax.set(xlabel='Model', ylabel='Normalized Levenshtein Distance')\n",
|
328 |
+
"#legend background should transparent\n",
|
329 |
+
"ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)\n",
|
330 |
+
"# add horizontal lines for thresholds for each model while making sure the line is within the boxplot\n",
|
331 |
+
"min_val = 0 \n",
|
332 |
+
"for model in models:\n",
|
333 |
+
" thresh = model_thresholds[model]\n",
|
334 |
+
" plt.axhline(y=thresh, color='g', linestyle='--',xmin=min_val,xmax=min_val+0.2)\n",
|
335 |
+
" min_val+=0.2\n",
|
336 |
+
"\n"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": null,
|
342 |
+
"metadata": {},
|
343 |
+
"outputs": [],
|
344 |
+
"source": []
|
345 |
+
}
|
346 |
+
],
|
347 |
+
"metadata": {
|
348 |
+
"kernelspec": {
|
349 |
+
"display_name": "transforna",
|
350 |
+
"language": "python",
|
351 |
+
"name": "python3"
|
352 |
+
},
|
353 |
+
"language_info": {
|
354 |
+
"codemirror_mode": {
|
355 |
+
"name": "ipython",
|
356 |
+
"version": 3
|
357 |
+
},
|
358 |
+
"file_extension": ".py",
|
359 |
+
"mimetype": "text/x-python",
|
360 |
+
"name": "python",
|
361 |
+
"nbconvert_exporter": "python",
|
362 |
+
"pygments_lexer": "ipython3",
|
363 |
+
"version": "3.9.18"
|
364 |
+
}
|
365 |
+
},
|
366 |
+
"nbformat": 4,
|
367 |
+
"nbformat_minor": 2
|
368 |
+
}
|
transforna/bin/figure_scripts/figure_S5.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
import transforna
|
6 |
+
from transforna import IDModelAugmenter, load
|
7 |
+
|
8 |
+
#%%
|
9 |
+
model_name = 'Seq'
|
10 |
+
config_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID/sub_class/{model_name}/meta/hp_settings.yaml'
|
11 |
+
config = load(config_path)
|
12 |
+
model_augmenter = IDModelAugmenter(df=None,config=config)
|
13 |
+
df = model_augmenter.predict_transforna_na()
|
14 |
+
tcga_df = load('/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv')
|
15 |
+
|
16 |
+
tcga_df.set_index('sequence',inplace=True)
|
17 |
+
tcga_df['Labels'] = tcga_df['subclass_name'][tcga_df['hico'] == True]
|
18 |
+
tcga_df['Labels'] = tcga_df['Labels'].astype('category')
|
19 |
+
#%%
|
20 |
+
tcga_df.loc[df.Sequence.values,'Labels'] = df['Net-Label'].values
|
21 |
+
|
22 |
+
loco_labels_df = tcga_df['subclass_name'].str.split(';',expand=True).loc[df['Sequence']]
|
23 |
+
#filter the rows having no_annotation in the first row of loco_labels_df
|
24 |
+
loco_labels_df = loco_labels_df.iloc[~(loco_labels_df[0] == 'no_annotation').values]
|
25 |
+
#%%
|
26 |
+
#get the Is Familiar? column from df based on index of loco_labels_df
|
27 |
+
novelty_prediction_loco_df = df[df['Sequence'].isin(loco_labels_df.index)].set_index('Sequence')['Is Familiar?']
|
28 |
+
#%%
|
29 |
+
id_predictions_df = tcga_df.loc[loco_labels_df.index]['Labels']
|
30 |
+
#copy the columns of id_predictions_df nuber of times equal to the number of columns in loco_labels_df
|
31 |
+
id_predictions_df = pd.concat([id_predictions_df]*loco_labels_df.shape[1],axis=1)
|
32 |
+
id_predictions_df.columns = np.arange(loco_labels_df.shape[1])
|
33 |
+
equ_mask = loco_labels_df == id_predictions_df
|
34 |
+
#check how many rows in eq_mask has atleast one True
|
35 |
+
num_true = equ_mask.any(axis=1).sum()
|
36 |
+
print('percentage of all loco RNAs: ',num_true/equ_mask.shape[0])
|
37 |
+
|
38 |
+
|
39 |
+
#split loco_labels_df into two dataframes. familiar and novel
|
40 |
+
fam_loco_labels_df = loco_labels_df[novelty_prediction_loco_df]
|
41 |
+
novel_loco_labels__df = loco_labels_df[~novelty_prediction_loco_df]
|
42 |
+
#seperate id_predictions_df into two dataframes. novel and familiar
|
43 |
+
id_predictions_fam_df = id_predictions_df[novelty_prediction_loco_df]
|
44 |
+
id_predictions_novel_df = id_predictions_df[~novelty_prediction_loco_df]
|
45 |
+
#%%
|
46 |
+
num_true_fam = (fam_loco_labels_df == id_predictions_fam_df).any(axis=1).sum()
|
47 |
+
num_true_novel = (novel_loco_labels__df == id_predictions_novel_df).any(axis=1).sum()
|
48 |
+
|
49 |
+
print('percentage of similar predictions in familiar: ',num_true_fam/fam_loco_labels_df.shape[0])
|
50 |
+
print('percentage of similar predictions not in novel: ',num_true_novel/novel_loco_labels__df.shape[0])
|
51 |
+
print('')
|
52 |
+
# %%
|
53 |
+
#remove the rows in fam_loco_labels_df and id_predictions_fam_df that have atleast one True in equ_mask
|
54 |
+
fam_loco_labels_no_overlap_df = fam_loco_labels_df[~equ_mask.any(axis=1)]
|
55 |
+
id_predictions_fam_no_overlap_df = id_predictions_fam_df[~equ_mask.any(axis=1)]
|
56 |
+
#collapse the dataframe of fam_loco_labels_df with a ';' seperator
|
57 |
+
collapsed_loco_labels_df = fam_loco_labels_no_overlap_df.apply(lambda x: ';'.join(x.dropna().astype(str)),axis=1)
|
58 |
+
#combined collapsed_loco_labels_df with id_predictions_fam_df[0]
|
59 |
+
predicted_fam_but_ann_novel_df = pd.concat([collapsed_loco_labels_df,id_predictions_fam_no_overlap_df[0]],axis=1)
|
60 |
+
#rename columns
|
61 |
+
predicted_fam_but_ann_novel_df.columns = ['KBA_labels','predicted_label']
|
62 |
+
# %%
|
63 |
+
#get major class for each column in KBA_labels and predicted_label
|
64 |
+
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
|
65 |
+
sc_to_mc_mapper_dict = load(mapping_dict_path)
|
66 |
+
|
67 |
+
predicted_fam_but_ann_novel_df['KBA_labels_mc'] = predicted_fam_but_ann_novel_df['KBA_labels'].str.split(';').apply(lambda x: ';'.join([sc_to_mc_mapper_dict[i] if i in sc_to_mc_mapper_dict.keys() else i for i in x]))
|
68 |
+
predicted_fam_but_ann_novel_df['predicted_label_mc'] = predicted_fam_but_ann_novel_df['predicted_label'].apply(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict.keys() else x)
|
69 |
+
# %%
|
70 |
+
#for the each of the sequence in predicted_fam_but_ann_novel_df, compute the sim seq along with the lv distance
|
71 |
+
from transforna import predict_transforna
|
72 |
+
|
73 |
+
sim_df = predict_transforna(model=model_name,sequences=predicted_fam_but_ann_novel_df.index.tolist(),similarity_flag=True,n_sim=1,trained_on='id',path_to_models='/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/')
|
74 |
+
sim_df = sim_df.set_index('Sequence')
|
75 |
+
|
76 |
+
#append the sim_df to predicted_fam_but_ann_novel_df except for the Labels column
|
77 |
+
predicted_fam_but_ann_novel_df = pd.concat([predicted_fam_but_ann_novel_df,sim_df.drop('Labels',axis=1)],axis=1)
|
78 |
+
# %%
|
79 |
+
#plot the mc proportions of predicted_label_mc
|
80 |
+
predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().plot(kind='bar')
|
81 |
+
#get order of labels on x axis
|
82 |
+
x_labels = predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().index.tolist()
|
83 |
+
# %%
|
84 |
+
#plot the LV distance per predicted_label_mc and order the x axis based on the order of x_labels
|
85 |
+
fig = predicted_fam_but_ann_novel_df.boxplot(column='NLD',by='predicted_label_mc',figsize=(20,10),rot=90,showfliers=False)
|
86 |
+
#reorder x axis in fig by x_labels
|
87 |
+
fig.set_xticklabels(x_labels)
|
88 |
+
#increase font of axis labels and ticks
|
89 |
+
fig.set_xlabel('Predicted Label',fontsize=20)
|
90 |
+
fig.set_ylabel('Levenstein Distance',fontsize=20)
|
91 |
+
fig.tick_params(axis='both', which='major', labelsize=20)
|
92 |
+
#display pandas full rows
|
93 |
+
pd.set_option('display.max_rows', None)
|
94 |
+
# %%
|
transforna/bin/figure_scripts/figure_S8.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
from transforna import load
|
3 |
+
from transforna import Results_Handler
|
4 |
+
import pandas as pd
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
import plotly.io as pio
|
7 |
+
path = '/media/ftp_share/hbdx/analysis/tcga/TransfoRNA_I_ID_V4/sub_class/Seq/embedds/'
|
8 |
+
results:Results_Handler = Results_Handler(path=path,splits=['train','valid','test','ood'],read_ad=True)
|
9 |
+
|
10 |
+
mapping_dict = load('/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json')
|
11 |
+
mapping_dict['artificial_affix'] = 'artificial_affix'
|
12 |
+
train_df = results.splits_df_dict['train_df']
|
13 |
+
valid_df = results.splits_df_dict['valid_df']
|
14 |
+
test_df = results.splits_df_dict['test_df']
|
15 |
+
ood_df = results.splits_df_dict['ood_df']
|
16 |
+
#remove RNA Sequences from the dataframe if not in results.ad.var.index
|
17 |
+
train_df = train_df[train_df['RNA Sequences'].isin(results.ad.var[results.ad.var['hico'] == True].index)['0'].values]
|
18 |
+
valid_df = valid_df[valid_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
|
19 |
+
test_df = test_df[test_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
|
20 |
+
ood_df = ood_df[ood_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
|
21 |
+
#concatenate train,valid and test
|
22 |
+
train_val_test_df = train_df.append(valid_df).append(test_df)
|
23 |
+
#map Labels to annotation
|
24 |
+
hico_id_labels = train_val_test_df['Labels','0'].map(mapping_dict).values
|
25 |
+
hico_ood_labels = ood_df['Labels','0'].map(mapping_dict).values
|
26 |
+
#read ad
|
27 |
+
ad = results.ad
|
28 |
+
|
29 |
+
hico_loco_df = pd.DataFrame(columns=['mc','hico_id','hico_ood','loco'])
|
30 |
+
for mc in ad.var['small_RNA_class_annotation'][ad.var['hico'] == True].unique():
|
31 |
+
hico_loco_df = hico_loco_df.append({'mc':mc,
|
32 |
+
'hico_id':sum([mc in i for i in hico_id_labels]),
|
33 |
+
'hico_ood':sum([mc in i for i in hico_ood_labels]),
|
34 |
+
'loco':sum([mc in i for i in ad.var['small_RNA_class_annotation'][ad.var['hico'] != True].values])},ignore_index=True)
|
35 |
+
# %%
|
36 |
+
|
37 |
+
|
38 |
+
order = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
|
39 |
+
|
40 |
+
fig = go.Figure()
|
41 |
+
fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_id'],name='HICO ID',marker_color='#00CC96'))
|
42 |
+
fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_ood'],name='HICO OOD',marker_color='darkcyan'))
|
43 |
+
fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['loco'],name='LOCO',marker_color='#7f7f7f'))
|
44 |
+
fig.update_layout(barmode='group')
|
45 |
+
fig.update_layout(width=800,height=800)
|
46 |
+
#order the x axis
|
47 |
+
fig.update_layout(xaxis={'categoryorder':'array','categoryarray':order})
|
48 |
+
fig.update_layout(xaxis_title='Major Class',yaxis_title='Number of Sequences')
|
49 |
+
fig.update_layout(title='Number of Sequences per Major Class in ID, OOD and LOCO')
|
50 |
+
fig.update_layout(yaxis_type="log")
|
51 |
+
#save as png
|
52 |
+
pio.write_image(fig,'hico_id_ood_loco_proportion.png')
|
53 |
+
#save as svg
|
54 |
+
pio.write_image(fig,'hico_id_ood_loco_proportion.svg')
|
55 |
+
fig.show()
|
56 |
+
# %%
|
transforna/bin/figure_scripts/figure_S9_S11.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#in this file, the progression of the number of hicos per major class is computed per model
|
2 |
+
#this is done before ID, after FULL.
|
3 |
+
#%%
|
4 |
+
from transforna import load
|
5 |
+
from transforna import predict_transforna,predict_transforna_all_models
|
6 |
+
import pandas as pd
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
import numpy as np
|
9 |
+
import plotly.io as pio
|
10 |
+
import plotly.express as px
|
11 |
+
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
|
12 |
+
models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
|
13 |
+
|
14 |
+
mapping_dict = load(mapping_dict_path)
|
15 |
+
|
16 |
+
#%%
|
17 |
+
dataset:str = 'LC'
|
18 |
+
hico_loco_na_flag:str = 'hico'
|
19 |
+
assert hico_loco_na_flag in ['hico','loco_na'], 'hico_loco_na_flag must be either hico or loco_na'
|
20 |
+
if dataset == 'TCGA':
|
21 |
+
dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
|
22 |
+
else:
|
23 |
+
dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
|
24 |
+
|
25 |
+
prediction_single_pd = predict_transforna(['AAAAAAACCCCCTTTTTTT'],model='Seq',logits_flag = True,trained_on='id',path_to_models=models_path)
|
26 |
+
sub_classes_used_for_training = prediction_single_pd.columns.tolist()
|
27 |
+
|
28 |
+
var = load(dataset_path_train).set_index('sequence')
|
29 |
+
#remove from var all indexes that are longer than 30
|
30 |
+
var = var[var.index.str.len() <= 30]
|
31 |
+
hico_seqs_all = var.index[var['hico']].tolist()
|
32 |
+
hico_labels_all = var['subclass_name'][var['hico']].values
|
33 |
+
|
34 |
+
hico_seqs_id = var.index[var.hico & var.subclass_name.isin(sub_classes_used_for_training)].tolist()
|
35 |
+
hico_labels_id = var['subclass_name'][var.hico & var.subclass_name.isin(sub_classes_used_for_training)].values
|
36 |
+
|
37 |
+
non_hico_seqs = var['subclass_name'][var['hico'] == False].index.values
|
38 |
+
non_hico_labels = var['subclass_name'][var['hico'] == False].values
|
39 |
+
|
40 |
+
|
41 |
+
#filter hico labels and hico seqs to hico ID
|
42 |
+
if hico_loco_na_flag == 'loco_na':
|
43 |
+
curr_seqs = non_hico_seqs
|
44 |
+
curr_labels = non_hico_labels
|
45 |
+
elif hico_loco_na_flag == 'hico':
|
46 |
+
curr_seqs = hico_seqs_id
|
47 |
+
curr_labels = hico_labels_id
|
48 |
+
|
49 |
+
full_df = predict_transforna_all_models(sequences=curr_seqs,path_to_models=models_path)
|
50 |
+
|
51 |
+
|
52 |
+
#%%
|
53 |
+
mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
|
54 |
+
#for each mc, get the sequences of hicos in that mc and compute the number of hicos per model
|
55 |
+
num_hicos_per_mc = {}
|
56 |
+
if hico_loco_na_flag == 'hico':#this is where ground truth exists (hico id)
|
57 |
+
curr_labels_id_mc = [mapping_dict[label] for label in curr_labels]
|
58 |
+
|
59 |
+
elif hico_loco_na_flag == 'loco_na': # this is where ground truth does not exist (LOCO/NA)
|
60 |
+
ensemble_preds = full_df[full_df.Model == 'Ensemble'].set_index('Sequence').loc[curr_seqs].reset_index()
|
61 |
+
curr_labels_id_mc = [mapping_dict[label] for label in ensemble_preds['Net-Label']]
|
62 |
+
|
63 |
+
for mc in mcs:
|
64 |
+
#select sequences from hico_seqs that are in the major class mc
|
65 |
+
mc_seqs = [seq for seq,label in zip(curr_seqs,curr_labels_id_mc) if label == mc]
|
66 |
+
if len(mc_seqs) == 0:
|
67 |
+
num_hicos_per_mc[mc] = {model:0 for model in full_df.Model.unique()}
|
68 |
+
continue
|
69 |
+
#only keep in full_df the sequences that are in mc_seqs
|
70 |
+
mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)]
|
71 |
+
curr_num_hico_per_model = mc_full_df[mc_full_df['Is Familiar?']].groupby(['Model'])['Is Familiar?'].value_counts().droplevel(1)
|
72 |
+
#remove Baseline from index
|
73 |
+
curr_num_hico_per_model = curr_num_hico_per_model.drop('Baseline')
|
74 |
+
curr_num_hico_per_model -= curr_num_hico_per_model.min()
|
75 |
+
num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict()
|
76 |
+
#%%
|
77 |
+
to_plot_df = pd.DataFrame(num_hicos_per_mc)
|
78 |
+
to_plot_mcs = ['rRNA','tRNA','snoRNA']
|
79 |
+
fig = go.Figure()
|
80 |
+
#x axis should be the mcs
|
81 |
+
for model in num_hicos_per_mc['rRNA'].keys():
|
82 |
+
fig.add_trace(go.Bar(x=mcs, y=[num_hicos_per_mc[mc][model] for mc in mcs], name=model))
|
83 |
+
|
84 |
+
fig.update_layout(barmode='group')
|
85 |
+
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
|
86 |
+
fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg")
|
87 |
+
fig.update_yaxes(type="log")
|
88 |
+
fig.show()
|
89 |
+
|
90 |
+
#%%
|
91 |
+
|
92 |
+
import pandas as pd
|
93 |
+
import glob
|
94 |
+
from plotly import graph_objects as go
|
95 |
+
from transforna import load,predict_transforna
|
96 |
+
all_df = pd.DataFrame()
|
97 |
+
files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv')
|
98 |
+
for file in files:
|
99 |
+
df = pd.read_csv(file)
|
100 |
+
all_df = pd.concat([all_df,df])
|
101 |
+
all_df = all_df.drop(columns=['Unnamed: 0'])
|
102 |
+
all_df.loc[all_df.split.isnull(),'split'] = 'NA'
|
103 |
+
ensemble_df = all_df[all_df.Model == 'Ensemble']
|
104 |
+
# %%
|
105 |
+
|
106 |
+
lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
|
107 |
+
lc_df = load(lc_path)
|
108 |
+
lc_df.set_index('sequence',inplace=True)
|
109 |
+
# %%
|
110 |
+
#filter lc_df to only include sequences that are in ensemble_df
|
111 |
+
lc_df = lc_df.loc[ensemble_df.Sequence]
|
112 |
+
actual_major_classes = lc_df['small_RNA_class_annotation']
|
113 |
+
predicted_major_classes = ensemble_df[['Net-Label','Sequence']].set_index('Sequence').loc[lc_df.index]['Net-Label'].map(mapping_dict)
|
114 |
+
# %%
|
115 |
+
#plot correlation matrix between actual and predicted major classes
|
116 |
+
from sklearn.metrics import confusion_matrix
|
117 |
+
import seaborn as sns
|
118 |
+
import matplotlib.pyplot as plt
|
119 |
+
import numpy as np
|
120 |
+
major_classes = list(set(list(predicted_major_classes.unique())+list(actual_major_classes.unique())))
|
121 |
+
conf_matrix = confusion_matrix(actual_major_classes,predicted_major_classes,labels=major_classes)
|
122 |
+
conf_matrix = conf_matrix/np.sum(conf_matrix,axis=1)
|
123 |
+
|
124 |
+
sns.heatmap(conf_matrix,annot=True,cmap='Blues')
|
125 |
+
for i in range(conf_matrix.shape[0]):
|
126 |
+
for j in range(conf_matrix.shape[1]):
|
127 |
+
conf_matrix[i,j] = round(conf_matrix[i,j],1)
|
128 |
+
|
129 |
+
|
130 |
+
plt.xlabel('Predicted Major Class')
|
131 |
+
plt.ylabel('Actual Major Class')
|
132 |
+
plt.xticks(np.arange(len(major_classes)),major_classes,rotation=90)
|
133 |
+
plt.yticks(np.arange(len(major_classes)),major_classes,rotation=0)
|
134 |
+
plt.show()
|
135 |
+
|
136 |
+
# %%
|
transforna/bin/figure_scripts/infer_lc_using_tcga.py
ADDED
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import sys
|
3 |
+
from random import randint
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
from anndata import AnnData
|
8 |
+
|
9 |
+
#add parent directory to path
|
10 |
+
sys.path.append('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/')
|
11 |
+
from src import (Results_Handler, correct_labels, load, predict_transforna,
|
12 |
+
predict_transforna_all_models,get_fused_seqs)
|
13 |
+
|
14 |
+
|
15 |
+
def get_mc_sc(infer_df,sequences,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag = False):
|
16 |
+
|
17 |
+
infered_seqs = infer_df.loc[sequences]
|
18 |
+
sc_classes_df = infered_seqs['subclass_name'].str.split(';',expand=True)
|
19 |
+
#filter rows with all nans in sc_classes_df
|
20 |
+
sc_classes_df = sc_classes_df[~sc_classes_df.isnull().all(axis=1)]
|
21 |
+
#cmask for classes used for training
|
22 |
+
if ood_flag:
|
23 |
+
sub_classes_used_for_training_plus_neighbors = []
|
24 |
+
#for every subclass in sub_classes_used_for_training that contains bin, get previous and succeeding bins
|
25 |
+
for sub_class in sub_classes_used_for_training:
|
26 |
+
sub_classes_used_for_training_plus_neighbors.append(sub_class)
|
27 |
+
if 'bin' in sub_class:
|
28 |
+
bin_num = int(sub_class.split('_bin-')[1])
|
29 |
+
if bin_num > 0:
|
30 |
+
sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num-1}')
|
31 |
+
sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num+1}')
|
32 |
+
if 'tR' in sub_class:
|
33 |
+
#seperate the first part(either 3p/5p), also ge tthe second part after __
|
34 |
+
first_part = sub_class.split('-')[0]
|
35 |
+
second_part = sub_class.split('__')[1]
|
36 |
+
#get all classes in sc_to_mc_mapper_dict,values that contain both parts and append them to sub_classes_used_for_training_plus_neighbors
|
37 |
+
sub_classes_used_for_training_plus_neighbors += [sc for sc in sc_to_mc_mapper_dict.keys() if first_part in sc and second_part in sc]
|
38 |
+
sub_classes_used_for_training_plus_neighbors = list(set(sub_classes_used_for_training_plus_neighbors))
|
39 |
+
mask = sc_classes_df.applymap(lambda x: True if (x not in sub_classes_used_for_training_plus_neighbors and 'hypermapper' not in x)\
|
40 |
+
or pd.isnull(x) else False)
|
41 |
+
|
42 |
+
else:
|
43 |
+
mask = sc_classes_df.applymap(lambda x: True if x in sub_classes_used_for_training or pd.isnull(x) else False)
|
44 |
+
|
45 |
+
#check if any sub class in sub_classes_used_for_training is in sc_classes_df
|
46 |
+
if mask.apply(lambda x: all(x.tolist()), axis=1).sum() == 0:
|
47 |
+
#TODO: change to log
|
48 |
+
import logging
|
49 |
+
log_ = logging.getLogger(__name__)
|
50 |
+
log_.error('None of the sub classes used for training are in the sequences')
|
51 |
+
raise Exception('None of the sub classes used for training are in the sequences')
|
52 |
+
|
53 |
+
#filter rows with atleast one False in mask
|
54 |
+
sc_classes_df = sc_classes_df[mask.apply(lambda x: all(x.tolist()), axis=1)]
|
55 |
+
#get mc classes
|
56 |
+
mc_classes_df = sc_classes_df.applymap(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict else 'not_found')
|
57 |
+
#assign major class for not found values if containing 'miRNA', 'tRNA', 'rRNA', 'snRNA', 'snoRNA'
|
58 |
+
#mc_classes_df = mc_classes_df.applymap(lambda x: None if x is None else 'miRNA' if 'miR' in x else 'tRNA' if 'tRNA' in x else 'rRNA' if 'rRNA' in x else 'snRNA' if 'snRNA' in x else 'snoRNA' if 'snoRNA' in x else 'snoRNA' if 'SNO' in x else 'protein_coding' if 'RPL37A' in x else 'lncRNA' if 'SNHG1' in x else 'not_found')
|
59 |
+
#filter all 'not_found' rows
|
60 |
+
mc_classes_df = mc_classes_df[mc_classes_df.apply(lambda x: 'not_found' not in x.tolist() ,axis=1)]
|
61 |
+
#filter values with ; in mc_classes_df
|
62 |
+
mc_classes_df = mc_classes_df[~mc_classes_df[0].str.contains(';')]
|
63 |
+
#filter index
|
64 |
+
sc_classes_df = sc_classes_df.loc[mc_classes_df.index]
|
65 |
+
mc_classes_df = mc_classes_df.loc[sc_classes_df.index]
|
66 |
+
return mc_classes_df,sc_classes_df
|
67 |
+
|
68 |
+
def plot_confusion_false_novel(df,sc_df,mc_df,save_figs:bool=False):
|
69 |
+
#filter index of sc_classes_df to contain indices of outliers df
|
70 |
+
curr_sc_classes_df = sc_df.loc[[i for i in df.index if i in sc_df.index]]
|
71 |
+
curr_mc_classes_df = mc_df.loc[[i for i in df.index if i in mc_df.index]]
|
72 |
+
#convert Labels to mc_Labels
|
73 |
+
df = df.assign(predicted_mc_labels=df.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
|
74 |
+
#add mc classes
|
75 |
+
df = df.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
|
76 |
+
#add sc classes
|
77 |
+
df = df.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
|
78 |
+
#compute accuracy
|
79 |
+
df = df.assign(mc_accuracy=df.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
|
80 |
+
df = df.assign(sc_accuracy=df.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
|
81 |
+
|
82 |
+
#use plotly to plot confusion matrix based on mc classes
|
83 |
+
mc_confusion_matrix = df.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
|
84 |
+
mc_confusion_matrix = mc_confusion_matrix.fillna(0)
|
85 |
+
mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
|
86 |
+
mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,2))
|
87 |
+
#for columns not in rows, sum the column values and add them to a new column called 'other'
|
88 |
+
other_col = [0]*mc_confusion_matrix.shape[0]
|
89 |
+
for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]:
|
90 |
+
other_col += mc_confusion_matrix[i]
|
91 |
+
mc_confusion_matrix['other'] = other_col
|
92 |
+
#add an other row with all zeros
|
93 |
+
mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
|
94 |
+
#drop all columns not in rows
|
95 |
+
mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
|
96 |
+
#plot confusion matri
|
97 |
+
fig = go.Figure(data=go.Heatmap(
|
98 |
+
z=mc_confusion_matrix.values,
|
99 |
+
x=mc_confusion_matrix.columns,
|
100 |
+
y=mc_confusion_matrix.index,
|
101 |
+
hoverongaps = False))
|
102 |
+
#add z values to heatmap
|
103 |
+
for i in range(len(mc_confusion_matrix.index)):
|
104 |
+
for j in range(len(mc_confusion_matrix.columns)):
|
105 |
+
fig.add_annotation(text=str(mc_confusion_matrix.values[i][j]), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
|
106 |
+
showarrow=False, font_size=25, font_color='red')
|
107 |
+
#add title
|
108 |
+
fig.update_layout(title_text='Confusion matrix based on mc classes for false novel sequences')
|
109 |
+
#label x axis and y axis
|
110 |
+
fig.update_xaxes(title_text='Predicted mc class')
|
111 |
+
fig.update_yaxes(title_text='Actual mc class')
|
112 |
+
#save
|
113 |
+
if save_figs:
|
114 |
+
fig.write_image('transforna/bin/lc_figures/confusion_matrix_mc_classes_false_novel.png')
|
115 |
+
|
116 |
+
|
117 |
+
def compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers = False,fig_prefix:str = '',save_figs:bool=False):
|
118 |
+
font_size = 25
|
119 |
+
if fig_prefix == 'LC-familiar':
|
120 |
+
font_size = 10
|
121 |
+
#rename Labels to predicted_sc_labels
|
122 |
+
prediction_pd = prediction_pd.rename(columns={'Net-Label':'predicted_sc_labels'})
|
123 |
+
|
124 |
+
for model in prediction_pd['Model'].unique():
|
125 |
+
#get model predictions
|
126 |
+
num_rows = sc_classes_df.shape[0]
|
127 |
+
model_prediction_pd = prediction_pd[prediction_pd['Model'] == model]
|
128 |
+
model_prediction_pd = model_prediction_pd.set_index('Sequence')
|
129 |
+
#filter index of model_prediction_pd to contain indices of sc_classes_df
|
130 |
+
model_prediction_pd = model_prediction_pd.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
|
131 |
+
|
132 |
+
try: #try because ensemble models do not have a folder
|
133 |
+
#check how many of the hico seqs exist in the train_df
|
134 |
+
embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/embedds'
|
135 |
+
results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
|
136 |
+
except:
|
137 |
+
embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq-Rev/embedds'
|
138 |
+
results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
|
139 |
+
|
140 |
+
train_seqs = set(results.splits_df_dict['train_df']['RNA Sequences']['0'].values.tolist())
|
141 |
+
common_seqs = train_seqs.intersection(set(model_prediction_pd.index.tolist()))
|
142 |
+
print(f'Number of common seqs between train_df and {model} is {len(common_seqs)}')
|
143 |
+
#print(f'removing overlaping sequences between train set and inference')
|
144 |
+
#remove common_seqs from model_prediction_pd
|
145 |
+
#model_prediction_pd = model_prediction_pd.drop(common_seqs)
|
146 |
+
|
147 |
+
|
148 |
+
#compute number of sequences where NLD is higher than Novelty Threshold
|
149 |
+
num_outliers = sum(model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'])
|
150 |
+
false_novel_df = model_prediction_pd[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']]
|
151 |
+
|
152 |
+
plot_confusion_false_novel(false_novel_df,sc_classes_df,mc_classes_df,save_figs)
|
153 |
+
#draw a pie chart depicting number of outliers per actual_mc_labels
|
154 |
+
fig_outl = mc_classes_df.loc[false_novel_df.index][0].value_counts().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
|
155 |
+
fig_outl.set_title(f'False Novel per MC for {model}: {num_outliers}')
|
156 |
+
if save_figs:
|
157 |
+
fig_outl.get_figure().savefig(f'transforna/bin/lc_figures/false_novel_mc_{model}.png')
|
158 |
+
fig_outl.get_figure().clf()
|
159 |
+
#get number of unique sub classes per major class in false_novel_df
|
160 |
+
false_novel_sc_freq_df = sc_classes_df.loc[false_novel_df.index][0].value_counts().to_frame()
|
161 |
+
#save index as csv
|
162 |
+
#false_novel_sc_freq_df.to_csv(f'false_novel_sc_freq_df_{model}.csv')
|
163 |
+
#add mc to false_novel_sc_freq_df
|
164 |
+
false_novel_sc_freq_df['MC'] = false_novel_sc_freq_df.index.map(lambda x: sc_to_mc_mapper_dict[x])
|
165 |
+
#plot pie chart showing unique sub classes per major class in false_novel_df
|
166 |
+
fig_outl_sc = false_novel_sc_freq_df.groupby('MC')[0].sum().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
|
167 |
+
fig_outl_sc.set_title(f'False novel: No. Unique sub classes per MC {model}: {num_outliers}')
|
168 |
+
if save_figs:
|
169 |
+
fig_outl_sc.get_figure().savefig(f'transforna/bin/lc_figures/{fig_prefix}_false_novel_sc_{model}.png')
|
170 |
+
fig_outl_sc.get_figure().clf()
|
171 |
+
#filter outliers
|
172 |
+
if seperate_outliers:
|
173 |
+
model_prediction_pd = model_prediction_pd[model_prediction_pd['NLD'] <= model_prediction_pd['Novelty Threshold']]
|
174 |
+
else:
|
175 |
+
#set the predictions of outliers to 'Outlier'
|
176 |
+
model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_sc_labels'] = 'Outlier'
|
177 |
+
model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_mc_labels'] = 'Outlier'
|
178 |
+
sc_to_mc_mapper_dict['Outlier'] = 'Outlier'
|
179 |
+
|
180 |
+
#filter index of sc_classes_df to contain indices of model_prediction_pd
|
181 |
+
curr_sc_classes_df = sc_classes_df.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
|
182 |
+
curr_mc_classes_df = mc_classes_df.loc[[i for i in model_prediction_pd.index if i in mc_classes_df.index]]
|
183 |
+
#convert Labels to mc_Labels
|
184 |
+
model_prediction_pd = model_prediction_pd.assign(predicted_mc_labels=model_prediction_pd.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
|
185 |
+
#add mc classes
|
186 |
+
model_prediction_pd = model_prediction_pd.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
|
187 |
+
#add sc classes
|
188 |
+
model_prediction_pd = model_prediction_pd.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
|
189 |
+
#correct labels
|
190 |
+
model_prediction_pd['predicted_sc_labels'] = correct_labels(model_prediction_pd['predicted_sc_labels'],model_prediction_pd['actual_sc_labels'],sc_to_mc_mapper_dict)
|
191 |
+
#compute accuracy
|
192 |
+
model_prediction_pd = model_prediction_pd.assign(mc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
|
193 |
+
model_prediction_pd = model_prediction_pd.assign(sc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
|
194 |
+
|
195 |
+
if not seperate_outliers:
|
196 |
+
cols_to_save = ['actual_mc_labels','predicted_mc_labels','predicted_sc_labels','actual_sc_labels']
|
197 |
+
total_false_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels != model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
|
198 |
+
#add a column indicating if NLD is greater than Novelty Threshold
|
199 |
+
total_false_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_false_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_false_mc_predictions_df.index]['Novelty Threshold']
|
200 |
+
#save
|
201 |
+
total_false_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_false_mcs_w_out_{model}.csv')
|
202 |
+
total_true_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels == model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
|
203 |
+
#add a column indicating if NLD is greater than Novelty Threshold
|
204 |
+
total_true_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_true_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_true_mc_predictions_df.index]['Novelty Threshold']
|
205 |
+
#save
|
206 |
+
total_true_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_true_mcs_w_out_{model}.csv')
|
207 |
+
|
208 |
+
print('Model: ', model)
|
209 |
+
print('num_outliers: ', num_outliers)
|
210 |
+
#print accuracy including outliers
|
211 |
+
print('mc_accuracy: ', model_prediction_pd['mc_accuracy'].mean())
|
212 |
+
print('sc_accuracy: ', model_prediction_pd['sc_accuracy'].mean())
|
213 |
+
|
214 |
+
#print balanced accuracy
|
215 |
+
print('mc_balanced_accuracy: ', model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean())
|
216 |
+
print('sc_balanced_accuracy: ', model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean())
|
217 |
+
|
218 |
+
#use plotly to plot confusion matrix based on mc classes
|
219 |
+
mc_confusion_matrix = model_prediction_pd.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
|
220 |
+
mc_confusion_matrix = mc_confusion_matrix.fillna(0)
|
221 |
+
mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
|
222 |
+
mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,4))
|
223 |
+
#for columns not in rows, sum the column values and add them to a new column called 'other'
|
224 |
+
other_col = [0]*mc_confusion_matrix.shape[0]
|
225 |
+
for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]:
|
226 |
+
other_col += mc_confusion_matrix[i]
|
227 |
+
mc_confusion_matrix['other'] = other_col
|
228 |
+
#add an other row with all zeros
|
229 |
+
mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
|
230 |
+
#drop all columns not in rows
|
231 |
+
mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
|
232 |
+
#plot confusion matrix
|
233 |
+
|
234 |
+
fig = go.Figure(data=go.Heatmap(
|
235 |
+
z=mc_confusion_matrix.values,
|
236 |
+
x=mc_confusion_matrix.columns,
|
237 |
+
y=mc_confusion_matrix.index,
|
238 |
+
colorscale='Blues',
|
239 |
+
hoverongaps = False))
|
240 |
+
#add z values to heatmap
|
241 |
+
for i in range(len(mc_confusion_matrix.index)):
|
242 |
+
for j in range(len(mc_confusion_matrix.columns)):
|
243 |
+
fig.add_annotation(text=str(round(mc_confusion_matrix.values[i][j],2)), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
|
244 |
+
showarrow=False, font_size=font_size, font_color='black')
|
245 |
+
|
246 |
+
fig.update_layout(
|
247 |
+
title='Confusion matrix for mc classes - ' + model + ' - ' + 'mc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean(),2)) \
|
248 |
+
+ ' - ' + 'sc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean(),2)) + '<br>' + \
|
249 |
+
'percent false novel: ' + str(round(num_outliers/num_rows,2)),
|
250 |
+
xaxis_nticks=36)
|
251 |
+
#label x axis and y axis
|
252 |
+
fig.update_xaxes(title_text='Predicted mc class')
|
253 |
+
fig.update_yaxes(title_text='Actual mc class')
|
254 |
+
if save_figs:
|
255 |
+
#save plot
|
256 |
+
if seperate_outliers:
|
257 |
+
fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.png')
|
258 |
+
#save svg
|
259 |
+
fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.svg')
|
260 |
+
else:
|
261 |
+
fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.png')
|
262 |
+
#save svg
|
263 |
+
fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.svg')
|
264 |
+
print('\n')
|
265 |
+
|
266 |
+
|
267 |
+
if __name__ == '__main__':
|
268 |
+
#####################################################################################################################
|
269 |
+
mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
|
270 |
+
LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
|
271 |
+
path_to_models = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
|
272 |
+
|
273 |
+
trained_on = 'full' #id or full
|
274 |
+
save_figs = True
|
275 |
+
|
276 |
+
infer_aa = infer_relaxed_mirna = infer_hico = infer_ood = infer_other_affixes = infer_random = infer_fused = infer_na = infer_loco = False
|
277 |
+
|
278 |
+
split = 'infer_hico'#sys.argv[1]
|
279 |
+
print(f'Running inference for {split}')
|
280 |
+
if split == 'infer_aa':
|
281 |
+
infer_aa = True
|
282 |
+
elif split == 'infer_relaxed_mirna':
|
283 |
+
infer_relaxed_mirna = True
|
284 |
+
elif split == 'infer_hico':
|
285 |
+
infer_hico = True
|
286 |
+
elif split == 'infer_ood':
|
287 |
+
infer_ood = True
|
288 |
+
elif split == 'infer_other_affixes':
|
289 |
+
infer_other_affixes = True
|
290 |
+
elif split == 'infer_random':
|
291 |
+
infer_random = True
|
292 |
+
elif split == 'infer_fused':
|
293 |
+
infer_fused = True
|
294 |
+
elif split == 'infer_na':
|
295 |
+
infer_na = True
|
296 |
+
elif split == 'infer_loco':
|
297 |
+
infer_loco = True
|
298 |
+
|
299 |
+
#####################################################################################################################
|
300 |
+
#only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico should be true
|
301 |
+
if sum([infer_aa,infer_relaxed_mirna,infer_hico,infer_ood,infer_other_affixes,infer_random,infer_fused,infer_na,infer_loco]) != 1:
|
302 |
+
raise Exception('Only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico or infer_other_affixes or infer_random or infer_fused or infer_na should be true')
|
303 |
+
|
304 |
+
#set fig_prefix
|
305 |
+
if infer_aa:
|
306 |
+
fig_prefix = '5\'A-affixes'
|
307 |
+
elif infer_other_affixes:
|
308 |
+
fig_prefix = 'other_affixes'
|
309 |
+
elif infer_relaxed_mirna:
|
310 |
+
fig_prefix = 'Relaxed-miRNA'
|
311 |
+
elif infer_hico:
|
312 |
+
fig_prefix = 'LC-familiar'
|
313 |
+
elif infer_ood:
|
314 |
+
fig_prefix = 'LC-novel'
|
315 |
+
elif infer_random:
|
316 |
+
fig_prefix = 'Random'
|
317 |
+
elif infer_fused:
|
318 |
+
fig_prefix = 'Fused'
|
319 |
+
elif infer_na:
|
320 |
+
fig_prefix = 'NA'
|
321 |
+
elif infer_loco:
|
322 |
+
fig_prefix = 'LOCO'
|
323 |
+
|
324 |
+
infer_df = load(LC_path)
|
325 |
+
if isinstance(infer_df,AnnData):
|
326 |
+
infer_df = infer_df.var
|
327 |
+
infer_df.set_index('sequence',inplace=True)
|
328 |
+
sc_to_mc_mapper_dict = load(mapping_dict_path)
|
329 |
+
#select hico sequences
|
330 |
+
hico_seqs = infer_df.index[infer_df['hico']].tolist()
|
331 |
+
art_affix_seqs = infer_df[~infer_df['five_prime_adapter_filter']].index.tolist()
|
332 |
+
|
333 |
+
if infer_hico:
|
334 |
+
hico_seqs = hico_seqs
|
335 |
+
|
336 |
+
if infer_aa:
|
337 |
+
hico_seqs = art_affix_seqs
|
338 |
+
|
339 |
+
if infer_other_affixes:
|
340 |
+
hico_seqs = infer_df[~infer_df['hbdx_spikein_affix_filter']].index.tolist()
|
341 |
+
|
342 |
+
if infer_na:
|
343 |
+
hico_seqs = infer_df[infer_df.subclass_name == 'no_annotation'].index.tolist()
|
344 |
+
|
345 |
+
if infer_loco:
|
346 |
+
hico_seqs = infer_df[~infer_df['hico']][infer_df.subclass_name != 'no_annotation'].index.tolist()
|
347 |
+
|
348 |
+
#for mirnas
|
349 |
+
if infer_relaxed_mirna:
|
350 |
+
#subclass name must contain miR, let, Let and not contain ; and that are not hico
|
351 |
+
mirnas_seqs = infer_df[infer_df.subclass_name.str.contains('miR') | infer_df.subclass_name.str.contains('let')][~infer_df.subclass_name.str.contains(';')].index.tolist()
|
352 |
+
#remove the ones that are true in ad.hico column
|
353 |
+
hico_seqs = list(set(mirnas_seqs).difference(set(hico_seqs)))
|
354 |
+
|
355 |
+
#novel mirnas
|
356 |
+
#mirnas_not_in_train_mask = (ad['hico']==True).values * ~(ad['subclass_name'].isin(mirna_train_sc)).values * (ad['small_RNA_class_annotation'].isin(['miRNA']))
|
357 |
+
#hicos = ad[mirnas_not_in_train_mask].index.tolist()
|
358 |
+
|
359 |
+
|
360 |
+
if infer_random:
|
361 |
+
#create random sequences
|
362 |
+
random_seqs = []
|
363 |
+
while len(random_seqs) < 200:
|
364 |
+
random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(18,30)))
|
365 |
+
if random_seq not in random_seqs:
|
366 |
+
random_seqs.append(random_seq)
|
367 |
+
hico_seqs = random_seqs
|
368 |
+
|
369 |
+
if infer_fused:
|
370 |
+
hico_seqs = get_fused_seqs(hico_seqs,num_sequences=200)
|
371 |
+
|
372 |
+
|
373 |
+
#hico_seqs = ad[ad.subclass_name.str.contains('mir')][~ad.subclass_name.str.contains(';')]['subclass_name'].index.tolist()
|
374 |
+
hico_seqs = [seq for seq in hico_seqs if len(seq) <= 30]
|
375 |
+
#set cuda 1
|
376 |
+
import os
|
377 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
|
378 |
+
|
379 |
+
#run prediction
|
380 |
+
prediction_pd = predict_transforna_all_models(hico_seqs,trained_on=trained_on,path_to_models=path_to_models)
|
381 |
+
prediction_pd['split'] = fig_prefix
|
382 |
+
#the if condition here is to make sure to filter seqs with sub classes not used in training
|
383 |
+
if not infer_ood and not infer_relaxed_mirna and not infer_hico:
|
384 |
+
prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
|
385 |
+
if infer_aa or infer_other_affixes or infer_random or infer_fused:
|
386 |
+
for model in prediction_pd.Model.unique():
|
387 |
+
num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
|
388 |
+
print(f'Number of non novel sequences for {model} is {num_non_novel}')
|
389 |
+
print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the lower the better')
|
390 |
+
|
391 |
+
else:
|
392 |
+
if infer_na or infer_loco:
|
393 |
+
#print number of Is Familiar per model
|
394 |
+
for model in prediction_pd.Model.unique():
|
395 |
+
num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
|
396 |
+
print(f'Number of non novel sequences for {model} is {num_non_novel}')
|
397 |
+
print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the higher the better')
|
398 |
+
print('\n')
|
399 |
+
else:
|
400 |
+
#only to get classes used for training
|
401 |
+
prediction_single_pd = predict_transforna(hico_seqs[0],model='Seq',logits_flag = True,trained_on=trained_on,path_to_models=path_to_models)
|
402 |
+
sub_classes_used_for_training = prediction_single_pd.columns.tolist()
|
403 |
+
|
404 |
+
|
405 |
+
mc_classes_df,sc_classes_df = get_mc_sc(infer_df,hico_seqs,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag=infer_ood)
|
406 |
+
if infer_ood:
|
407 |
+
for model in prediction_pd.Model.unique():
|
408 |
+
#filter sequences in prediction_pd to only include sequences in sc_classes_df
|
409 |
+
curr_prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
|
410 |
+
#filter curr_prediction toonly include model
|
411 |
+
curr_prediction_pd = curr_prediction_pd[curr_prediction_pd.Model == model]
|
412 |
+
num_seqs = curr_prediction_pd.shape[0]
|
413 |
+
#filter Is Familiar
|
414 |
+
curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Is Familiar?']]
|
415 |
+
#filter sc_classes_df to only include sequences in curr_prediction_pd
|
416 |
+
curr_sc_classes_df = sc_classes_df[sc_classes_df.index.isin(curr_prediction_pd['Sequence'].values)]
|
417 |
+
#correct labels and remove the correct labels from the curr_prediction_pd
|
418 |
+
curr_prediction_pd['Net-Label'] = correct_labels(curr_prediction_pd['Net-Label'].values,curr_sc_classes_df[0].values,sc_to_mc_mapper_dict)
|
419 |
+
#filter rows in curr_prediction where Labels is equal to sc_classes_df[0]
|
420 |
+
curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Net-Label'].values != curr_sc_classes_df[0].values]
|
421 |
+
num_non_novel = len(curr_prediction_pd)
|
422 |
+
print(f'Number of non novel sequences for {model} is {num_non_novel}')
|
423 |
+
print(f'Percent non novel for {model} is {num_non_novel/num_seqs}, the lower the better')
|
424 |
+
print('\n')
|
425 |
+
else:
|
426 |
+
#filter prediction_pd to include only sequences in prediction_pd
|
427 |
+
|
428 |
+
#compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=False,fig_prefix = fig_prefix,save_figs=save_figs)
|
429 |
+
compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=True,fig_prefix = fig_prefix,save_figs=save_figs)
|
430 |
+
|
431 |
+
if infer_ood or infer_relaxed_mirna or infer_hico:
|
432 |
+
prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
|
433 |
+
#save lev_dist_df
|
434 |
+
prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
|
435 |
+
|
436 |
+
|
437 |
+
|
438 |
+
|
transforna/src/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .callbacks import *
|
2 |
+
from .callbacks.tbWriter import writer
|
3 |
+
from .inference import *
|
4 |
+
from .model import *
|
5 |
+
from .novelty_prediction import *
|
6 |
+
from .processing import *
|
7 |
+
from .score import *
|
8 |
+
from .train import *
|
9 |
+
from .utils import *
|
transforna/src/callbacks/LRCallback.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections.abc import Iterable
|
3 |
+
from math import cos, floor, log, pi
|
4 |
+
|
5 |
+
import skorch
|
6 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
7 |
+
|
8 |
+
_LRScheduler
|
9 |
+
|
10 |
+
|
11 |
+
class CyclicCosineDecayLR(skorch.callbacks.Callback):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
optimizer,
|
15 |
+
init_interval,
|
16 |
+
min_lr,
|
17 |
+
len_param_groups,
|
18 |
+
base_lrs,
|
19 |
+
restart_multiplier=None,
|
20 |
+
restart_interval=None,
|
21 |
+
restart_lr=None,
|
22 |
+
last_epoch=-1,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Initialize new CyclicCosineDecayLR object
|
26 |
+
:param optimizer: (Optimizer) - Wrapped optimizer.
|
27 |
+
:param init_interval: (int) - Initial decay cycle interval.
|
28 |
+
:param min_lr: (float or iterable of floats) - Minimal learning rate.
|
29 |
+
:param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals,
|
30 |
+
if this parameter is set, restart_interval must be None.
|
31 |
+
:param restart_interval: (int) - Restart interval for fixed cycle intervals,
|
32 |
+
if this parameter is set, restart_multiplier must be None.
|
33 |
+
:param restart_lr: (float or iterable of floats) - Optional, the learning rate at cycle restarts,
|
34 |
+
if not provided, initial learning rate will be used.
|
35 |
+
:param last_epoch: (int) - Last epoch.
|
36 |
+
"""
|
37 |
+
self.len_param_groups = len_param_groups
|
38 |
+
if restart_interval is not None and restart_multiplier is not None:
|
39 |
+
raise ValueError(
|
40 |
+
"You can either set restart_interval or restart_multiplier but not both"
|
41 |
+
)
|
42 |
+
|
43 |
+
if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups:
|
44 |
+
raise ValueError(
|
45 |
+
"Expected len(min_lr) to be equal to len(optimizer.param_groups), "
|
46 |
+
"got {} and {} instead".format(len(min_lr), self.len_param_groups)
|
47 |
+
)
|
48 |
+
|
49 |
+
if isinstance(restart_lr, Iterable) and len(restart_lr) != len(
|
50 |
+
self.len_param_groups
|
51 |
+
):
|
52 |
+
raise ValueError(
|
53 |
+
"Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
|
54 |
+
"got {} and {} instead".format(len(restart_lr), self.len_param_groups)
|
55 |
+
)
|
56 |
+
|
57 |
+
if init_interval <= 0:
|
58 |
+
raise ValueError(
|
59 |
+
"init_interval must be a positive number, got {} instead".format(
|
60 |
+
init_interval
|
61 |
+
)
|
62 |
+
)
|
63 |
+
|
64 |
+
group_num = self.len_param_groups
|
65 |
+
self._init_interval = init_interval
|
66 |
+
self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr
|
67 |
+
self._restart_lr = (
|
68 |
+
[restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
|
69 |
+
)
|
70 |
+
self._restart_interval = restart_interval
|
71 |
+
self._restart_multiplier = restart_multiplier
|
72 |
+
self.last_epoch = last_epoch
|
73 |
+
self.base_lrs = base_lrs
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
def on_batch_end(self, net, training, **kwargs):
|
77 |
+
if self.last_epoch < self._init_interval:
|
78 |
+
return self._calc(self.last_epoch, self._init_interval, self.base_lrs)
|
79 |
+
|
80 |
+
elif self._restart_interval is not None:
|
81 |
+
cycle_epoch = (
|
82 |
+
self.last_epoch - self._init_interval
|
83 |
+
) % self._restart_interval
|
84 |
+
lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
|
85 |
+
return self._calc(cycle_epoch, self._restart_interval, lrs)
|
86 |
+
|
87 |
+
elif self._restart_multiplier is not None:
|
88 |
+
n = self._get_n(self.last_epoch)
|
89 |
+
sn_prev = self._partial_sum(n)
|
90 |
+
cycle_epoch = self.last_epoch - sn_prev
|
91 |
+
interval = self._init_interval * self._restart_multiplier ** n
|
92 |
+
lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
|
93 |
+
return self._calc(cycle_epoch, interval, lrs)
|
94 |
+
else:
|
95 |
+
return self._min_lr
|
96 |
+
|
97 |
+
def _calc(self, t, T, lrs):
|
98 |
+
return [
|
99 |
+
min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
|
100 |
+
for lr, min_lr in zip(lrs, self._min_lr)
|
101 |
+
]
|
102 |
+
|
103 |
+
def _get_n(self, epoch):
|
104 |
+
a = self._init_interval
|
105 |
+
r = self._restart_multiplier
|
106 |
+
_t = 1 - (1 - r) * epoch / a
|
107 |
+
return floor(log(_t, r))
|
108 |
+
|
109 |
+
def _partial_sum(self, n):
|
110 |
+
a = self._init_interval
|
111 |
+
r = self._restart_multiplier
|
112 |
+
return a * (1 - r ** n) / (1 - r)
|
113 |
+
|
114 |
+
|
115 |
+
class LearningRateDecayCallback(skorch.callbacks.Callback):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
config,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
self.lr_warmup_end = config.lr_warmup_end
|
122 |
+
self.lr_warmup_start = config.lr_warmup_start
|
123 |
+
self.learning_rate = config.learning_rate
|
124 |
+
self.warmup_batch = config.warmup_epoch * config.batch_per_epoch
|
125 |
+
self.final_batch = config.final_epoch * config.batch_per_epoch
|
126 |
+
|
127 |
+
self.batch_idx = 0
|
128 |
+
|
129 |
+
def on_batch_end(self, net, training, **kwargs):
|
130 |
+
"""
|
131 |
+
|
132 |
+
:param trainer:
|
133 |
+
:type trainer:
|
134 |
+
:param pl_module:
|
135 |
+
:type pl_module:
|
136 |
+
:param batch:
|
137 |
+
:type batch:
|
138 |
+
:param batch_idx:
|
139 |
+
:type batch_idx:
|
140 |
+
:param dataloader_idx:
|
141 |
+
:type dataloader_idx:
|
142 |
+
"""
|
143 |
+
# to avoid updating after validation batch
|
144 |
+
if training:
|
145 |
+
|
146 |
+
if self.batch_idx < self.warmup_batch:
|
147 |
+
# linear warmup, in paper: start from 0.1 to 1 over lr_warmup_end batches
|
148 |
+
lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch))
|
149 |
+
lr = self.lr_warmup_start + lr_mult * (
|
150 |
+
self.lr_warmup_end - self.lr_warmup_start
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
# Cosine learning rate decay
|
154 |
+
progress = float(self.batch_idx - self.warmup_batch) / float(
|
155 |
+
max(1, self.final_batch - self.warmup_batch)
|
156 |
+
)
|
157 |
+
lr = max(
|
158 |
+
self.learning_rate
|
159 |
+
+ 0.5
|
160 |
+
* (1.0 + math.cos(math.pi * progress))
|
161 |
+
* (self.lr_warmup_end - self.learning_rate),
|
162 |
+
self.learning_rate,
|
163 |
+
)
|
164 |
+
net.lr = lr
|
165 |
+
# for param_group in net.optimizer.param_groups:
|
166 |
+
# param_group["lr"] = lr
|
167 |
+
|
168 |
+
self.batch_idx += 1
|
169 |
+
|
170 |
+
|
171 |
+
class LRAnnealing(skorch.callbacks.Callback):
|
172 |
+
def on_epoch_end(self, net, **kwargs):
|
173 |
+
if not net.history[-1]["valid_loss_best"]:
|
174 |
+
net.lr /= 4.0
|
transforna/src/callbacks/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .criterion import *
|
2 |
+
from .LRCallback import *
|
3 |
+
from .metrics import *
|
4 |
+
from .tbWriter import *
|
transforna/src/callbacks/criterion.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import random
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class LossFunction(nn.Module):
|
12 |
+
def __init__(self,main_config):
|
13 |
+
super(LossFunction, self).__init__()
|
14 |
+
self.train_config = main_config["train_config"]
|
15 |
+
self.model_config = main_config["model_config"]
|
16 |
+
self.batch_per_epoch = self.train_config.batch_per_epoch
|
17 |
+
self.warm_up_annealing = (
|
18 |
+
self.train_config.warmup_epoch * self.batch_per_epoch
|
19 |
+
)
|
20 |
+
self.num_embed_hidden = self.model_config.num_embed_hidden
|
21 |
+
self.batch_idx = 0
|
22 |
+
self.loss_anealing_term = 0
|
23 |
+
|
24 |
+
|
25 |
+
class_weights = self.model_config.class_weights
|
26 |
+
#TODO: use device as in main_config
|
27 |
+
class_weights = torch.FloatTensor([float(x) for x in class_weights])
|
28 |
+
|
29 |
+
if self.model_config.num_classes > 2:
|
30 |
+
self.clf_loss_fn = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.train_config.label_smoothing_clf,reduction='none')
|
31 |
+
else:
|
32 |
+
self.clf_loss_fn = self.focal_loss
|
33 |
+
|
34 |
+
|
35 |
+
# @staticmethod
|
36 |
+
def cosine_similarity_matrix(
|
37 |
+
self, gene_embedd: torch.Tensor, second_input_embedd: torch.Tensor, annealing=True
|
38 |
+
) -> torch.Tensor:
|
39 |
+
# if annealing is true, then this function is being called from Net.predict and
|
40 |
+
# doesnt pass the instantiated object LossFunction, therefore no access to self.
|
41 |
+
# in Predict we also just need the max of predictions.
|
42 |
+
# for some reason, skorch only passes the LossFunction initialized object, only
|
43 |
+
# from get_loss fn.
|
44 |
+
|
45 |
+
assert gene_embedd.size(0) == second_input_embedd.size(0)
|
46 |
+
|
47 |
+
cosine_similarity = torch.matmul(gene_embedd, second_input_embedd.T)
|
48 |
+
|
49 |
+
if annealing:
|
50 |
+
if self.batch_idx < self.warm_up_annealing:
|
51 |
+
self.loss_anealing_term = 1 + (
|
52 |
+
self.batch_idx / self.warm_up_annealing
|
53 |
+
) * torch.sqrt(torch.tensor(self.num_embed_hidden))
|
54 |
+
|
55 |
+
cosine_similarity *= self.loss_anealing_term
|
56 |
+
|
57 |
+
return cosine_similarity
|
58 |
+
def get_similar_labels(self,y:torch.Tensor):
|
59 |
+
'''
|
60 |
+
This function recieves y, the labels tensor
|
61 |
+
It creates a list of lists containing at every index a list(min_len = 2) of the indices of the labels that are similar
|
62 |
+
'''
|
63 |
+
# create a test array
|
64 |
+
labels_y = y[:,0].cpu().detach().numpy()
|
65 |
+
|
66 |
+
# creates an array of indices, sorted by unique element
|
67 |
+
idx_sort = np.argsort(labels_y)
|
68 |
+
|
69 |
+
# sorts records array so all unique elements are together
|
70 |
+
sorted_records_array = labels_y[idx_sort]
|
71 |
+
|
72 |
+
# returns the unique values, the index of the first occurrence of a value, and the count for each element
|
73 |
+
vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
|
74 |
+
|
75 |
+
# splits the indices into separate arrays
|
76 |
+
res = np.split(idx_sort, idx_start[1:])
|
77 |
+
#filter them with respect to their size, keeping only items occurring more than once
|
78 |
+
vals = vals[count > 1]
|
79 |
+
res = filter(lambda x: x.size > 1, res)
|
80 |
+
|
81 |
+
indices_similar_labels = []
|
82 |
+
similar_labels = []
|
83 |
+
for r in res:
|
84 |
+
indices_similar_labels.append(list(r))
|
85 |
+
similar_labels.append(list(labels_y[r]))
|
86 |
+
|
87 |
+
return indices_similar_labels,similar_labels
|
88 |
+
|
89 |
+
def get_triplet_samples(self,indices_similar_labels,similar_labels):
|
90 |
+
'''
|
91 |
+
This function creates three lists, positives, anchors and negatives
|
92 |
+
Each index in the three lists correpond to a single triplet
|
93 |
+
'''
|
94 |
+
positives,anchors,negatives = [],[],[]
|
95 |
+
for idx_similar_labels in indices_similar_labels:
|
96 |
+
random_indices = random.sample(range(len(idx_similar_labels)), 2)
|
97 |
+
positives.append(idx_similar_labels[random_indices[0]])
|
98 |
+
anchors.append(idx_similar_labels[random_indices[1]])
|
99 |
+
|
100 |
+
negatives = copy.deepcopy(positives)
|
101 |
+
random.shuffle(negatives)
|
102 |
+
while (np.array(positives) == np.array(negatives)).any():
|
103 |
+
random.shuffle(negatives)
|
104 |
+
|
105 |
+
return positives,anchors,negatives
|
106 |
+
def get_triplet_loss(self,y,gene_embedd,second_input_embedd):
|
107 |
+
'''
|
108 |
+
This function computes triplet loss by creating triplet samples of positives, negatives and anchors
|
109 |
+
The objective is to decrease the distance of the embeddings between the anchors and the positives
|
110 |
+
while increasing the distance between the anchor and the negatives.
|
111 |
+
This is done seperately for both the embeddings, gene embedds 0 and ss embedds 1
|
112 |
+
'''
|
113 |
+
#get similar labels
|
114 |
+
indices_similar_labels,similar_labels = self.get_similar_labels(y)
|
115 |
+
#insuring that there's at least two sets of labels in a given list (indices_similar_labels)
|
116 |
+
if len(indices_similar_labels)>1:
|
117 |
+
#get triplet samples
|
118 |
+
positives,anchors,negatives = self.get_triplet_samples(indices_similar_labels,similar_labels)
|
119 |
+
#get triplet loss for gene embedds
|
120 |
+
gene_embedd_triplet_loss = self.triplet_loss(gene_embedd[positives,:],
|
121 |
+
gene_embedd[anchors,:],
|
122 |
+
gene_embedd[negatives,:])
|
123 |
+
#get triplet loss for ss embedds
|
124 |
+
second_input_embedd_triplet_loss = self.triplet_loss(second_input_embedd[positives,:],
|
125 |
+
second_input_embedd[anchors,:],
|
126 |
+
second_input_embedd[negatives,:])
|
127 |
+
return gene_embedd_triplet_loss + second_input_embedd_triplet_loss
|
128 |
+
else:
|
129 |
+
return 0
|
130 |
+
|
131 |
+
def focal_loss(self,predicted_labels,y):
|
132 |
+
y = y.unsqueeze(dim=1)
|
133 |
+
y_new = torch.zeros(y.shape[0], 2).type(torch.cuda.FloatTensor)
|
134 |
+
y_new[range(y.shape[0]), y[:,0]]=1
|
135 |
+
BCE_loss = F.binary_cross_entropy_with_logits(predicted_labels.float(), y_new.float(), reduction='none')
|
136 |
+
pt = torch.exp(-BCE_loss) # prevents nans when probability 0
|
137 |
+
F_loss = (1-pt)**2 * BCE_loss
|
138 |
+
loss = 10*F_loss.mean()
|
139 |
+
return loss
|
140 |
+
|
141 |
+
def contrastive_loss(self,cosine_similarity,batch_size):
|
142 |
+
j = -torch.sum(torch.diagonal(cosine_similarity))
|
143 |
+
|
144 |
+
cosine_similarity.diagonal().copy_(torch.zeros(cosine_similarity.size(0)))
|
145 |
+
|
146 |
+
j = (1 - self.train_config.label_smoothing_sim) * j + (
|
147 |
+
self.train_config.label_smoothing_sim / (cosine_similarity.size(0) * (cosine_similarity.size(0) - 1))
|
148 |
+
) * torch.sum(cosine_similarity)
|
149 |
+
|
150 |
+
j += torch.sum(torch.logsumexp(cosine_similarity, dim=0))
|
151 |
+
|
152 |
+
if j < 0:
|
153 |
+
j = j-j
|
154 |
+
return j/batch_size
|
155 |
+
|
156 |
+
def forward(self, embedds: List[torch.Tensor], y=None) -> torch.Tensor:
|
157 |
+
self.batch_idx += 1
|
158 |
+
gene_embedd, second_input_embedd, predicted_labels,curr_epoch = embedds
|
159 |
+
|
160 |
+
|
161 |
+
loss = self.clf_loss_fn(predicted_labels,y.squeeze())
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
return loss
|
transforna/src/callbacks/metrics.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import skorch
|
5 |
+
import torch
|
6 |
+
from sklearn.metrics import confusion_matrix, make_scorer
|
7 |
+
from skorch.callbacks import BatchScoring
|
8 |
+
from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter
|
9 |
+
from skorch.callbacks.training import Checkpoint
|
10 |
+
|
11 |
+
from .LRCallback import LearningRateDecayCallback
|
12 |
+
|
13 |
+
writer = None
|
14 |
+
|
15 |
+
def accuracy_score(y_true, y_pred: torch.tensor,task:str=None,mirna_flag:bool = False):
|
16 |
+
#sample
|
17 |
+
|
18 |
+
# premirna
|
19 |
+
if task == "premirna":
|
20 |
+
y_pred = y_pred[:,:-1]
|
21 |
+
miRNA_idx = np.where(y_true.squeeze()==mirna_flag)
|
22 |
+
correct = torch.max(y_pred,1).indices.cpu().numpy()[miRNA_idx] == mirna_flag
|
23 |
+
return sum(correct)
|
24 |
+
|
25 |
+
# sncrna
|
26 |
+
if task == "sncrna":
|
27 |
+
y_pred = y_pred[:,:-1]
|
28 |
+
# correct is of [samples], where each entry is true if it was found in top k
|
29 |
+
correct = torch.max(y_pred,1).indices.cpu().numpy() == y_true.squeeze()
|
30 |
+
|
31 |
+
return sum(correct) / y_pred.shape[0]
|
32 |
+
|
33 |
+
|
34 |
+
def accuracy_score_tcga(y_true, y_pred):
|
35 |
+
|
36 |
+
if torch.is_tensor(y_pred):
|
37 |
+
y_pred = y_pred.clone().detach().cpu().numpy()
|
38 |
+
if torch.is_tensor(y_true):
|
39 |
+
y_true = y_true.clone().detach().cpu().numpy()
|
40 |
+
|
41 |
+
#y pred contains logits | samples weights
|
42 |
+
sample_weight = y_pred[:,-1]
|
43 |
+
y_pred = np.argmax(y_pred[:,:-1],axis=1)
|
44 |
+
|
45 |
+
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
|
46 |
+
with np.errstate(divide='ignore', invalid='ignore'):
|
47 |
+
per_class = np.diag(C) / C.sum(axis=1)
|
48 |
+
if np.any(np.isnan(per_class)):
|
49 |
+
per_class = per_class[~np.isnan(per_class)]
|
50 |
+
score = np.mean(per_class)
|
51 |
+
return score
|
52 |
+
|
53 |
+
def score_callbacks(cfg):
|
54 |
+
|
55 |
+
acc_scorer = make_scorer(accuracy_score,task=cfg["task"])
|
56 |
+
if cfg['task'] == 'tcga':
|
57 |
+
acc_scorer = make_scorer(accuracy_score_tcga)
|
58 |
+
|
59 |
+
|
60 |
+
if cfg["task"] == "premirna":
|
61 |
+
acc_scorer_mirna = make_scorer(accuracy_score,task=cfg["task"],mirna_flag = True)
|
62 |
+
|
63 |
+
val_score_callback_mirna = BatchScoringPremirna( mirna_flag=True,
|
64 |
+
scoring = acc_scorer_mirna, lower_is_better=False, name="val_acc_mirna")
|
65 |
+
|
66 |
+
train_score_callback_mirna = BatchScoringPremirna(mirna_flag=True,
|
67 |
+
scoring = acc_scorer_mirna, on_train=True, lower_is_better=False, name="train_acc_mirna")
|
68 |
+
|
69 |
+
val_score_callback = BatchScoringPremirna(mirna_flag=False,
|
70 |
+
scoring = acc_scorer, lower_is_better=False, name="val_acc")
|
71 |
+
|
72 |
+
train_score_callback = BatchScoringPremirna(mirna_flag=False,
|
73 |
+
scoring = acc_scorer, on_train=True, lower_is_better=False, name="train_acc")
|
74 |
+
|
75 |
+
|
76 |
+
scoring_callbacks = [
|
77 |
+
train_score_callback,
|
78 |
+
train_score_callback_mirna
|
79 |
+
]
|
80 |
+
if cfg["train_split"]:
|
81 |
+
scoring_callbacks.extend([val_score_callback_mirna,val_score_callback])
|
82 |
+
|
83 |
+
if cfg["task"] in ["sncrna", "tcga"]:
|
84 |
+
|
85 |
+
val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
|
86 |
+
train_score_callback = BatchScoring(
|
87 |
+
acc_scorer, on_train=True, lower_is_better=False, name="train_acc"
|
88 |
+
)
|
89 |
+
scoring_callbacks = [train_score_callback]
|
90 |
+
|
91 |
+
#tcga dataset has a predifined valid split, so train_split is false, but still valid metric is required
|
92 |
+
#TODO: remove predifined valid from tcga from prepare_data_tcga
|
93 |
+
if cfg["train_split"] or cfg['task'] == 'tcga':
|
94 |
+
scoring_callbacks.append(val_score_callback)
|
95 |
+
|
96 |
+
return scoring_callbacks
|
97 |
+
|
98 |
+
def get_callbacks(path,cfg):
|
99 |
+
|
100 |
+
callback_list = [("lrcallback", LearningRateDecayCallback)]
|
101 |
+
if cfg['tensorboard'] == True:
|
102 |
+
from .tbWriter import writer
|
103 |
+
callback_list.append(MetricsVizualization)
|
104 |
+
|
105 |
+
if (cfg["train_split"] or cfg['task'] == 'tcga') and cfg['inference'] == False:
|
106 |
+
monitor = "val_acc_best"
|
107 |
+
if cfg['trained_on'] == 'full':
|
108 |
+
monitor = 'train_acc_best'
|
109 |
+
ckpt_path = path+"/ckpt/"
|
110 |
+
try:
|
111 |
+
os.mkdir(ckpt_path)
|
112 |
+
except:
|
113 |
+
pass
|
114 |
+
model_name = f'model_params_{cfg["task"]}.pt'
|
115 |
+
callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path,f_params=model_name))
|
116 |
+
|
117 |
+
scoring_callbacks = score_callbacks(cfg)
|
118 |
+
#TODO: For some reason scoring callbaks have to be inserted before checpoint and metrics viz callbacks
|
119 |
+
#otherwise NeuralNet notify function throws an exception
|
120 |
+
callback_list[1:1] = scoring_callbacks
|
121 |
+
|
122 |
+
return callback_list
|
123 |
+
|
124 |
+
|
125 |
+
class MetricsVizualization(skorch.callbacks.Callback):
|
126 |
+
def __init__(self, batch_idx=0) -> None:
|
127 |
+
super().__init__()
|
128 |
+
self.batch_idx = batch_idx
|
129 |
+
|
130 |
+
# TODO: Change to display metrics at epoch ends
|
131 |
+
def on_batch_end(self, net, training, **kwargs):
|
132 |
+
# validation batch
|
133 |
+
if not training:
|
134 |
+
# log val accuracy. accessing net.history:[ epoch ,batches, last batch,column in batch]
|
135 |
+
writer.add_scalar(
|
136 |
+
"Accuracy/val_acc",
|
137 |
+
net.history[-1, "batches", -1, "val_acc"],
|
138 |
+
self.batch_idx,
|
139 |
+
)
|
140 |
+
# log val loss
|
141 |
+
writer.add_scalar(
|
142 |
+
"Loss/val_loss",
|
143 |
+
net.history[-1, "batches", -1, "valid_loss"],
|
144 |
+
self.batch_idx,
|
145 |
+
)
|
146 |
+
# update batch idx after validation on batch is computed
|
147 |
+
# train batch
|
148 |
+
else:
|
149 |
+
# log lr
|
150 |
+
writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
|
151 |
+
# log train accuracy
|
152 |
+
writer.add_scalar(
|
153 |
+
"Accuracy/train_acc",
|
154 |
+
net.history[-1, "batches", -1, "train_acc"],
|
155 |
+
self.batch_idx,
|
156 |
+
)
|
157 |
+
# log train loss
|
158 |
+
writer.add_scalar(
|
159 |
+
"Loss/train_loss",
|
160 |
+
net.history[-1, "batches", -1, "train_loss"],
|
161 |
+
self.batch_idx,
|
162 |
+
)
|
163 |
+
self.batch_idx += 1
|
164 |
+
|
165 |
+
class BatchScoringPremirna(ScoringBase):
|
166 |
+
def __init__(self,mirna_flag:bool = False,*args,**kwargs):
|
167 |
+
super().__init__(*args,**kwargs)
|
168 |
+
#self.total_num_samples = total_num_samples
|
169 |
+
self.total_num_samples = 0
|
170 |
+
self.mirna_flag = mirna_flag
|
171 |
+
self.first_batch_flag = True
|
172 |
+
def on_batch_end(self, net, X, y, training, **kwargs):
|
173 |
+
if training != self.on_train:
|
174 |
+
return
|
175 |
+
|
176 |
+
y_preds = [kwargs['y_pred']]
|
177 |
+
#only for the first batch: get no. of samples belonging to same class samples
|
178 |
+
if self.first_batch_flag:
|
179 |
+
self.total_num_samples += sum(kwargs["batch"][1] == self.mirna_flag).detach().cpu().numpy()[0]
|
180 |
+
|
181 |
+
with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
|
182 |
+
# In case of y=None we will not have gathered any samples.
|
183 |
+
# We expect the scoring function to deal with y=None.
|
184 |
+
y = None if y is None else self.target_extractor(y)
|
185 |
+
try:
|
186 |
+
score = self._scoring(cached_net, X, y)
|
187 |
+
cached_net.history.record_batch(self.name_, score)
|
188 |
+
except KeyError:
|
189 |
+
pass
|
190 |
+
def get_avg_score(self, history):
|
191 |
+
if self.on_train:
|
192 |
+
bs_key = 'train_batch_size'
|
193 |
+
else:
|
194 |
+
bs_key = 'valid_batch_size'
|
195 |
+
|
196 |
+
weights, scores = list(zip(
|
197 |
+
*history[-1, 'batches', :, [bs_key, self.name_]]))
|
198 |
+
#score_avg = np.average(scores, weights=weights)
|
199 |
+
score_avg = sum(scores)/self.total_num_samples
|
200 |
+
return score_avg
|
201 |
+
|
202 |
+
# pylint: disable=unused-argument
|
203 |
+
def on_epoch_end(self, net, **kwargs):
|
204 |
+
self.first_batch_flag = False
|
205 |
+
history = net.history
|
206 |
+
try: # don't raise if there is no valid data
|
207 |
+
history[-1, 'batches', :, self.name_]
|
208 |
+
except KeyError:
|
209 |
+
return
|
210 |
+
|
211 |
+
score_avg = self.get_avg_score(history)
|
212 |
+
is_best = self._is_best_score(score_avg)
|
213 |
+
if is_best:
|
214 |
+
self.best_score_ = score_avg
|
215 |
+
|
216 |
+
history.record(self.name_, score_avg)
|
217 |
+
if is_best is not None:
|
218 |
+
history.record(self.name_ + '_best', bool(is_best))
|
transforna/src/callbacks/tbWriter.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from torch.utils.tensorboard import SummaryWriter
|
5 |
+
|
6 |
+
writer = SummaryWriter(str(Path(__file__).parent.parent.parent.absolute())+"/runs/")
|
transforna/src/inference/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .inference_api import *
|
2 |
+
from .inference_benchmark import *
|
3 |
+
from .inference_tcga import *
|
transforna/src/inference/inference_api.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import logging
|
3 |
+
import warnings
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
from contextlib import redirect_stdout
|
6 |
+
from datetime import datetime
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
from hydra.utils import instantiate
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from sklearn.preprocessing import StandardScaler
|
15 |
+
from umap import UMAP
|
16 |
+
|
17 |
+
from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
|
18 |
+
from ..processing.seq_tokenizer import SeqTokenizer
|
19 |
+
from ..utils.file import load
|
20 |
+
from ..utils.tcga_post_analysis_utils import Results_Handler
|
21 |
+
from ..utils.utils import (get_model, infer_from_pd,
|
22 |
+
prepare_inference_results_tcga,
|
23 |
+
update_config_with_inference_params)
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
warnings.filterwarnings("ignore")
|
28 |
+
|
29 |
+
def aggregate_ensemble_model(lev_dist_df:pd.DataFrame):
|
30 |
+
'''
|
31 |
+
This function aggregates the predictions of the ensemble model by choosing the model with the lowest and the highest NLD per query sequence.
|
32 |
+
If the lowest NLD is lower than Novelty Threshold, then the model with the lowest NLD is chosen as the ensemble prediction.
|
33 |
+
Otherwise, the model with the highest NLD is chosen as the ensemble prediction.
|
34 |
+
'''
|
35 |
+
#for every sequence, if at least one model scores an NLD < Novelty Threshold, then get the one with the least NLD as the ensemble prediction
|
36 |
+
#otherwise, get the highest NLD.
|
37 |
+
#get the minimum NLD per query sequence
|
38 |
+
#remove the baseline model
|
39 |
+
baseline_df = lev_dist_df[lev_dist_df['Model'] == 'Baseline'].reset_index(drop=True)
|
40 |
+
lev_dist_df = lev_dist_df[lev_dist_df['Model'] != 'Baseline'].reset_index(drop=True)
|
41 |
+
min_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmin().values]
|
42 |
+
#get the maximum NLD per query sequence
|
43 |
+
max_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmax().values]
|
44 |
+
#choose between each row in min_lev_dist_df and max_lev_dist_df based on the value of Novelty Threshold
|
45 |
+
novel_mask_df = min_lev_dist_df['NLD'] > min_lev_dist_df['Novelty Threshold']
|
46 |
+
#get the rows where NLD is lower than Novelty Threshold
|
47 |
+
min_lev_dist_df = min_lev_dist_df[~novel_mask_df.values]
|
48 |
+
#get the rows where NLD is higher than Novelty Threshold
|
49 |
+
max_lev_dist_df = max_lev_dist_df[novel_mask_df.values]
|
50 |
+
#merge min_lev_dist_df and max_lev_dist_df
|
51 |
+
ensemble_lev_dist_df = pd.concat([min_lev_dist_df,max_lev_dist_df])
|
52 |
+
#add ensemble model
|
53 |
+
ensemble_lev_dist_df['Model'] = 'Ensemble'
|
54 |
+
#add ensemble_lev_dist_df to lev_dist_df
|
55 |
+
lev_dist_df = pd.concat([lev_dist_df,ensemble_lev_dist_df,baseline_df])
|
56 |
+
return lev_dist_df.reset_index(drop=True)
|
57 |
+
|
58 |
+
|
59 |
+
def read_inference_model_config(model:str,mc_or_sc,trained_on:str,path_to_models:str):
|
60 |
+
transforna_folder = "TransfoRNA_ID"
|
61 |
+
if trained_on == "full":
|
62 |
+
transforna_folder = "TransfoRNA_FULL"
|
63 |
+
|
64 |
+
model_path = f"{path_to_models}/{transforna_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml"
|
65 |
+
cfg = OmegaConf.load(model_path)
|
66 |
+
return cfg
|
67 |
+
|
68 |
+
def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc:str='sub_class',\
|
69 |
+
logits_flag:bool = False,attention_flag:bool = False,\
|
70 |
+
similarity_flag:bool=False,n_sim:int=3,embedds_flag:bool = False, \
|
71 |
+
umap_flag:bool = False,trained_on:str='full',path_to_models:str='') -> pd.DataFrame:
|
72 |
+
'''
|
73 |
+
This function predicts the major class or sub class of a list of sequences using the TransfoRNA model.
|
74 |
+
Additionaly, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings.
|
75 |
+
|
76 |
+
Input:
|
77 |
+
sequences: list of sequences to predict
|
78 |
+
model: model to use for prediction
|
79 |
+
mc_or_sc: models trained on major class or sub class
|
80 |
+
logits_flag: whether to return logits
|
81 |
+
attention_flag: whether to return attention scores (obtained from the self-attention layer)
|
82 |
+
similarity_flag: whether to return explanatory/similar sequences in the training set
|
83 |
+
n_sim: number of similar sequences to return
|
84 |
+
embedds_flag: whether to return embeddings of the sequences
|
85 |
+
umap_flag: whether to return umap embeddings
|
86 |
+
trained_on: whether to use the model trained on the full dataset or the ID dataset
|
87 |
+
Output:
|
88 |
+
pd.DataFrame with the predictions
|
89 |
+
'''
|
90 |
+
#assers that only one flag is True
|
91 |
+
assert sum([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]) <= 1, 'One option at most can be True'
|
92 |
+
# capitalize the first letter of the model and the first letter after the -
|
93 |
+
model = "-".join([word.capitalize() for word in model.split("-")])
|
94 |
+
cfg = read_inference_model_config(model,mc_or_sc,trained_on,path_to_models)
|
95 |
+
cfg = update_config_with_inference_params(cfg,mc_or_sc,trained_on,path_to_models)
|
96 |
+
root_dir = Path(__file__).parents[1].absolute()
|
97 |
+
|
98 |
+
with redirect_stdout(None):
|
99 |
+
cfg, net = get_model(cfg, root_dir)
|
100 |
+
#original_infer_pd might include seqs that are longer than input model. if so, infer_pd contains the trimmed sequences
|
101 |
+
infer_pd = pd.Series(sequences, name="Sequences").to_frame()
|
102 |
+
predicted_labels, logits, gene_embedds_df,attn_scores_pd,all_data, max_len, net,_ = infer_from_pd(cfg, net, infer_pd, SeqTokenizer,attention_flag)
|
103 |
+
|
104 |
+
if model == 'Seq':
|
105 |
+
gene_embedds_df = gene_embedds_df.iloc[:,:int(gene_embedds_df.shape[1]/2)]
|
106 |
+
if logits_flag:
|
107 |
+
cfg['log_logits'] = True
|
108 |
+
prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
|
109 |
+
infer_pd = all_data["infere_rna_seq"]
|
110 |
+
|
111 |
+
if logits_flag:
|
112 |
+
logits_df = infer_pd.rename_axis("Sequence").reset_index()
|
113 |
+
logits_cols = [col for col in infer_pd.columns if "Logits" in col]
|
114 |
+
logits_df = infer_pd[logits_cols]
|
115 |
+
logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"])
|
116 |
+
logits_df.columns = logits_df.columns.droplevel(0)
|
117 |
+
return logits_df
|
118 |
+
|
119 |
+
elif attention_flag:
|
120 |
+
return attn_scores_pd
|
121 |
+
|
122 |
+
elif embedds_flag:
|
123 |
+
return gene_embedds_df
|
124 |
+
|
125 |
+
else: #return table with predictions, entropy, threshold, is familiar
|
126 |
+
#add aa predictions to infer_pd
|
127 |
+
embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2])+'/embedds'
|
128 |
+
results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
|
129 |
+
results.get_knn_model()
|
130 |
+
lv_threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"]
|
131 |
+
logger.info(f'computing levenstein distance for the inference set')
|
132 |
+
#prepare infer split
|
133 |
+
gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)]
|
134 |
+
#add index of gene_embedds_df to be a column with name results.seq_col
|
135 |
+
gene_embedds_df[results.seq_col] = gene_embedds_df.index
|
136 |
+
#set gene_embedds_df as the new infer split
|
137 |
+
results.splits_df_dict['infer_df'] = gene_embedds_df
|
138 |
+
|
139 |
+
|
140 |
+
_,_,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,'infer',num_neighbors=n_sim)
|
141 |
+
|
142 |
+
if similarity_flag:
|
143 |
+
#create df
|
144 |
+
sim_df = pd.DataFrame()
|
145 |
+
#populate query sequences and duplicate them n times
|
146 |
+
sequences = gene_embedds_df.index.tolist()
|
147 |
+
#duplicate each sequence n_sim times
|
148 |
+
sequences_duplicated = [seq for seq in sequences for _ in range(n_sim)]
|
149 |
+
sim_df['Sequence'] = sequences_duplicated
|
150 |
+
#assign top_5_seqs list to df column
|
151 |
+
sim_df[f'Explanatory Sequence'] = top_n_seqs
|
152 |
+
sim_df['NLD'] = lev_dist
|
153 |
+
sim_df['Explanatory Label'] = top_n_labels
|
154 |
+
sim_df['Novelty Threshold'] = lv_threshold
|
155 |
+
#for every query sequence, order the NLD in a increasing order
|
156 |
+
sim_df = sim_df.sort_values(by=['Sequence','NLD'],ascending=[False,True])
|
157 |
+
return sim_df
|
158 |
+
|
159 |
+
logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
|
160 |
+
#for every n_sim elements in the list, get the smallest levenstein distance
|
161 |
+
lv_dist_closest = [min(lev_dist[i:i+n_sim]) for i in range(0,len(lev_dist),n_sim)]
|
162 |
+
top_n_labels_closest = [top_n_labels[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
|
163 |
+
top_n_seqs_closest = [top_n_seqs[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
|
164 |
+
infer_pd['Is Familiar?'] = [True if lv<lv_threshold else False for lv in lv_dist_closest]
|
165 |
+
|
166 |
+
if umap_flag:
|
167 |
+
#compute umap
|
168 |
+
logger.info(f'computing umap for the inference set')
|
169 |
+
gene_embedds_df = gene_embedds_df.drop(results.seq_col,axis=1)
|
170 |
+
umap = UMAP(n_components=2,random_state=42)
|
171 |
+
scaled_embedds = StandardScaler().fit_transform(gene_embedds_df.values)
|
172 |
+
gene_embedds_df = pd.DataFrame(umap.fit_transform(scaled_embedds),columns=['UMAP1','UMAP2'])
|
173 |
+
gene_embedds_df['Net-Label'] = infer_pd['Net-Label'].values
|
174 |
+
gene_embedds_df['Is Familiar?'] = infer_pd['Is Familiar?'].values
|
175 |
+
gene_embedds_df['Explanatory Label'] = top_n_labels_closest
|
176 |
+
gene_embedds_df['Explanatory Sequence'] = top_n_seqs_closest
|
177 |
+
gene_embedds_df['Sequence'] = infer_pd.index
|
178 |
+
return gene_embedds_df
|
179 |
+
|
180 |
+
#override threshold
|
181 |
+
infer_pd['Novelty Threshold'] = lv_threshold
|
182 |
+
infer_pd['NLD'] = lv_dist_closest
|
183 |
+
infer_pd['Explanatory Label'] = top_n_labels_closest
|
184 |
+
infer_pd['Explanatory Sequence'] = top_n_seqs_closest
|
185 |
+
infer_pd = infer_pd.round({"NLD": 2, "Novelty Threshold": 2})
|
186 |
+
logger.info(f'num of new hico based on levenstein distance is {np.sum(infer_pd["Is Familiar?"])}')
|
187 |
+
return infer_pd.rename_axis("Sequence").reset_index()
|
188 |
+
|
189 |
+
def predict_transforna_all_models(sequences: List[str], mc_or_sc:str = 'sub_class',logits_flag: bool = False, attention_flag: bool = False,\
|
190 |
+
similarity_flag: bool = False, n_sim:int = 3,
|
191 |
+
embedds_flag:bool=False, umap_flag:bool = False, trained_on:str="full",path_to_models:str='') -> pd.DataFrame:
|
192 |
+
"""
|
193 |
+
Predicts the labels of the sequences using all the models available in the transforna package.
|
194 |
+
If non of the flags are true, it constructs and aggrgates the output of the ensemble model.
|
195 |
+
|
196 |
+
Input:
|
197 |
+
sequences: list of sequences to predict
|
198 |
+
mc_or_sc: models trained on major class or sub class
|
199 |
+
logits_flag: whether to return logits
|
200 |
+
attention_flag: whether to return attention scores (obtained from the self-attention layer)
|
201 |
+
similarity_flag: whether to return explanatory/similar sequences in the training set
|
202 |
+
n_sim: number of similar sequences to return
|
203 |
+
embedds_flag: whether to return embeddings of the sequences
|
204 |
+
umap_flag: whether to return umap embeddings
|
205 |
+
trained_on: whether to use the model trained on the full dataset or the ID dataset
|
206 |
+
Output:
|
207 |
+
df: dataframe with the predictions
|
208 |
+
"""
|
209 |
+
now = datetime.now()
|
210 |
+
before_time = now.strftime("%H:%M:%S")
|
211 |
+
models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
|
212 |
+
if similarity_flag or embedds_flag: #remove baseline, takes long time
|
213 |
+
models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
|
214 |
+
if attention_flag: #remove single based transformer models
|
215 |
+
models = ["Seq", "Seq-Struct", "Seq-Rev"]
|
216 |
+
df = None
|
217 |
+
for model in models:
|
218 |
+
logger.info(model)
|
219 |
+
df_ = predict_transforna(sequences, model, mc_or_sc,logits_flag,attention_flag,similarity_flag,n_sim,embedds_flag,umap_flag,trained_on=trained_on,path_to_models = path_to_models)
|
220 |
+
df_["Model"] = model
|
221 |
+
df = pd.concat([df, df_], axis=0)
|
222 |
+
#aggregate ensemble model if not of the flags are true
|
223 |
+
if not logits_flag and not attention_flag and not similarity_flag and not embedds_flag and not umap_flag:
|
224 |
+
df = aggregate_ensemble_model(df)
|
225 |
+
|
226 |
+
now = datetime.now()
|
227 |
+
after_time = now.strftime("%H:%M:%S")
|
228 |
+
delta_time = datetime.strptime(after_time, "%H:%M:%S") - datetime.strptime(before_time, "%H:%M:%S")
|
229 |
+
logger.info(f"Time taken: {delta_time}")
|
230 |
+
|
231 |
+
return df
|
232 |
+
|
233 |
+
|
234 |
+
if __name__ == "__main__":
|
235 |
+
parser = ArgumentParser()
|
236 |
+
parser.add_argument("sequences", nargs="+")
|
237 |
+
parser.add_argument("--logits_flag", nargs="?", const = True,default=False)
|
238 |
+
parser.add_argument("--attention_flag", nargs="?", const = True,default=False)
|
239 |
+
parser.add_argument("--similarity_flag", nargs="?", const = True,default=False)
|
240 |
+
parser.add_argument("--n_sim", nargs="?", const = 3,default=3)
|
241 |
+
parser.add_argument("--embedds_flag", nargs="?", const = True,default=False)
|
242 |
+
parser.add_argument("--trained_on", nargs="?", const = True,default="full")
|
243 |
+
predict_transforna_all_models(**vars(parser.parse_args()))
|
transforna/src/inference/inference_benchmark.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict
|
4 |
+
|
5 |
+
from ..callbacks.metrics import accuracy_score
|
6 |
+
from ..processing.seq_tokenizer import SeqTokenizer
|
7 |
+
from ..score.score import infer_from_model, infer_testset
|
8 |
+
from ..utils.file import load, save
|
9 |
+
from ..utils.utils import *
|
10 |
+
|
11 |
+
|
12 |
+
def infer_benchmark(cfg:Dict= None,path:str = None):
|
13 |
+
if cfg['tensorboard']:
|
14 |
+
from ..callbacks.tbWriter import writer
|
15 |
+
|
16 |
+
model = cfg["model_name"]+'_'+cfg['task']
|
17 |
+
|
18 |
+
#set seed
|
19 |
+
set_seed_and_device(cfg["seed"],cfg["device_number"])
|
20 |
+
#get data
|
21 |
+
ad = load(cfg["train_config"].dataset_path_train)
|
22 |
+
|
23 |
+
#instantiate dataset class
|
24 |
+
dataset_class = SeqTokenizer(ad.var,cfg)
|
25 |
+
test_data = load(cfg["train_config"].dataset_path_test)
|
26 |
+
#prepare data for training and inference
|
27 |
+
all_data = prepare_data_benchmark(dataset_class,test_data,cfg)
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
#sync skorch config with params in train and model config
|
32 |
+
sync_skorch_with_config(cfg["model"]["skorch_model"],cfg)
|
33 |
+
|
34 |
+
# instantiate skorch model
|
35 |
+
net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path)
|
36 |
+
net.initialize()
|
37 |
+
net.load_params(f_params=f'{cfg["inference_settings"]["model_path"]}')
|
38 |
+
|
39 |
+
#perform inference on task specific testset
|
40 |
+
if cfg["inference_settings"]["infere_original_testset"]:
|
41 |
+
infer_testset(net,cfg,all_data,accuracy_score)
|
42 |
+
|
43 |
+
#inference on custom data
|
44 |
+
predicted_labels,logits,_,_ = infer_from_model(net,all_data["infere_data"])
|
45 |
+
prepare_inference_results_benchmarck(net,cfg,predicted_labels,logits,all_data)
|
46 |
+
save(path=Path(__file__).parent.parent.absolute() / f'inference_results_{model}',data=all_data["infere_rna_seq"])
|
47 |
+
if cfg['tensorboard']:
|
48 |
+
writer.close()
|
transforna/src/inference/inference_tcga.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
from anndata import AnnData
|
5 |
+
|
6 |
+
from ..processing.seq_tokenizer import SeqTokenizer
|
7 |
+
from ..utils.file import load
|
8 |
+
from ..utils.utils import *
|
9 |
+
|
10 |
+
|
11 |
+
def infer_tcga(cfg:Dict= None,path:str = None):
|
12 |
+
if cfg['tensorboard']:
|
13 |
+
from ..callbacks.tbWriter import writer
|
14 |
+
cfg,net = get_model(cfg,path)
|
15 |
+
inference_path = cfg['inference_settings']['sequences_path']
|
16 |
+
original_infer_df = load(inference_path, index_col=0)
|
17 |
+
if isinstance(original_infer_df,AnnData):
|
18 |
+
original_infer_df = original_infer_df.var
|
19 |
+
predicted_labels,logits,_,_,all_data,max_len,net,infer_df = infer_from_pd(cfg,net,original_infer_df,SeqTokenizer)
|
20 |
+
|
21 |
+
#create inference_output if it does not exist
|
22 |
+
if not os.path.exists(f"inference_output"):
|
23 |
+
os.makedirs(f"inference_output")
|
24 |
+
if cfg['log_embedds']:
|
25 |
+
embedds_pd = log_embedds(cfg,net,all_data['infere_rna_seq'])
|
26 |
+
embedds_pd.to_csv(f"inference_output/{cfg['model_name']}_embedds.csv")
|
27 |
+
|
28 |
+
prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len)
|
29 |
+
|
30 |
+
#if sequences were trimmed, add mapping of trimmed sequences to original sequences
|
31 |
+
if original_infer_df.shape[0] != infer_df.shape[0]:
|
32 |
+
all_data["infere_rna_seq"] = add_original_seqs_to_predictions(infer_df,all_data['infere_rna_seq'])
|
33 |
+
#save
|
34 |
+
all_data["infere_rna_seq"].to_csv(f"inference_output/{cfg['model_name']}_inference_results.csv")
|
35 |
+
|
36 |
+
if cfg['tensorboard']:
|
37 |
+
writer.close()
|
38 |
+
return predicted_labels
|
transforna/src/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .model_components import *
|
2 |
+
from .skorchWrapper import *
|
transforna/src/model/model_components.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Dict, Optional
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from omegaconf import DictConfig
|
12 |
+
from torch.nn.modules.normalization import LayerNorm
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
def circulant_mask(n: int, window: int) -> torch.Tensor:
|
17 |
+
"""Calculate the relative attention mask, calculated once when model instatiated, as a subset of this matrix
|
18 |
+
will be used for a input length less than max.
|
19 |
+
i,j represent relative token positions in this matrix and in the attention scores matrix,
|
20 |
+
this mask enables attention scores to be set to 0 if further than the specified window length
|
21 |
+
|
22 |
+
:param n: a fixed parameter set to be larger than largest max sequence length across batches
|
23 |
+
:param window: [window length],
|
24 |
+
:return relative attention mask
|
25 |
+
"""
|
26 |
+
circulant_t = torch.zeros(n, n)
|
27 |
+
# [0, 1, 2, ..., window, -1, -2, ..., window]
|
28 |
+
offsets = [0] + [i for i in range(window + 1)] + [-i for i in range(window + 1)]
|
29 |
+
if window >= n:
|
30 |
+
return torch.ones(n, n)
|
31 |
+
for offset in offsets:
|
32 |
+
# size of the 1-tensor depends on the length of the diagonal
|
33 |
+
circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset)))
|
34 |
+
return circulant_t
|
35 |
+
|
36 |
+
|
37 |
+
class SelfAttention(nn.Module):
|
38 |
+
|
39 |
+
"""normal query, key, value based self attention but with relative attention functionality
|
40 |
+
and a learnable bias encoding relative token position which is added to the attention scores before the softmax"""
|
41 |
+
|
42 |
+
def __init__(self, config: DictConfig, relative_attention: int):
|
43 |
+
"""init self attention weight of each key, query, value and output projection layer.
|
44 |
+
|
45 |
+
:param config: model config
|
46 |
+
:type config: ConveRTModelConfig
|
47 |
+
"""
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.config = config
|
51 |
+
self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project)
|
52 |
+
self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project)
|
53 |
+
self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project)
|
54 |
+
|
55 |
+
self.softmax = nn.Softmax(dim=-1)
|
56 |
+
self.output_projection = nn.Linear(
|
57 |
+
config.num_attention_project, config.num_embed_hidden
|
58 |
+
)
|
59 |
+
self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
|
60 |
+
stdv = 1.0 / math.sqrt(self.bias.data.size(0))
|
61 |
+
self.bias.data.uniform_(-stdv, stdv)
|
62 |
+
self.relative_attention = relative_attention
|
63 |
+
self.n = self.config.n
|
64 |
+
self.half_n = self.n // 2
|
65 |
+
self.register_buffer(
|
66 |
+
"relative_mask",
|
67 |
+
circulant_mask(config.tokens_len, self.relative_attention),
|
68 |
+
)
|
69 |
+
|
70 |
+
def forward(
|
71 |
+
self, attn_input: torch.Tensor, attention_mask: torch.Tensor
|
72 |
+
) -> torch.Tensor:
|
73 |
+
"""calculate self-attention of query, key and weighted to value at the end.
|
74 |
+
self-attention input is projected by linear layer at the first time.
|
75 |
+
applying attention mask for ignore pad index attention weight. Relative attention mask
|
76 |
+
applied and a learnable bias added to the attention scores.
|
77 |
+
return value after apply output projection layer to value * attention
|
78 |
+
|
79 |
+
:param attn_input: [description]
|
80 |
+
:type attn_input: [type]
|
81 |
+
:param attention_mask: [description], defaults to None
|
82 |
+
:type attention_mask: [type], optional
|
83 |
+
:return: [description]
|
84 |
+
:rtype: [type]
|
85 |
+
"""
|
86 |
+
self.T = attn_input.size()[1]
|
87 |
+
# input is B x max seq len x n_emb
|
88 |
+
_query = self.query.forward(attn_input)
|
89 |
+
_key = self.key.forward(attn_input)
|
90 |
+
_value = self.value.forward(attn_input)
|
91 |
+
|
92 |
+
# scaled dot product
|
93 |
+
attention_scores = torch.matmul(_query, _key.transpose(1, 2))
|
94 |
+
attention_scores = attention_scores / math.sqrt(
|
95 |
+
self.config.num_attention_project
|
96 |
+
)
|
97 |
+
|
98 |
+
# Relative attention
|
99 |
+
|
100 |
+
# extended_attention_mask = attention_mask.to(attention_scores.device) # fp16 compatibility
|
101 |
+
extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0
|
102 |
+
attention_scores = attention_scores + extended_attention_mask
|
103 |
+
|
104 |
+
# fix circulant_matrix to matrix of size 60 x60 (max token truncation_length,
|
105 |
+
# register as buffer, so not keep creating masks of different sizes.
|
106 |
+
|
107 |
+
attention_scores = attention_scores.masked_fill(
|
108 |
+
self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf")
|
109 |
+
)
|
110 |
+
|
111 |
+
# Learnable bias vector is used of max size,for each i, different subsets of it are added to the scores, where the permutations
|
112 |
+
# depend on the relative position (i-j). this way cleverly allows no loops. bias vector is 2*max truncation length+1
|
113 |
+
# so has a learnable parameter for each eg. (i-j) /in {-60,...60} .
|
114 |
+
|
115 |
+
ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T))
|
116 |
+
B_matrix = self.bias[self.n // 2 - ii + jj]
|
117 |
+
|
118 |
+
attention_scores = attention_scores + B_matrix.unsqueeze(0)
|
119 |
+
|
120 |
+
attention_scores = self.softmax(attention_scores)
|
121 |
+
output = torch.matmul(attention_scores, _value)
|
122 |
+
|
123 |
+
output = self.output_projection(output)
|
124 |
+
|
125 |
+
return [output,attention_scores] # B x T x num embed hidden
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
class FeedForward1(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0
|
132 |
+
):
|
133 |
+
# 512 2048
|
134 |
+
|
135 |
+
super().__init__()
|
136 |
+
|
137 |
+
self.linear_1 = nn.Linear(input_hidden, intermediate_hidden)
|
138 |
+
self.dropout = nn.Dropout(dropout_rate)
|
139 |
+
self.linear_2 = nn.Linear(intermediate_hidden, input_hidden)
|
140 |
+
|
141 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
142 |
+
|
143 |
+
x = F.gelu(self.linear_1(x))
|
144 |
+
return self.linear_2(self.dropout(x))
|
145 |
+
|
146 |
+
|
147 |
+
class SharedInnerBlock(nn.Module):
|
148 |
+
def __init__(self, config: DictConfig, relative_attn: int):
|
149 |
+
super().__init__()
|
150 |
+
|
151 |
+
self.config = config
|
152 |
+
self.self_attention = SelfAttention(config, relative_attn)
|
153 |
+
self.norm1 = LayerNorm(config.num_embed_hidden) # 512
|
154 |
+
self.dropout = nn.Dropout(config.dropout)
|
155 |
+
self.ff1 = FeedForward1(
|
156 |
+
config.num_embed_hidden, config.feed_forward1_hidden, config.dropout
|
157 |
+
)
|
158 |
+
self.norm2 = LayerNorm(config.num_embed_hidden)
|
159 |
+
|
160 |
+
def forward(self, x: torch.Tensor, attention_mask: int) -> torch.Tensor:
|
161 |
+
|
162 |
+
new_values_x,attn_scores = self.self_attention(x, attention_mask=attention_mask)
|
163 |
+
x = x+new_values_x
|
164 |
+
x = self.norm1(x)
|
165 |
+
x = x + self.ff1(x)
|
166 |
+
return self.norm2(x),attn_scores
|
167 |
+
|
168 |
+
|
169 |
+
# pretty basic, just single head. but done many times, stack to have another dimension (4 with batches).# so get stacks of B x H of attention scores T x T..
|
170 |
+
# then matrix multiply these extra stacks with the v
|
171 |
+
# (B xnh)x T xT . (Bx nh xTx hs) gives (B Nh) T x hs stacks. now hs is set to be final dimension/ number of heads, so reorder the stacks (concatenating them)
|
172 |
+
# can have optional extra projection layer, but doing that later
|
173 |
+
|
174 |
+
|
175 |
+
class MultiheadAttention(nn.Module):
|
176 |
+
def __init__(self, config: DictConfig):
|
177 |
+
super().__init__()
|
178 |
+
self.num_attention_heads = config.num_attention_heads
|
179 |
+
self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads
|
180 |
+
self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads)
|
181 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
182 |
+
|
183 |
+
self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
|
184 |
+
self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
|
185 |
+
self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
|
186 |
+
|
187 |
+
self.dropout = nn.Dropout(config.dropout)
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
191 |
+
) -> torch.Tensor:
|
192 |
+
B, T, _ = hidden_states.size()
|
193 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
194 |
+
k = (
|
195 |
+
self.key(hidden_states)
|
196 |
+
.view(B, T, self.num_attention_heads, self.attention_head_size)
|
197 |
+
.transpose(1, 2)
|
198 |
+
) # (B, nh, T, hs)
|
199 |
+
q = (
|
200 |
+
self.query(hidden_states)
|
201 |
+
.view(B, T, self.num_attention_heads, self.attention_head_size)
|
202 |
+
.transpose(1, 2)
|
203 |
+
) # (B, nh, T, hs)
|
204 |
+
v = (
|
205 |
+
self.value(hidden_states)
|
206 |
+
.view(B, T, self.num_attention_heads, self.attention_head_size)
|
207 |
+
.transpose(1, 2)
|
208 |
+
) # (B, nh, T, hs)
|
209 |
+
|
210 |
+
attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
211 |
+
|
212 |
+
if attention_mask is not None:
|
213 |
+
attention_mask = attention_mask[:, None, None, :]
|
214 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
215 |
+
|
216 |
+
attention_scores = attention_scores + attention_mask
|
217 |
+
|
218 |
+
attention_scores = F.softmax(attention_scores, dim=-1)
|
219 |
+
|
220 |
+
attention_scores = self.dropout(attention_scores)
|
221 |
+
|
222 |
+
y = attention_scores @ v
|
223 |
+
|
224 |
+
y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj)
|
225 |
+
|
226 |
+
return y
|
227 |
+
|
228 |
+
|
229 |
+
class PositionalEncoding(nn.Module):
|
230 |
+
def __init__(self, model_config: DictConfig,):
|
231 |
+
super(PositionalEncoding, self).__init__()
|
232 |
+
self.dropout = nn.Dropout(p=model_config.dropout)
|
233 |
+
self.num_embed_hidden = model_config.num_embed_hidden
|
234 |
+
pe = torch.zeros(model_config.tokens_len, self.num_embed_hidden)
|
235 |
+
position = torch.arange(
|
236 |
+
0, model_config.tokens_len, dtype=torch.float
|
237 |
+
).unsqueeze(1)
|
238 |
+
div_term = torch.exp(
|
239 |
+
torch.arange(0, self.num_embed_hidden, 2).float()
|
240 |
+
* (-math.log(10000.0) / self.num_embed_hidden)
|
241 |
+
)
|
242 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
243 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
244 |
+
pe = pe.unsqueeze(0)
|
245 |
+
self.register_buffer("pe", pe)
|
246 |
+
|
247 |
+
def forward(self, x):
|
248 |
+
x = x + self.pe[: x.size(0), :]
|
249 |
+
return self.dropout(x)
|
250 |
+
|
251 |
+
|
252 |
+
class RNAFFrwd(
|
253 |
+
nn.Module
|
254 |
+
): # params are not shared for context and reply. so need two sets of weights
|
255 |
+
"""Fully-Connected 3-layer Linear Model"""
|
256 |
+
|
257 |
+
def __init__(self, model_config: DictConfig):
|
258 |
+
"""
|
259 |
+
:param input_hidden: first-hidden layer input embed-dim
|
260 |
+
:type input_hidden: int
|
261 |
+
:param intermediate_hidden: layer-(hidden)-layer middle point weight
|
262 |
+
:type intermediate_hidden: int
|
263 |
+
:param dropout_rate: dropout rate, defaults to None
|
264 |
+
:type dropout_rate: float, optional
|
265 |
+
"""
|
266 |
+
# paper specifies,skip connections,layer normalization, and orthogonal initialization
|
267 |
+
|
268 |
+
super().__init__()
|
269 |
+
# 3,679,744 x2 params
|
270 |
+
self.rna_ffwd_input_dim = (
|
271 |
+
model_config.num_embed_hidden * model_config.num_attention_heads
|
272 |
+
)
|
273 |
+
self.linear_1 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
|
274 |
+
self.linear_2 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
|
275 |
+
|
276 |
+
self.norm1 = LayerNorm(self.rna_ffwd_input_dim)
|
277 |
+
self.norm2 = LayerNorm(self.rna_ffwd_input_dim)
|
278 |
+
self.final = nn.Linear(self.rna_ffwd_input_dim, model_config.num_embed_hidden)
|
279 |
+
self.orthogonal_initialization() # torch implementation works perfectly out the box,
|
280 |
+
|
281 |
+
def orthogonal_initialization(self):
|
282 |
+
for l in [
|
283 |
+
self.linear_1,
|
284 |
+
self.linear_2,
|
285 |
+
]:
|
286 |
+
torch.nn.init.orthogonal_(l.weight)
|
287 |
+
|
288 |
+
def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor:
|
289 |
+
sentence_lengths = attn_msk.sum(1)
|
290 |
+
|
291 |
+
# adding square root reduction projection separately as not a shared.
|
292 |
+
# part of the diagram torch.Size([Batch, scent_len, embedd_dim])
|
293 |
+
|
294 |
+
# x has dims B x T x 2*d_emb
|
295 |
+
norms = 1 / torch.sqrt(sentence_lengths.double()).float() # 64
|
296 |
+
# TODO: Aggregation is done on all words including the masked ones
|
297 |
+
x = norms.unsqueeze(1) * torch.sum(x, dim=1) # 64 x1024
|
298 |
+
|
299 |
+
x = x + F.gelu(self.linear_1(self.norm1(x)))
|
300 |
+
x = x + F.gelu(self.linear_2(self.norm2(x)))
|
301 |
+
|
302 |
+
return F.normalize(self.final(x), dim=1, p=2) # 64 512
|
303 |
+
|
304 |
+
|
305 |
+
class RNATransformer(nn.Module):
|
306 |
+
def __init__(self, model_config: DictConfig):
|
307 |
+
super().__init__()
|
308 |
+
self.num_embedd_hidden = model_config.num_embed_hidden
|
309 |
+
self.encoder = nn.Embedding(
|
310 |
+
model_config.vocab_size, model_config.num_embed_hidden
|
311 |
+
)
|
312 |
+
self.model_input = model_config.model_input
|
313 |
+
if 'baseline' not in self.model_input:
|
314 |
+
# positional encoder
|
315 |
+
self.pos_encoder = PositionalEncoding(model_config)
|
316 |
+
|
317 |
+
self.transformer_layers = nn.ModuleList(
|
318 |
+
[
|
319 |
+
SharedInnerBlock(model_config, int(window/model_config.window))
|
320 |
+
for window in model_config.relative_attns[
|
321 |
+
: model_config.num_encoder_layers
|
322 |
+
]
|
323 |
+
]
|
324 |
+
)
|
325 |
+
self.MHA = MultiheadAttention(model_config)
|
326 |
+
# self.concatenate = FeedForward2(model_config)
|
327 |
+
|
328 |
+
self.rna_ffrwd = RNAFFrwd(model_config)
|
329 |
+
self.pad_id = 0
|
330 |
+
|
331 |
+
def forward(self, x:torch.Tensor) -> torch.Tensor:
|
332 |
+
if x.is_cuda:
|
333 |
+
long_tensor = torch.cuda.LongTensor
|
334 |
+
else:
|
335 |
+
long_tensor = torch.LongTensor
|
336 |
+
|
337 |
+
embedds = self.encoder(x)
|
338 |
+
if 'baseline' not in self.model_input:
|
339 |
+
output = self.pos_encoder(embedds)
|
340 |
+
attention_mask = (x != self.pad_id).int()
|
341 |
+
|
342 |
+
for l in self.transformer_layers:
|
343 |
+
output,attn_scores = l(output, attention_mask)
|
344 |
+
output = self.MHA(output)
|
345 |
+
output = self.rna_ffrwd(output, attention_mask)
|
346 |
+
return output,attn_scores
|
347 |
+
else:
|
348 |
+
embedds = torch.flatten(embedds,start_dim=1)
|
349 |
+
return embedds,None
|
350 |
+
|
351 |
+
class GeneEmbeddModel(nn.Module):
|
352 |
+
def __init__(
|
353 |
+
self, main_config: DictConfig,
|
354 |
+
):
|
355 |
+
super().__init__()
|
356 |
+
self.train_config = main_config["train_config"]
|
357 |
+
self.model_config = main_config["model_config"]
|
358 |
+
self.device = self.train_config.device
|
359 |
+
self.model_input = self.model_config["model_input"]
|
360 |
+
self.false_input_perc = self.model_config["false_input_perc"]
|
361 |
+
#adjust n (used to add rel bias on attn scores)
|
362 |
+
self.model_config.n = self.model_config.tokens_len*2+1
|
363 |
+
self.transformer_layers = RNATransformer(self.model_config)
|
364 |
+
#save tokens_len of sequences to be used to split ids between transformers
|
365 |
+
self.tokens_len = self.model_config.tokens_len
|
366 |
+
#reassign tokens_len and vocab_size to init a new transformer
|
367 |
+
#more clean solution -> RNATransformer and its children should
|
368 |
+
# have a flag input indicating which transformer
|
369 |
+
self.model_config.tokens_len = self.model_config.second_input_token_len
|
370 |
+
self.model_config.n = self.model_config.tokens_len*2+1
|
371 |
+
self.seq_vocab_size = self.model_config.vocab_size
|
372 |
+
#this differs between both models not the token_len/ss_token_len
|
373 |
+
self.model_config.vocab_size = self.model_config.second_input_vocab_size
|
374 |
+
|
375 |
+
self.second_input_model = RNATransformer(self.model_config)
|
376 |
+
|
377 |
+
#num_transformers refers to using either one model or two in parallel
|
378 |
+
self.num_transformers = 2
|
379 |
+
if self.model_input == 'seq':
|
380 |
+
self.num_transformers = 1
|
381 |
+
# could be moved to model
|
382 |
+
self.weight_decay = self.train_config.l2_weight_decay
|
383 |
+
if 'baseline' in self.model_input:
|
384 |
+
self.num_transformers = 1
|
385 |
+
num_nodes = self.model_config.num_embed_hidden*self.tokens_len
|
386 |
+
self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
|
387 |
+
else:
|
388 |
+
#setting classification layer
|
389 |
+
num_nodes = self.num_transformers*self.model_config.num_embed_hidden
|
390 |
+
if self.num_transformers == 1:
|
391 |
+
self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
|
392 |
+
else:
|
393 |
+
self.final_clf_1 = nn.Linear(num_nodes,num_nodes)
|
394 |
+
self.final_clf_2 = nn.Linear(num_nodes,self.model_config.num_classes)
|
395 |
+
self.relu = nn.ReLU()
|
396 |
+
self.BN = nn.BatchNorm1d(num_nodes)
|
397 |
+
self.dropout = nn.Dropout(0.6)
|
398 |
+
|
399 |
+
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
400 |
+
|
401 |
+
|
402 |
+
def distort_input(self,x):
|
403 |
+
for sample_idx in range(x.shape[0]):
|
404 |
+
seq_length = x[sample_idx,-1]
|
405 |
+
num_tokens_flipped = int(self.false_input_perc*seq_length)
|
406 |
+
max_start_flip_idx = seq_length - num_tokens_flipped
|
407 |
+
|
408 |
+
random_feat_idx = random.randint(0,max_start_flip_idx-1)
|
409 |
+
x[sample_idx,random_feat_idx:random_feat_idx+num_tokens_flipped] = \
|
410 |
+
torch.tensor(np.random.choice(range(1,self.seq_vocab_size-1),size=num_tokens_flipped,replace=True))
|
411 |
+
|
412 |
+
x[sample_idx,random_feat_idx+self.tokens_len:random_feat_idx+self.tokens_len+num_tokens_flipped] = \
|
413 |
+
torch.tensor(np.random.choice(range(1,self.model_config.second_input_vocab_size-1),size=num_tokens_flipped,replace=True))
|
414 |
+
return x
|
415 |
+
|
416 |
+
def forward(self, x,train=False):
|
417 |
+
if self.device == 'cuda':
|
418 |
+
long_tensor = torch.cuda.LongTensor
|
419 |
+
float_tensor = torch.cuda.FloatTensor
|
420 |
+
else:
|
421 |
+
long_tensor = torch.LongTensor
|
422 |
+
float_tensor = torch.FloatTensor
|
423 |
+
if train:
|
424 |
+
if self.false_input_perc > 0:
|
425 |
+
x = self.distort_input(x)
|
426 |
+
|
427 |
+
gene_embedd,attn_scores_first = self.transformer_layers(
|
428 |
+
x[:, : self.tokens_len].type(long_tensor)
|
429 |
+
)
|
430 |
+
attn_scores_second = None
|
431 |
+
second_input_embedd,attn_scores_second = self.second_input_model(
|
432 |
+
x[:, self.tokens_len :-1].type(long_tensor)
|
433 |
+
)
|
434 |
+
|
435 |
+
#for tcga: if seq or baseline
|
436 |
+
if self.num_transformers == 1:
|
437 |
+
activations = self.final_clf_1(gene_embedd)
|
438 |
+
else:
|
439 |
+
out_clf_1 = self.final_clf_1(torch.cat((gene_embedd, second_input_embedd), 1))
|
440 |
+
out = self.BN(out_clf_1)
|
441 |
+
out = self.relu(out)
|
442 |
+
out = self.dropout(out)
|
443 |
+
activations = self.final_clf_2(out)
|
444 |
+
|
445 |
+
#create dummy attn scores for baseline
|
446 |
+
if 'baseline' in self.model_input:
|
447 |
+
attn_scores_first = torch.ones((1,2,2),device=x.device)
|
448 |
+
|
449 |
+
return [gene_embedd, second_input_embedd, activations,attn_scores_first,attn_scores_second]
|
transforna/src/model/skorchWrapper.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
import skorch
|
6 |
+
import torch
|
7 |
+
from skorch.dataset import Dataset, ValidSplit
|
8 |
+
from skorch.setter import optimizer_setter
|
9 |
+
from skorch.utils import is_dataset, to_device
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
#from ..tbWriter import writer
|
13 |
+
|
14 |
+
|
15 |
+
class Net(skorch.NeuralNet):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
clip=0.25,
|
19 |
+
top_k=1,
|
20 |
+
correct=0,
|
21 |
+
save_embedding=False,
|
22 |
+
gene_embedds=[],
|
23 |
+
second_input_embedd=[],
|
24 |
+
confidence_threshold = 0.95,
|
25 |
+
*args,
|
26 |
+
**kwargs
|
27 |
+
):
|
28 |
+
self.clip = clip
|
29 |
+
self.curr_epoch = 0
|
30 |
+
super(Net, self).__init__(*args, **kwargs)
|
31 |
+
self.correct = correct
|
32 |
+
self.save_embedding = save_embedding
|
33 |
+
self.gene_embedds = gene_embedds
|
34 |
+
self.second_input_embedds = second_input_embedd
|
35 |
+
self.main_config = kwargs["module__main_config"]
|
36 |
+
self.train_config = self.main_config["train_config"]
|
37 |
+
self.top_k = self.train_config.top_k
|
38 |
+
self.num_classes = self.main_config["model_config"].num_classes
|
39 |
+
self.labels_mapping_path = self.train_config.labels_mapping_path
|
40 |
+
if self.labels_mapping_path:
|
41 |
+
with open(self.labels_mapping_path, 'rb') as handle:
|
42 |
+
self.labels_mapping_dict = pickle.load(handle)
|
43 |
+
self.confidence_threshold = confidence_threshold
|
44 |
+
self.max_epochs = kwargs["max_epochs"]
|
45 |
+
self.task = '' #is set in utils.instantiate_predictor
|
46 |
+
self.log_tb = False
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def set_save_epoch(self):
|
52 |
+
'''
|
53 |
+
scale best train epoch by valid size
|
54 |
+
'''
|
55 |
+
if self.task !='tcga':
|
56 |
+
if self.train_split:
|
57 |
+
self.save_epoch = self.main_config["train_config"].train_epoch
|
58 |
+
else:
|
59 |
+
self.save_epoch = round(self.main_config["train_config"].train_epoch*\
|
60 |
+
(1+self.main_config["valid_size"]))
|
61 |
+
|
62 |
+
def save_benchmark_model(self):
|
63 |
+
'''
|
64 |
+
saves benchmark epochs when train_split is none
|
65 |
+
'''
|
66 |
+
try:
|
67 |
+
os.mkdir("ckpt")
|
68 |
+
except:
|
69 |
+
pass
|
70 |
+
cwd = os.getcwd()+"/ckpt/"
|
71 |
+
self.save_params(f_params= f'{cwd}/model_params_{self.main_config["task"]}.pt')
|
72 |
+
|
73 |
+
|
74 |
+
def fit(self, X, y=None, valid_ds=None,**fit_params):
|
75 |
+
#all sequence lengths should be saved to compute the median based
|
76 |
+
self.all_lengths = [[] for i in range(self.num_classes)]
|
77 |
+
self.median_lengths = []
|
78 |
+
|
79 |
+
if not self.warm_start or not self.initialized_:
|
80 |
+
self.initialize()
|
81 |
+
|
82 |
+
if valid_ds:
|
83 |
+
self.validation_dataset = valid_ds
|
84 |
+
else:
|
85 |
+
self.validation_dataset = None
|
86 |
+
|
87 |
+
self.partial_fit(X, y, **fit_params)
|
88 |
+
return self
|
89 |
+
|
90 |
+
def fit_loop(self, X, y=None, epochs=None, **fit_params):
|
91 |
+
#if id then train longer otherwise stop at 0.99
|
92 |
+
rounding_digits = 3
|
93 |
+
if self.main_config['trained_on'] == 'full':
|
94 |
+
rounding_digits = 2
|
95 |
+
self.check_data(X, y)
|
96 |
+
epochs = epochs if epochs is not None else self.max_epochs
|
97 |
+
|
98 |
+
dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)
|
99 |
+
|
100 |
+
if self.validation_dataset is not None:
|
101 |
+
dataset_valid = self.validation_dataset.keywords["valid_ds"]
|
102 |
+
|
103 |
+
on_epoch_kwargs = {
|
104 |
+
"dataset_train": dataset_train,
|
105 |
+
"dataset_valid": dataset_valid,
|
106 |
+
}
|
107 |
+
|
108 |
+
iterator_train = self.get_iterator(dataset_train, training=True)
|
109 |
+
iterator_valid = None
|
110 |
+
if dataset_valid is not None:
|
111 |
+
iterator_valid = self.get_iterator(dataset_valid, training=False)
|
112 |
+
|
113 |
+
self.set_save_epoch()
|
114 |
+
|
115 |
+
for epoch_no in range(epochs):
|
116 |
+
#save model if training only on test set
|
117 |
+
self.curr_epoch = epoch_no
|
118 |
+
#save epoch is scaled by best train epoch
|
119 |
+
#save benchmark only when training on boith train and val sets
|
120 |
+
if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split == None:
|
121 |
+
self.save_benchmark_model()
|
122 |
+
|
123 |
+
self.notify("on_epoch_begin", **on_epoch_kwargs)
|
124 |
+
|
125 |
+
self.run_single_epoch(
|
126 |
+
iterator_train,
|
127 |
+
training=True,
|
128 |
+
prefix="train",
|
129 |
+
step_fn=self.train_step,
|
130 |
+
**fit_params
|
131 |
+
)
|
132 |
+
|
133 |
+
if dataset_valid is not None:
|
134 |
+
self.run_single_epoch(
|
135 |
+
iterator_valid,
|
136 |
+
training=False,
|
137 |
+
prefix="valid",
|
138 |
+
step_fn=self.validation_step,
|
139 |
+
**fit_params
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
self.notify("on_epoch_end", **on_epoch_kwargs)
|
144 |
+
#manual early stopping for tcga
|
145 |
+
if self.task == 'tcga':
|
146 |
+
train_acc = round(self.history[:,'train_acc'][-1],rounding_digits)
|
147 |
+
if train_acc == 1:
|
148 |
+
break
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
return self
|
153 |
+
|
154 |
+
def train_step(self, X, y=None):
|
155 |
+
y = X[1]
|
156 |
+
X = X[0]
|
157 |
+
sample_weights = X[:,-1]
|
158 |
+
if self.device == 'cuda':
|
159 |
+
sample_weights = sample_weights.to(self.train_config.device)
|
160 |
+
self.module_.train()
|
161 |
+
self.module_.zero_grad()
|
162 |
+
gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1],train=True)
|
163 |
+
#curr_epoch is passed to loss as it is used to switch loss criteria from unsup. -> sup
|
164 |
+
loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
|
165 |
+
|
166 |
+
###sup loss should be X with samples weight and aggregated
|
167 |
+
|
168 |
+
loss = loss*sample_weights
|
169 |
+
loss = loss.mean()
|
170 |
+
|
171 |
+
loss.backward()
|
172 |
+
|
173 |
+
# TODO: clip only some parameters
|
174 |
+
torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
|
175 |
+
self.optimizer_.step()
|
176 |
+
|
177 |
+
return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
|
178 |
+
|
179 |
+
def validation_step(self, X, y=None):
|
180 |
+
y = X[1]
|
181 |
+
X = X[0]
|
182 |
+
sample_weights = X[:,-1]
|
183 |
+
if self.device == 'cuda':
|
184 |
+
sample_weights = sample_weights.to(self.train_config.device)
|
185 |
+
self.module_.eval()
|
186 |
+
with torch.no_grad():
|
187 |
+
gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1])
|
188 |
+
loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
|
189 |
+
|
190 |
+
###sup loss should be X with samples weight and aggregated
|
191 |
+
|
192 |
+
loss = loss*sample_weights
|
193 |
+
loss = loss.mean()
|
194 |
+
|
195 |
+
return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
|
196 |
+
|
197 |
+
def get_attention_scores(self, X, y=None):
|
198 |
+
'''
|
199 |
+
returns attention scores for a given input
|
200 |
+
'''
|
201 |
+
self.module_.eval()
|
202 |
+
with torch.no_grad():
|
203 |
+
_, _, _,attn_scores_first,attn_scores_second = self.module_(X[:,:-1])
|
204 |
+
|
205 |
+
attn_scores_first = attn_scores_first.detach().cpu().numpy()
|
206 |
+
if attn_scores_second is not None:
|
207 |
+
attn_scores_second = attn_scores_second.detach().cpu().numpy()
|
208 |
+
return attn_scores_first,attn_scores_second
|
209 |
+
|
210 |
+
def predict(self, X):
|
211 |
+
self.module_.train(False)
|
212 |
+
embedds = self.module_(X[:,:-1])
|
213 |
+
sample_weights = X[:,-1]
|
214 |
+
if self.device == 'cuda':
|
215 |
+
sample_weights = sample_weights.to(self.train_config.device)
|
216 |
+
|
217 |
+
gene_embedd, second_input_embedd, activations,_,_ = embedds
|
218 |
+
if self.save_embedding:
|
219 |
+
self.gene_embedds.append(gene_embedd.detach().cpu())
|
220 |
+
#in case only a single transformer is deployed, then second_input_embedd are None. thus have no detach()
|
221 |
+
if second_input_embedd is not None:
|
222 |
+
self.second_input_embedds.append(second_input_embedd.detach().cpu())
|
223 |
+
|
224 |
+
predictions = torch.cat([activations,sample_weights[:,None]],dim=1)
|
225 |
+
return predictions
|
226 |
+
|
227 |
+
|
228 |
+
def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
|
229 |
+
# log gradients and weights
|
230 |
+
for _, m in self.module_.named_modules():
|
231 |
+
for pn, p in m.named_parameters():
|
232 |
+
if pn.endswith("weight") and pn.find("norm") < 0:
|
233 |
+
if p.grad != None:
|
234 |
+
if self.log_tb:
|
235 |
+
from ..callbacks.tbWriter import writer
|
236 |
+
writer.add_histogram("weights/" + pn, p, len(net.history))
|
237 |
+
writer.add_histogram(
|
238 |
+
"gradients/" + pn, p.grad.data, len(net.history)
|
239 |
+
)
|
240 |
+
|
241 |
+
return
|
242 |
+
|
243 |
+
def configure_opt(self, l2_weight_decay):
|
244 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
245 |
+
params_decay = [
|
246 |
+
p
|
247 |
+
for n, p in self.module_.named_parameters()
|
248 |
+
if not any(nd in n for nd in no_decay)
|
249 |
+
]
|
250 |
+
params_nodecay = [
|
251 |
+
p
|
252 |
+
for n, p in self.module_.named_parameters()
|
253 |
+
if any(nd in n for nd in no_decay)
|
254 |
+
]
|
255 |
+
optim_groups = [
|
256 |
+
{"params": params_decay, "weight_decay": l2_weight_decay},
|
257 |
+
{"params": params_nodecay, "weight_decay": 0.0},
|
258 |
+
]
|
259 |
+
return optim_groups
|
260 |
+
|
261 |
+
def initialize_optimizer(self, triggered_directly=True):
|
262 |
+
"""Initialize the model optimizer. If ``self.optimizer__lr``
|
263 |
+
is not set, use ``self.lr`` instead.
|
264 |
+
|
265 |
+
Parameters
|
266 |
+
----------
|
267 |
+
triggered_directly : bool (default=True)
|
268 |
+
Only relevant when optimizer is re-initialized.
|
269 |
+
Initialization of the optimizer can be triggered directly
|
270 |
+
(e.g. when lr was changed) or indirectly (e.g. when the
|
271 |
+
module was re-initialized). If and only if the former
|
272 |
+
happens, the user should receive a message informing them
|
273 |
+
about the parameters that caused the re-initialization.
|
274 |
+
|
275 |
+
"""
|
276 |
+
# get learning rate from train config
|
277 |
+
optimizer_params = self.main_config["train_config"]
|
278 |
+
kwargs = {}
|
279 |
+
kwargs["lr"] = optimizer_params.learning_rate
|
280 |
+
# get l2 weight decay to init opt params
|
281 |
+
args = self.configure_opt(optimizer_params.l2_weight_decay)
|
282 |
+
|
283 |
+
if self.initialized_ and self.verbose:
|
284 |
+
msg = self._format_reinit_msg(
|
285 |
+
"optimizer", kwargs, triggered_directly=triggered_directly
|
286 |
+
)
|
287 |
+
logger.info(msg)
|
288 |
+
|
289 |
+
self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])
|
290 |
+
|
291 |
+
self._register_virtual_param(
|
292 |
+
["optimizer__param_groups__*__*", "optimizer__*", "lr"],
|
293 |
+
optimizer_setter,
|
294 |
+
)
|
295 |
+
|
296 |
+
def initialize_criterion(self):
|
297 |
+
"""Initializes the criterion."""
|
298 |
+
# critereon takes train_config and model_config as an input.
|
299 |
+
# we get both from the module parameters
|
300 |
+
self.criterion_ = self.criterion(
|
301 |
+
self.main_config
|
302 |
+
)
|
303 |
+
if isinstance(self.criterion_, torch.nn.Module):
|
304 |
+
self.criterion_ = to_device(self.criterion_, self.device)
|
305 |
+
return self
|
306 |
+
|
307 |
+
def initialize_callbacks(self):
|
308 |
+
"""Initializes all callbacks and save the result in the
|
309 |
+
``callbacks_`` attribute.
|
310 |
+
|
311 |
+
Both ``default_callbacks`` and ``callbacks`` are used (in that
|
312 |
+
order). Callbacks may either be initialized or not, and if
|
313 |
+
they don't have a name, the name is inferred from the class
|
314 |
+
name. The ``initialize`` method is called on all callbacks.
|
315 |
+
|
316 |
+
The final result will be a list of tuples, where each tuple
|
317 |
+
consists of a name and an initialized callback. If names are
|
318 |
+
not unique, a ValueError is raised.
|
319 |
+
|
320 |
+
"""
|
321 |
+
if self.callbacks == "disable":
|
322 |
+
self.callbacks_ = []
|
323 |
+
return self
|
324 |
+
|
325 |
+
callbacks_ = []
|
326 |
+
|
327 |
+
class Dummy:
|
328 |
+
# We cannot use None as dummy value since None is a
|
329 |
+
# legitimate value to be set.
|
330 |
+
pass
|
331 |
+
|
332 |
+
for name, cb in self._uniquely_named_callbacks():
|
333 |
+
# check if callback itself is changed
|
334 |
+
param_callback = getattr(self, "callbacks__" + name, Dummy)
|
335 |
+
if param_callback is not Dummy: # callback itself was set
|
336 |
+
cb = param_callback
|
337 |
+
|
338 |
+
# below: check for callback params
|
339 |
+
# don't set a parameter for non-existing callback
|
340 |
+
|
341 |
+
# if the callback is lrcallback then initializa it with the train config,
|
342 |
+
# which is an input to the module
|
343 |
+
if name == "lrcallback":
|
344 |
+
params["config"] = self.main_config["train_config"]
|
345 |
+
else:
|
346 |
+
params = self.get_params_for("callbacks__{}".format(name))
|
347 |
+
if (cb is None) and params:
|
348 |
+
raise ValueError(
|
349 |
+
"Trying to set a parameter for callback {} "
|
350 |
+
"which does not exist.".format(name)
|
351 |
+
)
|
352 |
+
if cb is None:
|
353 |
+
continue
|
354 |
+
|
355 |
+
if isinstance(cb, type): # uninitialized:
|
356 |
+
cb = cb(**params)
|
357 |
+
else:
|
358 |
+
cb.set_params(**params)
|
359 |
+
cb.initialize()
|
360 |
+
callbacks_.append((name, cb))
|
361 |
+
|
362 |
+
self.callbacks_ = callbacks_
|
363 |
+
|
364 |
+
return self
|
transforna/src/novelty_prediction/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .id_vs_ood_entropy_clf import *
|
2 |
+
from .id_vs_ood_nld_clf import *
|