Yak-hbdx committed on
Commit 0b11a42 · verified · 1 Parent(s): fc438c0

uploaded TransfoRNA repo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +674 -0
  2. README.md +145 -3
  3. conf/__init__.py +0 -0
  4. conf/hydra/job_logging/custom.yaml +13 -0
  5. conf/inference_settings/default.yaml +3 -0
  6. conf/main_config.yaml +51 -0
  7. conf/model/transforna.yaml +9 -0
  8. conf/readme.md +17 -0
  9. conf/train_model_configs/__init__.py +0 -0
  10. conf/train_model_configs/custom.py +76 -0
  11. conf/train_model_configs/premirna.py +68 -0
  12. conf/train_model_configs/sncrna.py +70 -0
  13. conf/train_model_configs/tcga.py +81 -0
  14. environment.yml +24 -0
  15. install.sh +33 -0
  16. kba_pipeline/README.md +58 -0
  17. kba_pipeline/environment.yml +20 -0
  18. kba_pipeline/src/annotate_from_mapping.py +751 -0
  19. kba_pipeline/src/make_anno.py +59 -0
  20. kba_pipeline/src/map_2_HBDxBase.py +318 -0
  21. kba_pipeline/src/precursor_bins.py +127 -0
  22. kba_pipeline/src/utils.py +163 -0
  23. requirements.txt +18 -0
  24. scripts/test_inference_api.py +29 -0
  25. scripts/train.sh +80 -0
  26. setup.py +40 -0
  27. transforna/__init__.py +7 -0
  28. transforna/__main__.py +54 -0
  29. transforna/bin/figure_scripts/figure_4_table_3.py +173 -0
  30. transforna/bin/figure_scripts/figure_5_S10_table_4.py +466 -0
  31. transforna/bin/figure_scripts/figure_6.ipynb +228 -0
  32. transforna/bin/figure_scripts/figure_S4.ipynb +368 -0
  33. transforna/bin/figure_scripts/figure_S5.py +94 -0
  34. transforna/bin/figure_scripts/figure_S8.py +56 -0
  35. transforna/bin/figure_scripts/figure_S9_S11.py +136 -0
  36. transforna/bin/figure_scripts/infer_lc_using_tcga.py +438 -0
  37. transforna/src/__init__.py +9 -0
  38. transforna/src/callbacks/LRCallback.py +174 -0
  39. transforna/src/callbacks/__init__.py +4 -0
  40. transforna/src/callbacks/criterion.py +165 -0
  41. transforna/src/callbacks/metrics.py +218 -0
  42. transforna/src/callbacks/tbWriter.py +6 -0
  43. transforna/src/inference/__init__.py +3 -0
  44. transforna/src/inference/inference_api.py +243 -0
  45. transforna/src/inference/inference_benchmark.py +48 -0
  46. transforna/src/inference/inference_tcga.py +38 -0
  47. transforna/src/model/__init__.py +2 -0
  48. transforna/src/model/model_components.py +449 -0
  49. transforna/src/model/skorchWrapper.py +364 -0
  50. transforna/src/novelty_prediction/__init__.py +2 -0
LICENSE ADDED
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+ software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ the GNU General Public License is intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users. We, the Free Software Foundation, use the
+ GNU General Public License for most of our software; it applies also to
+ any other work released this way by its authors. You can apply it to
+ your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+ these rights or asking you to surrender the rights. Therefore, you have
+ certain responsibilities if you distribute copies of the software, or if
+ you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+ gratis or for a fee, you must pass on to the recipients the same
+ freedoms that you received. You must make sure that they, too, receive
+ or can get the source code. And you must show them these terms so they
+ know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+ (1) assert copyright on the software, and (2) offer you this License
+ giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+ that there is no warranty for this free software. For both users' and
+ authors' sake, the GPL requires that modified versions be marked as
+ changed, so that their problems will not be attributed erroneously to
+ authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+ modified versions of the software inside them, although the manufacturer
+ can do so. This is fundamentally incompatible with the aim of
+ protecting users' freedom to change the software. The systematic
+ pattern of such abuse occurs in the area of products for individuals to
+ use, which is precisely where it is most unacceptable. Therefore, we
+ have designed this version of the GPL to prohibit the practice for those
+ products. If such problems arise substantially in other domains, we
+ stand ready to extend this provision to those domains in future versions
+ of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+ States should not allow patents to restrict development and use of
+ software on general-purpose computers, but in those that do, we wish to
+ avoid the special danger that patents applied to a free program could
+ make it effectively proprietary. To prevent this, the GPL assures that
+ patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md CHANGED
@@ -1,3 +1,145 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TransfoRNA
2
+ TransfoRNA is a **bioinformatics** and **machine learning** tool based on **Transformers** to provide annotations for 11 major classes (miRNA, rRNA, tRNA, snoRNA, protein
3
+ -coding/mRNA, lncRNA, YRNA, piRNA, snRNA, snoRNA and vtRNA) and 1923 sub-classes
4
+ for **human small RNAs and RNA fragments**. These are typically detected in RNA-seq NGS (next generation sequencing) data.
5
+
6
+ TransfoRNA can be trained on just the RNA sequences and optionally on additional information such as secondary structure. The result is a major and sub-class assignment combined with a novelty score (Normalized Levenshtein Distance) that quantifies the difference between the query sequence and the closest match found in the training set. Based on that score, it decides whether the query sequence is novel or familiar. TransfoRNA uses a small curated set of ground truth labels obtained from common knowledge-based bioinformatics tools that map the sequences to transcriptome databases and a reference genome. Using TransfoRNA's framework, the number of high-confidence annotations in the TCGA dataset can be increased threefold.
7
+
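The novelty score described above can be illustrated with a minimal sketch. It assumes the edit distance is normalized by the length of the longer sequence and that a query counts as familiar when its distance to the closest reference is at or below a threshold; the function names and the threshold value are hypothetical and are not taken from the transforna package.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) via dynamic programming."""
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]


def normalized_levenshtein(query: str, reference: str) -> float:
    """Assumption: normalize by the longer of the two sequences."""
    longest = max(len(query), len(reference))
    return levenshtein(query, reference) / longest if longest else 0.0


def is_familiar(query: str, references: list, threshold: float):
    """Return (familiar?, NLD to the closest reference). Threshold is hypothetical."""
    nld = min(normalized_levenshtein(query, ref) for ref in references)
    return nld <= threshold, nld
```

For example, `is_familiar("ACGT", ["ACGA", "TTTT"], threshold=0.3)` compares the query against both references and keeps the smaller distance.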
8
+
9
+ ## Dataset (Objective):
10
+ - **The Cancer Genome Atlas, [TCGA](https://www.cancer.gov/about-nci/organization/ccg/research/structural-genomics/tcga)** offers sequencing data of small RNAs and is used to evaluate TransfoRNA's classification performance.
11
+ - Sequences are annotated based on a knowledge-based annotation approach that provides annotations for ~2k different sub-classes belonging to 11 major classes.
12
+ - Knowledge-based annotations are divided into three sets of varying confidence levels: a **high-confidence (HICO)** set, a **low-confidence (LOCO)** set, and a **non-annotated (NA)** set for sequences that could not be annotated at all. Only HICO annotations are used for training.
13
+ - HICO RNAs cover ~2k sub-classes and constitute 19.6% of all RNAs found in TCGA. LOCO and NA sets comprise 66.9% and 13.6% of RNAs, respectively.
14
+ - HICO RNAs are further divided into **in-distribution, ID** (374 sub-classes) and **out-of-distribution, OOD** (1549 sub-classes) sets.
15
+ - Criteria for ID and OOD: Sub-classes containing more than 8 sequences are considered ID; all others are considered OOD.
16
+ - An additional **putative 5' adapter affixes set** contains 294 sequences known to be technical artefacts. The 5’-end perfectly matches the last five or more nucleotides of the 5’-adapter sequence, commonly used in small RNA sequencing.
17
+ - The knowledge-based annotation (KBA) pipeline, including an installation guide, is located under `kba_pipeline`.
18
+
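The ID/OOD criterion listed above (sub-classes with more than 8 sequences are ID, the rest OOD) can be sketched as follows. This is only an illustration of the counting rule; the function name and the way labels are passed in are assumptions, not the repository's implementation.

```python
from collections import Counter


def split_id_ood(subclass_labels, min_sequences=8):
    """Split sub-class names into ID (> min_sequences sequences) and OOD (the rest)."""
    counts = Counter(subclass_labels)
    id_classes = {name for name, n in counts.items() if n > min_sequences}
    ood_classes = set(counts) - id_classes
    return id_classes, ood_classes
```

Note that a sub-class with exactly 8 sequences falls into the OOD set under this reading of "more than 8".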
19
+ ## Models
20
+ There are 5 classifier models currently available, each with a different input representation.
21
+ - Baseline:
22
+ - Input: (single input) Sequence
23
+ - Model: An embedding layer that converts sequences into vectors followed by a classification feed forward layer.
24
+ - Seq:
25
+ - Input: (single input) Sequence
26
+ - Model: A transformer based encoder model.
27
+ - Seq-Seq:
28
+ - Input: (dual inputs) Sequence divided into even and odd tokens.
29
+ - Model: A transformer encoder is placed for odd tokens and another for even tokens.
30
+ - Seq-Struct:
31
+ - Input: (dual inputs) Sequence + Secondary structure
32
+ - Model: A transformer encoder for the sequence and another for the secondary structure.
33
+ - Seq-Rev (best performing):
34
+ - Input: (dual inputs) Sequence
35
+ - Model: A transformer encoder for the sequence and another for the sequence reversed.
36
+
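As a rough illustration of the dual-input variants listed above, the sketch below builds the two token streams for Seq-Seq and Seq-Rev. It assumes an overlapping k-mer tokenizer with `window=2` (the `overlap` option and window size that appear in the training configs) and assumes "even and odd tokens" means even- and odd-indexed tokens; the helper names are hypothetical and the real tokenizer in transforna may differ.

```python
def overlap_tokens(sequence: str, window: int = 2) -> list:
    """Overlapping k-mers: 'ACGU' with window 2 gives ['AC', 'CG', 'GU']."""
    return [sequence[i:i + window] for i in range(len(sequence) - window + 1)]


def seq_seq_inputs(sequence: str, window: int = 2):
    """Seq-Seq (assumed): even-indexed tokens for one encoder, odd-indexed for the other."""
    tokens = overlap_tokens(sequence, window)
    return tokens[0::2], tokens[1::2]


def seq_rev_inputs(sequence: str, window: int = 2):
    """Seq-Rev (assumed): the tokenized sequence and the tokenized reversed sequence."""
    return overlap_tokens(sequence, window), overlap_tokens(sequence[::-1], window)
```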
37
+
38
+ *Note: These (Transformer) based models show overlapping and distinct capabilities. Consequently, an ensemble model is created to leverage those capabilities.*
39
+
40
+
41
+ <img width="948" alt="Screenshot 2023-08-16 at 16 39 20" src="https://github.com/gitHBDX/TransfoRNA-Framework/assets/82571392/d7d092d8-8cbd-492a-9ccc-994ffdd5aa5f">
42
+
43
+ ## Data Availability
44
+ Downloading the data and the models can be done from [here](https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0).
45
+
46
+ This will download three subfolders that should be kept on the same folder level as `src`:
47
+ - `data`: Contains three files:
48
+ - `TCGA` anndata with ~75k sequences and `var` columns containing the knowledge based annotations.
49
+ - `HBDXBase.csv` containing a list of RNA precursors which are then used for data augmentation.
50
+ - `subclass_to_annotation.json` holds mappings for every sub-class to major-class.
51
+
52
+ - `models`:
53
+ - `benchmark` : contains benchmark models trained on sncRNA and premiRNA data. (See additional datasets at the bottom)
54
+ - `tcga`: All models trained on the TCGA data; `TransfoRNA_ID` (for testing and validation) and `TransfoRNA_FULL` (the production version) containing higher RNA major and sub-class coverage. Each of the two folders contains all the models trained separately on major-class and sub-class.
55
+ - `kba_pipeline`: contains mapping reference data required to run the knowledge based pipeline manually
56
+ ## Repo Structure
57
+ - `conf`: Contains the configurations of each model, training and inference settings.
58
+
59
+ The `conf/main_config.yaml` file offers options to change the task, the training settings and the logging. The following shows all the options and permitted values for each option.
60
+
61
+ <img width="835" alt="Screenshot 2024-05-22 at 10 19 15" src="https://github.com/gitHBDX/TransfoRNA/assets/82571392/225d2c98-ed45-4ca7-9e86-557a73af702d">
62
+
63
+ - transforna contains two folders:
64
+ - `src` folder which contains transforna package. View transforna's architecture [here](https://github.com/gitHBDX/TransfoRNA/blob/master/transforna/src/readme.md).
65
+ - `bin` folder contains all scripts necessary for reproducing manuscript figures.
66
+
67
+ ## Installation
68
+
69
+ The `install.sh` script creates a transforna environment in which all the required packages for TransfoRNA are installed. Simply navigate to the root directory and run from the terminal:
70
+
71
+ ```
72
+ #make install script executable
73
+ chmod +x install.sh
74
+
75
+
76
+ #run script
77
+ ./install.sh
78
+ ```
79
+
80
+ ## TransfoRNA API
81
+ In `transforna/src/inference/inference_api.py`, all the functionalities of transforna are offered as APIs. There are two functions of interest:
82
+ - `predict_transforna`: Computes, for a set of sequences and a given model, one of several options: the embeddings, logits, explanatory (similar) sequences, attention masks or UMAP coordinates.
83
+ - `predict_transforna_all_models`: Same as `predict_transforna` but computes the desired option for all the models as well as aggregates the output of the ensemble model.
84
+ Both return a pandas dataframe containing the sequence along with the desired computation.
85
+
86
+ Check the script at `scripts/test_inference_api.py` for a basic demo on how to call either of the APIs.
87
+
88
+ ## Inference from terminal
89
+ For inference, two paths in `conf/inference_settings/default.yaml` have to be edited:
90
+ - `sequences_path`: The full path to a csv file containing the sequences for which annotations are to be inferred.
91
+ - `model_path`: The full path of the model. (currently this points to the Seq model)
92
+
93
+ Also in the `main_config.yaml`, make sure to edit the `model_name` to match the input expected by the loaded model.
94
+ - `model_name`: add the name of the model. One of `"seq"`,`"seq-seq"`,`"seq-struct"`,`"baseline"` or `"seq-rev"` (see above)
95
+
96
+
97
+ Then, navigate to the repository's root directory and run the following command:
98
+
99
+ ```
100
+ python transforna/__main__.py inference=True
101
+ ```
102
+
103
+ After inference, an `inference_output` folder will be created under `outputs/` which will include two files.
104
+ - `(model_name)_embedds.csv`: contains the vector embedding of each sequence in the inference set (could be used for downstream tasks).
105
+ *Note: The embedds of each sequence will only be logged if `log_embedds` in the `main_config` is `True`.*
106
+ - `(model_name)_inference_results.csv`: Contains the columns `Net-Label`, holding the predicted label, and `Is Familiar?`, a boolean column holding the output of the model's novelty predictor (True: familiar / False: novel).
107
+ *Note: The output will also contain the logits of the model if `log_logits` in the `main_config` is `True`.*
108
+
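A minimal sketch of how the results file described above could be post-processed with the standard library. It relies only on the two column names mentioned in this README (`Net-Label` and `Is Familiar?`) and assumes the boolean column is written as the text `True`/`False`; the function name is hypothetical.

```python
import csv


def summarize_inference(csv_file):
    """Split predicted labels into familiar and novel lists.

    `csv_file` is any open text file (or file-like object) holding the
    `(model_name)_inference_results.csv` content.
    """
    familiar, novel = [], []
    for row in csv.DictReader(csv_file):
        # Assumption: the boolean column is serialized as "True" / "False".
        is_familiar = row["Is Familiar?"].strip().lower() == "true"
        (familiar if is_familiar else novel).append(row["Net-Label"])
    return familiar, novel
```

In practice the function would be called inside `with open(path_to_results, newline="") as handle:`, where `path_to_results` points to the results file in the `inference_output` folder.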
109
+
110
+ ## Train on custom data
111
+ TransfoRNA can be trained using input data as Anndata, csv or fasta. If the input is anndata, then `anndata.var` should contain all the sequences. Some changes have to be made (follow `conf/train_model_configs/tcga`):
112
+
113
+ In `conf/train_model_configs/custom`:
114
+ - `dataset_path_train` has to point to the input data, which should contain: a `sequence` column; a `small_RNA_class_annotation` column indicating the major class if available (otherwise NaN); a `five_prime_adapter_filter` column specifying whether the sequence is considered a real sequence or an artifact (`True` for real and `False` for artifact); a `subclass_name` column containing the sub-class name if available (otherwise NaN); and a boolean column `hico` indicating whether a sequence is high confidence or not.
115
+ - If sampling from the precursor is required in order to augment the sub-classes, the `precursor_file_path` should include precursors. Follow the scheme of the HBDxBase.csv and have a look at `PrecursorAugmenter` class in `transforna/src/processing/augmentation.py`
116
+ - `mapping_dict_path` should contain the mapping from sub class to major class. i.e: 'miR-141-5p' to 'miRNA'.
117
+ - `clf_target` sets the classification target of the model and should be either `sub_class_hico` for training on targets in `subclass_name` or `major_class_hico` for training on targets in `small_RNA_class_annotation`. For both, only high confidence sequences are selected for training (based on the `hico` column).
118
+
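Before starting a training run, the input requirements listed above can be sanity-checked with a short script. The sketch below assumes a csv input and a sub-class-to-major-class mapping already loaded as a dict (for example via `json.load`); the function name and the treatment of empty or `NaN` cells as "not available" are assumptions made for illustration.

```python
import csv

# Column names taken from the list above.
REQUIRED_COLUMNS = [
    "sequence",
    "small_RNA_class_annotation",
    "five_prime_adapter_filter",
    "subclass_name",
    "hico",
]


def check_training_table(csv_file, mapping):
    """Return (missing required columns, sub-classes absent from the mapping dict)."""
    reader = csv.DictReader(csv_file)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        return missing, set()
    unmapped = set()
    for row in reader:
        name = row["subclass_name"].strip()
        # Assumption: empty cells or the text "NaN" mean "no sub-class available".
        if name and name.lower() != "nan" and name not in mapping:
            unmapped.add(name)
    return missing, unmapped
```

A mapping such as `{"miR-141-5p": "miRNA"}` would cover the example given for `mapping_dict_path`.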
119
+ In `conf/main_config.yaml`, some changes should be made:
120
+ - change `task` to `custom` or to whatever name `custom.py` has been renamed to.
121
+ - set the `model_name` as desired.
122
+
123
+ For training TransfoRNA from the root directory:
124
+ ```
125
+ python transforna/__main__.py
126
+ ```
127
+ Using [Hydra](https://hydra.cc/), any option in the main config can be changed. For instance, to train a `Seq-Struct` TransfoRNA model without using a validation split:
128
+ ```
129
+ python transforna/__main__.py train_split=False model_name='seq-struct'
130
+ ```
131
+ After training, an output folder is automatically created in the root directory where training is logged.
132
+ The structure of the output folder is chosen by hydra to be `/day/time/results folders`. Results folders are a set of folders created during training:
133
+ - `ckpt`: (containing the latest checkpoint of the model)
134
+ - `embedds`:
135
+ - Contains a file per each split (train/valid/test/ood/na).
136
+ - Each file is a `csv` containing the sequences plus their embeddings (numeric representations of each RNA sequence produced by the model) as well as the logits. The logits are the values the model produces for each sequence, reflecting its confidence that the sequence belongs to a certain class.
137
+ - `meta`: A folder containing a `yaml` file with all the hyperparameters used for the current run.
138
+ - `analysis`: contains the learned novelty threshold separating the in-distribution set (Familiar) from the out-of-distribution set (Novel).
139
+ - `figures`: contains figures showing the Normalized Levenshtein Distance (NLD) distribution per split.
140
+
141
+
142
+ ## Additional Datasets (Objective):
143
+ - sncRNA, collected from [RFam](https://rfam.org/) (classification of RNA precursors into 13 classes)
144
+ - premiRNA [human miRNAs](http://www.mirbase.org)(classification of true vs pseudo precursors)
145
+
conf/__init__.py ADDED
File without changes
conf/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1
2
+ formatters:
3
+ simple:
4
+ format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
5
+ handlers:
6
+ console:
7
+ class: logging.StreamHandler
8
+ formatter: simple
9
+ stream: ext://sys.stdout
10
+ root:
11
+ handlers: [console]
12
+
13
+ disable_existing_loggers: false
conf/inference_settings/default.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ infere_original_testset: false
2
+ model_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq/ckpt/model_params_tcga.pt
3
+ sequences_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/data/inference_set.csv
conf/main_config.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model: transforna
3
+ - inference_settings: default
4
+ - override hydra/job_logging: disabled
5
+
6
+ task: tcga # tcga,sncrna or premirna or custom (for a custom dataset)
7
+
8
+
9
+ train_config:
10
+ _target_: train_model_configs.${task}.GeneEmbeddTrainConfig
11
+
12
+ model_config:
13
+ _target_: train_model_configs.${task}.GeneEmbeddModelConfig
14
+
15
+ #train settings
16
+ model_name: seq #seq, seq-seq, seq-struct, seq-reverse, or baseline
17
+ trained_on: full #full(production, more coverage) or id (for test/eval purposes)
18
+ path_to_models: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/ #edit path to point to models/tcga/ directory: will be used if trained_on is full
19
+ inference: False # Should TransfoRNA be used for inference or train? True or False
20
+ #if inference is true, should the logits be logged?
21
+ log_logits: False
22
+
23
+
24
+ train_split: True # True or False
25
+ valid_size: 0.15 # 0 < valid_size < 1
26
+
27
+ #CV
28
+ cross_val: True # True or False
29
+ num_replicates: 1 # Integer, num_replicates for cross-validation
30
+
31
+ #seed
32
+ seed: 1 # Integer
33
+ device_number: 1 # Integer, select GPU
34
+
35
+
36
+ #logging sequence embeddings + metrics to tensorboard
37
+ log_embedds: True # True or False
38
+ tensorboard: False # True or False
39
+
40
+ #disable hydra output
41
+ hydra:
42
+ run:
43
+ dir: ./outputs/${now:%Y-%m-%d}/${model_name}
44
+ searchpath:
45
+ - file:///nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf
46
+ # output_subdir: null #uncomment to disable hydra output
47
+
48
+
49
+
50
+
51
+
conf/model/transforna.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ skorch_model:
2
+ _target_: transforna.Net
3
+ module: transforna.GeneEmbeddModel
4
+ criterion: transforna.LossFunction
5
+ max_epochs: 0 #inferred from task specific train config
6
+ optimizer: torch.optim.AdamW
7
+ device: cuda
8
+ batch_size: 64
9
+ iterator_train__shuffle: True
conf/readme.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The `inference_settings` has a default yaml containing four keys:
2
+ - `sequences_path`: The full path of the file containing the sequences for which annotations are to be inferred.
3
+ - `model_path`: the full path of the model to be used for inference.
4
+ - `model_name`: A model name indicating the inputs the model expects. One of `seq`,`seq-seq`,`seq-struct`,`seq-reverse` or `baseline`
5
+ - `infere_original_testset`: True/False indicating whether inference should be computed on the original test set.
6
+
7
+ `model` contains the skeleton of the model used, the optimizer, loss function and device. All models are built using [skorch](https://skorch.readthedocs.io/en/latest/?badge=latest)
8
+
9
+ `train_model_configs` contains the hyperparameters for each dataset; tcga, sncrna and premirna:
10
+
11
+ - Each file contains the model and the train config.
12
+
13
+ - Model config: contains the model hyperparameters, sequence tokenization scheme and allows for choosing the model.
14
+
15
+ - Train config: contains training settings such as the learning rate hyper parameters as well as `dataset_path_train`.
16
+ - `dataset_path_train`: should point to the dataset [(Anndata)](https://anndata.readthedocs.io/en/latest/) used for training.
17
+
conf/train_model_configs/__init__.py ADDED
File without changes
conf/train_model_configs/custom.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List
5
+
6
+ dirname, _ = os.path.split(os.path.dirname(__file__))
7
+
8
+
9
+ @dataclass
10
+ class GeneEmbeddModelConfig:
11
+
12
+ model_input: str = "" #will be inferred
13
+
14
+ num_embed_hidden: int = 100 #30 for exp, 100 for rest
15
+ ff_input_dim:int = 0 #is inferred later, equals gene expression len
16
+ ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico
17
+ feed_forward1_hidden: int = 256
18
+ num_attention_project: int = 64
19
+ num_encoder_layers: int = 1
20
+ dropout: float = 0.2
21
+ n: int = 121
22
+ relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11])
23
+ num_attention_heads: int = 5
24
+
25
+ window: int = 2
26
+ tokens_len: int = math.ceil(max_length / window) # NOTE: `max_length` is not defined in this file as shown; it must be defined before this line
27
+ second_input_token_len: int = 0 # is inferred during runtime
28
+ vocab_size: int = 0 # is inferred during runtime
29
+ second_input_vocab_size: int = 0 # is inferred during runtime
30
+ tokenizer: str = (
31
+ "overlap" # either overlap or no_overlap or overlap_multi_window
32
+ )
33
+
34
+ clf_target:str = 'm' # sub_class_hico or major_class_hico. hico = high confidence
35
+ num_classes: int = 0 #will be inferred during runtime
36
+ class_mappings:List = field(default_factory=lambda: [])#will be inferred during runtime
37
+ class_weights :List = field(default_factory=lambda: [])
38
+ # how many extra window sizes other than default window
39
+ temperatures: List = field(default_factory=lambda: [0,10])
40
+
41
+ tokens_mapping_dict: Dict = None
42
+ false_input_perc:float = 0.0
43
+
44
+
45
+ @dataclass
46
+ class GeneEmbeddTrainConfig:
47
+ dataset_path_train: str = 'path/to/anndata.h5ad'
48
+ precursor_file_path: str = 'path/to/precursor_file.csv' #if not provided, sampling from the precursor will not be done
49
+ mapping_dict_path: str = 'path/to/mapping_dict.json' #required for mapping sub class to major class, i.e: mir-568-3p to miRNA
50
+ device: str = "cuda"
51
+ l2_weight_decay: float = 0.05
52
+ batch_size: int = 512
53
+
54
+ batch_per_epoch:int = 0 # is inferred during runtime
55
+
56
+ label_smoothing_sim:float = 0.2
57
+ label_smoothing_clf:float = 0.0
58
+ # learning rate
59
+ learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
60
+ lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
61
+ lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
62
+ # TODO: 122 is the number of train batches per epoch, should be inferred and set
63
+ # warmup batch should be in the form epoch*(train batch per epoch)
64
+ warmup_epoch: int = 10 # how many batches linear warm up for
65
+ final_epoch: int = 20 # final batch of training when want learning rate
66
+
67
+ top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears in the top k, it is counted as correctly classified
68
+ cross_val: bool = False
69
+ labels_mapping_path: str = None
70
+ filter_seq_length:bool = False
71
+
72
+ num_augment_exp:int = 20
73
+ shuffle_exp: bool = False
74
+
75
+ max_epochs: int = 3000
76
+
conf/train_model_configs/premirna.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+
5
+ @dataclass
6
+ class GeneEmbeddModelConfig:
7
+ # input dim for the embedding and positional encoders
8
+ # as well as all k,q,v input/output dims
9
+ model_input: str = "seq-struct"
10
+ num_embed_hidden: int = 256
11
+ ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
12
+ feed_forward1_hidden: int = 1024
13
+ num_attention_project: int = 64
14
+ num_encoder_layers: int = 1
15
+ dropout: float = 0.5
16
+ n: int = 121
17
+ relative_attns: List = field(default_factory=lambda: [int(112), int(112), 6*3, 8*3, 10*3, 11*3])
18
+ num_attention_heads: int = 1
19
+
20
+ window: int = 2
21
+ tokens_len: int = 0 #will be inferred later
22
+ second_input_token_len: int = 0 # is inferred in runtime
23
+ vocab_size: int = 0 # is inferred in runtime
24
+ second_input_vocab_size: int = 0 # is inferred in runtime
25
+ tokenizer: str = (
26
+ "overlap" # either overlap or no_overlap or overlap_multi_window
27
+ )
28
+ num_classes: int = 0 #will be inferred in runtime
29
+ class_weights :List = field(default_factory=lambda: [])
30
+ tokens_mapping_dict: dict = None
31
+
32
+ #false input percentage
33
+ false_input_perc:float = 0.1
34
+ model_input: str = "seq-struct"
35
+
36
+ @dataclass
37
+ class GeneEmbeddTrainConfig:
38
+ dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/premirna/train"
39
+ dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/premirna/test"
40
+ datset_path_additional_testset: str = "/data/hbdx_ldap_local/analysis/data/premirna/"
41
+ labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/premirna/labels_mapping_dict.pkl"
42
+ device: str = "cuda"
43
+ l2_weight_decay: float = 1e-5
44
+ batch_size: int = 64
45
+
46
+ batch_per_epoch: int = 0 #will be inferred later
47
+ label_smoothing_sim:float = 0.0
48
+ label_smoothing_clf:float = 0.0
49
+
50
+ # learning rate
51
+ learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
52
+ lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
53
+ lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
54
+ # TODO: 122 is the number of train batches per epoch, should be infered and set
55
+ # warmup batch should be in the form epoch*(train batch per epoch)
56
+ warmup_epoch: int = 10 # how many batches linear warm up for
57
+ final_epoch: int = 20 # final batch of training when want learning rate
58
+
59
+ top_k: int = int(
60
+ 0.05 * batch_size
61
+ ) # if the corresponding rna/GE appears in the top k, the correctly classified
62
+ label_smoothing: float = 0.0
63
+ cross_val: bool = False
64
+ filter_seq_length:bool = True
65
+ train_epoch: int = 3000
66
+ max_epochs: int = 3500
67
+
68
+
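One detail of the training config above is easy to miss: `top_k` is computed from `batch_size` while the class body executes, so its default is fixed at `int(0.05 * 64)` and does not follow a `batch_size` passed to the constructor. The sketch below uses two stripped-down, hypothetical classes (`StaticConfig` and `DerivedConfig` are illustrative names, not part of the repository) to show the behaviour and one possible way to derive the value per instance.

```python
from dataclasses import dataclass, field


@dataclass
class StaticConfig:
    # Same pattern as in the diff: top_k is evaluated once, at class definition.
    batch_size: int = 64
    top_k: int = int(0.05 * batch_size)


@dataclass
class DerivedConfig:
    # Hypothetical alternative: compute top_k from the instance's batch_size.
    batch_size: int = 64
    top_k: int = field(init=False)

    def __post_init__(self):
        self.top_k = int(0.05 * self.batch_size)


static_cfg = StaticConfig(batch_size=512)
derived_cfg = DerivedConfig(batch_size=512)
print(static_cfg.top_k)   # 3, still derived from the default batch_size of 64
print(derived_cfg.top_k)  # 25, derived from the batch_size actually passed in
```

Whether the fixed default is intended here is not stated in the commit; the sketch only illustrates the Python semantics.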
conf/train_model_configs/sncrna.py ADDED
@@ -0,0 +1,70 @@
+ import os
+ from dataclasses import dataclass, field
+ from pickletools import int4
+ from typing import List
+
+
+ @dataclass
+ class GeneEmbeddModelConfig:
+     # input dim for the embedding and positional encoders
+     # as well as all k,q,v input/output dims
+     model_input: str = "seq-struct"
+     num_embed_hidden: int = 256
+     ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
+     feed_forward1_hidden: int = 1024
+     num_attention_project: int = 64
+     num_encoder_layers: int = 2
+     dropout: float = 0.3
+     n: int = 121
+     window:int = 4
+     relative_attns: List = field(default_factory=lambda: [int(360), int(360)])
+     num_attention_heads: int = 4
+
+     tokens_len: int = 0 #will be inferred later
+     second_input_token_len:int = 0 # is inferred at runtime
+     vocab_size: int = 0 # is inferred at runtime
+     second_input_vocab_size: int = 0 # is inferred at runtime
+     tokenizer: str = (
+         "overlap" # either overlap or no_overlap or overlap_multi_window
+     )
+     # how many extra window sizes other than default window
+     num_classes: int = 0 #will be inferred at runtime
+     class_weights :List = field(default_factory=lambda: [])
+     tokens_mapping_dict: dict = None
+
+     #false input percentage
+     false_input_perc:float = 0.2
+
+     model_input: str = "seq-struct"
+
+
+ @dataclass
+ class GeneEmbeddTrainConfig:
+     dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/train.h5ad"
+     dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/test.h5ad"
+     labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/sncRNA/labels_mapping_dict.pkl"
+     device: str = "cuda"
+     l2_weight_decay: float = 1e-5
+     batch_size: int = 64
+
+     batch_per_epoch:int = 0 #will be inferred later
+     label_smoothing_sim:float = 0.0
+     label_smoothing_clf:float = 0.0
+
+     # learning rate
+     learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
+     lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
+     lr_warmup_end: float = 1 # end of linear warmup section, annealing begins
+     # TODO: 122 is the number of train batches per epoch, should be inferred and set
+     # warmup batch should be in the form epoch*(train batch per epoch)
+     warmup_epoch: int = 10 # how many batches linear warm up for
+     final_epoch: int = 20 # final batch of training when want learning rate
+
+     top_k: int = int(
+         0.05 * batch_size
+     ) # if the corresponding rna/GE appears in the top k, it counts as correctly classified
+     label_smoothing: float = 0.0
+     cross_val: bool = False
+     filter_seq_length:bool = True
+     train_epoch: int = 800
+     max_epochs:int = 1000
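The `tokenizer` and `window` fields above choose how a sequence is cut into k-mers. The tokenizer implementation is not part of this section, so the following is only a sketch under an assumption: "overlap" is taken to mean a sliding window with stride 1, and "no_overlap" a stride equal to `window`. The function name `tokenize` is hypothetical.

```python
def tokenize(seq, window, overlap=True):
    """Split `seq` into k-mers of length `window`.

    Assumption (not confirmed by the diff): overlap=True slides the window by
    one nucleotide, overlap=False moves it by `window` nucleotides, so the last
    token may be shorter than `window`.
    """
    if overlap:
        return [seq[i:i + window] for i in range(len(seq) - window + 1)]
    return [seq[i:i + window] for i in range(0, len(seq), window)]


overlapping = tokenize("AACGUU", window=4)
non_overlapping = tokenize("AACGUU", window=4, overlap=False)
print(overlapping)       # ['AACG', 'ACGU', 'CGUU']
print(non_overlapping)   # ['AACG', 'UU']
```

Under this reading, the non-overlapping variant yields `ceil(len(seq) / window)` tokens, while the overlapping variant yields `len(seq) - window + 1`.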
conf/train_model_configs/tcga.py ADDED
@@ -0,0 +1,81 @@
+ import math
+ import os
+ from dataclasses import dataclass, field
+ from typing import Dict, List
+
+ dirname, _ = os.path.split(os.path.dirname(__file__))
+
+
+ @dataclass
+ class GeneEmbeddModelConfig:
+
+     model_input: str = "" #will be inferred
+
+     num_embed_hidden: int = 100 #30 for exp, 100 for rest
+     ff_input_dim:int = 0 #is inferred later, equals gene expression len
+     ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico
+     feed_forward1_hidden: int = 256
+     num_attention_project: int = 64
+     num_encoder_layers: int = 1
+     dropout: float = 0.2
+     n: int = 121
+     relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11])
+     num_attention_heads: int = 5
+
+     window: int = 2
+     # 200 is max rna length.
+     # TODO: if tokenizer is overlap, then max_length should be 60
+     # otherwise, will get cuda error, maybe dask can help
+     max_length: int = 40
+     tokens_len: int = math.ceil(max_length / window)
+     second_input_token_len: int = 0 # is inferred during runtime
+     vocab_size: int = 0 # is inferred during runtime
+     second_input_vocab_size: int = 0 # is inferred during runtime
+     tokenizer: str = (
+         "overlap" # either overlap or no_overlap or overlap_multi_window
+     )
+
+     clf_target:str = 'sub_class_hico' # sub_class, major_class, sub_class_hico or major_class_hico. hico = high confidence
+     num_classes: int = 0 #will be inferred during runtime
+     class_mappings:List = field(default_factory=lambda: [])#will be inferred during runtime
+     class_weights :List = field(default_factory=lambda: [])
+     # how many extra window sizes other than default window
+     temperatures: List = field(default_factory=lambda: [0,10])
+
+     tokens_mapping_dict: Dict = None
+     false_input_perc:float = 0.0
+
+
+ @dataclass
+ class GeneEmbeddTrainConfig:
+     dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
+     precursor_file_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/HBDxBase.csv'
+     mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
+     device: str = "cuda"
+     l2_weight_decay: float = 0.05
+     batch_size: int = 512
+
+     batch_per_epoch:int = 0 # is inferred during runtime
+
+     label_smoothing_sim:float = 0.2
+     label_smoothing_clf:float = 0.0
+     # learning rate
+     learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
+     lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
+     lr_warmup_end: float = 1 # end of linear warmup section, annealing begins
+     # TODO: 122 is the number of train batches per epoch, should be inferred and set
+     # warmup batch should be in the form epoch*(train batch per epoch)
+     warmup_epoch: int = 10 # how many batches linear warm up for
+     final_epoch: int = 20 # final batch of training when want learning rate
+
+     top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears in the top k, it counts as correctly classified
+     cross_val: bool = False
+     labels_mapping_path: str = None
+     filter_seq_length:bool = False
+
+     num_augment_exp:int = 20
+     shuffle_exp: bool = False
+
+     max_epochs: int = 3000
+
+
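The TCGA config sets `label_smoothing_sim` to 0.2. How the repository's loss consumes this value is not visible in this section, so the sketch below only illustrates the generic label-smoothing technique, using the common convention of spreading the smoothing mass uniformly over all classes. The helper name `smooth_one_hot` is hypothetical.

```python
def smooth_one_hot(target_index, num_classes, smoothing):
    """Return a smoothed target distribution as a list of floats.

    Generic convention (an assumption about intent, not the repo's code):
    every class receives smoothing / num_classes, and the target class
    additionally receives 1 - smoothing.
    """
    off_value = smoothing / num_classes
    on_value = 1.0 - smoothing + off_value
    return [on_value if i == target_index else off_value
            for i in range(num_classes)]


smoothed = smooth_one_hot(target_index=1, num_classes=4, smoothing=0.2)
print(smoothed)  # roughly [0.05, 0.85, 0.05, 0.05]
```

With `smoothing=0.0`, as in `label_smoothing_clf`, the function returns the plain one-hot target.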
environment.yml ADDED
@@ -0,0 +1,24 @@
+ name: transforna
+ channels:
+   - pytorch
+   - bioconda
+   - conda-forge
+ dependencies:
+   - anndata==0.8.0
+   - dill==0.3.6
+   - hydra-core==1.3.0
+   - imbalanced-learn==0.9.1
+   - matplotlib==3.5.3
+   - numpy==1.22.3
+   - omegaconf==2.2.2
+   - pandas==1.5.2
+   - plotly==5.10.0
+   - PyYAML==6.0
+   - rich==12.6.0
+   - viennarna=2.5.0=py39h98c8e45_0
+   - scanpy==1.9.1
+   - scikit-learn==1.2.0
+   - skorch==0.12.1
+   - pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0
+   - tensorboard==2.11.2
+   - Levenshtein==0.21.0
install.sh ADDED
@@ -0,0 +1,33 @@
+ #!/bin/bash
+ python_version="3.9"
+ # Initialize conda
+ eval "$(conda shell.bash hook)"
+ #print the current environment
+ echo "The current environment is $CONDA_DEFAULT_ENV."
+ while [[ "$CONDA_DEFAULT_ENV" != "base" ]]; do
+     conda deactivate
+ done
+
+ #if the conda environment transforna is not found in the list of environments, then create the environment
+ if [[ $(conda env list | grep "transforna") == "" ]]; then
+     conda create -n transforna python=$python_version -y
+     conda activate transforna
+     conda install -c anaconda setuptools -y
+
+
+
+ fi
+ conda activate transforna
+
+ echo "The current environment is transforna."
+ pip install setuptools==59.5.0
+ # Uninstall TransfoRNA using pip
+ pip uninstall -y TransfoRNA
+
+ rm -rf dist TransfoRNA.egg-info
+
+
+ # Reinstall TransfoRNA using pip
+ python setup.py sdist
+ pip install dist/TransfoRNA-0.0.1.tar.gz
+ rm -rf TransfoRNA.egg-info dist
kba_pipeline/README.md ADDED
@@ -0,0 +1,58 @@
+ # The HBDx knowledge-based annotation (KBA) pipeline for small RNA sequences
+
+ Most small RNA annotation tools map sequences sequentially to reference databases specific to individual small RNA classes, which prioritizes certain small RNA classes and conceals potential assignment ambiguities. The annotation strategy used here maps the sequences to the reference sequences of all small RNA classes at the same time, starting with zero mismatch tolerance. Unmapped sequences are then mapped with increasing mismatch tolerance, up to three mismatches. To reduce ambiguity, sequences are first mapped to the standard coding and non-coding genes with increasing mismatch tolerance. Only then are the unassigned sequences mapped to pseudogenes in the same manner. Additionally, all small RNA sequences are checked for potential bacterial or viral origin, for genomic overlap with human transposable element loci, and for whether they contain potential prefixes of the 5' adapter.
+
+ In cases of multiple assignments per sequence (multiple precursors could be the origin of the sequence), the ambiguous annotation is resolved if
+ a) all assigned precursors overlap with the genomic region of the precursor with the shortest genomic context -> the subclass name of the precursor with the shortest genomic context is used OR if
+ b) a bin of the assigned subclass names is at the 5' or 3' end of the respective precursor -> the subclass name matching the precursor end is used.
+ In cases where the subclass names of a) and b) are not identical, the subclass name of method a) is assigned.
+
+
+ ![kba_pipeline_scheme_v05](https://github.com/gitHBDX/TransfoRNA/assets/79092907/62bf9e36-c7c7-4ff5-b747-c2c651281b42)
+
+
+ a) Schematic overview of the knowledge-based annotation (KBA) strategy applied for TransfoRNA.
+
+ b) Schematic overview of the miRNA annotation of the custom annotation (isomiR definition based on recent miRNA research [1]).
+
+ c) Schematic overview of the tRNA annotation of the custom annotation (inspired by UNITAS sub-classification [2]).
+
+ d) Binning strategy used in the custom annotation for the remaining RNA major classes. The number of nucleotides per bin is constant for each precursor sequence and ranges between 20 and 39 nucleotides. Assignments are based on the bin with the highest overlap to the sequence of interest.
+
+ e) Filtering steps that were applied to obtain the set of HICO annotations that were used for training of the TransfoRNA models.
+
+
+ ## Install environment
+
+ ```bash
+ cd kba_pipeline
+ conda env create --file environment.yml
+ ```
+
+ ## Run annotation pipeline
+
+ <b>Prerequisites:</b>
+ - [ ] the sequences to be annotated need to be stored in FASTA format in the `kba_pipeline/data` folder
+ - [ ] the reference files for mapping need to be stored in the `kba_pipeline/references` folder (the required subfolders `HBDxBase`, `hg38` and `bacterial_viral` can be downloaded together with the TransfoRNA models from https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0)
+
+ ```bash
+ conda activate hbdx_kba
+ cd src
+ python make_anno.py --fasta_file your_sequences_to_be_annotated.fa
+ ```
+
+ This script calls two major functions:
+ - <b>map_2_HBDxBase</b>: sequential mismatch mapping to HBDxBase and genome
+ - <b>annotate_from_mapping</b>: generate sequence annotation based on mapping outputs
+
+ The main annotation file `sRNA_anno_aggregated_on_seq.csv` will be generated in the folder `outputs`.
+
+
+
+ ## References
+
+ [1] Tomasello, Luisa, Rosario Distefano, Giovanni Nigita, and Carlo M. Croce. 2021. “The MicroRNA Family Gets Wider: The IsomiRs Classification and Role.” Frontiers in Cell and Developmental Biology 9 (June): 1–15. https://doi.org/10.3389/fcell.2021.668648.
+
+ [2] Gebert, Daniel, Charlotte Hewel, and David Rosenkranz. 2017. “Unitas: The Universal Tool for Annotation of Small RNAs.” BMC Genomics 18 (1): 1–14. https://doi.org/10.1186/s12864-017-4031-9.
+
+
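Panel d) of the README above describes assigning a sequence to the precursor bin with which it overlaps most. The repository's own `get_bin_with_max_overlap` lives in `precursor_bins.py`, which is not shown here, so the following is only a minimal sketch of the idea. The function name `bin_with_max_overlap`, its parameters, the 0-based inclusive coordinates, and the tie-breaking rule (first bin wins) are all assumptions.

```python
def bin_with_max_overlap(ref_start, ref_end, precursor_length, bin_size):
    """Return (bin_number, overlap_in_nt) for the bin with the largest overlap.

    Coordinates are 0-based and inclusive; bins are numbered from 1 and the
    last bin may be shorter than `bin_size` (illustrative assumptions only).
    """
    best_bin, best_overlap = None, 0
    bin_starts = range(0, precursor_length, bin_size)
    for bin_number, bin_start in enumerate(bin_starts, start=1):
        bin_end = min(bin_start + bin_size, precursor_length) - 1
        overlap = min(ref_end, bin_end) - max(ref_start, bin_start) + 1
        if overlap > best_overlap:  # strict comparison: ties keep the earlier bin
            best_bin, best_overlap = bin_number, overlap
    return best_bin, best_overlap


# A 25-nt read at positions 20..44 of a 100-nt precursor split into 25-nt bins:
# 5 nt fall into bin 1 (0..24) and 20 nt into bin 2 (25..49).
result = bin_with_max_overlap(ref_start=20, ref_end=44, precursor_length=100, bin_size=25)
print(result)  # (2, 20)
```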
kba_pipeline/environment.yml ADDED
@@ -0,0 +1,20 @@
+ name: hbdx_kba
+ channels:
+   - bioconda
+   - conda-forge
+ dependencies:
+   - anndata=0.8.0
+   - bedtools=2.30.0
+   - biopython=1.79
+   - bowtie=1.3.1
+   - joblib>=1.2.0
+   - pyfastx=0.8.4
+   - pytest
+   - python=3.10.6
+   - pyyaml
+   - rich
+   - samtools=1.16.1
+   - tqdm
+   - viennarna=2.5.1
+   - levenshtein
+   - pip
kba_pipeline/src/annotate_from_mapping.py ADDED
@@ -0,0 +1,751 @@
+ ######################################################################################################
+ # annotate sequences based on mapping results
+ ######################################################################################################
+ #%%
+ import os
+ import logging
+
+ import numpy as np
+ import pandas as pd
+ from difflib import get_close_matches
+ from Levenshtein import distance
+ import json
+
+ from joblib import Parallel, delayed
+ import multiprocessing
+
+
+ from utils import (fasta2df, fasta2df_subheader,log_time, reverse_complement)
+ from precursor_bins import get_bin_with_max_overlap
+
+
+ log = logging.getLogger(__name__)
+
+ pd.options.mode.chained_assignment = None
+
+
+ ######################################################################################################
+ # paths to reference and mapping files
+ ######################################################################################################
+
+ version = '_v4'
+
+ HBDxBase_csv = f'../../references/HBDxBase/HBDxBase_all{version}.csv'
+ miRBase_mature_path = '../../references/HBDxBase/miRBase/mature.fa'
+ mat_miRNA_pos_path = '../../references/HBDxBase/miRBase/hsa_mature_position.txt'
+
+ mapped_file = 'seqsmapped2HBDxBase_combined.txt'
+ unmapped_file = 'tmp_seqs3mm2HBDxBase_pseudo__unmapped.fa'
+ TE_file = 'tmp_seqsmapped2genome_intersect_TE.txt'
+ mapped_genome_file = 'seqsmapped2genome_combined.txt'
+ toomanyloci_genome_file = 'tmp_seqs0mm2genome__toomanyalign.fa'
+ unmapped_adapter_file = 'tmp_seqs3mm2adapters__unmapped.fa'
+ unmapped_genome_file = 'tmp_seqs0mm2genome__unmapped.fa'
+ unmapped_bacterial_file = 'tmp_seqs0mm2bacterial__unmapped.fa'
+ unmapped_viral_file = 'tmp_seqs0mm2viral__unmapped.fa'
+
+
+ sRNA_anno_file = 'sRNA_anno_from_mapping.csv'
+ aggreg_sRNA_anno_file = 'sRNA_anno_aggregated_on_seq.csv'
+
+
+
+ #%%
+ ######################################################################################################
+ # specific functions
+ ######################################################################################################
+
+ @log_time(log)
+ def extract_general_info(mapping_file):
+     # load mapping file
+     mapping_df = pd.read_csv(mapping_file, sep='\t', header=None)
+     mapping_df.columns = ['tmp_seq_id','reference','ref_start','sequence','other_alignments','mm_descriptors']
+
+     # add precursor length + number of bins that will be used for names
+     HBDxBase_df = pd.read_csv(HBDxBase_csv, index_col=0)
+     HBDxBase_df = HBDxBase_df[['precursor_length','precursor_bins','pseudo_class']].reset_index()
+     HBDxBase_df.rename(columns={'index': "reference"}, inplace=True)
+     mapping_df = mapping_df.merge(HBDxBase_df, left_on='reference', right_on='reference', how='left')
+
+     # extract information
+     mapping_df.loc[:,'mms'] = mapping_df.mm_descriptors.fillna('').str.count('>')
+     mapping_df.loc[:,'mm_descriptors'] = mapping_df.mm_descriptors.str.replace(',', ';')
+     mapping_df.loc[:,'small_RNA_class_annotation'] = mapping_df.reference.str.split('|').str[0]
+     mapping_df.loc[:,'subclass_type'] = mapping_df.reference.str.split('|').str[2]
+     mapping_df.loc[:,'precursor_name_full'] = mapping_df.reference.str.split('|').str[1].str.split('|').str[0]
+     mapping_df.loc[:,'precursor_name'] = mapping_df.precursor_name_full.str.split('__').str[0].str.split('|').str[0]
+     mapping_df.loc[:,'seq_length'] = mapping_df.sequence.apply(lambda x: len(x))
+     mapping_df.loc[:,'ref_end'] = mapping_df.ref_start + mapping_df.seq_length - 1
+     mapping_df.loc[:,'mitochondrial'] = np.where(mapping_df.reference.str.contains(r'(\|MT-)|(12S)|(16S)'), 'mito', 'nuclear')
+
+     return mapping_df
+
+
+ #%%
+ @log_time(log)
+ def tRNA_annotation(mapping_df):
+     """Extract tRNA specific annotation from mapping.
+     """
+     # keep only tRNA leader/trailer with right cutting sites (+/- 5nt)
+     # leader
+     tRF_leader_df = mapping_df[mapping_df['subclass_type'] == 'leader_tRF']
+     # assign as misc-leader-tRF if exceeding defined cutting site range
+     tRF_leader_df.loc[:,'subclass_type'] = np.where((tRF_leader_df.ref_start + tRF_leader_df.sequence.apply(lambda x: len(x))).between(45, 55, inclusive='both'), 'leader_tRF', 'misc-leader-tRF')
+
+     # trailer
+     tRF_trailer_df = mapping_df[mapping_df['subclass_type'] == 'trailer_tRF']
+     # assign as misc-trailer-tRF if exceeding defined cutting site range
+     tRF_trailer_df.loc[:,'subclass_type'] = np.where(tRF_trailer_df.ref_start.between(0, 5, inclusive='both'), 'trailer_tRF', 'misc-trailer-tRF')
+
+     # define tRF subclasses (leader_tRF and trailer_tRF have been assigned previously)
+     # NOTE: allow more flexibility at ends (similar to miRNA annotation)
+     tRNAs_df = mapping_df[((mapping_df['small_RNA_class_annotation'] == 'tRNA') & mapping_df['subclass_type'].isna())]
+     tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '5p-tR-half'
+     tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '5p-tRF'
+     tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 6) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '3p-tR-half'
+     tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)).between(3,6,inclusive='neither')) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-tRF'
+     tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-CCA-tRF'
+     tRNAs_df.loc[tRNAs_df.subclass_type.isna(),'subclass_type'] = 'misc-tRF'
+     # add ref_iso flag
+     tRNAs_df['tRNA_ref_iso'] = np.where(
+         (
+             (tRNAs_df.ref_start == 0)
+             | ((tRNAs_df.ref_end + 1) == tRNAs_df.precursor_length)
+             | ((tRNAs_df.ref_end + 1) == (tRNAs_df.precursor_length - 3))
+         ), 'reftRF', 'isotRF'
+     )
+     # concat tRNA, leader & trailer dfs
+     tRNAs_df = pd.concat([tRNAs_df, tRF_leader_df, tRF_trailer_df],axis=0)
+     # adjust precursor name and create tRNA name
+     tRNAs_df['precursor_name'] = tRNAs_df.precursor_name.str.extract(r"((tRNA-...-...)|(MT-..)|(tRX-...-...)|(tRNA-i...-...))", expand=True)[0]
+     tRNAs_df['subclass_name'] = tRNAs_df.subclass_type + '__' + tRNAs_df.precursor_name
+
+     return tRNAs_df
+
+ #%%
+ def faustrules_check(row):
+     """Check if isomiRs follow Faustrules (based on Tomasello et al. 2021).
+     """
+
+     # mark seqs that are not in range +/- 2nt of mature start
+     # check if ref_start.between(miRNAs_df.mature_start-2, miRNAs_df.mature_start+2, inclusive='both')]
+     ref_start = row['ref_start']
+     mature_start = row['mature_start']
+
+     if ref_start < mature_start - 2 or ref_start > mature_start + 2:
+         return False
+
+     # mark seqs with mismatch unless A>G or C>T in seed region (= position 0-8) or 3' polyA/polyT (max 3nt)
+     if pd.isna(row['mm_descriptors']):
+         return True
+
+     seed_region_positions = set(range(9))
+     non_templated_ends = {'A', 'AA', 'AAA', 'T', 'TT', 'TTT'}
+
+     sequence = row['sequence']
+     mm_descriptors = row['mm_descriptors'].split(';')
+
+     seed_region_mismatches = 0
+     three_prime_end_mismatches = 0
+
+     for descriptor in mm_descriptors:
+         pos, change = descriptor.split(':')
+         pos = int(pos)
+         original, new = change.split('>')
+
+         if pos in seed_region_positions and (original == 'A' and new == 'G' or original == 'C' and new == 'T'):
+             seed_region_mismatches += 1
+
+         if pos >= len(sequence) - 3 and sequence[pos:] in non_templated_ends:
+             three_prime_end_mismatches += 1
+
+     total_mismatches = seed_region_mismatches + three_prime_end_mismatches
+
+     return total_mismatches == len(mm_descriptors)
+
+ @log_time(log)
+ def miRNA_annotation(mapping_df):
+     """Extract miRNA specific annotation from mapping. RaH Faustrules are applied.
+     """
+
+     miRNAs_df = mapping_df[mapping_df.small_RNA_class_annotation == 'miRNA']
+
+     nr_missing_alignments_expected = len(miRNAs_df.loc[miRNAs_df.duplicated(['tmp_seq_id','reference'], keep='first'),:])
+
+     # load positions of mature miRNAs within precursor
+     miRNA_pos_df = pd.read_csv(mat_miRNA_pos_path, sep='\t')
+     miRNA_pos_df.drop(columns=['precursor_length'], inplace=True)
+     miRNAs_df = miRNAs_df.merge(miRNA_pos_df, left_on='precursor_name_full', right_on='name_precursor', how='left')
+
+     # load mature miRNA sequences from miRBase
+     miRBase_mature_df = fasta2df_subheader(miRBase_mature_path,0)
+     # subset to human miRNAs
+     miRBase_mature_df = miRBase_mature_df.loc[miRBase_mature_df.index.str.contains('hsa-'),:]
+     miRBase_mature_df.index = miRBase_mature_df.index.str.replace('hsa-','')
+     miRBase_mature_df.reset_index(inplace=True)
+     miRBase_mature_df.columns = ['name_mature','ref_miR_seq']
+     # add 'ref_miR_seq'
+     miRNAs_df = miRNAs_df.merge(miRBase_mature_df, left_on='name_mature', right_on='name_mature', how='left')
+
+     # for each duplicated tmp_seq_id/reference combi, keep the one lowest lev dist of sequence to ref_miR_seq
+     miRNAs_df['lev_dist'] = miRNAs_df.apply(lambda x: distance(x['sequence'], x['ref_miR_seq']), axis=1)
+     miRNAs_df = miRNAs_df.sort_values(by=['tmp_seq_id','lev_dist'], ascending=[True, True]).drop_duplicates(['tmp_seq_id','reference'], keep='first')
+
+     # add ref_iso flag
+     miRNAs_df['miRNA_ref_iso'] = np.where(
+         (
+             (miRNAs_df.ref_start == miRNAs_df.mature_start)
+             & (miRNAs_df.ref_end == miRNAs_df.mature_end)
+             & (miRNAs_df.mms == 0)
+         ), 'refmiR', 'isomiR'
+     )
+
+     # apply RaH Faustrules
+     miRNAs_df['faustrules_check'] = miRNAs_df.apply(faustrules_check, axis=1)
+
+     # set miRNA_ref_iso to 'misc-miR' if faustrules_check is False
+     miRNAs_df.loc[~miRNAs_df.faustrules_check,'miRNA_ref_iso'] = 'misc-miR'
+
+     # set subclass_name to name_mature if faustrules_check is True, else use precursor_name
+     miRNAs_df['subclass_name'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.precursor_name)
+
+     # store name_mature for functional analysis as miRNA_names, set miR- to mir- if faustrules_check is False
+     miRNAs_df['miRNA_names'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.name_mature.str.replace('miR-', 'mir-'))
+
+     # add subclass (NOTE: in cases where subclass is not part of mature name, use position relative to precursor half to define group )
+     miRNAs_df['subclass_type'] = np.where(miRNAs_df.name_mature.str.endswith('5p'), '5p', np.where(miRNAs_df.name_mature.str.endswith('3p'), '3p', 'tbd'))
+     miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start < miRNAs_df.precursor_length/2)), 'subclass_type'] = '5p'
+     miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start >= miRNAs_df.precursor_length/2)), 'subclass_type'] = '3p'
+
+     # subset to relevant columns
+     miRNAs_df = miRNAs_df[list(mapping_df.columns) + ['subclass_name','miRNA_ref_iso','miRNA_names','ref_miR_seq']]
+
+     return miRNAs_df, nr_missing_alignments_expected
+
+
+ #%%
+ ######################################################################################################
+ # annotation of other sRNA classes
+ ######################################################################################################
+ def get_bin_with_max_overlap_parallel(df):
+     return df.apply(get_bin_with_max_overlap, axis=1)
+
+ def applyParallel(df, func):
+     retLst = Parallel(n_jobs=multiprocessing.cpu_count())(delayed(func)(group) for group in np.array_split(df,30))
+     return pd.concat(retLst)
+
+
+ @log_time(log)
+ def other_sRNA_annotation_new_binning(mapping_df):
+     """Generate subclass_name for non-tRNA/miRNA sRNAs by precursor-binning.
+     New binning approach: bin size is dynamically determined by the precursor length. Assignments are based on the bin with the highest overlap.
+     """
+
+     other_sRNAs_df = mapping_df[~((mapping_df.small_RNA_class_annotation == 'miRNA') | (mapping_df.small_RNA_class_annotation == 'tRNA'))]
+
+     #create empty columns; bin start and bin end
+     other_sRNAs_df['bin_start'] = ''
+     other_sRNAs_df['bin_end'] = ''
+
+     other_sRNAs_df = applyParallel(other_sRNAs_df, get_bin_with_max_overlap_parallel)
+
+     return other_sRNAs_df
+
+
+ #%%
+ @log_time(log)
+ def extract_sRNA_class_specific_info(mapping_df):
+     tRNAs_df = tRNA_annotation(mapping_df)
+     miRNAs_df, nr_missing_alignments_expected = miRNA_annotation(mapping_df)
+     other_sRNAs_df = other_sRNA_annotation_new_binning(mapping_df)
+
+     # add miRNA columns
+     tRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq'])
+     other_sRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq'])
+
+     # re-concat sRNA class dfs
+     sRNA_anno_df = pd.concat([miRNAs_df, tRNAs_df, other_sRNAs_df],axis=0)
+
+     # TEST if alignments were lost or duplicated
+     assert ((len(mapping_df) - nr_missing_alignments_expected) == len(sRNA_anno_df)), "alignments were lost or duplicated"
+
+     return sRNA_anno_df
+
274
+ #%%
275
+ def get_nth_nt(row):
276
+ return row['sequence'][int(row['PTM_position_in_seq'])-1]
277
+
278
+
279
+
280
+ #%%
281
+ @log_time(log)
282
+ def aggregate_info_per_seq(sRNA_anno_df):
283
+ # fillna of 'subclass_name_bin_pos' with 'subclass_name'
284
+ sRNA_anno_df['subclass_name_bin_pos'] = sRNA_anno_df['subclass_name_bin_pos'].fillna(sRNA_anno_df['subclass_name'])
285
+ # get aggregated info per seq
286
+ aggreg_per_seq_df = sRNA_anno_df.groupby(['sequence']).agg({'small_RNA_class_annotation': lambda x: ';'.join(sorted(x.unique())), 'pseudo_class': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_type': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_name': lambda x: ';'.join(sorted(x.unique())), 'subclass_name_bin_pos': lambda x: ';'.join(sorted(x.unique())), 'miRNA_names': lambda x: ';'.join(x.fillna('').unique()), 'precursor_name_full': lambda x: ';'.join(sorted(x.unique())), 'mms': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'reference': lambda x: len(x), 'mitochondrial': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'ref_miR_seq': lambda x: ';'.join(x.fillna('').unique())})
287
+ aggreg_per_seq_df['miRNA_names'] = aggreg_per_seq_df.miRNA_names.str.replace(r';$','', regex=True)
288
+ aggreg_per_seq_df['ref_miR_seq'] = aggreg_per_seq_df.ref_miR_seq.str.replace(r';$','', regex=True)
289
+ aggreg_per_seq_df['mms'] = aggreg_per_seq_df['mms'].astype(int)
290
+
291
+ # re-add 'miRNA_ref_iso','tRNA_ref_iso'
292
+ refmir_df = sRNA_anno_df[['sequence','miRNA_ref_iso','tRNA_ref_iso']]
293
+ refmir_df.drop_duplicates('sequence', inplace=True)
294
+ refmir_df.set_index('sequence', inplace=True)
295
+ aggreg_per_seq_df = aggreg_per_seq_df.merge(refmir_df, left_index=True, right_index=True, how='left')
296
+
297
+ # TEST if sequences were lost
298
+ assert (len(aggreg_per_seq_df) == len(sRNA_anno_df.sequence.unique())), "sequences were lost by aggregation"
299
+
300
+ # load unmapped seqs, if it exits
301
+ if os.path.exists(unmapped_file):
302
+ unmapped_df = fasta2df(unmapped_file)
303
+ unmapped_df = pd.DataFrame(data='no_annotation', index=unmapped_df.sequence, columns=aggreg_per_seq_df.columns)
304
+ unmapped_df['mms'] = np.nan
305
+ unmapped_df['reference'] = np.nan
306
+ unmapped_df['pseudo_class'] = True # set no annotation as pseudo_class
307
+
308
+ # merge mapped and unmapped
309
+ annotation_df = pd.concat([aggreg_per_seq_df,unmapped_df])
310
+ else:
311
+ annotation_df = aggreg_per_seq_df.copy()
312
+
313
+     # load mapping to genome file
+     mapping_genome_df = pd.read_csv(mapped_genome_file, index_col=0, sep='\t', header=None)
+     mapping_genome_df.columns = ['strand','reference','ref_start','sequence','other_alignments','mm_descriptors']
+     mapping_genome_df = mapping_genome_df[['strand','reference','ref_start','sequence','other_alignments']]
+
+     # use reverse complement of 'sequence' for 'strand' == '-'
+     mapping_genome_df.loc[:,'sequence'] = np.where(mapping_genome_df.strand == '-', mapping_genome_df.sequence.apply(lambda x: reverse_complement(x)), mapping_genome_df.sequence)
+
+     # get aggregated info per seq
+     aggreg_per_seq__genome_df = mapping_genome_df.groupby('sequence').agg({'reference': lambda x: ';'.join(sorted(x.unique())), 'other_alignments': lambda x: len(x)})
+     aggreg_per_seq__genome_df['other_alignments'] = aggreg_per_seq__genome_df['other_alignments'].astype(int)
+
+     # number of genomic loci
+     genomic_loci_df = pd.DataFrame(mapping_genome_df.sequence.value_counts())
+     genomic_loci_df.columns = ['num_genomic_loci_maps']
+
+     # load seqs with too many alignments (recorded as 101 loci)
+     if os.path.exists(toomanyloci_genome_file):
+         toomanyloci_genome_df = fasta2df(toomanyloci_genome_file)
+         toomanyloci_genome_df = pd.DataFrame(data=101, index=toomanyloci_genome_df.sequence, columns=genomic_loci_df.columns)
+     else:
+         toomanyloci_genome_df = pd.DataFrame(columns=genomic_loci_df.columns)
+
+     # load unmapped seqs (recorded as 0 loci)
+     if os.path.exists(unmapped_genome_file):
+         unmapped_genome_df = fasta2df(unmapped_genome_file)
+         unmapped_genome_df = pd.DataFrame(data=0, index=unmapped_genome_df.sequence, columns=genomic_loci_df.columns)
+     else:
+         unmapped_genome_df = pd.DataFrame(columns=genomic_loci_df.columns)
+
+     # concat toomanyloci, unmapped, and genomic_loci
+     num_genomic_loci_maps_df = pd.concat([genomic_loci_df,toomanyloci_genome_df,unmapped_genome_df])
+
+     # merge to annotation_df
+     annotation_df = annotation_df.merge(num_genomic_loci_maps_df, left_index=True, right_index=True, how='left')
+     annotation_df.reset_index(inplace=True)
+
+     # add 'miRNA_seed'
+     annotation_df.loc[:,"miRNA_seed"] = np.where(annotation_df.small_RNA_class_annotation.str.contains('miRNA', na=False), annotation_df.sequence.str[1:9], "")
+
+     # TEST if nan values in 'num_genomic_loci_maps'
+     assert (annotation_df.num_genomic_loci_maps.isna().any() == False), "nan values in 'num_genomic_loci_maps'"
+
+     return annotation_df
+
+
+
+
+ #%%
+ @log_time(log)
+ def get_five_prime_adapter_info(annotation_df, five_prime_adapter):
+     adapter_df = pd.DataFrame(index=annotation_df.sequence)
+
+     min_length = 6
+
+     is_prefixed = None
+     print("5' adapter affixes:")
+     for l in range(0, len(five_prime_adapter) - min_length):
+         is_prefixed_l = adapter_df.index.str.startswith(five_prime_adapter[l:])
+         print(f"{five_prime_adapter[l:].ljust(30, ' ')}{is_prefixed_l.sum()}")
+         adapter_df.loc[adapter_df.index.str.startswith(five_prime_adapter[l:]), "five_prime_adapter_length"] = len(five_prime_adapter[l:])
+         if is_prefixed is None:
+             is_prefixed = is_prefixed_l
+         else:
+             is_prefixed |= is_prefixed_l
+
+     print(f"There are {is_prefixed.sum()} prefixed features.")
+     print("\n")
+
+     adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].fillna(0)
+     adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].astype('int')
+     adapter_df['five_prime_adapter_filter'] = np.where(adapter_df['five_prime_adapter_length'] == 0, True, False)
+     adapter_df = adapter_df.reset_index()
+
+     return adapter_df
+
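The loop in `get_five_prime_adapter_info` tests whether a read starts with progressively shorter 3' suffixes of the 5' adapter. A minimal pure-Python sketch of that idea for a single sequence, without pandas, might look as follows. `adapter_prefix_length` is a hypothetical helper name and the demo read is made up for illustration; only the adapter string is taken from `make_anno.py`.

```python
def adapter_prefix_length(sequence: str, five_prime_adapter: str, min_length: int = 6) -> int:
    """Return the length of the adapter suffix that prefixes `sequence`, or 0.

    Mirrors the loop in get_five_prime_adapter_info: suffixes are tried from
    longest to shortest, and a later (shorter) match overwrites an earlier one.
    """
    length = 0
    for start in range(0, len(five_prime_adapter) - min_length):
        suffix = five_prime_adapter[start:]
        if sequence.startswith(suffix):
            length = len(suffix)
    return length


adapter = "GTTCAGAGTTCTACAGTCCGACGATC"  # default --five_prime_adapter in make_anno.py
read_with_adapter = adapter[-8:] + "TGAGGTAGTAGGTTGT"  # made-up read carrying the last 8 adapter bases
read_without_adapter = "TGAGGTAGTAGGTTGT"  # made-up read without adapter bases

print(adapter_prefix_length(read_with_adapter, adapter))
print(adapter_prefix_length(read_without_adapter, adapter))
```

Because `range` excludes its upper bound, `min_length = 6` means the shortest suffix tried has 7 bases, so a read that carries only the last 6 adapter bases is not flagged.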
+ #%%
+ @log_time(log)
+ def reduce_ambiguity(annotation_df: pd.DataFrame) -> pd.DataFrame:
+     """Reduce ambiguity by
+
+     a) using subclass_name of precursor with shortest genomic context, if all other assigned precursors overlap with its genomic region
+
+     b) using subclass_name whose bin is at the 5' or 3' end of the precursor
+
+     Parameters
+     ----------
+     annotation_df : pd.DataFrame
+         A DataFrame containing the annotation of the sequences (var)
+
+     Returns
+     -------
+     pd.DataFrame
+         An improved version of the input DataFrame with reduced ambiguity
+     """
+
+     # extract ambiguous assignments for subclass name (copies, since columns are added below)
+     ambigious_matches_df = annotation_df[annotation_df.subclass_name.str.contains(';',na=False)].copy()
+     if len(ambigious_matches_df) == 0:
+         print('No ambiguous assignments for subclass name found.')
+         return annotation_df
+     clear_matches_df = annotation_df[~annotation_df.subclass_name.str.contains(';',na=False)].copy()
+
+     # extract required information from HBDxBase
+     HBDxBase_all_df = pd.read_csv(HBDxBase_csv, index_col=0)
+     bin_dict = HBDxBase_all_df[['precursor_name','precursor_bins']].set_index('precursor_name').to_dict()['precursor_bins']
+     sRNA_class_dict = HBDxBase_all_df[['precursor_name','small_RNA_class_annotation']].set_index('precursor_name').to_dict()['small_RNA_class_annotation']
+     pseudo_class_dict = HBDxBase_all_df[['precursor_name','pseudo_class']].set_index('precursor_name').to_dict()['pseudo_class']
+     sc_type_dict = HBDxBase_all_df[['precursor_name','subclass_type']].set_index('precursor_name').to_dict()['subclass_type']
+     genomic_context_bed = HBDxBase_all_df[['chr','start','end','precursor_name','score','strand']].copy()
+     genomic_context_bed.columns = ['seq_id','start','end','name','score','strand']
+     genomic_context_bed.reset_index(drop=True, inplace=True)
+     genomic_context_bed['genomic_length'] = genomic_context_bed.end - genomic_context_bed.start
+
+
+     def get_overlaps(genomic_context_bed: pd.DataFrame, name: str = None, complement: bool = False) -> list:
+         """Get genomic overlap of a given precursor name
+
+         Parameters
+         ----------
+         genomic_context_bed : pd.DataFrame
+             A DataFrame containing genomic locations of precursors in bed format
+             with column names: 'seq_id','start','end','name','score','strand'
+         name : str
+             The name of the precursor to get genomic context for
+         complement : bool
+             If True, return all precursors that do not overlap with the given precursor
+
+         Returns
+         -------
+         list
+             A list containing the precursors in the genomic (anti-)context of the given precursor
+             (including the precursor itself)
+         """
+         series_OI = genomic_context_bed[genomic_context_bed['name'] == name]
+         start = series_OI['start'].values[0]
+         end = series_OI['end'].values[0]
+         seq_id = series_OI['seq_id'].values[0]
+         strand = series_OI['strand'].values[0]
+
+         overlap_df = genomic_context_bed.copy()
+
+         condition = (((overlap_df.start > start) &
+                       (overlap_df.start < end)) |
+                      ((overlap_df.end > start) &
+                       (overlap_df.end < end)) |
+                      ((overlap_df.start < start) &
+                       (overlap_df.end > start)) |
+                      ((overlap_df.start == start) &
+                       (overlap_df.end == end)) |
+                      ((overlap_df.start == start) &
+                       (overlap_df.end > end)) |
+                      ((overlap_df.start < start) &
+                       (overlap_df.end == end)))
+         if not complement:
+             overlap_df = overlap_df[condition]
+         else:
+             overlap_df = overlap_df[~condition]
+         overlap_df = overlap_df[overlap_df.seq_id == seq_id]
+         if strand is not None:
+             overlap_df = overlap_df[overlap_df.strand == strand]
+         overlap_list = overlap_df['name'].tolist()
+         return overlap_list
+
+
+     def check_genomic_ctx_of_smallest_prec(precursor_name: str) -> str:
+         """Check for a given ambiguous precursor assignment (several names separated by ';')
+         if all assigned precursors overlap with the genomic region
+         of the precursor with the shortest genomic context
+
+         Parameters
+         ----------
+         precursor_name: str
+             A string containing several precursor names separated by ';'
+
+         Returns
+         -------
+         str
+             The precursor suggested to be used instead of the multi assignment,
+             or None if the ambiguity could not be resolved
+         """
+         assigned_names = precursor_name.split(';')
+
+         tmp_genomic_context = genomic_context_bed[genomic_context_bed.name.isin(assigned_names)]
+         # get name of smallest genomic region
+         if len(tmp_genomic_context) > 0:
+             smallest_name = tmp_genomic_context.name[tmp_genomic_context.genomic_length.idxmin()]
+             # check if all assigned names are in overlap of smallest genomic region
+             if set(assigned_names).issubset(set(get_overlaps(genomic_context_bed,smallest_name))):
+                 return smallest_name
+             else:
+                 return None
+         else:
+             return None
+
+     def get_subclass_name(subclass_name: str, short_prec_match_new_name: str) -> str:
+         """Get the subclass name matching a precursor name from an ambiguous assignment (several names separated by ';')
+
+         Parameters
+         ----------
+         subclass_name: str
+             A string containing several subclass names separated by ';'
+         short_prec_match_new_name: str
+             The name of the precursor to be used instead of the multi assignment
+
+         Returns
+         -------
+         str
+             The subclass name suggested to be used instead of the multi assignment,
+             or None if the ambiguity could not be resolved
+         """
+         if short_prec_match_new_name is not None:
+             matches = get_close_matches(short_prec_match_new_name,subclass_name.split(';'),cutoff=0.2)
+             if matches:
+                 return matches[0]
+             else:
+                 print(f"Could not find match for {short_prec_match_new_name} in {subclass_name}")
+                 return subclass_name
+         else:
+             return None
+
+
+     def check_end_bins(subclass_name: str) -> str:
+         """Check for a given ambiguous subclass name assignment (several names separated by ';')
+         if ambiguity can be resolved by selecting the subclass name whose bin matches the 3'/5' end of the precursor
+
+         Parameters
+         ----------
+         subclass_name: str
+             A string containing several subclass names separated by ';'
+
+         Returns
+         -------
+         str
+             The subclass name suggested to be used instead of the multi assignment,
+             or None if the ambiguity could not be resolved
+         """
+         for name in subclass_name.split(';'):
+             if '_bin-' in name:
+                 name_parts = name.split('_bin-')
+                 if name_parts[0] in bin_dict and bin_dict[name_parts[0]] == int(name_parts[1]):
+                     return name
+                 elif int(name_parts[1]) == 1:
+                     return name
+         return None
+
+
+     def adjust_4_resolved_cases(row: pd.Series) -> tuple:
+         """For resolved ambiguous subclass names, return adjusted values of
+         precursor_name_full, small_RNA_class_annotation, pseudo_class, and subclass_type
+
+         Parameters
+         ----------
+         row: pd.Series
+             A row of the var annotation containing the columns 'subclass_name', 'precursor_name_full',
+             'small_RNA_class_annotation', 'pseudo_class', 'subclass_type', and 'ambiguity_resolved'
+
+         Returns
+         -------
+         tuple
+             A tuple containing the adjusted values of 'precursor_name_full', 'small_RNA_class_annotation',
+             'pseudo_class', and 'subclass_type' for resolved ambiguous cases and the original values for unresolved cases
+         """
+         if row.ambiguity_resolved:
+             matches_prec = get_close_matches(row.subclass_name, row.precursor_name_full.split(';'), cutoff=0.2)
+             if matches_prec:
+                 return matches_prec[0], sRNA_class_dict[matches_prec[0]], pseudo_class_dict[matches_prec[0]], sc_type_dict[matches_prec[0]]
+         return row.precursor_name_full, row.small_RNA_class_annotation, row.pseudo_class, row.subclass_type
+
+
+     # resolve ambiguity by checking genomic context of smallest precursor
+     ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.precursor_name_full.apply(check_genomic_ctx_of_smallest_prec)
+     ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.apply(lambda x: get_subclass_name(x.subclass_name, x.short_prec_match_new_name), axis=1)
+     ambigious_matches_df['short_prec_match'] = ambigious_matches_df['short_prec_match_new_name'].notnull()
+
+     # resolve ambiguity by checking if bin matches 3'/5' end of precursor
+     ambigious_matches_df['end_bin_match_new_name'] = ambigious_matches_df.subclass_name.apply(check_end_bins)
+     ambigious_matches_df['end_bin_match'] = ambigious_matches_df['end_bin_match_new_name'].notnull()
+
+     # report cases where both rules apply but suggest different subclass names
+     test_df = ambigious_matches_df[((ambigious_matches_df.short_prec_match == True) & (ambigious_matches_df.end_bin_match == True))]
+     if not (test_df.short_prec_match_new_name == test_df.end_bin_match_new_name).all():
+         print('Number of cases where short_prec_match_new_name does not match end_bin_match_new_name:',len(test_df[(test_df.short_prec_match_new_name != test_df.end_bin_match_new_name)]))
+
+     # replace subclass_name with short_prec_match_new_name or end_bin_match_new_name
+     # NOTE: if short_prec_match and end_bin_match are True, short_prec_match_new_name is used
+     ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.end_bin_match_new_name if x.end_bin_match == True else x.subclass_name, axis=1)
+     ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.short_prec_match_new_name if x.short_prec_match == True else x.subclass_name, axis=1)
+
+     # generate column 'ambiguity_resolved' which is True if short_prec_match and/or end_bin_match is True
+     ambigious_matches_df['ambiguity_resolved'] = ambigious_matches_df.short_prec_match | ambigious_matches_df.end_bin_match
+     print("Ambiguity resolved?\n",ambigious_matches_df.ambiguity_resolved.value_counts(normalize=True))
+
+     # for resolved ambiguous matches, adjust precursor_name_full, small_RNA_class_annotation, pseudo_class, subclass_type
+     ambigious_matches_df[['precursor_name_full','small_RNA_class_annotation','pseudo_class','subclass_type']] = ambigious_matches_df.apply(adjust_4_resolved_cases, axis=1, result_type='expand')
+
+     # drop temporary columns
+     ambigious_matches_df.drop(columns=['short_prec_match_new_name','short_prec_match','end_bin_match_new_name','end_bin_match'], inplace=True)
+
+     # concat with clear_matches_df
+     clear_matches_df['ambiguity_resolved'] = False
+     improved_annotation_df = pd.concat([clear_matches_df, ambigious_matches_df], axis=0)
+     improved_annotation_df = improved_annotation_df.reindex(annotation_df.index)
+
+     return improved_annotation_df
+
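The six-clause `condition` in `get_overlaps` (nested inside `reduce_ambiguity`) enumerates the ways two intervals can overlap. Assuming every interval satisfies `start < end`, which the code does not check, the same cases collapse to the usual two-comparison test. The sketch below uses hypothetical helper names and plain integers instead of DataFrame columns, and compares both formulations exhaustively on small coordinates.

```python
from itertools import product


def overlaps_six_clause(s: int, e: int, start: int, end: int) -> bool:
    """Scalar transcription of the six-clause condition in get_overlaps."""
    return (
        (s > start and s < end)
        or (e > start and e < end)
        or (s < start and e > start)
        or (s == start and e == end)
        or (s == start and e > end)
        or (s < start and e == end)
    )


def overlaps_compact(s: int, e: int, start: int, end: int) -> bool:
    """Standard overlap test, valid when s < e and start < end."""
    return s < end and e > start


# Compare both tests on every pair of well-formed intervals with coordinates 0..7.
coords = range(0, 8)
mismatches = [
    (s, e, start, end)
    for s, e, start, end in product(coords, repeat=4)
    if s < e and start < end
    and overlaps_six_clause(s, e, start, end) != overlaps_compact(s, e, start, end)
]
print(len(mismatches))
```

In DataFrame terms the compact form would read `(overlap_df.start < end) & (overlap_df.end > start)`. Zero-length intervals are outside the stated assumption and are not covered by this comparison.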
+ #%%
+ ######################################################################################################
+ # HICO (=high confidence) annotation
+ ######################################################################################################
+ @log_time(log)
+ def add_hico_annotation(annotation_df, five_prime_adapter):
+     """For miRNAs only use hico annotation if part of miRBase hico set AND refmiR
+     """
+
+     # add 'TE_annotation'
+     TE_df = pd.read_csv(TE_file, sep='\t', header=None, names=['sequence','TE_annotation'])
+     annotation_df = annotation_df.merge(TE_df, left_on='sequence', right_on='sequence', how='left')
+
+     # add 'bacterial' mapping filter (True if the sequence is absent from the bacterial unmapped file)
+     bacterial_unmapped_df = fasta2df(unmapped_bacterial_file)
+     annotation_df.loc[:,'bacterial'] = np.where(annotation_df.sequence.isin(bacterial_unmapped_df.sequence), False, True)
+
+     # add 'viral' mapping filter (True if the sequence is absent from the viral unmapped file)
+     viral_unmapped_df = fasta2df(unmapped_viral_file)
+     annotation_df.loc[:,'viral'] = np.where(annotation_df.sequence.isin(viral_unmapped_df.sequence), False, True)
+
+     # add 'adapter_mapping_filter' column (True if the sequence did not map to the adapters)
+     adapter_unmapped_df = fasta2df(unmapped_adapter_file)
+     annotation_df.loc[:,'adapter_mapping_filter'] = np.where(annotation_df.sequence.isin(adapter_unmapped_df.sequence), True, False)
+
+     # add filter column 'five_prime_adapter_filter' and column 'five_prime_adapter_length' indicating the length of the prefixed 5' adapter sequence
+     adapter_df = get_five_prime_adapter_info(annotation_df, five_prime_adapter)
+     annotation_df = annotation_df.merge(adapter_df, left_on='sequence', right_on='sequence', how='left')
+
+     # apply ambiguity reduction
+     annotation_df = reduce_ambiguity(annotation_df)
+
+     # add 'single_class_annotation'
+     annotation_df.loc[:,'single_class_annotation'] = np.where(annotation_df.small_RNA_class_annotation.str.contains(';',na=True), False, True)
+
+     # add 'single_name_annotation'
+     annotation_df.loc[:,'single_name_annotation'] = np.where(annotation_df.subclass_name.str.contains(';',na=True), False, True)
+
+     # add 'hypermapper' for sequences where more than 50 potential mapping references are recorded
+     annotation_df.loc[annotation_df.reference > 50,'subclass_name'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
+     annotation_df.loc[annotation_df.reference > 50,'subclass_name_bin_pos'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
+     annotation_df.loc[annotation_df.reference > 50,'precursor_name_full'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str)
+
+     annotation_df.loc[:,'mitochondrial'] = np.where(annotation_df.mitochondrial.str.contains('mito',na=False), True, False)
+
+     # add 'hico'
+     annotation_df.loc[:,'hico'] = np.where((
+         (annotation_df.mms == 0)
+         & (annotation_df.single_name_annotation == True)
+         & (annotation_df.TE_annotation.isna() == True)
+         & (annotation_df.bacterial == False)
+         & (annotation_df.viral == False)
+         & (annotation_df.adapter_mapping_filter == True)
+         & (annotation_df.five_prime_adapter_filter == True)
+     ), True, False)
+     ## NOTE: for miRNAs only use hico annotation if part of refmiR set
+     annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] = annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] & (annotation_df.miRNA_ref_iso == 'refmiR')
+
+     print(annotation_df[annotation_df.single_class_annotation == True].groupby('small_RNA_class_annotation').hico.value_counts())
+
+     return annotation_df
+
+
+
+
+ #%%
+ ######################################################################################################
+ # annotation pipeline
+ ######################################################################################################
+ @log_time(log)
+ def main(five_prime_adapter):
+     """Executes 'annotate_from_mapping'.
+
+     Uses:
+
+         - HBDxBase_csv
+         - miRBase_mature_path
+         - mat_miRNA_pos_path
+
+         - mapped_file
+         - unmapped_file
+         - mapped_genome_file
+         - toomanyloci_genome_file
+         - unmapped_genome_file
+
+         - TE_file
+         - unmapped_adapter_file
+         - unmapped_bacterial_file
+         - unmapped_viral_file
+         - five_prime_adapter
+
+     """
+
+
+     print('-------- extract general information for sequences that mapped to the HBDxBase --------')
+     mapped_info_df = extract_general_info(mapped_file)
+     print("\n")
+
+     print('-------- extract sRNA class specific information for sequences that mapped to the HBDxBase --------')
+     mapped_sRNA_anno_df = extract_sRNA_class_specific_info(mapped_info_df)
+
+     print('-------- save to file --------')
+     mapped_sRNA_anno_df.to_csv(sRNA_anno_file)
+     print("\n")
+
+     print('-------- aggregate information for mapped and unmapped sequences (HBDxBase & human genome) --------')
+     sRNA_anno_per_seq_df = aggregate_info_per_seq(mapped_sRNA_anno_df)
+     print("\n")
+
+     print('-------- add hico annotation (based on aggregated infos + mapping to viral/bacterial genomes + intersection with TEs) --------')
+     sRNA_anno_per_seq_df = add_hico_annotation(sRNA_anno_per_seq_df, five_prime_adapter)
+     print("\n")
+
+     print('-------- save to file --------')
+     # set sequence as index again
+     sRNA_anno_per_seq_df.set_index('sequence', inplace=True)
+     sRNA_anno_per_seq_df.to_csv(aggreg_sRNA_anno_file)
+     print("\n")
+
+     print('-------- generate subclass_to_annotation dict --------')
+     result_df = sRNA_anno_per_seq_df[['subclass_name', 'small_RNA_class_annotation']].copy()
+     result_df.reset_index(drop=True, inplace=True)
+     result_df.drop_duplicates(inplace=True)
+     result_df = result_df[~result_df["subclass_name"].str.contains(";")]
+     subclass_to_annotation = dict(zip(result_df["subclass_name"],result_df["small_RNA_class_annotation"]))
+     with open('subclass_to_annotation.json', 'w') as fp:
+         json.dump(subclass_to_annotation, fp)
+
+     print('-------- delete tmp files --------')
+     os.system("rm *tmp_*")
+
+
+ #%%
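The cleanup step `os.system("rm *tmp_*")` depends on a POSIX shell. A portable variant could use `glob` and `os.remove`, sketched here under the assumption that the temporary files are regular files matching the same `*tmp_*` pattern. `remove_tmp_files` and its `directory` parameter are hypothetical, and the demo runs in a throwaway directory.

```python
import glob
import os
import tempfile


def remove_tmp_files(directory: str = ".", pattern: str = "*tmp_*") -> list:
    """Delete regular files in `directory` whose names match `pattern`.

    Returns the sorted base names of the files that were removed.
    """
    removed = []
    for path in sorted(glob.glob(os.path.join(directory, pattern))):
        if os.path.isfile(path):
            os.remove(path)
            removed.append(os.path.basename(path))
    return removed


# Demo in a temporary directory so no real pipeline output is touched.
with tempfile.TemporaryDirectory() as demo_dir:
    for file_name in ("tmp_seqs0mm2HBDxBase.txt", "seqs.fa"):
        open(os.path.join(demo_dir, file_name), "w").close()
    removed_names = remove_tmp_files(demo_dir)
    remaining_names = sorted(os.listdir(demo_dir))

print(removed_names, remaining_names)
```

Returning the removed names makes it easy to log what was deleted, which the shell one-liner does not report.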
kba_pipeline/src/make_anno.py ADDED
@@ -0,0 +1,59 @@
+
+ #%%
+ import argparse
+ import os
+ import logging
+
+ from utils import make_output_dir,write_2_log,log_time
+ import map_2_HBDxBase as map_2_HBDxBase
+ import annotate_from_mapping as annotate_from_mapping
+
+
+ log = logging.getLogger(__name__)
+
+
+
+ #%%
+ # get command line arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--five_prime_adapter', type=str, default='GTTCAGAGTTCTACAGTCCGACGATC')
+ parser.add_argument('--fasta_file', type=str, help="Required: FASTA file to annotate, e.g. --fasta_file sequences_to_be_annotated.fa") # NOTE: needs to be stored in "data" folder
+ args = parser.parse_args()
+ if not args.fasta_file:
+     parser.print_help()
+     exit()
+ five_prime_adapter = args.five_prime_adapter
+ sequence_file = args.fasta_file
+
+ #%%
+ @log_time(log)
+ def main(five_prime_adapter, sequence_file):
+     """Executes 'make_anno'.
+     1. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes.
+     2. Extracts information from mapping files.
+     3. Generates annotation columns and final annotation dataframe.
+
+     Uses:
+
+         - sequence_file
+         - five_prime_adapter
+
+     """
+     output_dir = make_output_dir(sequence_file)
+     os.chdir(output_dir)
+
+     log_folder = "log"
+     if not os.path.exists(log_folder):
+         os.makedirs(log_folder)
+     write_2_log(f"{log_folder}/make_anno.log")
+
+     # add name of sequence_file to log file
+     with open(f"{log_folder}/make_anno.log", "a") as ofile:
+         ofile.write(f"Sequence file: {sequence_file}\n")
+
+     map_2_HBDxBase.main("../../data/" + sequence_file)
+     annotate_from_mapping.main(five_prime_adapter)
+
+
+ main(five_prime_adapter, sequence_file)
+ # %%
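`make_anno.py` checks for a missing `--fasta_file` by hand and then prints the help text. `argparse` can enforce this itself through `required=True`. The following is a minimal sketch of that alternative, not the script's actual behaviour: `build_parser` is a hypothetical helper, and the argument list passed to `parse_args` is only there for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the same two options as make_anno.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--five_prime_adapter', type=str, default='GTTCAGAGTTCTACAGTCCGACGATC')
    # required=True makes argparse print a usage message and exit with an error
    # when the flag is missing, so no manual check is needed afterwards
    parser.add_argument('--fasta_file', type=str, required=True,
                        help="FASTA file to annotate, stored in the 'data' folder")
    return parser


# An explicit list is parsed here for illustration; a script would call parse_args() without arguments.
demo_args = build_parser().parse_args(['--fasta_file', 'sequences_to_be_annotated.fa'])
print(demo_args.fasta_file, demo_args.five_prime_adapter)
```

Wrapping the parsing and the final `main(...)` call in an `if __name__ == "__main__":` guard would additionally let the module be imported without running the pipeline.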
kba_pipeline/src/map_2_HBDxBase.py ADDED
@@ -0,0 +1,318 @@
+ ######################################################################################################
+ # map sequences to HBDxBase
+ ######################################################################################################
+ #%%
+ import os
+ import logging
+
+ from utils import fasta2df,log_time
+
+ log = logging.getLogger(__name__)
+
+
+ ######################################################################################################
+ # paths to reference files
+ ######################################################################################################
+
+ version = '_v4'
+
+ HBDxBase_index_path = f'../../references/HBDxBase/HBDxBase{version}'
+ HBDxBase_pseudo_index_path = f'../../references/HBDxBase/HBDxBase_pseudo{version}'
+ genome_index_path = '../../references/hg38/genome'
+ adapter_index_path = '../../references/HBDxBase/adapters'
+ TE_path = '../../references/hg38/TE.bed'
+ bacterial_index_path = '../../references/bacterial_viral/all_bacterial_refseq_with_human_host__201127.index'
+ viral_index_path = '../../references/bacterial_viral/viral_refseq_with_human_host__201127.index'
+
+
+
+
+ #%%
+ ######################################################################################################
+ # specific functions
+ ######################################################################################################
+
+ @log_time(log)
+ def prepare_input_files(seq_input):
+
+     # check if seq_input is path or list
+     if isinstance(seq_input, str):
+         # get seqs in dataset
+         seqs = fasta2df(seq_input)
+         seqs = seqs.sequence
+     elif isinstance(seq_input, list):
+         seqs = seq_input
+     else:
+         raise ValueError('seq_input must be either path to fasta file or list of sequences')
+
+     # add number of sequences to log file
+     log_folder = "log"
+     with open(f"{log_folder}/make_anno.log", "a") as ofile:
+         ofile.write(f"KBA pipeline based on HBDxBase{version}\n")
+         ofile.write(f"Number of sequences to be annotated: {str(len(seqs))}\n")
+
+     if isinstance(seq_input, str):
+         with open('seqs.fa', 'w') as ofile_1:
+             for i in range(len(seqs)):
+                 ofile_1.write(">" + seqs.index[i] + "\n" + seqs.iloc[i] + "\n")
+     else:
+         with open('seqs.fa', 'w') as ofile_1:
+             for i in range(len(seqs)):
+                 ofile_1.write(">seq_" + str(i) + "\n" + seqs[i] + "\n")
+
+ @log_time(log)
+ def map_seq_2_HBDxBase(
+     number_mm,
+     fasta_in_file,
+     out_prefix
+ ):
+
+     bowtie_index_file = HBDxBase_index_path
+
+     os.system(
+         f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --al {out_prefix + str(number_mm) + 'mm2HBDxBase__mapped.fa'} \
+         --un {out_prefix + str(number_mm) + 'mm2HBDxBase__unmapped.fa'} \
+         {out_prefix + str(number_mm) + 'mm2HBDxBase.txt'}"
+     )
+ @log_time(log)
+ def map_seq_2_HBDxBase_pseudo(
+     number_mm,
+     fasta_in_file,
+     out_prefix
+ ):
+
+     bowtie_index_file = HBDxBase_pseudo_index_path
+
+     os.system(
+         f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --al {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__mapped.fa'} \
+         --un {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__unmapped.fa'} \
+         {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo.txt'}"
+     )
+     # -a          Report all valid alignments per read
+     # --norc      No mapping to reverse strand
+     # -v          Report alignments with at most <int> mismatches
+     # -f          Input reads are FASTA (-q would be FASTQ); for our pipeline FASTA makes more sense
+     # --suppress  Suppress columns of output in the default output mode
+     # -x          The basename of the Bowtie, or Bowtie 2, index to be searched
+
+ @log_time(log)
+ def map_seq_2_adapters(
+     fasta_in_file,
+     out_prefix
+ ):
+
+     bowtie_index_file = adapter_index_path
+
+     os.system(
+         f"bowtie -a --best --strata --norc -v 3 -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --al {out_prefix + '3mm2adapters__mapped.fa'} \
+         --un {out_prefix + '3mm2adapters__unmapped.fa'} \
+         {out_prefix + '3mm2adapters.txt'}"
+     )
+     # -a --best --strata  Specifying --strata in addition to -a and --best causes bowtie to report only those alignments in the best alignment “stratum”. The alignments in the best stratum are those having the least number of mismatches
+     # --norc              No mapping to reverse strand
+     # -v                  Report alignments with at most <int> mismatches
+     # -f                  Input reads are FASTA (-q would be FASTQ); for our pipeline FASTA makes more sense
+     # --suppress          Suppress columns of output in the default output mode
+     # -x                  The basename of the Bowtie, or Bowtie 2, index to be searched
+
+
+ @log_time(log)
+ def map_seq_2_genome(
+     fasta_in_file,
+     out_prefix
+ ):
+
+     bowtie_index_file = genome_index_path
+
+     os.system(
+         f"bowtie -a -v 0 -f -m 100 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --max {out_prefix + '0mm2genome__toomanyalign.fa'} \
+         --un {out_prefix + '0mm2genome__unmapped.fa'} \
+         {out_prefix + '0mm2genome.txt'}"
+     )
+     # -a          Report all valid alignments per read
+     # -v          Report alignments with at most <int> mismatches
+     # -f          Input reads are FASTA (-q would be FASTQ); for our pipeline FASTA makes more sense
+     # -m          Suppress all alignments for a particular read if more than <int> reportable alignments exist for it
+     # --suppress  Suppress columns of output in the default output mode
+     # -x          The basename of the Bowtie, or Bowtie 2, index to be searched
+
+
+ @log_time(log)
+ def map_seq_2_bacterial_viral(
+     fasta_in_file,
+     out_prefix
+ ):
+
+     bowtie_index_file = bacterial_index_path
+
+     os.system(
+         f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --al {out_prefix + '0mm2bacterial__mapped.fa'} \
+         --max {out_prefix + '0mm2bacterial__toomanyalign.fa'} \
+         --un {out_prefix + '0mm2bacterial__unmapped.fa'} \
+         {out_prefix + '0mm2bacterial.txt'}"
+     )
+
+
+     bowtie_index_file = viral_index_path
+
+     os.system(
+         f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \
+         --al {out_prefix + '0mm2viral__mapped.fa'} \
+         --max {out_prefix + '0mm2viral__toomanyalign.fa'} \
+         --un {out_prefix + '0mm2viral__unmapped.fa'} \
+         {out_prefix + '0mm2viral.txt'}"
+     )
+     # -a          Report all valid alignments per read
+     # -v          Report alignments with at most <int> mismatches
+     # -f          Input reads are FASTA (-q would be FASTQ); for our pipeline FASTA makes more sense
+     # -m          Suppress all alignments for a particular read if more than <int> reportable alignments exist for it
+     # --suppress  Suppress columns of output in the default output mode
+     # -x          The basename of the Bowtie, or Bowtie 2, index to be searched
+
+
+
+
+
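The mapping functions assemble each bowtie call as one shell string for `os.system`, whose return value is not checked. One possible alternative is to build the call as an argument list, which avoids shell quoting issues and can be handed to `subprocess.run(cmd, check=True)` so that a failing bowtie run raises an exception. The sketch below only builds and prints the command for the HBDxBase case. `build_bowtie_cmd` and its `label` parameter are hypothetical; the flags and file name patterns are copied from `map_seq_2_HBDxBase`.

```python
import shlex


def build_bowtie_cmd(index_path: str, fasta_in_file: str, out_prefix: str,
                     number_mm: int, label: str = "HBDxBase") -> list:
    """Assemble the bowtie call of map_seq_2_HBDxBase as an argument list."""
    stem = f"{out_prefix}{number_mm}mm2{label}"
    return [
        "bowtie", "-a", "--norc", "-v", str(number_mm), "-f",
        "--suppress", "2,6", "--threads", "8",
        "-x", index_path, fasta_in_file,
        "--al", f"{stem}__mapped.fa",
        "--un", f"{stem}__unmapped.fa",
        f"{stem}.txt",
    ]


cmd = build_bowtie_cmd("../../references/HBDxBase/HBDxBase_v4", "seqs.fa", "tmp_seqs", 0)
print(shlex.join(cmd))
# To execute (requires bowtie on PATH): import subprocess; subprocess.run(cmd, check=True)
```

Passing `label="HBDxBase_pseudo"` together with the pseudo index path would reproduce the file names used by `map_seq_2_HBDxBase_pseudo`.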
+ #%%
+ ######################################################################################################
+ # mapping pipeline
+ ######################################################################################################
+ @log_time(log)
+ def main(sequence_file):
+     """Executes 'map_2_HBDxBase'. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes.
+
+     Uses:
+
+         - HBDxBase_index_path
+         - HBDxBase_pseudo_index_path
+         - genome_index_path
+         - bacterial_index_path
+         - viral_index_path
+         - sequence_file
+
+     """
+
+     prepare_input_files(sequence_file)
+
+     # sequential mm mapping to HBDxBase
+     print('-------- map to HBDxBase --------')
+
+     print('-------- mapping seqs (0 mm) --------')
+     map_seq_2_HBDxBase(
+         0,
+         'seqs.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (1 mm) --------')
+     map_seq_2_HBDxBase(
+         1,
+         'tmp_seqs0mm2HBDxBase__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (2 mm) --------')
+     map_seq_2_HBDxBase(
+         2,
+         'tmp_seqs1mm2HBDxBase__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (3 mm) --------')
+     map_seq_2_HBDxBase(
+         3,
+         'tmp_seqs2mm2HBDxBase__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     # sequential mm mapping to Pseudo-HBDxBase
+     print('-------- map to Pseudo-HBDxBase --------')
+
+     print('-------- mapping seqs (0 mm) --------')
+     map_seq_2_HBDxBase_pseudo(
+         0,
+         'tmp_seqs3mm2HBDxBase__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (1 mm) --------')
+     map_seq_2_HBDxBase_pseudo(
+         1,
+         'tmp_seqs0mm2HBDxBase_pseudo__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (2 mm) --------')
+     map_seq_2_HBDxBase_pseudo(
+         2,
+         'tmp_seqs1mm2HBDxBase_pseudo__unmapped.fa',
+         'tmp_seqs'
+     )
+
+     print('-------- mapping seqs (3 mm) --------')
+     map_seq_2_HBDxBase_pseudo(
+         3,
+         'tmp_seqs2mm2HBDxBase_pseudo__unmapped.fa',
+         'tmp_seqs'
+     )
+
+
+     # concatenate files
+     print('-------- concatenate mapping files --------')
+     os.system("cat tmp_seqs0mm2HBDxBase.txt tmp_seqs1mm2HBDxBase.txt tmp_seqs2mm2HBDxBase.txt tmp_seqs3mm2HBDxBase.txt tmp_seqs0mm2HBDxBase_pseudo.txt tmp_seqs1mm2HBDxBase_pseudo.txt tmp_seqs2mm2HBDxBase_pseudo.txt tmp_seqs3mm2HBDxBase_pseudo.txt > seqsmapped2HBDxBase_combined.txt")
+
+     print('\n')
+
+     # mapping to adapters (allowing for 3 mms)
+     print('-------- map to adapters (3 mm) --------')
+     map_seq_2_adapters(
+         'seqs.fa',
+         'tmp_seqs'
+     )
+
+     # mapping to genome (reads with more than 100 alignments are not reported, see -m 100)
+     print('-------- map to human genome --------')
+
+     print('-------- mapping seqs (0 mm) --------')
+     map_seq_2_genome(
+         'seqs.fa',
+         'tmp_seqs'
+     )
+
+
+     ## concatenate files
+     print('-------- concatenate mapping files --------')
290
+ os.system("cp tmp_seqs0mm2genome.txt seqsmapped2genome_combined.txt")
291
+
292
+ print('\n')
293
+
294
+ ## intersect genome mapping hits with TE.bed
295
+ print('-------- intersect genome mapping hits with TE.bed --------')
296
+ # convert to BED format
297
+ os.system("awk 'BEGIN {FS= \"\t\"; OFS=\"\t\"} {print $3, $4, $4+length($5)-1, $5, 111, $2}' seqsmapped2genome_combined.txt > tmp_seqsmapped2genome_combined.bed")
298
+ # intersect with TE.bed (force strandedness -> fetch only sRNA_sequence and TE_name -> aggregate TE annotation on sequences)
299
+ os.system(f"bedtools intersect -a tmp_seqsmapped2genome_combined.bed -b {TE_path} -wa -wb -s" + "| awk '{print $4,$10}' | awk '{a[$1]=a[$1]\";\"$2} END {for(i in a) print i\"\t\"substr(a[i],2)}' > tmp_seqsmapped2genome_intersect_TE.txt")
300
+
301
+ # mapping to bacterial and viral genomes (more than 10 alignments are not reported)
302
+ print('-------- map to bacterial and viral genome --------')
303
+
304
+ print('-------- mapping seqs (0 mm) --------')
305
+ map_seq_2_bacterial_viral(
306
+ 'seqs.fa',
307
+ 'tmp_seqs'
308
+ )
309
+
310
+ ## concatenate files
311
+ print('-------- concatenate mapping files --------')
312
+ os.system("cat tmp_seqs0mm2bacterial.txt tmp_seqs0mm2viral.txt > seqsmapped2bacterialviral_combined.txt")
313
+
314
+ print('\n')
315
+
316
+
317
+
318
+
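The repeated 0 to 3 mismatch calls in `main` follow one pattern: each round reads the unmapped reads written by the previous round. A minimal sketch of that cascade, assuming the file naming seen above (`<prefix><mm>mm2<target>__unmapped.fa`); `mismatch_cascade` is a hypothetical helper and does not call bowtie itself.

```python
from typing import List, Tuple


def mismatch_cascade(first_input: str, prefix: str, target: str,
                     max_mm: int = 3) -> List[Tuple[int, str, str]]:
    """Return (mismatches, input_fasta, out_prefix) for each mapping round.

    Every round after the first reads the unmapped reads of the round before.
    """
    steps = []
    fasta = first_input
    for mm in range(max_mm + 1):
        steps.append((mm, fasta, prefix))
        # the next round starts from what this round could not map
        fasta = f"{prefix}{mm}mm2{target}__unmapped.fa"
    return steps


for mm, fasta, prefix in mismatch_cascade("seqs.fa", "tmp_seqs", "HBDxBase"):
    print(mm, fasta, prefix)
```

Each tuple could then be passed to `map_seq_2_HBDxBase(mm, fasta, prefix)`, and the same helper with `target="HBDxBase_pseudo"` would describe the pseudo rounds.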
kba_pipeline/src/precursor_bins.py ADDED
@@ -0,0 +1,127 @@
1
+ #%%
2
+ import pandas as pd
3
+ from typing import List
4
+ from collections.abc import Callable
5
+
6
+ def load_HBDxBase():
7
+ version = '_v4'
8
+ HBDxBase_file = f'../../references/HBDxBase/HBDxBase_all{version}.csv'
9
+ HBDxBase_df = pd.read_csv(HBDxBase_file, index_col=0)
10
+ HBDxBase_df.loc[:,'precursor_bins'] = (HBDxBase_df.precursor_length/25).astype(int)
11
+ return HBDxBase_df
12
+
13
+ def compute_dynamic_bin_size(precursor_len:int, name:str=None, min_bin_size:int=20, max_bin_size:int=30) -> List[int]:
14
+ '''
15
+ This function splits precursor to bins of size max_bin_size
16
+ if the last bin is smaller than min_bin_size, it will split the precursor to bins of size max_bin_size-1
17
+ This process will continue until the last bin is larger than min_bin_size.
18
+ if the min bin size is reached and still the last bin is smaller than min_bin_size, the last two bins will be merged.
19
+ so the maximum bin size possible would be min_bin_size+(min_bin_size-1) = 39
20
+ '''
21
+ def split_precursor_to_bins(precursor_len,max_bin_size):
22
+ '''
23
+ This function splits precursor to bins of size max_bin_size
24
+ '''
25
+ precursor_bin_lens = []
26
+ for i in range(0, precursor_len, max_bin_size):
27
+ if i+max_bin_size < precursor_len:
28
+ precursor_bin_lens.append(max_bin_size)
29
+ else:
30
+ precursor_bin_lens.append(precursor_len-i)
31
+ return precursor_bin_lens
32
+
33
+ if precursor_len < min_bin_size:
34
+ return [precursor_len]
35
+ else:
36
+ precursor_bin_lens = split_precursor_to_bins(precursor_len,max_bin_size)
37
+ reduced_len = max_bin_size-1
38
+ while precursor_bin_lens[-1] < min_bin_size:
39
+ precursor_bin_lens = split_precursor_to_bins(precursor_len,reduced_len)
40
+ reduced_len -= 1
41
+ if reduced_len < min_bin_size:
42
+ #add last two bins together
43
+ precursor_bin_lens[-2] += precursor_bin_lens[-1]
44
+ precursor_bin_lens = precursor_bin_lens[:-1]
45
+ break
46
+
47
+ return precursor_bin_lens
48
+
49
+ def get_bin_no_from_pos(precursor_len:int,position:int,name:str=None,min_bin_size:int=20,max_bin_size:int=30) -> int:
50
+ '''
51
+ This function returns the bin number of a position in a precursor
52
+ bins start from 1
53
+ '''
54
+ precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
55
+ bin_no = 0
56
+ for i,bin_len in enumerate(precursor_bin_lens):
57
+ if position < bin_len:
58
+ bin_no = i
59
+ break
60
+ else:
61
+ position -= bin_len
62
+ return bin_no+1
63
+
64
+ def get_bin_with_max_overlap(row) -> int:
65
+ '''
66
+ This function annotates the row with the bin that overlaps the most with the fragment
67
+ '''
68
+ precursor_len = row.precursor_length
69
+ start_frag_pos = row.ref_start
70
+ frag_len = row.seq_length
71
+ name = row.precursor_name_full
72
+ min_bin_size = 20
73
+ max_bin_size = 30
74
+ precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
75
+ bin_no = 0
76
+ for i,bin_len in enumerate(precursor_bin_lens):
77
+ if start_frag_pos < bin_len:
78
+ #get overlap with curr bin
79
+ overlap = min(bin_len-start_frag_pos,frag_len)
80
+
81
+ if overlap > frag_len/2:
82
+ bin_no = i
83
+ else:
84
+ bin_no = i+1
85
+ break
86
+
87
+ else:
88
+ start_frag_pos -= bin_len
89
+ #get bin start and bin end
90
+ bin_start,bin_end = sum(precursor_bin_lens[:bin_no]),sum(precursor_bin_lens[:bin_no+1])
91
+ row['bin_start'] = bin_start
92
+ row['bin_end'] = bin_end
93
+ row['subclass_name'] = name + '_bin-' + str(bin_no+1)
94
+ row['precursor_bins'] = len(precursor_bin_lens)
95
+ row['subclass_name_bin_pos'] = name + '_binpos-' + str(bin_start) + ':' + str(bin_end)
96
+ return row
97
+
98
+ def convert_bin_to_pos(precursor_len:int,bin_no:int,bin_function:Callable=compute_dynamic_bin_size,name:str=None,min_bin_size:int=20,max_bin_size:int=30):
99
+ '''
100
+ This function returns the start and end position of a bin
101
+ '''
102
+ precursor_bin_lens = bin_function(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
103
+ start_pos = 0
104
+ end_pos = 0
105
+ for i,bin_len in enumerate(precursor_bin_lens):
106
+ if i+1 == bin_no:
107
+ end_pos = start_pos+bin_len
108
+ break
109
+ else:
110
+ start_pos += bin_len
111
+ return start_pos,end_pos
112
+
113
+ #main
114
+ if __name__ == '__main__':
115
+ #read hbdxbase
116
+ HBDxBase_df = load_HBDxBase()
117
+ min_bin_size = 20
118
+ max_bin_size = 30
119
+ #select indices of precursors that include 'rRNA' but not 'pseudo'
120
+ rRNA_df = HBDxBase_df[HBDxBase_df.index.str.contains('rRNA') * ~HBDxBase_df.index.str.contains('pseudo')]
121
+
122
+ #get bin of index 1
123
+ bins = compute_dynamic_bin_size(len(rRNA_df.iloc[0].sequence),rRNA_df.iloc[0].name,min_bin_size,max_bin_size)
124
+ bin_no = get_bin_no_from_pos(len(rRNA_df.iloc[0].sequence),name=rRNA_df.iloc[0].name,position=1)
125
+ annotation_bin = get_bin_with_max_overlap(pd.Series({'precursor_length': len(rRNA_df.iloc[0].sequence), 'ref_start': 1, 'seq_length': 50, 'precursor_name_full': rRNA_df.iloc[0].name}))
126
+
127
+ # %%
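The binning rule described in the `compute_dynamic_bin_size` docstring can be sketched compactly. This sketch follows the docstring, not the function line by line: it tries bin sizes from `max_bin_size` down to `min_bin_size`, returns the first split whose last bin is long enough, and merges the last two bins only if no size works. As far as the original loop reads, it also merges when the smallest size yields a valid split (for example a length of 40), so the two can differ in that case. The names `split_into_bins` and `dynamic_bins` are hypothetical.

```python
from typing import List


def split_into_bins(precursor_len: int, bin_size: int) -> List[int]:
    """Consecutive bins of bin_size; the last bin takes the remainder."""
    full, rest = divmod(precursor_len, bin_size)
    return [bin_size] * full + ([rest] if rest else [])


def dynamic_bins(precursor_len: int, min_bin_size: int = 20,
                 max_bin_size: int = 30) -> List[int]:
    """Bin lengths following the docstring rule (simplified variant)."""
    if precursor_len < min_bin_size:
        return [precursor_len]
    for size in range(max_bin_size, min_bin_size - 1, -1):
        bins = split_into_bins(precursor_len, size)
        if bins[-1] >= min_bin_size:
            return bins
    # no bin size left a long enough tail: merge the last two bins
    bins[-2] += bins[-1]
    return bins[:-1]


print(dynamic_bins(65))
print(dynamic_bins(39))
```

With the default sizes, a merged bin is at most 20 + 19 = 39 long, which matches the bound stated in the docstring.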
kba_pipeline/src/utils.py ADDED
@@ -0,0 +1,163 @@
1
+ import pandas as pd
2
+ import os
3
+ import errno
4
+ from pathlib import Path
5
+ from Bio.SeqIO.FastaIO import SimpleFastaParser
6
+ from datetime import datetime
7
+ from getpass import getuser
8
+
9
+ import logging
10
+ from rich.logging import RichHandler
11
+ from functools import wraps
12
+ from time import perf_counter
13
+ from typing import Callable
14
+
15
+ default_path = '../outputs/'
16
+
17
+ def humanize_time(time_in_seconds: float, /) -> str:
18
+ """Return a nicely human-readable string of a time_in_seconds.
19
+
20
+ Parameters
21
+ ----------
22
+ time_in_seconds : float
23
+ Time in seconds (may be fractional).
24
+
25
+ Returns
26
+ -------
27
+ str
28
+ A description of the time in one of the forms:
29
+ - 300.1 ms
30
+ - 4.5 sec
31
+ - 5 min 43.1 sec
32
+ """
33
+ sgn = "" if time_in_seconds >= 0 else "- "
34
+ time_in_seconds = abs(time_in_seconds)
35
+ if time_in_seconds < 1:
36
+ return f"{sgn}{time_in_seconds*1e3:.1f} ms"
37
+ elif time_in_seconds < 60:
38
+ return f"{sgn}{time_in_seconds:.1f} sec"
39
+ else:
40
+ return f"{sgn}{int(time_in_seconds//60)} min {time_in_seconds%60:.1f} sec"
41
+
42
+
43
+ class log_time:
44
+ """A decorator / context manager to log the time a certain function / code block took.
45
+
46
+ Usage either with:
47
+
48
+ @log_time(log)
49
+ def function_getting_logged_every_time(…):
50
+
51
+
52
+ producing:
53
+
54
+ function_getting_logged_every_time took 5 sec.
55
+
56
+ or:
57
+
58
+ with log_time(log, "Name of this codeblock"):
59
+
60
+
61
+ producing:
62
+
63
+ Name of this codeblock took 5 sec.
64
+ """
65
+
66
+ def __init__(self, logger: logging.Logger, name: str = None):
67
+ """
68
+ Parameters
69
+ ----------
70
+ logger : logging.Logger
71
+ The logger to use for logging the time, if None use print.
72
+ name : str, optional
73
+ The name in the message, when used as a decorator this defaults to the function name, by default None
74
+ """
75
+ self.logger = logger
76
+ self.name = name
77
+
78
+ def __call__(self, func: Callable):
79
+ if self.name is None:
80
+ self.name = func.__qualname__
81
+
82
+ @wraps(func)
83
+ def inner(*args, **kwds):
84
+ with self:
85
+ return func(*args, **kwds)
86
+
87
+ return inner
88
+
89
+ def __enter__(self):
90
+ self.start_time = perf_counter()
91
+
92
+ def __exit__(self, *exc):
93
+ self.exit_time = perf_counter()
94
+
95
+ time_delta = humanize_time(self.exit_time - self.start_time)
96
+ if self.logger is None:
97
+ print(f"{self.name} took {time_delta}.")
98
+ else:
99
+ self.logger.info(f"{self.name} took {time_delta}.")
100
+
101
+
102
+ def write_2_log(log_file):
103
+ # Setup logging
104
+ log_file_handler = logging.FileHandler(log_file)
105
+ log_file_handler.setLevel(logging.INFO)
106
+ log_file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
107
+ log_rich_handler = RichHandler()
108
+ log_rich_handler.setLevel(logging.INFO) #cli_args.log_level
109
+ log_rich_handler.setFormatter(logging.Formatter("%(message)s"))
110
+ logging.basicConfig(level=logging.INFO, datefmt="[%X]", handlers=[log_file_handler, log_rich_handler])
111
+
112
+
113
+ def fasta2df(path):
114
+ with open(path) as fasta_file:
115
+ identifiers = []
116
+ seqs = []
117
+ for header, sequence in SimpleFastaParser(fasta_file):
118
+ identifiers.append(header)
119
+ seqs.append(sequence)
120
+
121
+ fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence'])
122
+ fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T'))
123
+ return fasta_df
124
+
125
+
126
+
127
+ def fasta2df_subheader(path, id_pos):
128
+ with open(path) as fasta_file:
129
+ identifiers = []
130
+ seqs = []
131
+ for header, sequence in SimpleFastaParser(fasta_file):
132
+ identifiers.append(header.split(None)[id_pos])
133
+ seqs.append(sequence)
134
+
135
+ fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence'])
136
+ fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T'))
137
+ return fasta_df
138
+
139
+
140
+
141
+ def build_bowtie_index(bowtie_index_file):
142
+ #index_example = Path(bowtie_index_file + '.1.ebwt')
143
+ #if not index_example.is_file():
144
+ print('-------- index is being built --------')
145
+ os.system(f"bowtie-build {bowtie_index_file + '.fa'} {bowtie_index_file}")
146
+ #else: print('-------- previously built index is used --------')
147
+
148
+
149
+
150
+ def make_output_dir(fasta_file):
151
+ output_dir = default_path + datetime.now().strftime('%Y-%m-%d') + ('__') + fasta_file.replace('.fasta', '').replace('.fa', '') + '/'
152
+ try:
153
+ os.makedirs(output_dir)
154
+ except OSError as e:
155
+ if e.errno != errno.EEXIST:
156
+ raise # This was not a "directory exists" error.
157
+ return output_dir
158
+
159
+
160
+ def reverse_complement(seq):
161
+ complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}
162
+ return ''.join([complement[base] for base in seq[::-1]])
163
+
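`reverse_complement` above looks up every base in a four-letter dictionary, so any other character (for example `N` or a lowercase base) raises a `KeyError`. A minimal sketch of a more tolerant variant built on `str.translate`; the name `reverse_complement_tolerant` and the choice to pass unknown characters through unchanged are assumptions, not the pipeline's behaviour.

```python
# translation table: complements upper- and lowercase bases, keeps N as N
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement_tolerant(seq: str) -> str:
    """Reverse complement; characters outside the table are left unchanged."""
    return seq.translate(_COMPLEMENT)[::-1]


print(reverse_complement_tolerant("AACG"))
print(reverse_complement_tolerant("acgtN"))
```

Whether unknown characters should pass through or raise is a design choice; failing loudly, as the original does, can be preferable when inputs are expected to be clean.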
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ anndata==0.8.0
2
+ dill==0.3.6
3
+ hydra-core==1.3.0
4
+ imbalanced-learn==0.9.1
5
+ matplotlib==3.5.3
6
+ numpy==1.22.3
7
+ omegaconf==2.2.2
8
+ pandas==1.5.2
9
+ plotly==5.10.0
10
+ PyYAML==6.0
11
+ rich==12.6.0
12
+ viennarna==2.5.0a5
13
+ scanpy==1.9.1
14
+ scikit_learn==1.2.0
15
+ skorch==0.12.1
16
+ torch==1.10.1
17
+ tensorboard==2.11.2
18
+ Levenshtein==0.21.0
scripts/test_inference_api.py ADDED
@@ -0,0 +1,29 @@
1
+ from transforna import predict_transforna, predict_transforna_all_models
2
+
3
+ seqs = [
4
+ 'AACGAAGCTCGACTTTTAAGG',
5
+ 'GTCCACCCCAAAGCGTAGG']
6
+
7
+ path_to_models = '/path/to/tcga/models/'
8
+ sc_preds_id_df = predict_transforna_all_models(seqs,path_to_models = path_to_models) #/models/tcga/
9
+ #%%
10
+ #get sc predictions for models trained on id (in distribution)
11
+ sc_preds_id_df = predict_transforna(seqs, model="seq",trained_on='id',path_to_models = path_to_models)
12
+ #get sc predictions for models trained on full (all sub classes)
13
+ sc_preds_df = predict_transforna(seqs, model="seq",path_to_models = path_to_models)
14
+ #predict using models trained on major class
15
+ mc_preds_df = predict_transforna(seqs, model="seq",mc_or_sc='major_class',path_to_models = path_to_models)
16
+ #get logits
17
+ logits_df = predict_transforna(seqs, model="seq",logits_flag=True,path_to_models = path_to_models)
18
+ #get embedds
19
+ embedd_df = predict_transforna(seqs, model="seq",embedds_flag=True,path_to_models = path_to_models)
20
+ #get the top 4 similar sequences
21
+ sim_df = predict_transforna(seqs, model="seq",similarity_flag=True,n_sim=4,path_to_models = path_to_models)
22
+ #get umaps
23
+ umaps_df = predict_transforna(seqs, model="seq",umaps_flag=True,path_to_models = path_to_models)
24
+
25
+
26
+ all_preds_df = predict_transforna_all_models(seqs,path_to_models=path_to_models)
27
+ all_preds_df
28
+
29
+ # %%
scripts/train.sh ADDED
@@ -0,0 +1,80 @@
1
+ #date and time of the hydra output folder
2
+ get_data_time(){
3
+ date=$(ls outputs/ | head -n 1)
4
+ time=$(ls outputs/*/ | head -n 1)
5
+ date=$date
6
+ time=$time
7
+ }
8
+
9
+ train_model(){
10
+ python -m transforna --config-dir="/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf"\
11
+ model_name=$1 trained_on=$2 num_replicates=$4
12
+
13
+ get_data_time
14
+ #rename the folder to model_name
15
+ mv outputs/$date/$time outputs/$date/$3
16
+ ls outputs/$date/
17
+ rm -rf models/tcga/TransfoRNA_${2^^}/$5/$3
18
+
19
+
20
+
21
+ mv -f outputs/$date/$3 models/tcga/TransfoRNA_${2^^}/$5/
22
+ rm -rf outputs/
23
+
24
+ }
25
+ #activate transforna environment
26
+ eval "$(conda shell.bash hook)"
27
+ conda activate transforna
28
+
29
+ #create the models folder if it does not exist
30
+ if [[ ! -d "models/tcga/TransfoRNA_ID/major_class" ]]; then
31
+ mkdir -p models/tcga/TransfoRNA_ID/major_class
32
+ fi
33
+ if [[ ! -d "models/tcga/TransfoRNA_FULL/sub_class" ]]; then
34
+ mkdir -p models/tcga/TransfoRNA_FULL/sub_class
35
+ fi
36
+ if [[ ! -d "models/tcga/TransfoRNA_ID/sub_class" ]]; then
37
+ mkdir -p models/tcga/TransfoRNA_ID/sub_class
38
+ fi
39
+ if [[ ! -d "models/tcga/TransfoRNA_FULL/major_class" ]]; then
40
+ mkdir -p models/tcga/TransfoRNA_FULL/major_class
41
+ fi
42
+ #remove the outputs folder
43
+ rm -rf outputs
44
+
45
+
46
+ #define models
47
+ models=("seq" "seq-seq" "seq-rev" "seq-struct" "baseline")
48
+ models_capitalized=("Seq" "Seq-Seq" "Seq-Rev" "Seq-Struct" "Baseline")
49
+
50
+
51
+ num_replicates=5
52
+
53
+
54
+ ############train major_class_hico
55
+
56
+ ##replace clf_target:str = 'sub_class_hico' to clf_target:str = 'major_class_hico' in ../conf/train_model_configs/tcga.py
57
+ sed -i "s/clf_target:str = 'sub_class_hico'/clf_target:str = 'major_class_hico'/g" conf/train_model_configs/tcga.py
58
+ #print the file content
59
+ cat conf/train_model_configs/tcga.py
60
+ #loop and train
61
+ for i in ${!models[@]}; do
62
+ echo "Training model ${models_capitalized[$i]} for id on major_class"
63
+ train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "major_class"
64
+ echo "Training model ${models[$i]} for full on major_class"
65
+ train_model ${models[$i]} full ${models_capitalized[$i]} 1 "major_class"
66
+ done
67
+
68
+
69
+ ############train sub_class_hico
70
+
71
+ #replace clf_target:str = 'major_class_hico' to clf_target:str = 'sub_class_hico' in ../conf/train_model_configs/tcga.py
72
+ sed -i "s/clf_target:str = 'major_class_hico'/clf_target:str = 'sub_class_hico'/g" conf/train_model_configs/tcga.py
73
+
74
+ for i in ${!models[@]}; do
75
+ echo "Training model ${models_capitalized[$i]} for id on sub_class"
76
+ train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "sub_class"
77
+ echo "Training model ${models[$i]} for full on sub_class"
78
+ train_model ${models[$i]} full ${models_capitalized[$i]} 1 "sub_class"
79
+ done
80
+
setup.py ADDED
@@ -0,0 +1,40 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name='TransfoRNA',
5
+ version='0.0.1',
6
+ description='TransfoRNA: Navigating the Uncertainties of Small RNA Annotation with an Adaptive Machine Learning Strategy',
7
+ url='https://github.com/gitHBDX/TransfoRNA',
8
+ author='YasserTaha,JuliaJehn',
9
+ author_email='ytaha@hb-dx.com,jjehn@hb-dx.com,tsikosek@hb-dx.com',
10
+ classifiers=[
11
+ 'Development Status :: 4 - Beta',
12
+ 'Intended Audience :: Biological Researchers',
13
+ 'License :: OSI Approved :: MIT License',
14
+ 'Programming Language :: Python :: 3.9',
15
+ ],
16
+ packages=find_packages(include=['transforna', 'transforna.*']),
17
+ install_requires=[
18
+ "anndata==0.8.0",
19
+ "dill==0.3.6",
20
+ "hydra-core==1.3.0",
21
+ "imbalanced-learn==0.9.1",
22
+ "matplotlib==3.5.3",
23
+ "numpy==1.22.3",
24
+ "omegaconf==2.2.2",
25
+ "pandas==1.5.2",
26
+ "plotly==5.10.0",
27
+ "PyYAML==6.0",
28
+ "rich==12.6.0",
29
+ "viennarna>=2.5.0a5",
30
+ "scanpy==1.9.1",
31
+ "scikit_learn==1.2.0",
32
+ "skorch==0.12.1",
33
+ "torch==1.10.1",
34
+ "tensorboard==2.16.2",
35
+ "Levenshtein==0.21.0"
36
+ ],
37
+ python_requires='>=3.9',
38
+ #move yaml files to package
39
+ package_data={'': ['*.yaml']},
40
+ )
transforna/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .src.callbacks import *
2
+ from .src.inference import *
3
+ from .src.model import *
4
+ from .src.novelty_prediction import *
5
+ from .src.processing import *
6
+ from .src.train import *
7
+ from .src.utils import *
transforna/__main__.py ADDED
@@ -0,0 +1,54 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ import warnings
5
+
6
+ import hydra
7
+ from hydra.core.hydra_config import HydraConfig
8
+ from hydra.utils import instantiate
9
+ from omegaconf import DictConfig
10
+
11
+ from transforna import compute_cv, infer_benchmark, infer_tcga, train
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def add_config_to_sys_path():
19
+ cfg = HydraConfig.get()
20
+ config_path = [path["path"] for path in cfg.runtime.config_sources if path["schema"] == "file"][0]
21
+ sys.path.append(config_path)
22
+
23
+ #transforna can be called from anywhere:
24
+ #python -m transforna --config-dir=/path/to/configs
25
+ @hydra.main(config_path='../conf', config_name="main_config")
26
+ def main(cfg: DictConfig) -> None:
27
+ add_config_to_sys_path()
28
+ #get path of hydra outputs folder
29
+ output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
30
+
31
+ path = os.getcwd()
32
+ #init train and model config
33
+ cfg['train_config'] = instantiate(cfg['train_config']).__dict__
34
+ cfg['model_config'] = instantiate(cfg['model_config']).__dict__
35
+
36
+ #update model config with the name of the model
37
+ cfg['model_config']["model_input"] = cfg["model_name"]
38
+
39
+ #inference or train
40
+ if cfg["inference"]:
41
+ logger.info(f"Started inference on {cfg['task']}")
42
+ if cfg['task'] == 'tcga':
43
+ return infer_tcga(cfg,path=path)
44
+ else:
45
+ return infer_benchmark(cfg,path=path)
46
+ else:
47
+ if cfg["cross_val"]:
48
+ compute_cv(cfg,path,output_dir=output_dir)
49
+
50
+ else:
51
+ train(cfg,path=path,output_dir=output_dir)
52
+
53
+ if __name__ == "__main__":
54
+ main()
transforna/bin/figure_scripts/figure_4_table_3.py ADDED
@@ -0,0 +1,173 @@
1
+
2
+
3
+ #%%
4
+ #read all files ending with dist_df in bin/lc_files/
5
+ import pandas as pd
6
+ import glob
7
+ from plotly import graph_objects as go
8
+ from transforna import load,predict_transforna
9
+ all_df = pd.DataFrame()
10
+ files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/*lev_dist_df.csv')
11
+ for file in files:
12
+ df = pd.read_csv(file)
13
+ all_df = pd.concat([all_df,df])
14
+ all_df = all_df.drop(columns=['Unnamed: 0'])
15
+ all_df.loc[all_df.split.isnull(),'split'] = 'NA'
16
+
17
+ #%%
18
+ #draw a box plot for the Ensemble model for each of the splits using seaborn
19
+ ensemble_df = all_df[all_df.Model == 'Ensemble'].reset_index(drop=True)
20
+ #remove other_affixes
21
+ ensemble_df = ensemble_df[ensemble_df.split != 'other_affixes'].reset_index(drop=True)
22
+ #rename
23
+ ensemble_df['split'] = ensemble_df['split'].replace({'5\'A-affixes':'Putative 5’-adapter prefixes','Fused':'Recombined'})
24
+ ensemble_df['split'] = ensemble_df['split'].replace({'Relaxed-miRNA':'Isomirs'})
25
+ #%%
26
+ #plot the boxplot using seaborn
27
+ import seaborn as sns
28
+ import matplotlib.pyplot as plt
29
+ sns.set_theme(style="whitegrid")
30
+ sns.set(rc={'figure.figsize':(15,10)})
31
+ sns.set(font_scale=1.5)
32
+ order = ['LC-familiar','LC-novel','Random','Putative 5’-adapter prefixes','Recombined','NA','LOCO','Isomirs']
33
+ ax = sns.boxplot(x="split", y="NLD", data=ensemble_df, palette="Set3",order=order,showfliers = True)
34
+
35
+ #add Novelty Threshold line
36
+ ax.axhline(y=ensemble_df['Novelty Threshold'].mean(), color='g', linestyle='--',xmin=0,xmax=1)
37
+ #annotate mean of Novelty Threshold
38
+ ax.annotate('NLD threshold', xy=(1.5, ensemble_df['Novelty Threshold'].mean()), xytext=(1.5, ensemble_df['Novelty Threshold'].mean()-0.07), arrowprops=dict(facecolor='black', shrink=0.05))
39
+ #rename
40
+ ax.set_xticklabels(['LC-Familiar','LC-Novel','Random','5’-adapter artefacts','Recombined','NA','LOCO','IsomiRs'])
41
+ #add title
42
+ ax.set_facecolor('None')
43
+ plt.title('Levenshtein Distance Distribution per Split on LC')
44
+ ax.set(xlabel='Split', ylabel='Normalized Levenshtein Distance')
45
+ #save legend
46
+ plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)
47
+ plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_no_out_boxplot.svg',dpi=400)
48
+ #tilt x axis labels
49
+ plt.xticks(rotation=-22.5)
50
+ #save svg
51
+ plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.svg',dpi=1000)
52
+ ##save png
53
+ plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.png',dpi=1000)
54
+ #%%
55
+ bars = [r for r in ax.get_children()]
56
+ colors = []
57
+ for c in bars[:-1]:
58
+ try: colors.append(c.get_facecolor())
59
+ except: pass
60
+ isomir_color = colors[len(order)-1]
61
+ isomir_color = [255*x for x in isomir_color]
62
+ #convert to rgba(r,g,b,a); the alpha channel stays in the 0-1 range
63
+ isomir_color = 'rgba(%s,%s,%s,%s)'%(isomir_color[0],isomir_color[1],isomir_color[2],isomir_color[3]/255)
64
+
65
+ #%%
66
+ relaxed_mirna_df = all_df[all_df.split == 'Relaxed-miRNA']
67
+ models = relaxed_mirna_df.Model.unique()
68
+ percentage_dict = {}
69
+ for model in models:
70
+ model_df = relaxed_mirna_df[relaxed_mirna_df.Model == model]
71
+ #compute the % of sequences with NLD < Novelty Threshold for each model
72
+ percentage_dict[model] = len(model_df[model_df['NLD'] > model_df['Novelty Threshold']])/len(model_df)
73
+ percentage_dict[model]*=100
74
+
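The per-model loop above can also be written as a single grouped mean. A minimal sketch on toy data, assuming the same column names (`Model`, `NLD`, `Novelty Threshold`); `percent_novel` is a hypothetical helper and the numbers are made up for illustration only.

```python
import pandas as pd


def percent_novel(df: pd.DataFrame) -> pd.Series:
    """Percentage of rows per model whose NLD exceeds the novelty threshold."""
    is_novel = df["NLD"] > df["Novelty Threshold"]
    # the mean of a boolean column is the fraction of True values
    return is_novel.groupby(df["Model"]).mean() * 100


toy = pd.DataFrame({
    "Model": ["Seq", "Seq", "Ensemble", "Ensemble"],
    "NLD": [0.1, 0.9, 0.8, 0.7],
    "Novelty Threshold": [0.5, 0.5, 0.5, 0.5],
})
print(percent_novel(toy))
```

The result is a Series indexed by model name, so `percent_novel(df).to_dict()` would give a mapping with the same shape as `percentage_dict`.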
75
+ fig = go.Figure()
76
+ for model in ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev','Ensemble']:
77
+ fig.add_trace(go.Bar(x=[model],y=[percentage_dict[model]],name=model,marker_color=isomir_color))
78
+ #add percentage on top of each bar
79
+ fig.add_annotation(x=model,y=percentage_dict[model]+2,text='%s%%'%(round(percentage_dict[model],2)),showarrow=False)
80
+ #increase size of annotation
81
+ fig.update_annotations(dict(font_size=13))
82
+ #add title in the center
83
+ fig.update_layout(title='Percentage of Isomirs considered novel per model')
84
+ fig.update_layout(xaxis_tickangle=+22.5)
85
+ fig.update_layout(showlegend=False)
86
+ #make transparent background
87
+ fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
88
+ #y axis label
89
+ fig.update_yaxes(title_text='Percentage of Novel Sequences')
90
+ #show the figure
91
+ fig.show()
92
+ #save svg
93
+ #fig.write_image('relaxed_mirna_novel_perc_lc_barplot.svg')
94
+ #%%
95
+ #here we explore the false familiar of the ood lc set
96
+ ood_df = pd.read_csv('/nfs/home/yat_ldap/VS_Projects/TransfoRNA/bin/lc_files/LC-novel_lev_dist_df.csv')
97
+ mapping_dict_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json'
98
+ mapping_dict = load(mapping_dict_path)
99
+
100
+ LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/LC__ngs__DI_HB_GEL-23.1.2.h5ad'
101
+ ad = load(LC_path)
102
+ #%%
103
+ model = 'Ensemble'
104
+ ood_seqs = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == True).values].Sequence.tolist()
105
+ ood_predicted_labels = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == True).values].Labels.tolist()
106
+ ood_actual_labels = ad.var.loc[ood_seqs]['subclass_name'].values.tolist()
107
+ from transforna import correct_labels
108
+ ood_predicted_labels = correct_labels(ood_predicted_labels,ood_actual_labels,mapping_dict)
109
+
110
+ #get indices where ood_predicted_labels != ood_actual_labels (false familiar predictions)
111
+ correct_indices = [i for i, x in enumerate(ood_predicted_labels) if x != ood_actual_labels[i]]
112
+ #keep only those indices in ood_seqs, ood_predicted_labels, ood_actual_labels
113
+ ood_seqs = [ood_seqs[i] for i in correct_indices]
114
+ ood_predicted_labels = [ood_predicted_labels[i] for i in correct_indices]
115
+ ood_actual_labels = [ood_actual_labels[i] for i in correct_indices]
116
+ #get the major class of the actual labels
117
+ ood_actual_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_actual_labels]
118
+ ood_predicted_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_predicted_labels ]
119
+ #get frequencies of each major class
120
+ from collections import Counter
121
+ ood_actual_major_class_freq = Counter(ood_actual_major_class)
122
+ ood_predicted_major_class_freq = Counter(ood_predicted_major_class)
123
+
124
+
125
+
126
+ # %%
127
+ import plotly.express as px
128
+ major_classes = list(ood_actual_major_class_freq.keys())
129
+
130
+ ood_seqs_len = [len(seq) for seq in ood_seqs]
131
+ ood_seqs_len_freq = Counter(ood_seqs_len)
132
+ fig = px.bar(x=list(ood_seqs_len_freq.keys()),y=list(ood_seqs_len_freq.values()))
133
+ fig.show()
134
+
135
+ #%%
136
+ import plotly.graph_objects as go
137
+ fig = go.Figure()
138
+ for major_class in major_classes:
139
+ len_dist = [len(ood_seqs[i]) for i, x in enumerate(ood_actual_major_class) if x == major_class]
140
+ len_dist_freq = Counter(len_dist)
141
+ fig.add_trace(go.Bar(x=list(len_dist_freq.keys()),y=list(len_dist_freq.values()),name=major_class))
142
+ #stack
143
+ fig.update_layout(barmode='stack')
144
+ #make transparent background
145
+ fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
146
+ #set y axis label to Count and x axis label to Length
147
+ fig.update_layout(yaxis_title='Count',xaxis_title='Length')
148
+ #set title
149
+ fig.update_layout(title_text="Length distribution of false familiar sequences per major class")
150
+ #save as svg
151
+ fig.write_image('false_familiar_length_distribution_per_major_class_stacked.svg')
152
+ fig.show()
153
+
154
+ # %%
155
+ #for each model and each split, print the percentage of sequences with Is Familiar? == True
156
+ for model in all_df.Model.unique():
157
+ print('\n\n')
158
+ model_df = all_df[all_df.Model == model]
159
+ num_hicos = 0
160
+ total_samples = 0
161
+ for split in ['LC-familiar','LC-novel','LOCO','NA','Relaxed-miRNA']:
162
+
163
+ split_df = model_df[model_df.split == split]
164
+ #print('Model: %s, Split: %s, Familiar: %s, Number of Sequences: %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True]),len(split_df)))
165
+ #print model, split %
166
+ print('%s %s %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True])/len(split_df)*100))
167
+ if split != 'LC-novel':
168
+ num_hicos+=len(split_df[split_df['Is Familiar?'] == True])
169
+ total_samples+=len(split_df)
170
+ #print % of hicos
171
+ print('%s %s %s'%(model,'HICO',num_hicos/total_samples*100))
172
+ print(total_samples)
173
+ # %%
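The stacked bar plot in the script above is built from one length histogram per major class. A minimal sketch of that counting step is shown here, detached from plotly and from the TransfoRNA data. The function name `length_counts_per_class` and the toy inputs are illustrative assumptions and are not part of the repository.

```python
from collections import Counter, defaultdict


def length_counts_per_class(seqs, classes):
    """Count sequence lengths separately for each class label.

    seqs and classes are parallel lists: classes[i] is the label of seqs[i].
    Returns {class_label: {length: count}}.
    """
    per_class = defaultdict(Counter)
    for seq, cls in zip(seqs, classes):
        per_class[cls][len(seq)] += 1
    # convert to plain dicts so the result is easy to inspect or serialize
    return {cls: dict(counts) for cls, counts in per_class.items()}


# hypothetical toy input, not taken from the dataset
toy_seqs = ['ACGT', 'AC', 'GGTT', 'A']
toy_classes = ['miRNA', 'tRNA', 'miRNA', 'tRNA']
toy_counts = length_counts_per_class(toy_seqs, toy_classes)
```

Each inner dictionary can then be passed to one bar trace (lengths on the x axis, counts on the y axis), which is what the loop over `major_classes` does in a single pass over the data per class.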
transforna/bin/figure_scripts/figure_5_S10_table_4.py ADDED
@@ -0,0 +1,466 @@
1
+ #in this file, the progression of the number of hicos per major class is computed per model
2
+ #this is done at three stages: before ID, after ID, and after FULL.
3
+ #%%
4
+ from transforna import load
5
+ from transforna import predict_transforna,predict_transforna_all_models
6
+ import pandas as pd
7
+ import plotly.graph_objects as go
8
+ import numpy as np
9
+
10
+ def compute_overlap_models_ensemble(full_df:pd.DataFrame,mapping_dict:dict):
11
+ full_copy_df = full_df.copy()
12
+ full_copy_df['MC_Labels'] = full_copy_df['Net-Label'].map(mapping_dict)
13
+ #filter is familiar
14
+ full_copy_df = full_copy_df[full_copy_df['Is Familiar?']].set_index('Sequence')
15
+ #count the predicted miRNAs per each Model
16
+ full_copy_df.groupby('Model').MC_Labels.value_counts()
17
+
18
+ #for each of the models and for each of the mc classes, get the overlap between the sequences a model assigns to a certain mc and the sequences the ensemble assigns to the same mc
19
+ models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']
20
+ mcs = full_copy_df.MC_Labels.value_counts().index.tolist()
21
+ mc_stats = {}
22
+ novel_resid = {}
23
+ mcs_predicted_by_only_one_model = {}
24
+ #add all mcs as keys to mc_stats and add all models as keys in every mc
25
+ for mc in mcs:
26
+ mc_stats[mc] = {}
27
+ novel_resid[mc] = {}
28
+ mcs_predicted_by_only_one_model[mc] = {}
29
+ for model in models:
30
+ mc_stats[mc][model] = 0
31
+ novel_resid[mc][model] = 0
32
+ mcs_predicted_by_only_one_model[mc][model] = 0
33
+
34
+ for mc in mcs:
35
+ ensemble_xrna = full_copy_df[full_copy_df.Model == 'Ensemble'].iloc[full_copy_df[full_copy_df.Model == 'Ensemble'].MC_Labels.str.contains(mc).values].index.tolist()
36
+ for model in models:
37
+ model_xrna = full_copy_df[full_copy_df.Model == model].iloc[full_copy_df[full_copy_df.Model == model].MC_Labels.str.contains(mc).values].index.tolist()
38
+ common_xrna = set(ensemble_xrna).intersection(set(model_xrna))
39
+ try:
40
+ mc_stats[mc][model] = len(common_xrna)/len(ensemble_xrna)
41
+ except ZeroDivisionError:
42
+ mc_stats[mc][model] = 0
43
+ #check how many sequences exist in ensemble but not in model
44
+ try:
45
+ novel_resid[mc][model] = len(set(ensemble_xrna).difference(set(model_xrna)))/len(ensemble_xrna)
46
+ except ZeroDivisionError:
47
+ novel_resid[mc][model] = 0
48
+ #check how many sequences exist in model and in ensemble but not in other models
49
+ other_models_xrna = []
50
+ for other_model in models:
51
+ if other_model != model:
52
+ other_models_xrna.extend(full_copy_df[full_copy_df.Model == other_model].iloc[full_copy_df[full_copy_df.Model == other_model].MC_Labels.str.contains(mc).values].index.tolist())
53
+ #check how many of model_xrna are not in other_models_xrna and are in ensemble_xrna
54
+ try:
55
+ mcs_predicted_by_only_one_model[mc][model] = len(set(model_xrna).difference(set(other_models_xrna)).intersection(set(ensemble_xrna)))/len(ensemble_xrna)
56
+ except ZeroDivisionError:
57
+ mcs_predicted_by_only_one_model[mc][model] = 0
58
+
59
+ return models,mc_stats,novel_resid,mcs_predicted_by_only_one_model
60
+
61
+
62
+ def plot_bar_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model):
63
+ #plot the result as bar plot per mc
64
+ import plotly.graph_objects as go
65
+ import numpy as np
66
+ import plotly.express as px
67
+ #square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL
68
+ #add cascaded bar plot for novel resid. one per mc per model
69
+ positions = np.arange(len(models))
70
+ fig = go.Figure()
71
+ for model in models:
72
+ fig.add_trace(go.Bar(
73
+ x=list(mc_stats.keys()),
74
+ y=[mc_stats[mc][model] for mc in mc_stats.keys()],
75
+ name=model,
76
+ marker_color=px.colors.qualitative.Plotly[models.index(model)]
77
+ ))
78
+
79
+ fig.add_trace(go.Bar(
80
+ x=list(mc_stats.keys()),
81
+ y=[mcs_predicted_by_only_one_model[mc][model] for mc in mc_stats.keys()],
82
+ #base = [mc_stats[mc][model] for mc in mc_stats.keys()],
83
+ name = 'novel',
84
+ marker_color='lightgrey'
85
+ ))
86
+ fig.update_layout(title='Overlap between Ensemble and other models per MC class')
87
+
88
+ return fig
89
+
90
+ def plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap'):
91
+ '''
92
+ This function computes a heatmap of the overlap between the ensemble and the other models per mc class
93
+ input:
94
+ models: list of models
95
+ mc_stats: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent overlap between each model and the ensemble
96
+ novel_resid: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted by the ensemble as familiar but with specific model as novel
97
+ mcs_predicted_by_only_one_model: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted as familiar by only one model
98
+ what_to_plot: string. 'overlap' for overlap between ensemble and other models, 'novel' for novel resid, 'only_one_model' for mcs predicted as familiar by only one model
99
+
100
+ '''
101
+
102
+ if what_to_plot == 'overlap':
103
+ plot_dict = mc_stats
104
+ elif what_to_plot == 'novel':
105
+ plot_dict = novel_resid
106
+ elif what_to_plot == 'only_one_model':
107
+ plot_dict = mcs_predicted_by_only_one_model
108
+
109
+ import plotly.figure_factory as ff
110
+ fig = ff.create_annotated_heatmap(
111
+ z=[[plot_dict[mc][model] for mc in plot_dict.keys()] for model in models],
112
+ x=list(plot_dict.keys()),
113
+ y=models,
114
+ annotation_text=[[str(round(plot_dict[mc][model],2)) for mc in plot_dict.keys()] for model in models],
115
+ font_colors=['black'],
116
+ colorscale='Blues'
117
+ )
118
+ #set x axis order
119
+ order_x_axis = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','lncRNA','piRNA','YRNA','vtRNA']
120
+ fig.update_xaxes(type='category',categoryorder='array',categoryarray=order_x_axis)
121
+
122
+
123
+ fig.update_xaxes(side='bottom')
124
+ if what_to_plot == 'overlap':
125
+ fig.update_layout(title='Overlap between Ensemble and other models per MC class')
126
+ elif what_to_plot == 'novel':
127
+ fig.update_layout(title='Novel resid between Ensemble and other models per MC class')
128
+ elif what_to_plot == 'only_one_model':
129
+ fig.update_layout(title='MCs predicted by only one model')
130
+ return fig
131
+ #%%
132
+ #read TCGA
133
+ dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
134
+ models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
135
+ tcga_df = load(dataset_path_train)
136
+ tcga_df.set_index('sequence',inplace=True)
137
+ loco_hico_na_stats_before = {}
138
+ loco_hico_na_stats_before['HICO'] = sum(tcga_df['hico'])/tcga_df.shape[0]
139
+ before_hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values
140
+ loco_hico_na_stats_before['LOCO'] = (sum(tcga_df.subclass_name != 'no_annotation') - sum(tcga_df['hico']))/tcga_df.shape[0]
141
+ before_loco_seqs = tcga_df[(tcga_df.hico != True) & (tcga_df.subclass_name != 'no_annotation')].index.values
142
+ loco_hico_na_stats_before['NA'] = sum(tcga_df.subclass_name == 'no_annotation')/tcga_df.shape[0]
143
+ before_na_seqs = tcga_df[tcga_df.subclass_name == 'no_annotation'].index.values
144
+ #load mapping dict
145
+ mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
146
+ mapping_dict = load(mapping_dict_path)
147
+ hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values
148
+ hicos_mc_before_id_stats = tcga_df.loc[hico_seqs].subclass_name.map(mapping_dict).value_counts()
149
+ #remove mcs with ; in them
150
+ #hicos_mc_before_id_stats = hicos_mc_before_id_stats[~hicos_mc_before_id_stats.index.str.contains(';')]
151
+ seqs_non_hico_id = tcga_df['subclass_name'][tcga_df['hico'] == False].index.values
152
+ id_df = predict_transforna(sequences=seqs_non_hico_id,model='Seq-Rev',trained_on='id',path_to_models=models_path)
153
+ id_df = id_df[id_df['Is Familiar?']].set_index('Sequence')
154
+ #print the fraction of familiar sequences predicted as no_annotation and the fraction predicted with an annotation
155
+ print('Percentage of sequences with no annotation: %s'%(id_df[id_df['Net-Label'] == 'no_annotation'].shape[0]/id_df.shape[0]))
156
+ print('Percentage of sequences with annotation: %s'%(id_df[id_df['Net-Label'] != 'no_annotation'].shape[0]/id_df.shape[0]))
157
+
158
+ #%%
159
+ hicos_mc_after_id_stats = id_df['Net-Label'].map(mapping_dict).value_counts()
160
+ #remove mcs with ; in them
161
+ #hicos_mc_after_id_stats = hicos_mc_after_id_stats[~hicos_mc_after_id_stats.index.str.contains(';')]
162
+ #add missing major classes with zeros
163
+ for mc in hicos_mc_before_id_stats.index:
164
+ if mc not in hicos_mc_after_id_stats.index:
165
+ hicos_mc_after_id_stats[mc] = 0
166
+ hicos_mc_after_id_stats = hicos_mc_after_id_stats+hicos_mc_before_id_stats
167
+
168
+ #%%
169
+ seqs_non_hico_full = list(set(seqs_non_hico_id).difference(set(id_df.index.values)))
170
+ full_df = predict_transforna_all_models(sequences=seqs_non_hico_full,trained_on='full',path_to_models=models_path)
171
+ #UNCOMMENT TO COMPUTE BEFORE AND AFTER PER MC: table_4
172
+ #ensemble_df = full_df[full_df['Model']=='Ensemble']
173
+ #ensemble_df['Major Class'] = ensemble_df['Net-Label'].map(mapping_dict)
174
+ #new_hico_mcs= ensemble_df['Major Class'].value_counts()
175
+ #ann_hico_mcs = tcga_df[tcga_df['hico'] == True]['small_RNA_class_annotation'].value_counts()
176
+
177
+ #%%
178
+ inspect_model = True
179
+ if inspect_model:
180
+ #from transforna import compute_overlap_models_ensemble,plot_heatmap_overlap_models_ensemble
181
+ models, mc_stats, novel_resid, mcs_predicted_by_only_one_model = compute_overlap_models_ensemble(full_df,mapping_dict)
182
+ fig = plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap')
183
+ fig.show()
184
+
185
+ #%%
186
+ df = full_df[full_df.Model == 'Ensemble']
187
+ print('Percentage of sequences with no annotation: %s'%(df[df['Is Familiar?'] == False].shape[0]/df.shape[0]))
188
+ print('Percentage of sequences with annotation: %s'%(df[df['Is Familiar?'] == True].shape[0]/df.shape[0]))
189
+ df = df[df['Is Familiar?']].set_index('Sequence') #filter after printing so the fractions above are not trivially 0 and 1
190
+ hicos_mc_after_full_stats = df['Net-Label'].map(mapping_dict).value_counts()
191
+ #remove mcs with ; in them
192
+ #hicos_mc_after_full_stats = hicos_mc_after_full_stats[~hicos_mc_after_full_stats.index.str.contains(';')]
193
+ #add missing major classes with zeros
194
+ for mc in hicos_mc_after_id_stats.index:
195
+ if mc not in hicos_mc_after_full_stats.index:
196
+ hicos_mc_after_full_stats[mc] = 0
197
+ hicos_mc_after_full_stats = hicos_mc_after_full_stats + hicos_mc_after_id_stats
198
+
199
+ # %%
200
+ #reorder the index of the series
201
+ hicos_mc_before_id_stats = hicos_mc_before_id_stats.reindex(hicos_mc_after_full_stats.index)
202
+ hicos_mc_after_id_stats = hicos_mc_after_id_stats.reindex(hicos_mc_after_full_stats.index)
203
+ #plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot
204
+ #%%
205
+ #%%
206
+ training_mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','YRNA','lncRNA']
207
+ hicos_mc_before_id_stats_train = hicos_mc_before_id_stats[training_mcs]
208
+ hicos_mc_after_id_stats_train = hicos_mc_after_id_stats[training_mcs]
209
+ hicos_mc_after_full_stats_train = hicos_mc_after_full_stats[training_mcs]
210
+ #plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot
211
+ import plotly.graph_objects as go
212
+ import numpy as np
213
+ import plotly.io as pio
214
+ import plotly.express as px
215
+
216
+ #make a square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL
217
+ fig = go.Figure()
218
+ fig.add_trace(go.Bar(
219
+ x=hicos_mc_before_id_stats_train.index,
220
+ y=hicos_mc_before_id_stats_train.values,
221
+ name='Before ID',
222
+ marker_color='rgb(31, 119, 180)',
223
+ opacity = 0.5
224
+ ))
225
+ fig.add_trace(go.Bar(
226
+ x=hicos_mc_after_id_stats_train.index,
227
+ y=hicos_mc_after_id_stats_train.values,
228
+ name='After ID',
229
+ marker_color='rgb(31, 119, 180)',
230
+ opacity=0.75
231
+ ))
232
+ fig.add_trace(go.Bar(
233
+ x=hicos_mc_after_full_stats_train.index,
234
+ y=hicos_mc_after_full_stats_train.values,
235
+ name='After FULL',
236
+ marker_color='rgb(31, 119, 180)',
237
+ opacity=1
238
+ ))
239
+ #set title, axis fonts, legend and grouped-bar layout (the log scale is applied further down)
240
+ fig.update_layout(
241
+ title='Progression of the Number of HICOs per Major Class',
242
+ xaxis_tickfont_size=14,
243
+ yaxis=dict(
244
+ title='Number of HICOs',
245
+ titlefont_size=16,
246
+ tickfont_size=14,
247
+ ),
248
+ xaxis=dict(
249
+ title='Major Class',
250
+ titlefont_size=16,
251
+ tickfont_size=14,
252
+ ),
253
+ legend=dict(
254
+ x=0.8,
255
+ y=1.0,
256
+ bgcolor='rgba(255, 255, 255, 0)',
257
+ bordercolor='rgba(255, 255, 255, 0)'
258
+ ),
259
+ barmode='group',
260
+ bargap=0.15,
261
+ bargroupgap=0.1
262
+ )
263
+ #make transparent background
264
+ fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
265
+ #log scale
266
+ fig.update_yaxes(type="log")
267
+
268
+ fig.update_layout(legend=dict(
269
+ yanchor="top",
270
+ y=0.99,
271
+ xanchor="left",
272
+ x=0.01
273
+ ))
274
+ #tilt the x axis labels
275
+ fig.update_layout(xaxis_tickangle=22.5)
276
+ #set the range of the y axis (on a log axis the range is given in log10 units, i.e. 10^0 to 10^4.5)
277
+ fig.update_yaxes(range=[0, 4.5])
278
+ fig.update_layout(legend=dict(
279
+ orientation="h",
280
+ yanchor="bottom",
281
+ y=1.02,
282
+ xanchor="right",
283
+ x=1
284
+ ))
285
+ fig.write_image("progression_hicos_per_mc_train.svg")
286
+ fig.show()
287
+ #%%
288
+ eval_mcs = ['miRNA','miscRNA','piRNA','vtRNA']
289
+ hicos_mc_before_id_stats_eval = hicos_mc_before_id_stats[eval_mcs]
290
+ hicos_mc_after_full_stats_eval = hicos_mc_after_full_stats[eval_mcs]
291
+
292
+ hicos_mc_after_full_stats_eval.index = hicos_mc_after_full_stats_eval.index + '*'
293
+ hicos_mc_before_id_stats_eval.index = hicos_mc_before_id_stats_eval.index + '*'
294
+ #%%
295
+ #plot the progression of the number of hicos per evaluation major class, before ID and after FULL, as a bar plot
296
+ import plotly.graph_objects as go
297
+ import numpy as np
298
+ import plotly.io as pio
299
+ import plotly.express as px
300
+
301
+ fig2 = go.Figure()
302
+ fig2.add_trace(go.Bar(
303
+ x=hicos_mc_before_id_stats_eval.index,
304
+ y=hicos_mc_before_id_stats_eval.values,
305
+ name='Before ID',
306
+ marker_color='rgb(31, 119, 180)',
307
+ opacity = 0.5
308
+ ))
309
+ fig2.add_trace(go.Bar(
310
+ x=hicos_mc_after_full_stats_eval.index,
311
+ y=hicos_mc_after_full_stats_eval.values,
312
+ name='After FULL',
313
+ marker_color='rgb(31, 119, 180)',
314
+ opacity=1
315
+ ))
316
+ #set title, axis fonts, legend and grouped-bar layout (the log scale is applied further down)
317
+ fig2.update_layout(
318
+ title='Progression of the Number of HICOs per Major Class',
319
+ xaxis_tickfont_size=14,
320
+ yaxis=dict(
321
+ title='Number of HICOs',
322
+ titlefont_size=16,
323
+ tickfont_size=14,
324
+ ),
325
+ xaxis=dict(
326
+ title='Major Class',
327
+ titlefont_size=16,
328
+ tickfont_size=14,
329
+ ),
330
+ legend=dict(
331
+ x=0.8,
332
+ y=1.0,
333
+ bgcolor='rgba(255, 255, 255, 0)',
334
+ bordercolor='rgba(255, 255, 255, 0)'
335
+ ),
336
+ barmode='group',
337
+ bargap=0.15,
338
+ bargroupgap=0.1
339
+ )
340
+ #make transparent background
341
+ fig2.update_layout(plot_bgcolor='rgba(0,0,0,0)')
342
+ #log scale
343
+ fig2.update_yaxes(type="log")
344
+
345
+ fig2.update_layout(legend=dict(
346
+ yanchor="top",
347
+ y=0.99,
348
+ xanchor="left",
349
+ x=0.01
350
+ ))
351
+ #tilt the x axis labels
352
+ fig2.update_layout(xaxis_tickangle=22.5)
353
+ #set the range of the y axis (on a log axis the range is given in log10 units, i.e. 10^0 to 10^4.5)
354
+ fig2.update_yaxes(range=[0, 4.5])
355
+ #adjust bargap
356
+ fig2.update_layout(bargap=0.3)
357
+ fig2.update_layout(legend=dict(
358
+ orientation="h",
359
+ yanchor="bottom",
360
+ y=1.02,
361
+ xanchor="right",
362
+ x=1
363
+ ))
364
+ #fig2.write_image("progression_hicos_per_mc_eval.svg")
365
+ fig2.show()
366
+ # %%
367
+ #concatenate df (FULL predictions) and id_df (ID predictions); DataFrame.append was removed in pandas 2.0
368
+ df_all_hico = pd.concat([df, id_df])
369
+ loco_hico_na_stats_after = {}
370
+ loco_hico_na_stats_after['HICO from NA'] = sum(df_all_hico.index.isin(before_na_seqs))/tcga_df.shape[0]
371
+ loco_pred_df = df_all_hico[df_all_hico.index.isin(before_loco_seqs)]
372
+ loco_anns_pd = tcga_df.loc[loco_pred_df.index].subclass_name.str.split(';',expand=True)
373
+ loco_anns_pd = loco_anns_pd.apply(lambda x: x.str.lower())
374
+ #duplicate labels in loco_pred_df * times as the num of columns in loco_anns_pd
375
+ loco_pred_labels_df = pd.DataFrame(np.repeat(loco_pred_df['Net-Label'].values,loco_anns_pd.shape[1]).reshape(loco_pred_df.shape[0],loco_anns_pd.shape[1])).set_index(loco_pred_df.index)
376
+ loco_pred_labels_df = loco_pred_labels_df.apply(lambda x: x.str.lower())
377
+
378
+
379
+
380
+ #%%
381
+ trna_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('_trna')).any(axis=1)
382
+ trna_loco_pred_df = loco_pred_labels_df[trna_mask_df]
383
+ #get trna_loco_anns_pd
384
+ trna_loco_anns_pd = loco_anns_pd[trna_mask_df]
385
+ #for trna_loco_pred_df, keep only what follows the '__' and drop the suffix after the last '-'
386
+ trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('__').str[1])
387
+ trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('-').str[:-1].str.join('-'))
388
+ #compute overlap between trna_loco_pred_df and trna_loco_anns_pd
389
+ #for every value in trna_loco_pred_df, check if is part of the corresponding position in trna_loco_anns_pd
390
+ num_hico_trna_from_loco = 0
391
+ for idx,row in trna_loco_pred_df.iterrows():
392
+ trna_label = row[0]
393
+ num_hico_trna_from_loco += trna_loco_anns_pd.loc[idx].apply(lambda x: x is not None and trna_label in x).any()
394
+
395
+
396
+ #%%
397
+ #check if 'mir' or 'let' is in any of the values per row. the columns are numbered from 0 to len(loco_anns_pd.columns)
398
+ mir_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('mir')).any(axis=1)
399
+ let_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('let')).any(axis=1)
400
+ mir_or_let_mask_df = mir_mask_df | let_mask_df
401
+ mir_or_let_loco_pred_df = loco_pred_labels_df[mir_or_let_mask_df]
402
+ mir_or_let_loco_anns_pd = loco_anns_pd[mir_or_let_mask_df]
403
+ #for each value in mir_or_let_loco_pred_df, if the value contains two '-', remove the last one and what comes after it
404
+ mir_or_let_loco_anns_pd = mir_or_let_loco_anns_pd.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x)
405
+ mir_or_let_loco_pred_df = mir_or_let_loco_pred_df.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x)
406
+ #compute overlap between mir_or_let_loco_pred_df and mir_or_let_loco_anns_pd
407
+ num_hico_mir_from_loco = sum((mir_or_let_loco_anns_pd == mir_or_let_loco_pred_df).any(axis=1))
408
+ #%%
409
+
410
+
411
+ #get rest_loco_anns_pd
412
+ rest_loco_pred_df = loco_pred_labels_df[~mir_or_let_mask_df & ~trna_mask_df]
413
+ rest_loco_anns_pd = loco_anns_pd[~mir_or_let_mask_df & ~trna_mask_df]
414
+
415
+ num_hico_bins_from_loco = 0
416
+ for idx,row in rest_loco_pred_df.iterrows():
417
+ rest_rna_label = row[0].split('-')[0]
418
+ try:
419
+ bin_no = int(row[0].split('-')[1])
420
+ except (IndexError, ValueError): #label has no '-' or its second part is not an integer bin
421
+ continue
422
+
423
+ num_hico_bins_from_loco += rest_loco_anns_pd.loc[idx].apply(lambda x: x!=None and rest_rna_label == x.split('-')[0] and abs(int(x.split('-')[1])- bin_no)<=1).any()
424
+
425
+ loco_hico_na_stats_after['HICO from LOCO'] = (num_hico_trna_from_loco + num_hico_mir_from_loco + num_hico_bins_from_loco)/tcga_df.shape[0]
426
+ loco_hico_na_stats_after['LOCO from NA'] = loco_hico_na_stats_before['NA'] - loco_hico_na_stats_after['HICO from NA']
427
+ loco_hico_na_stats_after['LOCO from LOCO'] = loco_hico_na_stats_before['LOCO'] - loco_hico_na_stats_after['HICO from LOCO']
428
+ loco_hico_na_stats_after['HICO'] = loco_hico_na_stats_before['HICO']
429
+
430
+ # %%
431
+
432
+ import plotly.graph_objects as go
433
+ import plotly.io as pio
434
+ import plotly.express as px
435
+
436
+ color_mapping = {}
437
+ for key in loco_hico_na_stats_before.keys():
438
+ if key.startswith('HICO'):
439
+ color_mapping[key] = "rgb(51,160,44)"
440
+ elif key.startswith('LOCO'):
441
+ color_mapping[key] = "rgb(178,223,138)"
442
+ else:
443
+ color_mapping[key] = "rgb(251,154,153)"
444
+ colors = list(color_mapping.values())
445
+ fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_before.keys()), values=list(loco_hico_na_stats_before.values()),hole=.0,marker=dict(colors=colors),sort=False)])
446
+ fig.update_layout(title='Percentage of HICOs, LOCOs and NAs before ID')
447
+ fig.show()
448
+ #save figure as svg
449
+ #fig.write_image("pie_chart_before_id.svg")
450
+
451
+ # %%
452
+
453
+ color_mapping = {}
454
+ for key in loco_hico_na_stats_after.keys():
455
+ if key.startswith('HICO'):
456
+ color_mapping[key] = "rgb(51,160,44)"
457
+ elif key.startswith('LOCO'):
458
+ color_mapping[key] = "rgb(178,223,138)"
459
+
460
+ loco_hico_na_stats_after = {k: loco_hico_na_stats_after[k] for k in sorted(loco_hico_na_stats_after, key=lambda k: k.startswith('HICO'), reverse=True)}
461
+
462
+ fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_after.keys()), values=list(loco_hico_na_stats_after.values()),hole=.0,marker=dict(colors=list(color_mapping.values())),sort=False)])
463
+ fig.update_layout(title='Percentage of HICOs, LOCOs and NAs after ID')
464
+ fig.show()
465
+ #save figure as svg
466
+ #fig.write_image("pie_chart_after_id.svg")
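The three statistics returned by `compute_overlap_models_ensemble` in the script above reduce to set operations on the familiar sequences of one major class. The sketch below restates that logic on plain Python sets so that the definitions are easy to check by hand. The function name `overlap_with_ensemble`, the key names, and the toy inputs are illustrative assumptions and do not exist in the repository.

```python
def overlap_with_ensemble(ensemble_seqs, model_seqs_by_name):
    """Compare each model's familiar sequences with the ensemble's.

    For every model this returns three fractions of the ensemble set:
      overlap          sequences found by both the model and the ensemble
      missed           sequences found by the ensemble but not by the model
      only_this_model  sequences found by the ensemble and by this model only
    All fractions are 0.0 when the ensemble set is empty.
    """
    ensemble = set(ensemble_seqs)
    stats = {}
    for name, seqs in model_seqs_by_name.items():
        model = set(seqs)
        # union of the sequences predicted by every other model
        others = set().union(
            *(set(s) for n, s in model_seqs_by_name.items() if n != name)
        )
        denom = len(ensemble)
        if denom == 0:
            stats[name] = {'overlap': 0.0, 'missed': 0.0, 'only_this_model': 0.0}
            continue
        stats[name] = {
            'overlap': len(ensemble & model) / denom,
            'missed': len(ensemble - model) / denom,
            'only_this_model': len((model - others) & ensemble) / denom,
        }
    return stats


# hypothetical toy input: four ensemble sequences and two models
toy_stats = overlap_with_ensemble(
    ['A', 'B', 'C', 'D'],
    {'Seq': ['A', 'B'], 'Seq-Rev': ['B', 'C', 'E']},
)
```

In this toy case each model shares two of the four ensemble sequences and contributes one sequence that no other model found, so `overlap` is 0.5 and `only_this_model` is 0.25 for both. Checking the empty case up front replaces the repeated `try`/`except ZeroDivisionError` blocks of the script with a single guard.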
transforna/bin/figure_scripts/figure_6.ipynb ADDED
@@ -0,0 +1,228 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/nfs/home/yat_ldap/conda/envs/hbdx/envs/transforna/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from transforna import load,predict_transforna_all_models,predict_transforna,fold_sequences\n",
19
+ "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'\n",
20
+ "lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'\n",
21
+ "tcga_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'\n",
22
+ "\n",
23
+ "tcga_df = load(tcga_path)\n",
24
+ "lc_df = load(lc_path)\n",
25
+ "\n",
26
+ "lc_df = lc_df[lc_df.sequence.str.len() <= 30]\n",
27
+ "\n",
28
+ "all_seqs = lc_df.sequence.tolist()+tcga_df.sequence.tolist()\n",
29
+ "\n",
30
+ "mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'\n",
31
+ "mapping_dict = load(mapping_dict_path)\n",
32
+ " "
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "predictions = predict_transforna_all_models(all_seqs,trained_on='full',path_to_models=models_path)\n",
42
+ "predictions.to_csv('predictions_lc_tcga.csv',index=False)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 2,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "#read predictions\n",
52
+ "predictions = load('predictions_lc_tcga.csv')"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "umaps = {}\n",
62
+ "models = predictions['Model'].unique()\n",
63
+ "for model in models:\n",
64
+ " if model == 'Ensemble':\n",
65
+ " continue\n",
66
+ " #get predictions\n",
67
+ " model_predictions = predictions[predictions['Model']==model]\n",
68
+ " #get is familiar rows\n",
69
+ " familiar_df = model_predictions[model_predictions['Is Familiar?']==True]\n",
70
+ " #get umap\n",
71
+ " umap_df = predict_transforna(model_predictions['Sequence'].tolist(),model=model,trained_on='full',path_to_models=models_path,umap_flag=True)\n",
72
+ " umaps[model] = umap_df"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "\n",
82
+ "import plotly.express as px\n",
83
+ "import numpy as np\n",
84
+ "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
85
+ "#filter out the classes that contain ;\n",
86
+ "mcs = [mc for mc in mcs if ';' not in mc]\n",
87
+ "colors = px.colors.qualitative.Plotly\n",
88
+ "color_mapping = dict(zip(mcs,colors))\n",
89
+ "for model,umap_df in umaps.items():\n",
90
+ " umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
91
+ " umap_df_copy = umap_df.copy()\n",
92
+ " #remove rows with Major Class containing ;\n",
93
+ " umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n",
94
+ " fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
95
+ " =['Sequence'],title=model,\\\n",
96
+ " width = 800, height=800,color_discrete_map=color_mapping)\n",
97
+ " fig.update_traces(marker=dict(size=1))\n",
98
+ " #white background\n",
99
+ " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
100
+ " #only show UMAP1 from 4.3 to 11\n",
101
+ " fig.update_xaxes(range=[4.3,11])\n",
102
+ " #and UMAP2 from -2.3 to 6.8\n",
103
+ " fig.update_yaxes(range=[-2.3,6.8])\n",
104
+ " #fig.show()\n",
105
+ " fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n",
106
+ " fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "import plotly.express as px\n",
116
+ "import numpy as np\n",
117
+ "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
118
+ "#filter out the classes that contain ;\n",
119
+ "mcs = [mc for mc in mcs if ';' not in mc]\n",
120
+ "colors = px.colors.qualitative.Plotly + px.colors.qualitative.Light24\n",
121
+ "color_mapping = dict(zip(mcs,colors))\n",
122
+ "for model,umap_df in umaps.items():\n",
123
+ " umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
124
+ " umap_df_copy = umap_df.copy()\n",
125
+ " #remove rows with Major Class containing ;\n",
126
+ " umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n",
127
+ " fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
128
+ " =['Sequence'],title=model,\\\n",
129
+ " width = 800, height=800,color_discrete_map=color_mapping)\n",
130
+ " fig.update_traces(marker=dict(size=1))\n",
131
+ " #white background\n",
132
+ " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
133
+ " #fig.show()\n",
134
+ " fig.write_image(f'lc_figures/lc_tcga_umap_{model}.png')\n",
135
+ " fig.write_image(f'lc_figures/lc_tcga_umap_{model}.svg')\n"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "#plot umap using px.scatter for each model\n",
145
+ "import plotly.express as px\n",
146
+ "import numpy as np\n",
147
+ "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n",
148
+ "#filter out the classes that contain ;\n",
149
+ "mcs = [mc for mc in mcs if ';' not in mc]\n",
150
+ "colors = px.colors.qualitative.Plotly\n",
151
+ "color_mapping = dict(zip(mcs,colors))\n",
152
+ "umap_df = umaps['Seq']\n",
153
+ "umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n",
154
+ "umap_df_copy = umap_df.copy()\n",
155
+ "#display points contained within the circle at center (7.9,2.5) and radius 4.3\n",
156
+ "umap_df_copy['distance'] = np.sqrt((umap_df_copy['UMAP1']-7.9)**2+(umap_df_copy['UMAP2']-2.5)**2)\n",
157
+ "umap_df_copy = umap_df_copy[umap_df_copy['distance']<=4.3]\n",
158
+ "#remove rows with Major Class containing ;\n",
159
+ "umap_df_copy = umap_df_copy[~umap_df_copy['Major Class'].str.contains(';')]\n",
160
+ "fig = px.scatter(umap_df_copy,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n",
161
+ " =['Sequence'],title='Seq',\\\n",
162
+ " width = 800, height=800,color_discrete_map=color_mapping)\n",
163
+ "fig.update_traces(marker=dict(size=1))\n",
164
+ "#white background\n",
165
+ "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
166
+ "fig.show()\n",
167
+ "#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n",
168
+ "#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "#fold the sequences to obtain their secondary structures\n",
178
+ "sec_struct = fold_sequences(model_predictions['Sequence'].tolist())['structure_37']\n",
179
+ "#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
180
+ "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 40,
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "umap_df = umaps['Seq-Struct']\n",
190
+ "fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color=sec_struct_ratio,hover_data=['Sequence'],title=model,\\\n",
191
+ " width = 800, height=800,color_continuous_scale='Viridis')\n",
192
+ "fig.update_traces(marker=dict(size=1))\n",
193
+ "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
194
+ "#save\n",
195
+ "fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.png')\n",
196
+ "fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.svg')"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": []
205
+ }
206
+ ],
207
+ "metadata": {
208
+ "kernelspec": {
209
+ "display_name": "transforna",
210
+ "language": "python",
211
+ "name": "python3"
212
+ },
213
+ "language_info": {
214
+ "codemirror_mode": {
215
+ "name": "ipython",
216
+ "version": 3
217
+ },
218
+ "file_extension": ".py",
219
+ "mimetype": "text/x-python",
220
+ "name": "python",
221
+ "nbconvert_exporter": "python",
222
+ "pygments_lexer": "ipython3",
223
+ "version": "3.9.18"
224
+ }
225
+ },
226
+ "nbformat": 4,
227
+ "nbformat_minor": 2
228
+ }
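The notebook above colours the UMAP by a "secondary structure ratio", described in its comment as the number of non-`.` characters divided by the sequence length. A minimal, dependency-free sketch of that ratio follows; the function name `paired_fraction` and the example structures are illustrative assumptions, not part of the repository.

```python
def paired_fraction(dot_bracket: str) -> float:
    """Fraction of positions that are not '.' (i.e. paired) in a dot-bracket string."""
    if not dot_bracket:
        # An empty structure has no defined ratio; fail loudly rather than divide by zero.
        raise ValueError("dot-bracket string must not be empty")
    return (len(dot_bracket) - dot_bracket.count(".")) / len(dot_bracket)


# Hypothetical structures, only for illustration.
structures = ["((((....))))", "............", "((..))"]
ratios = [paired_fraction(s) for s in structures]
print(ratios)
```

Applied to a pandas Series of structures, the same function can be passed to `Series.apply`, which is what the lambda in the notebook cell does inline.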
transforna/bin/figure_scripts/figure_S4.ipynb ADDED
@@ -0,0 +1,368 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "\n",
10
+ "import pandas as pd\n",
11
+ "scores = {'major_class':{},'sub_class':{}}\n",
12
+ "models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']\n",
13
+ "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n",
14
+ "for model1 in models:\n",
15
+ " summary_pd = pd.read_csv(models_path+'/major_class/'+model1+'/summary_pd.tsv',sep='\\t')\n",
16
+ " scores['major_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. Acc'].std()*100)+')'\n",
17
+ " summary_pd = pd.read_csv(models_path+'/sub_class/'+model1+'/summary_pd.tsv',sep='\\t')\n",
18
+ " scores['sub_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. Acc'].std()*100) +')'"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "text/plain": [
29
+ "{'Baseline': '52.83789870060305+/- (1.0961119898709506)',\n",
30
+ " 'Seq': '97.70018230805728+/- (0.3819207447704567)',\n",
31
+ " 'Seq-Seq': '95.65091330992355+/- (0.4963151975035616)',\n",
32
+ " 'Seq-Struct': '97.71071590680333+/- (0.6173598637101496)',\n",
33
+ " 'Seq-Rev': '97.51224133899979+/- (0.3418133671042992)'}"
34
+ ]
35
+ },
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "scores['sub_class']"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "\n",
52
+ "import json\n",
53
+ "import pandas as pd\n",
54
+ "with open('/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json') as f:\n",
55
+ " mapping_dict = json.load(f)\n",
56
+ "\n",
57
+ "b_acc_sc_to_mc = {}\n",
58
+ "for model1 in models:\n",
59
+ " b_acc = []\n",
60
+ " for idx in range(5):\n",
61
+ " confusion_matrix = pd.read_csv(models_path+'/sub_class/'+model1+f'/embedds/confusion_matrix_{idx}.csv',sep=',',index_col=0)\n",
62
+ " confusion_matrix.index = confusion_matrix.index.map(mapping_dict)\n",
63
+ " confusion_matrix.columns = confusion_matrix.columns.map(mapping_dict)\n",
64
+ " confusion_matrix = confusion_matrix.groupby(confusion_matrix.index).sum().groupby(confusion_matrix.columns,axis=1).sum()\n",
65
+ " b_acc.append(confusion_matrix.values.diagonal().sum()/confusion_matrix.values.sum())\n",
66
+ " b_acc_sc_to_mc[model1] = str(pd.Series(b_acc).mean()*100)+'+/-'+' ('+str(pd.Series(b_acc).std()*100)+')'\n"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 4,
72
+ "metadata": {},
73
+ "outputs": [
74
+ {
75
+ "data": {
76
+ "text/plain": [
77
+ "{'Baseline': '89.6182558114013+/- (0.6372156071358975)',\n",
78
+ " 'Seq': '99.66714304286457+/- (0.1404591049684126)',\n",
79
+ " 'Seq-Seq': '99.40702944026852+/- (0.18268320317601783)',\n",
80
+ " 'Seq-Struct': '99.77114728744993+/- (0.06976258667467564)',\n",
81
+ " 'Seq-Rev': '99.70878801385821+/- (0.11954774341354062)'}"
82
+ ]
83
+ },
84
+ "execution_count": 4,
85
+ "metadata": {},
86
+ "output_type": "execute_result"
87
+ }
88
+ ],
89
+ "source": [
90
+ "b_acc_sc_to_mc"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 5,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "\n",
100
+ "import plotly.express as px\n",
101
+ "no_annotation_predictions = {}\n",
102
+ "for model1 in models:\n",
103
+ " #multiindex\n",
104
+ " no_annotation_predictions[model1] = pd.read_csv(models_path+'/sub_class/'+model1+'/embedds/no_annotation_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
105
+ " no_annotation_predictions[model1].set_index([('RNA Sequences','0')] ,inplace=True)\n",
106
+ " no_annotation_predictions[model1].index.name = 'RNA Sequences'\n",
107
+ " no_annotation_predictions[model1] = no_annotation_predictions[model1]['Logits'].idxmax(axis=1)\n"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "from transforna.src.utils.tcga_post_analysis_utils import correct_labels\n",
117
+ "import pandas as pd\n",
118
+ "correlation = pd.DataFrame(index=models,columns=models)\n",
119
+ "for model1 in models:\n",
120
+ " for model2 in models:\n",
121
+ " model1_predictions = correct_labels(no_annotation_predictions[model1],no_annotation_predictions[model2],mapping_dict)\n",
122
+ " is_equal = model1_predictions == no_annotation_predictions[model2].values\n",
123
+ " correlation.loc[model1,model2] = is_equal.sum()/len(is_equal)\n",
124
+ "font_size = 20\n",
125
+ "fig = px.imshow(correlation, color_continuous_scale='Blues')\n",
126
+ "#annotate\n",
127
+ "for i in range(len(models)):\n",
128
+ " for j in range(len(models)):\n",
129
+ " if i != j:\n",
130
+ " font = dict(color='black', size=font_size)\n",
131
+ " else:\n",
132
+ " font = dict(color='white', size=font_size) \n",
133
+ " \n",
134
+ " fig.add_annotation(\n",
135
+ " x=j, y=i,\n",
136
+ " text=str(round(correlation.iloc[i,j],2)),\n",
137
+ " showarrow=False,\n",
138
+ " font=font\n",
139
+ " )\n",
140
+ "\n",
141
+ "#set figure size: width and height\n",
142
+ "fig.update_layout(width=800, height=800)\n",
143
+ "\n",
144
+ "fig.update_layout(title='Correlation between models for each sub_class model')\n",
145
+ "#set x and y axis to Models\n",
146
+ "fig.update_xaxes(title_text='Models', tickfont=dict(size=font_size))\n",
147
+ "fig.update_yaxes(title_text='Models', tickfont=dict(size=font_size))\n",
148
+ "fig.show()\n",
149
+ "#save\n",
150
+ "fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.png')\n",
151
+ "fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.svg')"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 7,
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "#create umap for every model from embedds folder\n",
161
+ "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n",
162
+ "\n",
163
+ "#read\n",
164
+ "sc_embedds = {}\n",
165
+ "mc_embedds = {}\n",
166
+ "sc_to_mc_labels = {}\n",
167
+ "sc_labels = {}\n",
168
+ "mc_labels = {}\n",
169
+ "for model in models:\n",
170
+ " df = pd.read_csv(models_path+'/sub_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
171
+ " sc_embedds[model] = df['RNA Embedds'].values\n",
172
+ " sc_labels[model] = df['Labels']['0']\n",
173
+ " sc_to_mc_labels[model] = sc_labels[model].map(mapping_dict).values\n",
174
+ "\n",
175
+ " #major class\n",
176
+ " df = pd.read_csv(models_path+'/major_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
177
+ " mc_embedds[model] = df['RNA Embedds'].values\n",
178
+ " mc_labels[model] = df['Labels']['0']"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 8,
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "import umap\n",
188
+ "#compute umap coordinates\n",
189
+ "sc_umap_coords = {}\n",
190
+ "mc_umap_coords = {}\n",
191
+ "for model in models:\n",
192
+ " sc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(sc_embedds[model])\n",
193
+ " mc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(mc_embedds[model])"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "#plot umap\n",
203
+ "import plotly.express as px\n",
204
+ "import numpy as np\n",
205
+ "\n",
206
+ "mcs = np.unique(sc_to_mc_labels[models[0]])\n",
207
+ "colors = px.colors.qualitative.Plotly\n",
208
+ "color_mapping = dict(zip(mcs,colors))\n",
209
+ "for model in models:\n",
210
+ " fig = px.scatter(x=sc_umap_coords[model][:,0],y=sc_umap_coords[model][:,1],color=sc_to_mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n",
211
+ "\n",
212
+ " hover_data={'Major Class':sc_to_mc_labels[model],'Sub Class':sc_labels[model]},color_discrete_map=color_mapping)\n",
213
+ "\n",
214
+ " fig.update_traces(marker=dict(size=1))\n",
215
+ " #white background\n",
216
+ " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
217
+ "\n",
218
+ " fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.svg')\n",
219
+ " fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.png')\n",
220
+ " fig.show()\n",
221
+ "\n",
222
+ " #plot umap for major class\n",
223
+ " fig = px.scatter(x=mc_umap_coords[model][:,0],y=mc_umap_coords[model][:,1],color=mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n",
224
+ "\n",
225
+ " hover_data={'Major Class':mc_labels[model]},color_discrete_map=color_mapping)\n",
226
+ " fig.update_traces(marker=dict(size=1))\n",
227
+ " #white background\n",
228
+ " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
229
+ "\n",
230
+ " fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.svg')\n",
231
+ " fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.png')\n",
232
+ " fig.show()"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "from transforna import fold_sequences\n",
242
+ "df = pd.read_csv(models_path+'/major_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
243
+ "sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n",
244
+ "#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
245
+ "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n",
246
+ "fig = px.scatter(x=mc_umap_coords['Seq-Struct'][:,0],y=mc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n",
247
+ " hover_data={'Major Class':mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n",
248
+ "\n",
249
+ "fig.update_traces(marker=dict(size=3))\n",
250
+ "#white background\n",
251
+ "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
252
+ "fig.show()\n",
253
+ "fig.write_image(models_path+'/major_class/Seq-Struct/figures/mc_umap_sec_struct.svg')\n"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "\n",
263
+ "from transforna import fold_sequences\n",
264
+ "df = pd.read_csv(models_path+'/sub_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n",
265
+ "sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n",
266
+ "#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n",
267
+ "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n",
268
+ "fig = px.scatter(x=sc_umap_coords['Seq-Struct'][:,0],y=sc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n",
269
+ " hover_data={'Major Class':sc_to_mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n",
270
+ "\n",
271
+ "fig.update_traces(marker=dict(size=3))\n",
272
+ "#white background\n",
273
+ "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n",
274
+ "fig.show()\n",
275
+ "fig.write_image(models_path+'/sub_class/Seq-Struct/figures/sc_umap_sec_struct.svg')"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 11,
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "from transforna import Results_Handler,get_closest_ngbr_per_split\n",
285
+ "\n",
286
+ "splits = ['train','valid','test','ood','artificial','no_annotation']\n",
287
+ "splits_to_plot = ['test','ood','random','recombined','artificial_affix']\n",
288
+ "renaming_dict= {'test':'ID (test)','ood':'Rare sub-classes','random':'Random','artificial_affix':'Putative 5\\'-adapter prefixes','recombined':'Recombined'}\n",
289
+ "\n",
290
+ "lev_dist_df = pd.DataFrame()\n",
291
+ "for model in models:\n",
292
+ " results = Results_Handler(models_path+f'/sub_class/{model}/embedds',splits=splits,read_dataset=True)\n",
293
+ " results.append_loco_variants()\n",
294
+ " results.get_knn_model()\n",
295
+ " \n",
296
+ " #compute Levenshtein distance per split\n",
297
+ " for split in splits_to_plot:\n",
298
+ " split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,split)\n",
299
+ " #create df from split and Levenshtein distance\n",
300
+ " lev_dist_split_df = pd.DataFrame({'split':split,'lev_dist':lev_dist,'seqs':split_seqs,'labels':split_labels,'top_n_seqs':top_n_seqs,'top_n_labels':top_n_labels})\n",
301
+ " #rename \n",
302
+ " lev_dist_split_df['split'] = lev_dist_split_df['split'].map(renaming_dict)\n",
303
+ " lev_dist_split_df['model'] = model\n",
304
+ " #append \n",
305
+ " lev_dist_df = pd.concat([lev_dist_df,lev_dist_split_df],axis=0)\n",
306
+ "\n"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "#plot the distribution of lev_dist for each split for each model\n",
316
+ "model_thresholds = {'Baseline':0.267,'Seq':0.246,'Seq-Seq':0.272,'Seq-Struct': 0.242,'Seq-Rev':0.237}\n",
317
+ "model_aucs = {'Baseline':0.76,'Seq':0.97,'Seq-Seq':0.96,'Seq-Struct': 0.97,'Seq-Rev':0.97}\n",
318
+ "import seaborn as sns\n",
319
+ "import matplotlib.pyplot as plt\n",
320
+ "sns.set_theme(style=\"whitegrid\")\n",
321
+ "sns.set(rc={'figure.figsize':(15,10)})\n",
322
+ "sns.set(font_scale=1.5)\n",
323
+ "ax = sns.boxplot(x=\"model\", y=\"lev_dist\", hue=\"split\", data=lev_dist_df, palette=\"Set3\",order=models,showfliers = True)\n",
324
+ "#add title\n",
325
+ "ax.set_facecolor('None')\n",
326
+ "plt.title('Levenshtein Distance Distribution per Model on ID')\n",
327
+ "ax.set(xlabel='Model', ylabel='Normalized Levenshtein Distance')\n",
328
+ "#legend background should be transparent\n",
329
+ "ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)\n",
330
+ "# add horizontal lines for thresholds for each model while making sure the line is within the boxplot\n",
331
+ "min_val = 0 \n",
332
+ "for model in models:\n",
333
+ " thresh = model_thresholds[model]\n",
334
+ " plt.axhline(y=thresh, color='g', linestyle='--',xmin=min_val,xmax=min_val+0.2)\n",
335
+ " min_val+=0.2\n",
336
+ "\n"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {},
343
+ "outputs": [],
344
+ "source": []
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "kernelspec": {
349
+ "display_name": "transforna",
350
+ "language": "python",
351
+ "name": "python3"
352
+ },
353
+ "language_info": {
354
+ "codemirror_mode": {
355
+ "name": "ipython",
356
+ "version": 3
357
+ },
358
+ "file_extension": ".py",
359
+ "mimetype": "text/x-python",
360
+ "name": "python",
361
+ "nbconvert_exporter": "python",
362
+ "pygments_lexer": "ipython3",
363
+ "version": "3.9.18"
364
+ }
365
+ },
366
+ "nbformat": 4,
367
+ "nbformat_minor": 2
368
+ }
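The notebook above maps sub-class confusion matrices onto major classes and then divides the diagonal sum by the total. That quantity is plain accuracy, even though the variable is named `b_acc`. The sketch below shows the same collapse on nested dictionaries and computes both accuracy and balanced accuracy (mean per-class recall) so the difference is visible. All label names and counts are made up for illustration, and the helper names are assumptions rather than TransfoRNA APIs.

```python
from collections import defaultdict


def collapse_confusion(confusion, sub_to_major):
    """Aggregate confusion[true_sub][pred_sub] counts into major-class counts."""
    collapsed = defaultdict(lambda: defaultdict(int))
    for true_sub, row in confusion.items():
        for pred_sub, count in row.items():
            collapsed[sub_to_major[true_sub]][sub_to_major[pred_sub]] += count
    return {true_mc: dict(row) for true_mc, row in collapsed.items()}


def accuracy(confusion):
    """Diagonal sum divided by the total count."""
    total = sum(sum(row.values()) for row in confusion.values())
    correct = sum(row.get(label, 0) for label, row in confusion.items())
    return correct / total


def balanced_accuracy(confusion):
    """Mean recall over classes that have at least one true sample."""
    recalls = [
        row.get(label, 0) / sum(row.values())
        for label, row in confusion.items()
        if sum(row.values()) > 0
    ]
    return sum(recalls) / len(recalls)


# Hypothetical sub-classes and counts (rows: true label, columns: predicted label).
sub_to_major = {"mir_a": "miRNA", "mir_b": "miRNA", "trna_a": "tRNA"}
confusion = {
    "mir_a": {"mir_a": 8, "mir_b": 2, "trna_a": 0},
    "mir_b": {"mir_a": 1, "mir_b": 9, "trna_a": 0},
    "trna_a": {"mir_a": 1, "mir_b": 0, "trna_a": 4},
}
collapsed = collapse_confusion(confusion, sub_to_major)
print(accuracy(confusion), accuracy(collapsed), balanced_accuracy(collapsed))
```

Errors between sub-classes of the same major class land on the diagonal after the collapse, which is why major-class scores derived this way are at least as high as the sub-class scores.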
transforna/bin/figure_scripts/figure_S5.py ADDED
@@ -0,0 +1,94 @@
+ #%%
+ import numpy as np
+ import pandas as pd
+
+ import transforna
+ from transforna import IDModelAugmenter, load
+
+ #%%
+ model_name = 'Seq'
+ config_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID/sub_class/{model_name}/meta/hp_settings.yaml'
+ config = load(config_path)
+ model_augmenter = IDModelAugmenter(df=None,config=config)
+ df = model_augmenter.predict_transforna_na()
+ tcga_df = load('/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv')
+
+ tcga_df.set_index('sequence',inplace=True)
+ tcga_df['Labels'] = tcga_df['subclass_name'][tcga_df['hico'] == True]
+ tcga_df['Labels'] = tcga_df['Labels'].astype('category')
+ #%%
+ tcga_df.loc[df.Sequence.values,'Labels'] = df['Net-Label'].values
+
+ loco_labels_df = tcga_df['subclass_name'].str.split(';',expand=True).loc[df['Sequence']]
+ #filter out the rows having no_annotation in the first column of loco_labels_df
+ loco_labels_df = loco_labels_df.iloc[~(loco_labels_df[0] == 'no_annotation').values]
+ #%%
+ #get the Is Familiar? column from df based on index of loco_labels_df
+ novelty_prediction_loco_df = df[df['Sequence'].isin(loco_labels_df.index)].set_index('Sequence')['Is Familiar?']
+ #%%
+ id_predictions_df = tcga_df.loc[loco_labels_df.index]['Labels']
+ #copy the column of id_predictions_df a number of times equal to the number of columns in loco_labels_df
+ id_predictions_df = pd.concat([id_predictions_df]*loco_labels_df.shape[1],axis=1)
+ id_predictions_df.columns = np.arange(loco_labels_df.shape[1])
+ equ_mask = loco_labels_df == id_predictions_df
+ #check how many rows in equ_mask have at least one True
+ num_true = equ_mask.any(axis=1).sum()
+ print('percentage of all loco RNAs: ',num_true/equ_mask.shape[0])
+
+
+ #split loco_labels_df into two dataframes. familiar and novel
+ fam_loco_labels_df = loco_labels_df[novelty_prediction_loco_df]
+ novel_loco_labels__df = loco_labels_df[~novelty_prediction_loco_df]
+ #separate id_predictions_df into two dataframes. novel and familiar
+ id_predictions_fam_df = id_predictions_df[novelty_prediction_loco_df]
+ id_predictions_novel_df = id_predictions_df[~novelty_prediction_loco_df]
+ #%%
+ num_true_fam = (fam_loco_labels_df == id_predictions_fam_df).any(axis=1).sum()
+ num_true_novel = (novel_loco_labels__df == id_predictions_novel_df).any(axis=1).sum()
+
+ print('percentage of similar predictions in familiar: ',num_true_fam/fam_loco_labels_df.shape[0])
+ print('percentage of similar predictions in novel: ',num_true_novel/novel_loco_labels__df.shape[0])
+ print('')
+ # %%
+ #remove the rows in fam_loco_labels_df and id_predictions_fam_df that have at least one True in equ_mask
+ fam_loco_labels_no_overlap_df = fam_loco_labels_df[~equ_mask.any(axis=1)]
+ id_predictions_fam_no_overlap_df = id_predictions_fam_df[~equ_mask.any(axis=1)]
+ #collapse the dataframe of fam_loco_labels_df with a ';' separator
+ collapsed_loco_labels_df = fam_loco_labels_no_overlap_df.apply(lambda x: ';'.join(x.dropna().astype(str)),axis=1)
+ #combine collapsed_loco_labels_df with id_predictions_fam_no_overlap_df[0]
+ predicted_fam_but_ann_novel_df = pd.concat([collapsed_loco_labels_df,id_predictions_fam_no_overlap_df[0]],axis=1)
+ #rename columns
+ predicted_fam_but_ann_novel_df.columns = ['KBA_labels','predicted_label']
+ # %%
+ #get major class for each column in KBA_labels and predicted_label
+ mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
+ sc_to_mc_mapper_dict = load(mapping_dict_path)
+
+ predicted_fam_but_ann_novel_df['KBA_labels_mc'] = predicted_fam_but_ann_novel_df['KBA_labels'].str.split(';').apply(lambda x: ';'.join([sc_to_mc_mapper_dict[i] if i in sc_to_mc_mapper_dict.keys() else i for i in x]))
+ predicted_fam_but_ann_novel_df['predicted_label_mc'] = predicted_fam_but_ann_novel_df['predicted_label'].apply(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict.keys() else x)
+ # %%
+ #for each sequence in predicted_fam_but_ann_novel_df, compute the most similar sequence along with the Levenshtein distance
+ from transforna import predict_transforna
+
+ sim_df = predict_transforna(model=model_name,sequences=predicted_fam_but_ann_novel_df.index.tolist(),similarity_flag=True,n_sim=1,trained_on='id',path_to_models='/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/')
+ sim_df = sim_df.set_index('Sequence')
+
+ #append the sim_df to predicted_fam_but_ann_novel_df except for the Labels column
+ predicted_fam_but_ann_novel_df = pd.concat([predicted_fam_but_ann_novel_df,sim_df.drop('Labels',axis=1)],axis=1)
+ # %%
+ #plot the mc proportions of predicted_label_mc
+ predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().plot(kind='bar')
+ #get order of labels on x axis
+ x_labels = predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().index.tolist()
+ # %%
+ #plot the Levenshtein distance per predicted_label_mc and order the x axis based on the order of x_labels
+ fig = predicted_fam_but_ann_novel_df.boxplot(column='NLD',by='predicted_label_mc',figsize=(20,10),rot=90,showfliers=False)
+ #reorder x axis in fig by x_labels
+ fig.set_xticklabels(x_labels)
+ #increase font of axis labels and ticks
+ fig.set_xlabel('Predicted Label',fontsize=20)
+ fig.set_ylabel('Levenshtein Distance',fontsize=20)
+ fig.tick_params(axis='both', which='major', labelsize=20)
+ #display pandas full rows
+ pd.set_option('display.max_rows', None)
+ # %%
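The script above counts a low-confidence (LOCO) sequence as agreeing with the model when the predicted sub-class equals any of the `;`-separated annotation labels, and reports that rate separately for sequences flagged familiar and novel. A small sketch of that matching rule in plain Python follows; the helper names and the example labels are illustrative assumptions, not repository code or real annotations.

```python
def any_label_match(annotation: str, predicted: str, sep: str = ";") -> bool:
    """True if the predicted label equals one of the separator-delimited annotation labels."""
    return predicted in annotation.split(sep)


def match_rate(annotations, predictions) -> float:
    """Fraction of sequences whose prediction matches at least one annotation label."""
    if len(annotations) != len(predictions):
        raise ValueError("annotations and predictions must have the same length")
    hits = sum(any_label_match(a, p) for a, p in zip(annotations, predictions))
    return hits / len(annotations)


# Hypothetical annotations, predictions and familiarity flags.
annotations = ["label_a;label_b", "label_c", "label_d;label_e", "label_f"]
predictions = ["label_b", "label_x", "label_d", "label_f"]
is_familiar = [True, True, False, True]

familiar = [(a, p) for a, p, f in zip(annotations, predictions, is_familiar) if f]
novel = [(a, p) for a, p, f in zip(annotations, predictions, is_familiar) if not f]
print(match_rate(*zip(*familiar)), match_rate(*zip(*novel)))
```

Splitting on the separator before comparing avoids accidental substring hits, for example a prediction `label_a` matching an annotation `label_ab`.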
transforna/bin/figure_scripts/figure_S8.py ADDED
@@ -0,0 +1,56 @@
+ #%%
+ from transforna import load
+ from transforna import Results_Handler
+ import pandas as pd
+ import plotly.graph_objects as go
+ import plotly.io as pio
+ path = '/media/ftp_share/hbdx/analysis/tcga/TransfoRNA_I_ID_V4/sub_class/Seq/embedds/'
+ results:Results_Handler = Results_Handler(path=path,splits=['train','valid','test','ood'],read_ad=True)
+
+ mapping_dict = load('/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json')
+ mapping_dict['artificial_affix'] = 'artificial_affix'
+ train_df = results.splits_df_dict['train_df']
+ valid_df = results.splits_df_dict['valid_df']
+ test_df = results.splits_df_dict['test_df']
+ ood_df = results.splits_df_dict['ood_df']
+ #remove RNA Sequences from the dataframe if not in results.ad.var.index
+ train_df = train_df[train_df['RNA Sequences'].isin(results.ad.var[results.ad.var['hico'] == True].index)['0'].values]
+ valid_df = valid_df[valid_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
+ test_df = test_df[test_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
+ ood_df = ood_df[ood_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values]
+ #concatenate train,valid and test (pd.concat, since DataFrame.append was removed in pandas 2.0)
+ train_val_test_df = pd.concat([train_df,valid_df,test_df])
+ #map Labels to annotation
+ hico_id_labels = train_val_test_df['Labels','0'].map(mapping_dict).values
+ hico_ood_labels = ood_df['Labels','0'].map(mapping_dict).values
+ #read ad
+ ad = results.ad
+
+ hico_mcs = ad.var['small_RNA_class_annotation'][ad.var['hico'] == True].unique()
+ hico_loco_df = pd.DataFrame([{'mc':mc,
+     'hico_id':sum([mc in i for i in hico_id_labels]),
+     'hico_ood':sum([mc in i for i in hico_ood_labels]),
+     'loco':sum([mc in i for i in ad.var['small_RNA_class_annotation'][ad.var['hico'] != True].values])}
+     for mc in hico_mcs],columns=['mc','hico_id','hico_ood','loco'])
+ # %%
+
+
+ order = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
+
+ fig = go.Figure()
+ fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_id'],name='HICO ID',marker_color='#00CC96'))
+ fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_ood'],name='HICO OOD',marker_color='darkcyan'))
+ fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['loco'],name='LOCO',marker_color='#7f7f7f'))
+ fig.update_layout(barmode='group')
+ fig.update_layout(width=800,height=800)
+ #order the x axis
+ fig.update_layout(xaxis={'categoryorder':'array','categoryarray':order})
+ fig.update_layout(xaxis_title='Major Class',yaxis_title='Number of Sequences')
+ fig.update_layout(title='Number of Sequences per Major Class in ID, OOD and LOCO')
+ fig.update_layout(yaxis_type="log")
+ #save as png
+ pio.write_image(fig,'hico_id_ood_loco_proportion.png')
+ #save as svg
+ pio.write_image(fig,'hico_id_ood_loco_proportion.svg')
+ fig.show()
+ # %%
transforna/bin/figure_scripts/figure_S9_S11.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #in this file, the progression of the number of hicos per major class is computed per model
2
+ #this is done before ID, after FULL.
3
+ #%%
4
+ from transforna import load
5
+ from transforna import predict_transforna,predict_transforna_all_models
6
+ import pandas as pd
7
+ import plotly.graph_objects as go
8
+ import numpy as np
9
+ import plotly.io as pio
10
+ import plotly.express as px
11
+ mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json'
12
+ models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
13
+
14
+ mapping_dict = load(mapping_dict_path)
15
+
16
+ #%%
17
+ dataset:str = 'LC'
18
+ hico_loco_na_flag:str = 'hico'
19
+ assert hico_loco_na_flag in ['hico','loco_na'], 'hico_loco_na_flag must be either hico or loco_na'
20
+ if dataset == 'TCGA':
21
+ dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'
22
+ else:
23
+ dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
24
+
25
+ prediction_single_pd = predict_transforna(['AAAAAAACCCCCTTTTTTT'],model='Seq',logits_flag = True,trained_on='id',path_to_models=models_path)
26
+ sub_classes_used_for_training = prediction_single_pd.columns.tolist()
27
+
28
+ var = load(dataset_path_train).set_index('sequence')
29
+ #remove from var all indexes that are longer than 30
30
+ var = var[var.index.str.len() <= 30]
31
+ hico_seqs_all = var.index[var['hico']].tolist()
32
+ hico_labels_all = var['subclass_name'][var['hico']].values
33
+
34
+ hico_seqs_id = var.index[var.hico & var.subclass_name.isin(sub_classes_used_for_training)].tolist()
35
+ hico_labels_id = var['subclass_name'][var.hico & var.subclass_name.isin(sub_classes_used_for_training)].values
36
+
37
+ non_hico_seqs = var['subclass_name'][var['hico'] == False].index.values
38
+ non_hico_labels = var['subclass_name'][var['hico'] == False].values
39
+
40
+
41
+ #filter hico labels and hico seqs to hico ID
42
+ if hico_loco_na_flag == 'loco_na':
43
+ curr_seqs = non_hico_seqs
44
+ curr_labels = non_hico_labels
45
+ elif hico_loco_na_flag == 'hico':
46
+ curr_seqs = hico_seqs_id
47
+ curr_labels = hico_labels_id
48
+
49
+ full_df = predict_transforna_all_models(sequences=curr_seqs,path_to_models=models_path)
50
+
51
+
52
+ #%%
53
+ mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA']
54
+ #for each mc, get the sequences of hicos in that mc and compute the number of hicos per model
55
+ num_hicos_per_mc = {}
56
+ if hico_loco_na_flag == 'hico':#this is where ground truth exists (hico id)
57
+ curr_labels_id_mc = [mapping_dict[label] for label in curr_labels]
58
+
59
+ elif hico_loco_na_flag == 'loco_na': # this is where ground truth does not exist (LOCO/NA)
60
+ ensemble_preds = full_df[full_df.Model == 'Ensemble'].set_index('Sequence').loc[curr_seqs].reset_index()
61
+ curr_labels_id_mc = [mapping_dict[label] for label in ensemble_preds['Net-Label']]
62
+
63
+ for mc in mcs:
64
+ #select sequences from hico_seqs that are in the major class mc
65
+ mc_seqs = [seq for seq,label in zip(curr_seqs,curr_labels_id_mc) if label == mc]
66
+ if len(mc_seqs) == 0:
67
+ num_hicos_per_mc[mc] = {model:0 for model in full_df.Model.unique()}
68
+ continue
69
+ #only keep in full_df the sequences that are in mc_seqs
70
+ mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)]
71
+ curr_num_hico_per_model = mc_full_df[mc_full_df['Is Familiar?']].groupby(['Model'])['Is Familiar?'].value_counts().droplevel(1)
72
+ #remove Baseline from index
73
+ curr_num_hico_per_model = curr_num_hico_per_model.drop('Baseline')
74
+ curr_num_hico_per_model -= curr_num_hico_per_model.min()
75
+ num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict()
76
+ #%%
77
+ to_plot_df = pd.DataFrame(num_hicos_per_mc)
78
+ to_plot_mcs = ['rRNA','tRNA','snoRNA']
79
+ fig = go.Figure()
80
+ #x axis should be the mcs
81
+ for model in num_hicos_per_mc['rRNA'].keys():
82
+ fig.add_trace(go.Bar(x=mcs, y=[num_hicos_per_mc[mc][model] for mc in mcs], name=model))
83
+
84
+ fig.update_layout(barmode='group')
85
+ fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
86
+ fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg")
87
+ fig.update_yaxes(type="log")
88
+ fig.show()
89
+
90
+ #%%
91
+
92
+ import pandas as pd
93
+ import glob
94
+ from plotly import graph_objects as go
95
+ from transforna import load,predict_transforna
96
+ all_df = pd.DataFrame()
97
+ files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv')
98
+ for file in files:
99
+ df = pd.read_csv(file)
100
+ all_df = pd.concat([all_df,df])
101
+ all_df = all_df.drop(columns=['Unnamed: 0'])
102
+ all_df.loc[all_df.split.isnull(),'split'] = 'NA'
103
+ ensemble_df = all_df[all_df.Model == 'Ensemble']
104
+ # %%
105
+
106
+ lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
107
+ lc_df = load(lc_path)
108
+ lc_df.set_index('sequence',inplace=True)
109
+ # %%
110
+ #filter lc_df to only include sequences that are in ensemble_df
111
+ lc_df = lc_df.loc[ensemble_df.Sequence]
112
+ actual_major_classes = lc_df['small_RNA_class_annotation']
113
+ predicted_major_classes = ensemble_df[['Net-Label','Sequence']].set_index('Sequence').loc[lc_df.index]['Net-Label'].map(mapping_dict)
114
+ # %%
115
+ #plot confusion matrix between actual and predicted major classes
116
+ from sklearn.metrics import confusion_matrix
117
+ import seaborn as sns
118
+ import matplotlib.pyplot as plt
119
+ import numpy as np
120
+ major_classes = list(set(list(predicted_major_classes.unique())+list(actual_major_classes.unique())))
121
+ conf_matrix = confusion_matrix(actual_major_classes,predicted_major_classes,labels=major_classes)
122
+ conf_matrix = conf_matrix/np.sum(conf_matrix,axis=1,keepdims=True) #keepdims so each row is divided by its own total
123
+
124
+ sns.heatmap(conf_matrix,annot=True,cmap='Blues')
125
+ for i in range(conf_matrix.shape[0]):
126
+ for j in range(conf_matrix.shape[1]):
127
+ conf_matrix[i,j] = round(conf_matrix[i,j],1)
128
+
129
+
130
+ plt.xlabel('Predicted Major Class')
131
+ plt.ylabel('Actual Major Class')
132
+ plt.xticks(np.arange(len(major_classes)),major_classes,rotation=90)
133
+ plt.yticks(np.arange(len(major_classes)),major_classes,rotation=0)
134
+ plt.show()
135
+
136
+ # %%
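Note on the heatmap cell above: a row-normalized confusion matrix divides every row by that row's own total, so each row of actual labels sums to 1. With NumPy, `np.sum(m, axis=1)` returns a 1-D array, and dividing an `(n, n)` matrix by it broadcasts along columns unless `keepdims=True` is passed. The following is a minimal stdlib-only sketch of the intended normalization; the function name `row_normalized_confusion` and the toy labels are illustrative assumptions, not part of the repository.

```python
from collections import Counter


def row_normalized_confusion(actual, predicted, labels):
    """Return a matrix where entry [i][j] is the fraction of samples whose
    actual label is labels[i] and whose predicted label is labels[j]."""
    counts = Counter(zip(actual, predicted))
    matrix = []
    for a in labels:
        row = [counts[(a, p)] for p in labels]
        total = sum(row)
        # Guard against labels that never occur as an actual class.
        matrix.append([c / total if total else 0.0 for c in row])
    return matrix


# Toy data (hypothetical): 2 miRNA, 3 tRNA and 1 rRNA sequences.
actual = ["miRNA", "miRNA", "tRNA", "tRNA", "tRNA", "rRNA"]
predicted = ["miRNA", "tRNA", "tRNA", "tRNA", "rRNA", "rRNA"]
labels = ["miRNA", "tRNA", "rRNA"]
matrix = row_normalized_confusion(actual, predicted, labels)
```

The zero-total guard matters when `labels` is the union of actual and predicted classes, because a class that only appears among the predictions has an empty row.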
transforna/bin/figure_scripts/infer_lc_using_tcga.py ADDED
@@ -0,0 +1,438 @@
1
+ import random
2
+ import sys
3
+ from random import randint
4
+
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ from anndata import AnnData
8
+
9
+ #add parent directory to path
10
+ sys.path.append('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/')
11
+ from src import (Results_Handler, correct_labels, load, predict_transforna,
12
+ predict_transforna_all_models,get_fused_seqs)
13
+
14
+
15
+ def get_mc_sc(infer_df,sequences,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag = False):
16
+
17
+ infered_seqs = infer_df.loc[sequences]
18
+ sc_classes_df = infered_seqs['subclass_name'].str.split(';',expand=True)
19
+ #filter rows with all nans in sc_classes_df
20
+ sc_classes_df = sc_classes_df[~sc_classes_df.isnull().all(axis=1)]
21
+ #mask for classes used for training
22
+ if ood_flag:
23
+ sub_classes_used_for_training_plus_neighbors = []
24
+ #for every subclass in sub_classes_used_for_training that contains bin, get previous and succeeding bins
25
+ for sub_class in sub_classes_used_for_training:
26
+ sub_classes_used_for_training_plus_neighbors.append(sub_class)
27
+ if 'bin' in sub_class:
28
+ bin_num = int(sub_class.split('_bin-')[1])
29
+ if bin_num > 0:
30
+ sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num-1}')
31
+ sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num+1}')
32
+ if 'tR' in sub_class:
33
+ #separate the first part (either 3p/5p), also get the second part after __
34
+ first_part = sub_class.split('-')[0]
35
+ second_part = sub_class.split('__')[1]
36
+ #get all classes in sc_to_mc_mapper_dict,values that contain both parts and append them to sub_classes_used_for_training_plus_neighbors
37
+ sub_classes_used_for_training_plus_neighbors += [sc for sc in sc_to_mc_mapper_dict.keys() if first_part in sc and second_part in sc]
38
+ sub_classes_used_for_training_plus_neighbors = list(set(sub_classes_used_for_training_plus_neighbors))
39
+ mask = sc_classes_df.applymap(lambda x: True if (x not in sub_classes_used_for_training_plus_neighbors and 'hypermapper' not in x)\
40
+ or pd.isnull(x) else False)
41
+
42
+ else:
43
+ mask = sc_classes_df.applymap(lambda x: True if x in sub_classes_used_for_training or pd.isnull(x) else False)
44
+
45
+ #check if any sub class in sub_classes_used_for_training is in sc_classes_df
46
+ if mask.apply(lambda x: all(x.tolist()), axis=1).sum() == 0:
47
+ #TODO: change to log
48
+ import logging
49
+ log_ = logging.getLogger(__name__)
50
+ log_.error('None of the sub classes used for training are in the sequences')
51
+ raise Exception('None of the sub classes used for training are in the sequences')
52
+
53
+ #filter rows with at least one False in mask
54
+ sc_classes_df = sc_classes_df[mask.apply(lambda x: all(x.tolist()), axis=1)]
55
+ #get mc classes
56
+ mc_classes_df = sc_classes_df.applymap(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict else 'not_found')
57
+ #assign major class for not found values if containing 'miRNA', 'tRNA', 'rRNA', 'snRNA', 'snoRNA'
58
+ #mc_classes_df = mc_classes_df.applymap(lambda x: None if x is None else 'miRNA' if 'miR' in x else 'tRNA' if 'tRNA' in x else 'rRNA' if 'rRNA' in x else 'snRNA' if 'snRNA' in x else 'snoRNA' if 'snoRNA' in x else 'snoRNA' if 'SNO' in x else 'protein_coding' if 'RPL37A' in x else 'lncRNA' if 'SNHG1' in x else 'not_found')
59
+ #filter all 'not_found' rows
60
+ mc_classes_df = mc_classes_df[mc_classes_df.apply(lambda x: 'not_found' not in x.tolist() ,axis=1)]
61
+ #filter values with ; in mc_classes_df
62
+ mc_classes_df = mc_classes_df[~mc_classes_df[0].str.contains(';')]
63
+ #filter index
64
+ sc_classes_df = sc_classes_df.loc[mc_classes_df.index]
65
+ mc_classes_df = mc_classes_df.loc[sc_classes_df.index]
66
+ return mc_classes_df,sc_classes_df
67
+
68
+ def plot_confusion_false_novel(df,sc_df,mc_df,save_figs:bool=False):
69
+ #filter index of sc_classes_df to contain indices of outliers df
70
+ curr_sc_classes_df = sc_df.loc[[i for i in df.index if i in sc_df.index]]
71
+ curr_mc_classes_df = mc_df.loc[[i for i in df.index if i in mc_df.index]]
72
+ #convert Labels to mc_Labels
73
+ df = df.assign(predicted_mc_labels=df.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
74
+ #add mc classes
75
+ df = df.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
76
+ #add sc classes
77
+ df = df.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
78
+ #compute accuracy
79
+ df = df.assign(mc_accuracy=df.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
80
+ df = df.assign(sc_accuracy=df.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
81
+
82
+ #use plotly to plot confusion matrix based on mc classes
83
+ mc_confusion_matrix = df.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
84
+ mc_confusion_matrix = mc_confusion_matrix.fillna(0)
85
+ mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
86
+ mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,2))
87
+ #for columns not in rows, sum the column values and add them to a new column called 'other'
88
+ other_col = [0]*mc_confusion_matrix.shape[0]
89
+ for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]:
90
+ other_col += mc_confusion_matrix[i]
91
+ mc_confusion_matrix['other'] = other_col
92
+ #add an other row with all zeros
93
+ mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
94
+ #drop all columns not in rows
95
+ mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
96
+ #plot confusion matrix
97
+ fig = go.Figure(data=go.Heatmap(
98
+ z=mc_confusion_matrix.values,
99
+ x=mc_confusion_matrix.columns,
100
+ y=mc_confusion_matrix.index,
101
+ hoverongaps = False))
102
+ #add z values to heatmap
103
+ for i in range(len(mc_confusion_matrix.index)):
104
+ for j in range(len(mc_confusion_matrix.columns)):
105
+ fig.add_annotation(text=str(mc_confusion_matrix.values[i][j]), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
106
+ showarrow=False, font_size=25, font_color='red')
107
+ #add title
108
+ fig.update_layout(title_text='Confusion matrix based on mc classes for false novel sequences')
109
+ #label x axis and y axis
110
+ fig.update_xaxes(title_text='Predicted mc class')
111
+ fig.update_yaxes(title_text='Actual mc class')
112
+ #save
113
+ if save_figs:
114
+ fig.write_image('transforna/bin/lc_figures/confusion_matrix_mc_classes_false_novel.png')
115
+
116
+
117
+ def compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers = False,fig_prefix:str = '',save_figs:bool=False):
118
+ font_size = 25
119
+ if fig_prefix == 'LC-familiar':
120
+ font_size = 10
121
+ #rename Labels to predicted_sc_labels
122
+ prediction_pd = prediction_pd.rename(columns={'Net-Label':'predicted_sc_labels'})
123
+
124
+ for model in prediction_pd['Model'].unique():
125
+ #get model predictions
126
+ num_rows = sc_classes_df.shape[0]
127
+ model_prediction_pd = prediction_pd[prediction_pd['Model'] == model]
128
+ model_prediction_pd = model_prediction_pd.set_index('Sequence')
129
+ #filter index of model_prediction_pd to contain indices of sc_classes_df
130
+ model_prediction_pd = model_prediction_pd.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
131
+
132
+ try: #try because ensemble models do not have a folder
133
+ #check how many of the hico seqs exist in the train_df
134
+ embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/embedds'
135
+ results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
136
+ except Exception:
137
+ embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq-Rev/embedds'
138
+ results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
139
+
140
+ train_seqs = set(results.splits_df_dict['train_df']['RNA Sequences']['0'].values.tolist())
141
+ common_seqs = train_seqs.intersection(set(model_prediction_pd.index.tolist()))
142
+ print(f'Number of common seqs between train_df and {model} is {len(common_seqs)}')
143
+ #print(f'removing overlapping sequences between train set and inference')
144
+ #remove common_seqs from model_prediction_pd
145
+ #model_prediction_pd = model_prediction_pd.drop(common_seqs)
146
+
147
+
148
+ #compute number of sequences where NLD is higher than Novelty Threshold
149
+ num_outliers = sum(model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'])
150
+ false_novel_df = model_prediction_pd[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']]
151
+
152
+ plot_confusion_false_novel(false_novel_df,sc_classes_df,mc_classes_df,save_figs)
153
+ #draw a pie chart depicting number of outliers per actual_mc_labels
154
+ fig_outl = mc_classes_df.loc[false_novel_df.index][0].value_counts().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
155
+ fig_outl.set_title(f'False Novel per MC for {model}: {num_outliers}')
156
+ if save_figs:
157
+ fig_outl.get_figure().savefig(f'transforna/bin/lc_figures/false_novel_mc_{model}.png')
158
+ fig_outl.get_figure().clf()
159
+ #get number of unique sub classes per major class in false_novel_df
160
+ false_novel_sc_freq_df = sc_classes_df.loc[false_novel_df.index][0].value_counts().to_frame()
161
+ #save index as csv
162
+ #false_novel_sc_freq_df.to_csv(f'false_novel_sc_freq_df_{model}.csv')
163
+ #add mc to false_novel_sc_freq_df
164
+ false_novel_sc_freq_df['MC'] = false_novel_sc_freq_df.index.map(lambda x: sc_to_mc_mapper_dict[x])
165
+ #plot pie chart showing unique sub classes per major class in false_novel_df
166
+ fig_outl_sc = false_novel_sc_freq_df.groupby('MC')[0].sum().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
167
+ fig_outl_sc.set_title(f'False novel: No. Unique sub classes per MC {model}: {num_outliers}')
168
+ if save_figs:
169
+ fig_outl_sc.get_figure().savefig(f'transforna/bin/lc_figures/{fig_prefix}_false_novel_sc_{model}.png')
170
+ fig_outl_sc.get_figure().clf()
171
+ #filter outliers
172
+ if seperate_outliers:
173
+ model_prediction_pd = model_prediction_pd[model_prediction_pd['NLD'] <= model_prediction_pd['Novelty Threshold']]
174
+ else:
175
+ #set the predictions of outliers to 'Outlier'
176
+ model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_sc_labels'] = 'Outlier'
177
+ model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_mc_labels'] = 'Outlier'
178
+ sc_to_mc_mapper_dict['Outlier'] = 'Outlier'
179
+
180
+ #filter index of sc_classes_df to contain indices of model_prediction_pd
181
+ curr_sc_classes_df = sc_classes_df.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
182
+ curr_mc_classes_df = mc_classes_df.loc[[i for i in model_prediction_pd.index if i in mc_classes_df.index]]
183
+ #convert Labels to mc_Labels
184
+ model_prediction_pd = model_prediction_pd.assign(predicted_mc_labels=model_prediction_pd.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
185
+ #add mc classes
186
+ model_prediction_pd = model_prediction_pd.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
187
+ #add sc classes
188
+ model_prediction_pd = model_prediction_pd.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
189
+ #correct labels
190
+ model_prediction_pd['predicted_sc_labels'] = correct_labels(model_prediction_pd['predicted_sc_labels'],model_prediction_pd['actual_sc_labels'],sc_to_mc_mapper_dict)
191
+ #compute accuracy
192
+ model_prediction_pd = model_prediction_pd.assign(mc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
193
+ model_prediction_pd = model_prediction_pd.assign(sc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
194
+
195
+ if not seperate_outliers:
196
+ cols_to_save = ['actual_mc_labels','predicted_mc_labels','predicted_sc_labels','actual_sc_labels']
197
+ total_false_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels != model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
198
+ #add a column indicating if NLD is greater than Novelty Threshold
199
+ total_false_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_false_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_false_mc_predictions_df.index]['Novelty Threshold']
200
+ #save
201
+ total_false_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_false_mcs_w_out_{model}.csv')
202
+ total_true_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels == model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
203
+ #add a column indicating if NLD is greater than Novelty Threshold
204
+ total_true_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_true_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_true_mc_predictions_df.index]['Novelty Threshold']
205
+ #save
206
+ total_true_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_true_mcs_w_out_{model}.csv')
207
+
208
+ print('Model: ', model)
209
+ print('num_outliers: ', num_outliers)
210
+ #print accuracy including outliers
211
+ print('mc_accuracy: ', model_prediction_pd['mc_accuracy'].mean())
212
+ print('sc_accuracy: ', model_prediction_pd['sc_accuracy'].mean())
213
+
214
+ #print balanced accuracy
215
+ print('mc_balanced_accuracy: ', model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean())
216
+ print('sc_balanced_accuracy: ', model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean())
217
+
218
+ #use plotly to plot confusion matrix based on mc classes
219
+ mc_confusion_matrix = model_prediction_pd.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
220
+ mc_confusion_matrix = mc_confusion_matrix.fillna(0)
221
+ mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
222
+ mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,4))
223
+ #for columns not in rows, sum the column values and add them to a new column called 'other'
224
+ other_col = [0]*mc_confusion_matrix.shape[0]
225
+ for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]:
226
+ other_col += mc_confusion_matrix[i]
227
+ mc_confusion_matrix['other'] = other_col
228
+ #add an other row with all zeros
229
+ mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
230
+ #drop all columns not in rows
231
+ mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
232
+ #plot confusion matrix
233
+
234
+ fig = go.Figure(data=go.Heatmap(
235
+ z=mc_confusion_matrix.values,
236
+ x=mc_confusion_matrix.columns,
237
+ y=mc_confusion_matrix.index,
238
+ colorscale='Blues',
239
+ hoverongaps = False))
240
+ #add z values to heatmap
241
+ for i in range(len(mc_confusion_matrix.index)):
242
+ for j in range(len(mc_confusion_matrix.columns)):
243
+ fig.add_annotation(text=str(round(mc_confusion_matrix.values[i][j],2)), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
244
+ showarrow=False, font_size=font_size, font_color='black')
245
+
246
+ fig.update_layout(
247
+ title='Confusion matrix for mc classes - ' + model + ' - ' + 'mc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean(),2)) \
248
+ + ' - ' + 'sc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean(),2)) + '<br>' + \
249
+ 'percent false novel: ' + str(round(num_outliers/num_rows,2)),
250
+ xaxis_nticks=36)
251
+ #label x axis and y axis
252
+ fig.update_xaxes(title_text='Predicted mc class')
253
+ fig.update_yaxes(title_text='Actual mc class')
254
+ if save_figs:
255
+ #save plot
256
+ if seperate_outliers:
257
+ fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.png')
258
+ #save svg
259
+ fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.svg')
260
+ else:
261
+ fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.png')
262
+ #save svg
263
+ fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.svg')
264
+ print('\n')
265
+
266
+
267
+ if __name__ == '__main__':
268
+ #####################################################################################################################
269
+ mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
270
+ LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
271
+ path_to_models = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
272
+
273
+ trained_on = 'full' #id or full
274
+ save_figs = True
275
+
276
+ infer_aa = infer_relaxed_mirna = infer_hico = infer_ood = infer_other_affixes = infer_random = infer_fused = infer_na = infer_loco = False
277
+
278
+ split = 'infer_hico'#sys.argv[1]
279
+ print(f'Running inference for {split}')
280
+ if split == 'infer_aa':
281
+ infer_aa = True
282
+ elif split == 'infer_relaxed_mirna':
283
+ infer_relaxed_mirna = True
284
+ elif split == 'infer_hico':
285
+ infer_hico = True
286
+ elif split == 'infer_ood':
287
+ infer_ood = True
288
+ elif split == 'infer_other_affixes':
289
+ infer_other_affixes = True
290
+ elif split == 'infer_random':
291
+ infer_random = True
292
+ elif split == 'infer_fused':
293
+ infer_fused = True
294
+ elif split == 'infer_na':
295
+ infer_na = True
296
+ elif split == 'infer_loco':
297
+ infer_loco = True
298
+
299
+ #####################################################################################################################
300
+ #exactly one of the infer_* flags should be true
301
+ if sum([infer_aa,infer_relaxed_mirna,infer_hico,infer_ood,infer_other_affixes,infer_random,infer_fused,infer_na,infer_loco]) != 1:
302
+ raise Exception('Exactly one of infer_aa, infer_relaxed_mirna, infer_hico, infer_ood, infer_other_affixes, infer_random, infer_fused, infer_na or infer_loco should be true')
303
+
304
+ #set fig_prefix
305
+ if infer_aa:
306
+ fig_prefix = '5\'A-affixes'
307
+ elif infer_other_affixes:
308
+ fig_prefix = 'other_affixes'
309
+ elif infer_relaxed_mirna:
310
+ fig_prefix = 'Relaxed-miRNA'
311
+ elif infer_hico:
312
+ fig_prefix = 'LC-familiar'
313
+ elif infer_ood:
314
+ fig_prefix = 'LC-novel'
315
+ elif infer_random:
316
+ fig_prefix = 'Random'
317
+ elif infer_fused:
318
+ fig_prefix = 'Fused'
319
+ elif infer_na:
320
+ fig_prefix = 'NA'
321
+ elif infer_loco:
322
+ fig_prefix = 'LOCO'
323
+
324
+ infer_df = load(LC_path)
325
+ if isinstance(infer_df,AnnData):
326
+ infer_df = infer_df.var
327
+ infer_df.set_index('sequence',inplace=True)
328
+ sc_to_mc_mapper_dict = load(mapping_dict_path)
329
+ #select hico sequences
330
+ hico_seqs = infer_df.index[infer_df['hico']].tolist()
331
+ art_affix_seqs = infer_df[~infer_df['five_prime_adapter_filter']].index.tolist()
332
+
333
+ if infer_hico:
334
+ hico_seqs = hico_seqs
335
+
336
+ if infer_aa:
337
+ hico_seqs = art_affix_seqs
338
+
339
+ if infer_other_affixes:
340
+ hico_seqs = infer_df[~infer_df['hbdx_spikein_affix_filter']].index.tolist()
341
+
342
+ if infer_na:
343
+ hico_seqs = infer_df[infer_df.subclass_name == 'no_annotation'].index.tolist()
344
+
345
+ if infer_loco:
346
+ hico_seqs = infer_df[~infer_df['hico']][infer_df.subclass_name != 'no_annotation'].index.tolist()
347
+
348
+ #for mirnas
349
+ if infer_relaxed_mirna:
350
+ #subclass name must contain miR, let, Let and not contain ; and that are not hico
351
+ mirnas_seqs = infer_df[infer_df.subclass_name.str.contains('miR') | infer_df.subclass_name.str.contains('let')][~infer_df.subclass_name.str.contains(';')].index.tolist()
352
+ #remove the ones that are true in ad.hico column
353
+ hico_seqs = list(set(mirnas_seqs).difference(set(hico_seqs)))
354
+
355
+ #novel mirnas
356
+ #mirnas_not_in_train_mask = (ad['hico']==True).values * ~(ad['subclass_name'].isin(mirna_train_sc)).values * (ad['small_RNA_class_annotation'].isin(['miRNA']))
357
+ #hicos = ad[mirnas_not_in_train_mask].index.tolist()
358
+
359
+
360
+ if infer_random:
361
+ #create random sequences
362
+ random_seqs = []
363
+ while len(random_seqs) < 200:
364
+ random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(18,30)))
365
+ if random_seq not in random_seqs:
366
+ random_seqs.append(random_seq)
367
+ hico_seqs = random_seqs
368
+
369
+ if infer_fused:
370
+ hico_seqs = get_fused_seqs(hico_seqs,num_sequences=200)
371
+
372
+
373
+ #hico_seqs = ad[ad.subclass_name.str.contains('mir')][~ad.subclass_name.str.contains(';')]['subclass_name'].index.tolist()
374
+ hico_seqs = [seq for seq in hico_seqs if len(seq) <= 30]
375
+ #set cuda 1
376
+ import os
377
+ os.environ["CUDA_VISIBLE_DEVICES"] = '1'
378
+
379
+ #run prediction
380
+ prediction_pd = predict_transforna_all_models(hico_seqs,trained_on=trained_on,path_to_models=path_to_models)
381
+ prediction_pd['split'] = fig_prefix
382
+ #the if condition here is to make sure to filter seqs with sub classes not used in training
383
+ if not infer_ood and not infer_relaxed_mirna and not infer_hico:
384
+ prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
385
+ if infer_aa or infer_other_affixes or infer_random or infer_fused:
386
+ for model in prediction_pd.Model.unique():
387
+ num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
388
+ print(f'Number of non novel sequences for {model} is {num_non_novel}')
389
+ print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the lower the better')
390
+
391
+ else:
392
+ if infer_na or infer_loco:
393
+ #print number of Is Familiar per model
394
+ for model in prediction_pd.Model.unique():
395
+ num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
396
+ print(f'Number of non novel sequences for {model} is {num_non_novel}')
397
+ print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the higher the better')
398
+ print('\n')
399
+ else:
400
+ #only to get classes used for training
401
+ prediction_single_pd = predict_transforna(hico_seqs[0],model='Seq',logits_flag = True,trained_on=trained_on,path_to_models=path_to_models)
402
+ sub_classes_used_for_training = prediction_single_pd.columns.tolist()
403
+
404
+
405
+ mc_classes_df,sc_classes_df = get_mc_sc(infer_df,hico_seqs,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag=infer_ood)
406
+ if infer_ood:
407
+ for model in prediction_pd.Model.unique():
408
+ #filter sequences in prediction_pd to only include sequences in sc_classes_df
409
+ curr_prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
410
+ #filter curr_prediction to only include model
411
+ curr_prediction_pd = curr_prediction_pd[curr_prediction_pd.Model == model]
412
+ num_seqs = curr_prediction_pd.shape[0]
413
+ #filter Is Familiar
414
+ curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Is Familiar?']]
415
+ #filter sc_classes_df to only include sequences in curr_prediction_pd
416
+ curr_sc_classes_df = sc_classes_df[sc_classes_df.index.isin(curr_prediction_pd['Sequence'].values)]
417
+ #correct labels and remove the correct labels from the curr_prediction_pd
418
+ curr_prediction_pd['Net-Label'] = correct_labels(curr_prediction_pd['Net-Label'].values,curr_sc_classes_df[0].values,sc_to_mc_mapper_dict)
419
+ #filter rows in curr_prediction where Labels is equal to sc_classes_df[0]
420
+ curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Net-Label'].values != curr_sc_classes_df[0].values]
421
+ num_non_novel = len(curr_prediction_pd)
422
+ print(f'Number of non novel sequences for {model} is {num_non_novel}')
423
+ print(f'Percent non novel for {model} is {num_non_novel/num_seqs}, the lower the better')
424
+ print('\n')
425
+ else:
426
+ #filter prediction_pd to include only sequences in prediction_pd
427
+
428
+ #compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=False,fig_prefix = fig_prefix,save_figs=save_figs)
429
+ compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=True,fig_prefix = fig_prefix,save_figs=save_figs)
430
+
431
+ if infer_ood or infer_relaxed_mirna or infer_hico:
432
+ prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
433
+ #save lev_dist_df
434
+ prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
435
+
436
+
437
+
438
+
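Note on the metrics printed by `compute_accuracy`: the "balanced accuracy" there is the mean of the per-class accuracies (`groupby(actual).mean().mean()`), so every class contributes equally regardless of how many sequences it has. The following is a minimal stdlib-only sketch of that computation next to plain accuracy; the helper name `balanced_accuracy` and the toy labels are illustrative assumptions, not taken from the repository.

```python
from collections import defaultdict


def balanced_accuracy(actual, predicted):
    """Mean of per-class accuracies, grouped by the actual label."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for a, p in zip(actual, predicted):
        totals[a] += 1
        hits[a] += int(a == p)
    per_class = [hits[c] / totals[c] for c in totals]
    return sum(per_class) / len(per_class)


# Toy, imbalanced data (hypothetical): 8 miRNA and 2 tRNA sequences,
# with one of the two tRNA sequences misclassified.
actual = ["miRNA"] * 8 + ["tRNA"] * 2
predicted = ["miRNA"] * 8 + ["rRNA", "tRNA"]

plain = sum(a == p for a, p in zip(actual, predicted)) / len(actual)
balanced = balanced_accuracy(actual, predicted)
```

On this toy input the plain accuracy is dominated by the large miRNA class, while the balanced figure is pulled down by the small tRNA class, which is why the script reports both.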
transforna/src/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .callbacks import *
2
+ from .callbacks.tbWriter import writer
3
+ from .inference import *
4
+ from .model import *
5
+ from .novelty_prediction import *
6
+ from .processing import *
7
+ from .score import *
8
+ from .train import *
9
+ from .utils import *
transforna/src/callbacks/LRCallback.py ADDED
@@ -0,0 +1,174 @@
1
+ import math
2
+ from collections.abc import Iterable
3
+ from math import cos, floor, log, pi
4
+
5
+ import skorch
6
+ from torch.optim.lr_scheduler import _LRScheduler
7
+
8
+ _LRScheduler
9
+
10
+
11
+ class CyclicCosineDecayLR(skorch.callbacks.Callback):
12
+ def __init__(
13
+ self,
14
+ optimizer,
15
+ init_interval,
16
+ min_lr,
17
+ len_param_groups,
18
+ base_lrs,
19
+ restart_multiplier=None,
20
+ restart_interval=None,
21
+ restart_lr=None,
22
+ last_epoch=-1,
23
+ ):
24
+ """
25
+ Initialize new CyclicCosineDecayLR object
26
+ :param optimizer: (Optimizer) - Wrapped optimizer.
27
+ :param init_interval: (int) - Initial decay cycle interval.
28
+ :param min_lr: (float or iterable of floats) - Minimal learning rate.
29
+ :param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals,
30
+ if this parameter is set, restart_interval must be None.
31
+ :param restart_interval: (int) - Restart interval for fixed cycle intervals,
32
+ if this parameter is set, restart_multiplier must be None.
33
+ :param restart_lr: (float or iterable of floats) - Optional, the learning rate at cycle restarts,
34
+ if not provided, initial learning rate will be used.
35
+ :param last_epoch: (int) - Last epoch.
36
+ """
37
+ self.len_param_groups = len_param_groups
38
+ if restart_interval is not None and restart_multiplier is not None:
39
+ raise ValueError(
40
+ "You can either set restart_interval or restart_multiplier but not both"
41
+ )
42
+
43
+ if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups:
44
+ raise ValueError(
45
+ "Expected len(min_lr) to be equal to len(optimizer.param_groups), "
46
+ "got {} and {} instead".format(len(min_lr), self.len_param_groups)
47
+ )
48
+
49
+ if isinstance(restart_lr, Iterable) and len(restart_lr) != (
50
+ self.len_param_groups
51
+ ):
52
+ raise ValueError(
53
+ "Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
54
+ "got {} and {} instead".format(len(restart_lr), self.len_param_groups)
55
+ )
56
+
57
+ if init_interval <= 0:
58
+ raise ValueError(
59
+ "init_interval must be a positive number, got {} instead".format(
60
+ init_interval
61
+ )
62
+ )
63
+
64
+ group_num = self.len_param_groups
65
+ self._init_interval = init_interval
66
+ self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr
67
+ self._restart_lr = (
68
+ [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
69
+ )
70
+ self._restart_interval = restart_interval
71
+ self._restart_multiplier = restart_multiplier
72
+ self.last_epoch = last_epoch
73
+ self.base_lrs = base_lrs
74
+ super().__init__()
75
+
76
+ def on_batch_end(self, net, training, **kwargs):
77
+ if self.last_epoch < self._init_interval:
78
+ return self._calc(self.last_epoch, self._init_interval, self.base_lrs)
79
+
80
+ elif self._restart_interval is not None:
81
+ cycle_epoch = (
82
+ self.last_epoch - self._init_interval
83
+ ) % self._restart_interval
84
+ lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
85
+ return self._calc(cycle_epoch, self._restart_interval, lrs)
86
+
87
+ elif self._restart_multiplier is not None:
88
+ n = self._get_n(self.last_epoch)
89
+ sn_prev = self._partial_sum(n)
90
+ cycle_epoch = self.last_epoch - sn_prev
91
+ interval = self._init_interval * self._restart_multiplier ** n
92
+ lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
93
+ return self._calc(cycle_epoch, interval, lrs)
94
+ else:
95
+ return self._min_lr
96
+
97
+ def _calc(self, t, T, lrs):
98
+ return [
99
+ min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
100
+ for lr, min_lr in zip(lrs, self._min_lr)
101
+ ]
102
+
103
+ def _get_n(self, epoch):
104
+ a = self._init_interval
105
+ r = self._restart_multiplier
106
+ _t = 1 - (1 - r) * epoch / a
107
+ return floor(log(_t, r))
108
+
109
+ def _partial_sum(self, n):
110
+ a = self._init_interval
111
+ r = self._restart_multiplier
112
+ return a * (1 - r ** n) / (1 - r)
113
+
114
+
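The restart logic above can be read as plain arithmetic on a geometric series. The following is a minimal standalone sketch, not the repository's callback: the function names `cosine_lr` and `cycle_position` are hypothetical, and it assumes a restart multiplier `r != 1`, counting epochs from 0.

```python
from math import cos, floor, log, pi


def cosine_lr(t, T, lr, min_lr):
    """Cosine interpolation from lr (at t=0) down to min_lr (at t=T)."""
    return min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2


def cycle_position(epoch, init_interval, multiplier):
    """Return (cycle index, epoch inside that cycle, length of that cycle).

    Cycle n has length a * r**n, so it starts at the partial geometric sum
    S_n = a * (1 - r**n) / (1 - r). Solving S_n <= epoch for n gives the log.
    """
    a, r = init_interval, multiplier
    n = floor(log(1 - (1 - r) * epoch / a, r))
    start = a * (1 - r ** n) / (1 - r)
    return n, epoch - start, a * r ** n


# Hypothetical settings: first cycle of 10 epochs, each later cycle twice as long.
n, t, T = cycle_position(15, init_interval=10, multiplier=2)
lr = cosine_lr(t, T, lr=0.1, min_lr=0.0)
```

Epochs that fall exactly on a cycle boundary depend on floating-point rounding inside `log`, so a production version may want an integer-based search instead.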
115
+ class LearningRateDecayCallback(skorch.callbacks.Callback):
116
+ def __init__(
117
+ self,
118
+ config,
119
+ ):
120
+ super().__init__()
121
+ self.lr_warmup_end = config.lr_warmup_end
122
+ self.lr_warmup_start = config.lr_warmup_start
123
+ self.learning_rate = config.learning_rate
124
+ self.warmup_batch = config.warmup_epoch * config.batch_per_epoch
125
+ self.final_batch = config.final_epoch * config.batch_per_epoch
126
+
127
+ self.batch_idx = 0
128
+
129
+ def on_batch_end(self, net, training, **kwargs):
130
+ """
131
+
132
+ :param trainer:
133
+ :type trainer:
134
+ :param pl_module:
135
+ :type pl_module:
136
+ :param batch:
137
+ :type batch:
138
+ :param batch_idx:
139
+ :type batch_idx:
140
+ :param dataloader_idx:
141
+ :type dataloader_idx:
142
+ """
143
+ # to avoid updating after validation batch
144
+ if training:
145
+
146
+ if self.batch_idx < self.warmup_batch:
147
+ # linear warmup from lr_warmup_start to lr_warmup_end over warmup_batch batches (in paper: from 0.1 to 1)
148
+ lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch))
149
+ lr = self.lr_warmup_start + lr_mult * (
150
+ self.lr_warmup_end - self.lr_warmup_start
151
+ )
152
+ else:
153
+ # Cosine learning rate decay
154
+ progress = float(self.batch_idx - self.warmup_batch) / float(
155
+ max(1, self.final_batch - self.warmup_batch)
156
+ )
157
+ lr = max(
158
+ self.learning_rate
159
+ + 0.5
160
+ * (1.0 + math.cos(math.pi * progress))
161
+ * (self.lr_warmup_end - self.learning_rate),
162
+ self.learning_rate,
163
+ )
164
+ net.lr = lr
165
+ # for param_group in net.optimizer.param_groups:
166
+ # param_group["lr"] = lr
167
+
168
+ self.batch_idx += 1
169
+
170
+
171
+ class LRAnnealing(skorch.callbacks.Callback):
172
+ def on_epoch_end(self, net, **kwargs):
173
+ if not net.history[-1]["valid_loss_best"]:
174
+ net.lr /= 4.0
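For reviewers who want to check the schedule in `LearningRateDecayCallback` without running skorch, here is a minimal sketch of the same per-batch formula as a pure function. The function name `lr_at_batch` and the numeric settings are assumptions chosen for illustration; the parameter names mirror the config fields read in the callback.

```python
import math


def lr_at_batch(batch_idx, warmup_batch, final_batch,
                lr_warmup_start, lr_warmup_end, learning_rate):
    """Linear warmup followed by cosine decay, evaluated for one batch index."""
    if batch_idx < warmup_batch:
        # Linear ramp from lr_warmup_start towards lr_warmup_end.
        frac = batch_idx / max(1, warmup_batch)
        return lr_warmup_start + frac * (lr_warmup_end - lr_warmup_start)
    # Cosine decay from lr_warmup_end down to learning_rate.
    progress = (batch_idx - warmup_batch) / max(1, final_batch - warmup_batch)
    decayed = learning_rate + 0.5 * (1.0 + math.cos(math.pi * progress)) * (
        lr_warmup_end - learning_rate
    )
    return max(decayed, learning_rate)


# Hypothetical settings: 10 warmup batches, decay finishes at batch 110.
schedule = [lr_at_batch(i, 10, 110, 0.0, 1e-3, 1e-4) for i in range(111)]
```

Note that `max(...)` only enforces a lower bound. For batch indices beyond `final_batch` the cosine term grows again, so clamping `progress` to 1 may be worth considering if training can run longer than `final_epoch`.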
transforna/src/callbacks/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .criterion import *
2
+ from .LRCallback import *
3
+ from .metrics import *
4
+ from .tbWriter import *
transforna/src/callbacks/criterion.py ADDED
@@ -0,0 +1,165 @@
1
+ import copy
2
+ import random
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class LossFunction(nn.Module):
12
+ def __init__(self,main_config):
13
+ super(LossFunction, self).__init__()
14
+ self.train_config = main_config["train_config"]
15
+ self.model_config = main_config["model_config"]
16
+ self.batch_per_epoch = self.train_config.batch_per_epoch
17
+ self.warm_up_annealing = (
18
+ self.train_config.warmup_epoch * self.batch_per_epoch
19
+ )
20
+ self.num_embed_hidden = self.model_config.num_embed_hidden
21
+ self.batch_idx = 0
22
+ self.loss_anealing_term = 0
23
+
24
+
25
+ class_weights = self.model_config.class_weights
26
+ #TODO: use device as in main_config
27
+ class_weights = torch.FloatTensor([float(x) for x in class_weights])
28
+
29
+ if self.model_config.num_classes > 2:
30
+ self.clf_loss_fn = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.train_config.label_smoothing_clf,reduction='none')
31
+ else:
32
+ self.clf_loss_fn = self.focal_loss
33
+
34
+
35
+ # @staticmethod
36
+ def cosine_similarity_matrix(
37
+ self, gene_embedd: torch.Tensor, second_input_embedd: torch.Tensor, annealing=True
38
+ ) -> torch.Tensor:
39
+ # if annealing is False, then this function is being called from Net.predict and
40
+ # doesn't pass the instantiated object LossFunction, therefore no access to self.
41
+ # in Predict we also just need the max of predictions.
42
+ # for some reason, skorch only passes the LossFunction initialized object, only
43
+ # from get_loss fn.
44
+
45
+ assert gene_embedd.size(0) == second_input_embedd.size(0)
46
+
47
+ cosine_similarity = torch.matmul(gene_embedd, second_input_embedd.T)
48
+
49
+ if annealing:
50
+ if self.batch_idx < self.warm_up_annealing:
51
+ self.loss_anealing_term = 1 + (
52
+ self.batch_idx / self.warm_up_annealing
53
+ ) * torch.sqrt(torch.tensor(self.num_embed_hidden))
54
+
55
+ cosine_similarity *= self.loss_anealing_term
56
+
57
+ return cosine_similarity
58
+ def get_similar_labels(self,y:torch.Tensor):
59
+ '''
60
+ This function receives y, the labels tensor
61
+ It creates a list of lists containing at every index a list(min_len = 2) of the indices of the labels that are similar
62
+ '''
63
+ # extract the label column as a NumPy array
64
+ labels_y = y[:,0].cpu().detach().numpy()
65
+
66
+ # creates an array of indices, sorted by unique element
67
+ idx_sort = np.argsort(labels_y)
68
+
69
+ # sorts records array so all unique elements are together
70
+ sorted_records_array = labels_y[idx_sort]
71
+
72
+ # returns the unique values, the index of the first occurrence of a value, and the count for each element
73
+ vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
74
+
75
+ # splits the indices into separate arrays
76
+ res = np.split(idx_sort, idx_start[1:])
77
+ #filter them with respect to their size, keeping only items occurring more than once
78
+ vals = vals[count > 1]
79
+ res = filter(lambda x: x.size > 1, res)
80
+
81
+ indices_similar_labels = []
82
+ similar_labels = []
83
+ for r in res:
84
+ indices_similar_labels.append(list(r))
85
+ similar_labels.append(list(labels_y[r]))
86
+
87
+ return indices_similar_labels,similar_labels
88
+
89
+ def get_triplet_samples(self,indices_similar_labels,similar_labels):
90
+ '''
91
+ This function creates three lists, positives, anchors and negatives
92
+ Each index in the three lists corresponds to a single triplet
93
+ '''
94
+ positives,anchors,negatives = [],[],[]
95
+ for idx_similar_labels in indices_similar_labels:
96
+ random_indices = random.sample(range(len(idx_similar_labels)), 2)
97
+ positives.append(idx_similar_labels[random_indices[0]])
98
+ anchors.append(idx_similar_labels[random_indices[1]])
99
+
100
+ negatives = copy.deepcopy(positives)
101
+ random.shuffle(negatives)
102
+ while (np.array(positives) == np.array(negatives)).any():
103
+ random.shuffle(negatives)
104
+
105
+ return positives,anchors,negatives
106
+ def get_triplet_loss(self,y,gene_embedd,second_input_embedd):
107
+ '''
108
+ This function computes triplet loss by creating triplet samples of positives, negatives and anchors
109
+ The objective is to decrease the distance of the embeddings between the anchors and the positives
110
+ while increasing the distance between the anchor and the negatives.
111
+ This is done separately for both the embeddings, gene embedds 0 and ss embedds 1
112
+ '''
113
+ #get similar labels
114
+ indices_similar_labels,similar_labels = self.get_similar_labels(y)
115
+ #ensuring that there are at least two sets of labels in a given list (indices_similar_labels)
116
+ if len(indices_similar_labels)>1:
117
+ #get triplet samples
118
+ positives,anchors,negatives = self.get_triplet_samples(indices_similar_labels,similar_labels)
119
+ #get triplet loss for gene embedds
120
+ gene_embedd_triplet_loss = self.triplet_loss(gene_embedd[positives,:],
121
+ gene_embedd[anchors,:],
122
+ gene_embedd[negatives,:])
123
+ #get triplet loss for ss embedds
124
+ second_input_embedd_triplet_loss = self.triplet_loss(second_input_embedd[positives,:],
125
+ second_input_embedd[anchors,:],
126
+ second_input_embedd[negatives,:])
127
+ return gene_embedd_triplet_loss + second_input_embedd_triplet_loss
128
+ else:
129
+ return 0
130
+
131
+ def focal_loss(self,predicted_labels,y):
132
+ y = y.unsqueeze(dim=1)
133
+ y_new = torch.zeros(y.shape[0], 2, dtype=torch.float, device=y.device)
134
+ y_new[range(y.shape[0]), y[:,0]]=1
135
+ BCE_loss = F.binary_cross_entropy_with_logits(predicted_labels.float(), y_new.float(), reduction='none')
136
+ pt = torch.exp(-BCE_loss) # prevents nans when probability 0
137
+ F_loss = (1-pt)**2 * BCE_loss
138
+ loss = 10*F_loss.mean()
139
+ return loss
140
+
141
+ def contrastive_loss(self,cosine_similarity,batch_size):
142
+ j = -torch.sum(torch.diagonal(cosine_similarity))
143
+
144
+ cosine_similarity.diagonal().copy_(torch.zeros(cosine_similarity.size(0)))
145
+
146
+ j = (1 - self.train_config.label_smoothing_sim) * j + (
147
+ self.train_config.label_smoothing_sim / (cosine_similarity.size(0) * (cosine_similarity.size(0) - 1))
148
+ ) * torch.sum(cosine_similarity)
149
+
150
+ j += torch.sum(torch.logsumexp(cosine_similarity, dim=0))
151
+
152
+ if j < 0:
153
+ j = j-j
154
+ return j/batch_size
155
+
156
+ def forward(self, embedds: List[torch.Tensor], y=None) -> torch.Tensor:
157
+ self.batch_idx += 1
158
+ gene_embedd, second_input_embedd, predicted_labels,curr_epoch = embedds
159
+
160
+
161
+ loss = self.clf_loss_fn(predicted_labels,y.squeeze())
162
+
163
+
164
+
165
+ return loss
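The triplet sampling in `get_similar_labels` and `get_triplet_samples` can be hard to follow through the NumPy sort/split steps. Below is a minimal pure-Python sketch of the same idea, under the assumption that labels are hashable scalars. The names `group_indices_by_label` and `sample_triplets` are hypothetical and not part of the repository.

```python
import random
from collections import defaultdict


def group_indices_by_label(labels):
    """Group batch positions by label, keeping labels seen at least twice."""
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    # Only a label with two or more samples can supply an anchor/positive pair.
    return [idxs for idxs in groups.values() if len(idxs) > 1]


def sample_triplets(groups, rng=random):
    """Draw one (positive, anchor, negative) index triplet per label group."""
    if len(groups) < 2:
        # With a single group there is no other label to draw negatives from.
        return [], [], []
    positives, anchors = [], []
    for idxs in groups:
        pos, anc = rng.sample(idxs, 2)
        positives.append(pos)
        anchors.append(anc)
    # Negatives are the positives of other groups: reshuffle until no
    # position keeps its own positive.
    negatives = positives[:]
    while any(p == n for p, n in zip(positives, negatives)):
        rng.shuffle(negatives)
    return positives, anchors, negatives


labels = [3, 1, 3, 2, 1, 1, 7]
groups = group_indices_by_label(labels)
pos, anc, neg = sample_triplets(groups, random.Random(0))
```

Because each group contributes exactly one positive, a negative that differs from the positive at the same position necessarily carries a different label.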
transforna/src/callbacks/metrics.py ADDED
@@ -0,0 +1,218 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import skorch
5
+ import torch
6
+ from sklearn.metrics import confusion_matrix, make_scorer
7
+ from skorch.callbacks import BatchScoring
8
+ from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter
9
+ from skorch.callbacks.training import Checkpoint
10
+
11
+ from .LRCallback import LearningRateDecayCallback
12
+
13
+ writer = None
14
+
15
+ def accuracy_score(y_true, y_pred: torch.tensor,task:str=None,mirna_flag:bool = False):
16
+ #sample
17
+
18
+ # premirna
19
+ if task == "premirna":
20
+ y_pred = y_pred[:,:-1]
21
+ miRNA_idx = np.where(y_true.squeeze()==mirna_flag)
22
+ correct = torch.max(y_pred,1).indices.cpu().numpy()[miRNA_idx] == mirna_flag
23
+ return sum(correct)
24
+
25
+ # sncrna
26
+ if task == "sncrna":
27
+ y_pred = y_pred[:,:-1]
28
+ # correct is of shape [samples], where each entry is True if the top-1 prediction matches the label
29
+ correct = torch.max(y_pred,1).indices.cpu().numpy() == y_true.squeeze()
30
+
31
+ return sum(correct) / y_pred.shape[0]
32
+
33
+
34
+ def accuracy_score_tcga(y_true, y_pred):
35
+
36
+ if torch.is_tensor(y_pred):
37
+ y_pred = y_pred.clone().detach().cpu().numpy()
38
+ if torch.is_tensor(y_true):
39
+ y_true = y_true.clone().detach().cpu().numpy()
40
+
41
+ #y pred contains logits | samples weights
42
+ sample_weight = y_pred[:,-1]
43
+ y_pred = np.argmax(y_pred[:,:-1],axis=1)
44
+
45
+ C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
46
+ with np.errstate(divide='ignore', invalid='ignore'):
47
+ per_class = np.diag(C) / C.sum(axis=1)
48
+ if np.any(np.isnan(per_class)):
49
+ per_class = per_class[~np.isnan(per_class)]
50
+ score = np.mean(per_class)
51
+ return score
52
+
53
+ def score_callbacks(cfg):
54
+
55
+ acc_scorer = make_scorer(accuracy_score,task=cfg["task"])
56
+ if cfg['task'] == 'tcga':
57
+ acc_scorer = make_scorer(accuracy_score_tcga)
58
+
59
+
60
+ if cfg["task"] == "premirna":
61
+ acc_scorer_mirna = make_scorer(accuracy_score,task=cfg["task"],mirna_flag = True)
62
+
63
+ val_score_callback_mirna = BatchScoringPremirna( mirna_flag=True,
64
+ scoring = acc_scorer_mirna, lower_is_better=False, name="val_acc_mirna")
65
+
66
+ train_score_callback_mirna = BatchScoringPremirna(mirna_flag=True,
67
+ scoring = acc_scorer_mirna, on_train=True, lower_is_better=False, name="train_acc_mirna")
68
+
69
+ val_score_callback = BatchScoringPremirna(mirna_flag=False,
70
+ scoring = acc_scorer, lower_is_better=False, name="val_acc")
71
+
72
+ train_score_callback = BatchScoringPremirna(mirna_flag=False,
73
+ scoring = acc_scorer, on_train=True, lower_is_better=False, name="train_acc")
74
+
75
+
76
+ scoring_callbacks = [
77
+ train_score_callback,
78
+ train_score_callback_mirna
79
+ ]
80
+ if cfg["train_split"]:
81
+ scoring_callbacks.extend([val_score_callback_mirna,val_score_callback])
82
+
83
+ if cfg["task"] in ["sncrna", "tcga"]:
84
+
85
+ val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
86
+ train_score_callback = BatchScoring(
87
+ acc_scorer, on_train=True, lower_is_better=False, name="train_acc"
88
+ )
89
+ scoring_callbacks = [train_score_callback]
90
+
91
+ #tcga dataset has a predefined valid split, so train_split is false, but still valid metric is required
92
+ #TODO: remove predefined valid from tcga from prepare_data_tcga
93
+ if cfg["train_split"] or cfg['task'] == 'tcga':
94
+ scoring_callbacks.append(val_score_callback)
95
+
96
+ return scoring_callbacks
97
+
98
+ def get_callbacks(path,cfg):
99
+
100
+ callback_list = [("lrcallback", LearningRateDecayCallback)]
101
+ if cfg['tensorboard'] == True:
102
+ from .tbWriter import writer
103
+ callback_list.append(MetricsVizualization)
104
+
105
+ if (cfg["train_split"] or cfg['task'] == 'tcga') and cfg['inference'] == False:
106
+ monitor = "val_acc_best"
107
+ if cfg['trained_on'] == 'full':
108
+ monitor = 'train_acc_best'
109
+ ckpt_path = path+"/ckpt/"
110
+ try:
111
+ os.mkdir(ckpt_path)
112
+ except OSError:
113
+ pass
114
+ model_name = f'model_params_{cfg["task"]}.pt'
115
+ callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path,f_params=model_name))
116
+
117
+ scoring_callbacks = score_callbacks(cfg)
118
+ #TODO: For some reason scoring callbacks have to be inserted before checkpoint and metrics viz callbacks
119
+ #otherwise NeuralNet notify function throws an exception
120
+ callback_list[1:1] = scoring_callbacks
121
+
122
+ return callback_list
123
+
124
+
125
+ class MetricsVizualization(skorch.callbacks.Callback):
126
+ def __init__(self, batch_idx=0) -> None:
127
+ super().__init__()
128
+ self.batch_idx = batch_idx
129
+
130
+ # TODO: Change to display metrics at epoch ends
131
+ def on_batch_end(self, net, training, **kwargs):
132
+ # validation batch
133
+ if not training:
134
+ # log val accuracy. accessing net.history:[ epoch ,batches, last batch,column in batch]
135
+ writer.add_scalar(
136
+ "Accuracy/val_acc",
137
+ net.history[-1, "batches", -1, "val_acc"],
138
+ self.batch_idx,
139
+ )
140
+ # log val loss
141
+ writer.add_scalar(
142
+ "Loss/val_loss",
143
+ net.history[-1, "batches", -1, "valid_loss"],
144
+ self.batch_idx,
145
+ )
146
+ # update batch idx after validation on batch is computed
147
+ # train batch
148
+ else:
149
+ # log lr
150
+ writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
151
+ # log train accuracy
152
+ writer.add_scalar(
153
+ "Accuracy/train_acc",
154
+ net.history[-1, "batches", -1, "train_acc"],
155
+ self.batch_idx,
156
+ )
157
+ # log train loss
158
+ writer.add_scalar(
159
+ "Loss/train_loss",
160
+ net.history[-1, "batches", -1, "train_loss"],
161
+ self.batch_idx,
162
+ )
163
+ self.batch_idx += 1
164
+
165
+ class BatchScoringPremirna(ScoringBase):
166
+ def __init__(self,mirna_flag:bool = False,*args,**kwargs):
167
+ super().__init__(*args,**kwargs)
168
+ #self.total_num_samples = total_num_samples
169
+ self.total_num_samples = 0
170
+ self.mirna_flag = mirna_flag
171
+ self.first_batch_flag = True
172
+ def on_batch_end(self, net, X, y, training, **kwargs):
173
+ if training != self.on_train:
174
+ return
175
+
176
+ y_preds = [kwargs['y_pred']]
177
+ #only for the first batch: get no. of samples belonging to same class samples
178
+ if self.first_batch_flag:
179
+ self.total_num_samples += sum(kwargs["batch"][1] == self.mirna_flag).detach().cpu().numpy()[0]
180
+
181
+ with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
182
+ # In case of y=None we will not have gathered any samples.
183
+ # We expect the scoring function to deal with y=None.
184
+ y = None if y is None else self.target_extractor(y)
185
+ try:
186
+ score = self._scoring(cached_net, X, y)
187
+ cached_net.history.record_batch(self.name_, score)
188
+ except KeyError:
189
+ pass
190
+ def get_avg_score(self, history):
191
+ if self.on_train:
192
+ bs_key = 'train_batch_size'
193
+ else:
194
+ bs_key = 'valid_batch_size'
195
+
196
+ weights, scores = list(zip(
197
+ *history[-1, 'batches', :, [bs_key, self.name_]]))
198
+ #score_avg = np.average(scores, weights=weights)
199
+ score_avg = sum(scores)/self.total_num_samples
200
+ return score_avg
201
+
202
+ # pylint: disable=unused-argument
203
+ def on_epoch_end(self, net, **kwargs):
204
+ self.first_batch_flag = False
205
+ history = net.history
206
+ try: # don't raise if there is no valid data
207
+ history[-1, 'batches', :, self.name_]
208
+ except KeyError:
209
+ return
210
+
211
+ score_avg = self.get_avg_score(history)
212
+ is_best = self._is_best_score(score_avg)
213
+ if is_best:
214
+ self.best_score_ = score_avg
215
+
216
+ history.record(self.name_, score_avg)
217
+ if is_best is not None:
218
+ history.record(self.name_ + '_best', bool(is_best))
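The score computed by `accuracy_score_tcga` is the mean of per-class recall, taken from the diagonal of a sample-weighted confusion matrix, with classes that have no true samples dropped. A minimal dependency-free sketch of that quantity follows. The function name `balanced_accuracy` is hypothetical, and it takes already-decoded class predictions rather than the logits-plus-weights array used in the repository.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred, sample_weight=None):
    """Mean per-class recall, skipping classes with zero total weight."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y_true)
    total = defaultdict(float)    # summed weight of true samples per class
    correct = defaultdict(float)  # summed weight of correctly predicted samples
    for t, p, w in zip(y_true, y_pred, sample_weight):
        total[t] += w
        if t == p:
            correct[t] += w
    recalls = [correct[c] / total[c] for c in total if total[c] > 0]
    if not recalls:
        return float("nan")
    return sum(recalls) / len(recalls)


score = balanced_accuracy([0, 0, 0, 1], [0, 0, 1, 1])
weighted = balanced_accuracy([0, 0, 0, 1], [0, 0, 1, 1], [1.0, 1.0, 0.0, 1.0])
```

In the second call the misclassified sample carries zero weight, so both classes reach a recall of 1.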
transforna/src/callbacks/tbWriter.py ADDED
@@ -0,0 +1,6 @@
1
+
2
+ from pathlib import Path
3
+
4
+ from torch.utils.tensorboard import SummaryWriter
5
+
6
+ writer = SummaryWriter(str(Path(__file__).parent.parent.parent.absolute())+"/runs/")
transforna/src/inference/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .inference_api import *
2
+ from .inference_benchmark import *
3
+ from .inference_tcga import *
transforna/src/inference/inference_api.py ADDED
@@ -0,0 +1,243 @@
1
+
2
+ import logging
3
+ import warnings
4
+ from argparse import ArgumentParser
5
+ from contextlib import redirect_stdout
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import List
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from hydra.utils import instantiate
13
+ from omegaconf import OmegaConf
14
+ from sklearn.preprocessing import StandardScaler
15
+ from umap import UMAP
16
+
17
+ from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
18
+ from ..processing.seq_tokenizer import SeqTokenizer
19
+ from ..utils.file import load
20
+ from ..utils.tcga_post_analysis_utils import Results_Handler
21
+ from ..utils.utils import (get_model, infer_from_pd,
22
+ prepare_inference_results_tcga,
23
+ update_config_with_inference_params)
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+ def aggregate_ensemble_model(lev_dist_df:pd.DataFrame):
30
+ '''
31
+ This function aggregates the predictions of the ensemble model by choosing the model with the lowest and the highest NLD per query sequence.
32
+ If the lowest NLD is lower than Novelty Threshold, then the model with the lowest NLD is chosen as the ensemble prediction.
33
+ Otherwise, the model with the highest NLD is chosen as the ensemble prediction.
34
+ '''
35
+ #for every sequence, if at least one model scores an NLD < Novelty Threshold, then get the one with the least NLD as the ensemble prediction
36
+ #otherwise, get the highest NLD.
37
+ #get the minimum NLD per query sequence
38
+ #remove the baseline model
39
+ baseline_df = lev_dist_df[lev_dist_df['Model'] == 'Baseline'].reset_index(drop=True)
40
+ lev_dist_df = lev_dist_df[lev_dist_df['Model'] != 'Baseline'].reset_index(drop=True)
41
+ min_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmin().values]
42
+ #get the maximum NLD per query sequence
43
+ max_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmax().values]
44
+ #choose between each row in min_lev_dist_df and max_lev_dist_df based on the value of Novelty Threshold
45
+ novel_mask_df = min_lev_dist_df['NLD'] > min_lev_dist_df['Novelty Threshold']
46
+ #get the rows where NLD is lower than Novelty Threshold
47
+ min_lev_dist_df = min_lev_dist_df[~novel_mask_df.values]
48
+ #get the rows where NLD is higher than Novelty Threshold
49
+ max_lev_dist_df = max_lev_dist_df[novel_mask_df.values]
50
+ #merge min_lev_dist_df and max_lev_dist_df
51
+ ensemble_lev_dist_df = pd.concat([min_lev_dist_df,max_lev_dist_df])
52
+ #add ensemble model
53
+ ensemble_lev_dist_df['Model'] = 'Ensemble'
54
+ #add ensemble_lev_dist_df to lev_dist_df
55
+ lev_dist_df = pd.concat([lev_dist_df,ensemble_lev_dist_df,baseline_df])
56
+ return lev_dist_df.reset_index(drop=True)
57
+
58
+
59
+ def read_inference_model_config(model:str,mc_or_sc,trained_on:str,path_to_models:str):
60
+ transforna_folder = "TransfoRNA_ID"
61
+ if trained_on == "full":
62
+ transforna_folder = "TransfoRNA_FULL"
63
+
64
+ model_path = f"{path_to_models}/{transforna_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml"
65
+ cfg = OmegaConf.load(model_path)
66
+ return cfg
67
+
68
+ def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc:str='sub_class',\
69
+ logits_flag:bool = False,attention_flag:bool = False,\
70
+ similarity_flag:bool=False,n_sim:int=3,embedds_flag:bool = False, \
71
+ umap_flag:bool = False,trained_on:str='full',path_to_models:str='') -> pd.DataFrame:
72
+ '''
73
+ This function predicts the major class or sub class of a list of sequences using the TransfoRNA model.
74
+ Additionally, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings.
75
+
76
+ Input:
77
+ sequences: list of sequences to predict
78
+ model: model to use for prediction
79
+ mc_or_sc: models trained on major class or sub class
80
+ logits_flag: whether to return logits
81
+ attention_flag: whether to return attention scores (obtained from the self-attention layer)
82
+ similarity_flag: whether to return explanatory/similar sequences in the training set
83
+ n_sim: number of similar sequences to return
84
+ embedds_flag: whether to return embeddings of the sequences
85
+ umap_flag: whether to return umap embeddings
86
+ trained_on: whether to use the model trained on the full dataset or the ID dataset
87
+ Output:
88
+ pd.DataFrame with the predictions
89
+ '''
90
+ #asserts that at most one flag is True
91
+ assert sum([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]) <= 1, 'One option at most can be True'
92
+ # capitalize the first letter of the model and the first letter after the -
93
+ model = "-".join([word.capitalize() for word in model.split("-")])
94
+ cfg = read_inference_model_config(model,mc_or_sc,trained_on,path_to_models)
95
+ cfg = update_config_with_inference_params(cfg,mc_or_sc,trained_on,path_to_models)
96
+ root_dir = Path(__file__).parents[1].absolute()
97
+
98
+ with redirect_stdout(None):
99
+ cfg, net = get_model(cfg, root_dir)
100
+ #original_infer_pd might include seqs that are longer than input model. if so, infer_pd contains the trimmed sequences
101
+ infer_pd = pd.Series(sequences, name="Sequences").to_frame()
102
+ predicted_labels, logits, gene_embedds_df,attn_scores_pd,all_data, max_len, net,_ = infer_from_pd(cfg, net, infer_pd, SeqTokenizer,attention_flag)
103
+
104
+ if model == 'Seq':
105
+ gene_embedds_df = gene_embedds_df.iloc[:,:int(gene_embedds_df.shape[1]/2)]
106
+ if logits_flag:
107
+ cfg['log_logits'] = True
108
+ prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
109
+ infer_pd = all_data["infere_rna_seq"]
110
+
111
+ if logits_flag:
112
+ logits_df = infer_pd.rename_axis("Sequence").reset_index()
113
+ logits_cols = [col for col in infer_pd.columns if "Logits" in col]
114
+ logits_df = infer_pd[logits_cols]
115
+ logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"])
116
+ logits_df.columns = logits_df.columns.droplevel(0)
117
+ return logits_df
118
+
119
+ elif attention_flag:
120
+ return attn_scores_pd
121
+
122
+ elif embedds_flag:
123
+ return gene_embedds_df
124
+
125
+ else: #return table with predictions, entropy, threshold, is familiar
126
+ #add aa predictions to infer_pd
127
+ embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2])+'/embedds'
128
+ results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
129
+ results.get_knn_model()
130
+ lv_threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"]
131
+ logger.info(f'computing Levenshtein distance for the inference set')
132
+ #prepare infer split
133
+ gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)]
134
+ #add index of gene_embedds_df to be a column with name results.seq_col
135
+ gene_embedds_df[results.seq_col] = gene_embedds_df.index
136
+ #set gene_embedds_df as the new infer split
137
+ results.splits_df_dict['infer_df'] = gene_embedds_df
138
+
139
+
140
+ _,_,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,'infer',num_neighbors=n_sim)
141
+
142
+ if similarity_flag:
143
+ #create df
144
+ sim_df = pd.DataFrame()
145
+ #populate query sequences and duplicate them n times
146
+ sequences = gene_embedds_df.index.tolist()
147
+ #duplicate each sequence n_sim times
148
+ sequences_duplicated = [seq for seq in sequences for _ in range(n_sim)]
149
+ sim_df['Sequence'] = sequences_duplicated
150
+ #assign top_n_seqs list to df column
151
+ sim_df[f'Explanatory Sequence'] = top_n_seqs
152
+ sim_df['NLD'] = lev_dist
153
+ sim_df['Explanatory Label'] = top_n_labels
154
+ sim_df['Novelty Threshold'] = lv_threshold
155
+ #for every query sequence, order the NLD in increasing order
156
+ sim_df = sim_df.sort_values(by=['Sequence','NLD'],ascending=[False,True])
157
+ return sim_df
158
+
159
+ logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
160
+ #for every n_sim elements in the list, get the smallest Levenshtein distance
161
+ lv_dist_closest = [min(lev_dist[i:i+n_sim]) for i in range(0,len(lev_dist),n_sim)]
162
+ top_n_labels_closest = [top_n_labels[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
163
+ top_n_seqs_closest = [top_n_seqs[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
164
+ infer_pd['Is Familiar?'] = [True if lv<lv_threshold else False for lv in lv_dist_closest]
165
+
166
+ if umap_flag:
167
+ #compute umap
168
+ logger.info(f'computing umap for the inference set')
169
+ gene_embedds_df = gene_embedds_df.drop(results.seq_col,axis=1)
170
+ umap = UMAP(n_components=2,random_state=42)
171
+ scaled_embedds = StandardScaler().fit_transform(gene_embedds_df.values)
172
+ gene_embedds_df = pd.DataFrame(umap.fit_transform(scaled_embedds),columns=['UMAP1','UMAP2'])
173
+ gene_embedds_df['Net-Label'] = infer_pd['Net-Label'].values
174
+ gene_embedds_df['Is Familiar?'] = infer_pd['Is Familiar?'].values
175
+ gene_embedds_df['Explanatory Label'] = top_n_labels_closest
176
+ gene_embedds_df['Explanatory Sequence'] = top_n_seqs_closest
177
+ gene_embedds_df['Sequence'] = infer_pd.index
178
+ return gene_embedds_df
179
+
180
+ #override threshold
181
+ infer_pd['Novelty Threshold'] = lv_threshold
182
+ infer_pd['NLD'] = lv_dist_closest
183
+ infer_pd['Explanatory Label'] = top_n_labels_closest
184
+ infer_pd['Explanatory Sequence'] = top_n_seqs_closest
185
+ infer_pd = infer_pd.round({"NLD": 2, "Novelty Threshold": 2})
186
+ logger.info(f'num of new hico based on Levenshtein distance is {np.sum(infer_pd["Is Familiar?"])}')
187
+ return infer_pd.rename_axis("Sequence").reset_index()
188
+
189
+ def predict_transforna_all_models(sequences: List[str], mc_or_sc:str = 'sub_class',logits_flag: bool = False, attention_flag: bool = False,\
190
+ similarity_flag: bool = False, n_sim:int = 3,
191
+ embedds_flag:bool=False, umap_flag:bool = False, trained_on:str="full",path_to_models:str='') -> pd.DataFrame:
192
+ """
193
+ Predicts the labels of the sequences using all the models available in the transforna package.
194
+ If none of the flags are true, it constructs and aggregates the output of the ensemble model.
195
+
196
+ Input:
197
+ sequences: list of sequences to predict
198
+ mc_or_sc: models trained on major class or sub class
199
+ logits_flag: whether to return logits
200
+ attention_flag: whether to return attention scores (obtained from the self-attention layer)
201
+ similarity_flag: whether to return explanatory/similar sequences in the training set
202
+ n_sim: number of similar sequences to return
203
+ embedds_flag: whether to return embeddings of the sequences
204
+ umap_flag: whether to return umap embeddings
205
+ trained_on: whether to use the model trained on the full dataset or the ID dataset
206
+ Output:
207
+ df: dataframe with the predictions
208
+ """
209
+ now = datetime.now()
210
+ before_time = now.strftime("%H:%M:%S")
211
+ models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
212
+ if similarity_flag or embedds_flag: #remove baseline, takes long time
213
+ models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]
214
+ if attention_flag: #remove single based transformer models
215
+ models = ["Seq", "Seq-Struct", "Seq-Rev"]
216
+ df = None
217
+ for model in models:
218
+ logger.info(model)
219
+ df_ = predict_transforna(sequences, model, mc_or_sc,logits_flag,attention_flag,similarity_flag,n_sim,embedds_flag,umap_flag,trained_on=trained_on,path_to_models = path_to_models)
220
+ df_["Model"] = model
221
+ df = pd.concat([df, df_], axis=0)
222
+ #aggregate ensemble model if none of the flags are true
223
+ if not logits_flag and not attention_flag and not similarity_flag and not embedds_flag and not umap_flag:
224
+ df = aggregate_ensemble_model(df)
225
+
226
+ now = datetime.now()
227
+ after_time = now.strftime("%H:%M:%S")
228
+ delta_time = datetime.strptime(after_time, "%H:%M:%S") - datetime.strptime(before_time, "%H:%M:%S")
229
+ logger.info(f"Time taken: {delta_time}")
230
+
231
+ return df
232
+
233
+
234
+ if __name__ == "__main__":
235
+ parser = ArgumentParser()
236
+ parser.add_argument("sequences", nargs="+")
237
+ parser.add_argument("--logits_flag", nargs="?", const = True,default=False)
238
+ parser.add_argument("--attention_flag", nargs="?", const = True,default=False)
239
+ parser.add_argument("--similarity_flag", nargs="?", const = True,default=False)
240
+ parser.add_argument("--n_sim", nargs="?", const = 3,default=3,type=int)
241
+ parser.add_argument("--embedds_flag", nargs="?", const = True,default=False)
242
+ parser.add_argument("--trained_on", nargs="?", const = "full",default="full")
243
+ predict_transforna_all_models(**vars(parser.parse_args()))
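The parser above declares its boolean options with `nargs="?"` and `const=True`, so an explicit `--logits_flag False` arrives as the non-empty string `"False"` and evaluates as true. A minimal sketch of an alternative, assuming plain on/off switches are enough; `build_parser` is a hypothetical helper and the `choices` for `--trained_on` are an assumption based on the docstring, not part of the TransfoRNA CLI:

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # Hypothetical parser mirroring the options declared above.
    parser = ArgumentParser()
    parser.add_argument("sequences", nargs="+")
    # store_true yields a real bool: False by default, True when the flag is given.
    for flag in ("logits_flag", "attention_flag", "similarity_flag", "embedds_flag"):
        parser.add_argument(f"--{flag}", action="store_true")
    # type=int converts the command-line string to an integer.
    parser.add_argument("--n_sim", type=int, default=3)
    # Assumed values: "full" or "id", following the docstring wording.
    parser.add_argument("--trained_on", choices=["full", "id"], default="full")
    return parser


args = build_parser().parse_args(
    ["AACGAAGCTCGACTTTTAAGG", "--similarity_flag", "--n_sim", "5"]
)
print(vars(args))
```

With this shape, every flag is a genuine `bool` and `n_sim` is an `int` before the values reach the prediction function.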
transforna/src/inference/inference_benchmark.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ from pathlib import Path
3
+ from typing import Dict
4
+
5
+ from ..callbacks.metrics import accuracy_score
6
+ from ..processing.seq_tokenizer import SeqTokenizer
7
+ from ..score.score import infer_from_model, infer_testset
8
+ from ..utils.file import load, save
9
+ from ..utils.utils import *
10
+
11
+
12
+ def infer_benchmark(cfg:Dict= None,path:str = None):
13
+ if cfg['tensorboard']:
14
+ from ..callbacks.tbWriter import writer
15
+
16
+ model = cfg["model_name"]+'_'+cfg['task']
17
+
18
+ #set seed
19
+ set_seed_and_device(cfg["seed"],cfg["device_number"])
20
+ #get data
21
+ ad = load(cfg["train_config"].dataset_path_train)
22
+
23
+ #instantiate dataset class
24
+ dataset_class = SeqTokenizer(ad.var,cfg)
25
+ test_data = load(cfg["train_config"].dataset_path_test)
26
+ #prepare data for training and inference
27
+ all_data = prepare_data_benchmark(dataset_class,test_data,cfg)
28
+
29
+
30
+
31
+ #sync skorch config with params in train and model config
32
+ sync_skorch_with_config(cfg["model"]["skorch_model"],cfg)
33
+
34
+ # instantiate skorch model
35
+ net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path)
36
+ net.initialize()
37
+ net.load_params(f_params=f'{cfg["inference_settings"]["model_path"]}')
38
+
39
+ #perform inference on task specific testset
40
+ if cfg["inference_settings"]["infere_original_testset"]:
41
+ infer_testset(net,cfg,all_data,accuracy_score)
42
+
43
+ #inference on custom data
44
+ predicted_labels,logits,_,_ = infer_from_model(net,all_data["infere_data"])
45
+ prepare_inference_results_benchmarck(net,cfg,predicted_labels,logits,all_data)
46
+ save(path=Path(__file__).parent.parent.absolute() / f'inference_results_{model}',data=all_data["infere_rna_seq"])
47
+ if cfg['tensorboard']:
48
+ writer.close()
transforna/src/inference/inference_tcga.py ADDED
@@ -0,0 +1,38 @@
1
+
2
+ from typing import Dict
3
+
4
+ from anndata import AnnData
5
+
6
+ from ..processing.seq_tokenizer import SeqTokenizer
7
+ from ..utils.file import load
8
+ from ..utils.utils import *
9
+
10
+
11
+ def infer_tcga(cfg:Dict= None,path:str = None):
12
+ if cfg['tensorboard']:
13
+ from ..callbacks.tbWriter import writer
14
+ cfg,net = get_model(cfg,path)
15
+ inference_path = cfg['inference_settings']['sequences_path']
16
+ original_infer_df = load(inference_path, index_col=0)
17
+ if isinstance(original_infer_df,AnnData):
18
+ original_infer_df = original_infer_df.var
19
+ predicted_labels,logits,_,_,all_data,max_len,net,infer_df = infer_from_pd(cfg,net,original_infer_df,SeqTokenizer)
20
+
21
+ #create inference_output if it does not exist
22
+ os.makedirs("inference_output", exist_ok=True)
24
+ if cfg['log_embedds']:
25
+ embedds_pd = log_embedds(cfg,net,all_data['infere_rna_seq'])
26
+ embedds_pd.to_csv(f"inference_output/{cfg['model_name']}_embedds.csv")
27
+
28
+ prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len)
29
+
30
+ #if sequences were trimmed, add mapping of trimmed sequences to original sequences
31
+ if original_infer_df.shape[0] != infer_df.shape[0]:
32
+ all_data["infere_rna_seq"] = add_original_seqs_to_predictions(infer_df,all_data['infere_rna_seq'])
33
+ #save
34
+ all_data["infere_rna_seq"].to_csv(f"inference_output/{cfg['model_name']}_inference_results.csv")
35
+
36
+ if cfg['tensorboard']:
37
+ writer.close()
38
+ return predicted_labels
transforna/src/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .model_components import *
2
+ from .skorchWrapper import *
transforna/src/model/model_components.py ADDED
@@ -0,0 +1,449 @@
1
+
2
+ import logging
3
+ import math
4
+ import random
5
+ from typing import Dict, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from omegaconf import DictConfig
12
+ from torch.nn.modules.normalization import LayerNorm
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ def circulant_mask(n: int, window: int) -> torch.Tensor:
16
+ """Calculate the relative attention mask. It is computed once, when the model is instantiated, because a subset of this matrix
17
+ is used for any input length less than the maximum.
19
+ i,j represent relative token positions in this matrix and in the attention scores matrix,
20
+ this mask enables attention scores to be set to 0 if further than the specified window length
21
+
22
+ :param n: a fixed parameter set to be larger than largest max sequence length across batches
23
+ :param window: [window length],
24
+ :return relative attention mask
25
+ """
26
+ circulant_t = torch.zeros(n, n)
27
+ # [0, 1, 2, ..., window, -1, -2, ..., -window]
28
+ offsets = [0] + [i for i in range(window + 1)] + [-i for i in range(window + 1)]
29
+ if window >= n:
30
+ return torch.ones(n, n)
31
+ for offset in offsets:
32
+ # size of the 1-tensor depends on the length of the diagonal
33
+ circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset)))
34
+ return circulant_t
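The mask built by `circulant_mask` is a band matrix: entry (i, j) is 1 when the two positions are at most `window` apart and 0 otherwise. A dependency-free sketch of the same pattern, using plain Python lists instead of tensors; `band_mask` is a hypothetical name used only for illustration:

```python
from typing import List


def band_mask(n: int, window: int) -> List[List[int]]:
    """Return an n x n 0/1 matrix with ones where |i - j| <= window."""
    return [[1 if abs(i - j) <= window else 0 for j in range(n)] for i in range(n)]


mask = band_mask(5, 1)
for row in mask:
    print(row)
```

When `window` is at least `n`, every entry satisfies the condition, which matches the early `torch.ones(n, n)` return above.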
35
+
36
+
37
+ class SelfAttention(nn.Module):
38
+
39
+ """normal query, key, value based self attention but with relative attention functionality
40
+ and a learnable bias encoding relative token position which is added to the attention scores before the softmax"""
41
+
42
+ def __init__(self, config: DictConfig, relative_attention: int):
43
+ """init self attention weight of each key, query, value and output projection layer.
44
+
45
+ :param config: model config
46
+ :type config: DictConfig
47
+ """
48
+ super().__init__()
49
+
50
+ self.config = config
51
+ self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project)
52
+ self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project)
53
+ self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project)
54
+
55
+ self.softmax = nn.Softmax(dim=-1)
56
+ self.output_projection = nn.Linear(
57
+ config.num_attention_project, config.num_embed_hidden
58
+ )
59
+ self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
60
+ stdv = 1.0 / math.sqrt(self.bias.data.size(0))
61
+ self.bias.data.uniform_(-stdv, stdv)
62
+ self.relative_attention = relative_attention
63
+ self.n = self.config.n
64
+ self.half_n = self.n // 2
65
+ self.register_buffer(
66
+ "relative_mask",
67
+ circulant_mask(config.tokens_len, self.relative_attention),
68
+ )
69
+
70
+ def forward(
71
+ self, attn_input: torch.Tensor, attention_mask: torch.Tensor
72
+ ) -> torch.Tensor:
72
+ """Compute self-attention from the query, key and value projections of the input.
73
+ The input is first projected by the query, key and value linear layers.
74
+ A padding mask is added to the scores, the relative attention mask is applied,
75
+ and a learnable relative-position bias is added before the softmax.
76
+ Returns the output projection of (attention weights @ value) together with the attention weights.
77
+
78
+ :param attn_input: input embeddings, B x T x num_embed_hidden
79
+ :type attn_input: torch.Tensor
80
+ :param attention_mask: padding mask, B x T, 1 for real tokens and 0 for padding
81
+ :type attention_mask: torch.Tensor
82
+ :return: [output, attention_scores]
83
+ :rtype: list of torch.Tensor
85
+ """
86
+ self.T = attn_input.size()[1]
87
+ # input is B x max seq len x n_emb
88
+ _query = self.query.forward(attn_input)
89
+ _key = self.key.forward(attn_input)
90
+ _value = self.value.forward(attn_input)
91
+
92
+ # scaled dot product
93
+ attention_scores = torch.matmul(_query, _key.transpose(1, 2))
94
+ attention_scores = attention_scores / math.sqrt(
95
+ self.config.num_attention_project
96
+ )
97
+
98
+ # Relative attention
99
+
100
+ # extended_attention_mask = attention_mask.to(attention_scores.device) # fp16 compatibility
101
+ extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0
102
+ attention_scores = attention_scores + extended_attention_mask
103
+
104
+ # the circulant mask is built once at the maximum token length (e.g. 60 x 60) and
104
+ # registered as a buffer, so masks of different sizes are not created on every call.
106
+
107
+ attention_scores = attention_scores.masked_fill(
108
+ self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf")
109
+ )
110
+
111
+ # The learnable bias vector has the maximum size, 2*max truncation length+1. For each position i, a different
111
+ # slice of it is added to the scores, indexed by the relative position (i-j), which avoids explicit loops.
112
+ # There is one learnable parameter per offset, e.g. (i-j) in {-60,...,60}.
114
+
115
+ ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T))
116
+ B_matrix = self.bias[self.n // 2 - ii + jj]
117
+
118
+ attention_scores = attention_scores + B_matrix.unsqueeze(0)
119
+
120
+ attention_scores = self.softmax(attention_scores)
121
+ output = torch.matmul(attention_scores, _value)
122
+
123
+ output = self.output_projection(output)
124
+
125
+ return [output,attention_scores] # B x T x num embed hidden
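The bias lookup `self.bias[self.n // 2 - ii + jj]` in the forward pass above produces a matrix whose entry (i, j) depends only on the offset j - i, so every diagonal shares one learnable value. A small sketch of that indexing with plain lists, assuming a bias vector of length 2*max_len+1; `relative_bias_matrix` is a hypothetical name:

```python
from typing import List


def relative_bias_matrix(bias: List[float], seq_len: int) -> List[List[float]]:
    """Build B with B[i][j] = bias[len(bias) // 2 - i + j]."""
    center = len(bias) // 2
    return [[bias[center - i + j] for j in range(seq_len)] for i in range(seq_len)]


max_len = 3
# One entry per relative offset in [-max_len, max_len]; values chosen for readability.
bias = [float(k) for k in range(2 * max_len + 1)]
B = relative_bias_matrix(bias, seq_len=3)
for row in B:
    print(row)
```

Because the centre index is `max_len`, the lookup stays inside the vector for any sequence length up to `max_len`.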
126
+
127
+
128
+
129
+ class FeedForward1(nn.Module):
130
+ def __init__(
131
+ self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0
132
+ ):
133
+ # 512 2048
134
+
135
+ super().__init__()
136
+
137
+ self.linear_1 = nn.Linear(input_hidden, intermediate_hidden)
138
+ self.dropout = nn.Dropout(dropout_rate)
139
+ self.linear_2 = nn.Linear(intermediate_hidden, input_hidden)
140
+
141
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
142
+
143
+ x = F.gelu(self.linear_1(x))
144
+ return self.linear_2(self.dropout(x))
145
+
146
+
147
+ class SharedInnerBlock(nn.Module):
148
+ def __init__(self, config: DictConfig, relative_attn: int):
149
+ super().__init__()
150
+
151
+ self.config = config
152
+ self.self_attention = SelfAttention(config, relative_attn)
153
+ self.norm1 = LayerNorm(config.num_embed_hidden) # 512
154
+ self.dropout = nn.Dropout(config.dropout)
155
+ self.ff1 = FeedForward1(
156
+ config.num_embed_hidden, config.feed_forward1_hidden, config.dropout
157
+ )
158
+ self.norm2 = LayerNorm(config.num_embed_hidden)
159
+
160
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
161
+
162
+ new_values_x,attn_scores = self.self_attention(x, attention_mask=attention_mask)
163
+ x = x+new_values_x
164
+ x = self.norm1(x)
165
+ x = x + self.ff1(x)
166
+ return self.norm2(x),attn_scores
167
+
168
+
169
+ # The blocks above are single-head. Multi-head attention stacks the heads into an extra dimension (4 dims with the batch),
169
+ # giving B x nh stacks of T x T attention scores, which are then matrix-multiplied with the values:
170
+ # (B x nh x T x T) @ (B x nh x T x hs) gives (B x nh x T x hs), where hs is the final dimension divided by the number of heads; the stacks are then reordered (concatenated).
171
+ # An optional extra projection layer could follow, but that is done later.
173
+
174
+
175
+ class MultiheadAttention(nn.Module):
176
+ def __init__(self, config: DictConfig):
177
+ super().__init__()
178
+ self.num_attention_heads = config.num_attention_heads
179
+ self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads
180
+ self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads)
181
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
182
+
183
+ self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
184
+ self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
185
+ self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
186
+
187
+ self.dropout = nn.Dropout(config.dropout)
188
+
189
+ def forward(
190
+ self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
191
+ ) -> torch.Tensor:
192
+ B, T, _ = hidden_states.size()
193
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
194
+ k = (
195
+ self.key(hidden_states)
196
+ .view(B, T, self.num_attention_heads, self.attention_head_size)
197
+ .transpose(1, 2)
198
+ ) # (B, nh, T, hs)
199
+ q = (
200
+ self.query(hidden_states)
201
+ .view(B, T, self.num_attention_heads, self.attention_head_size)
202
+ .transpose(1, 2)
203
+ ) # (B, nh, T, hs)
204
+ v = (
205
+ self.value(hidden_states)
206
+ .view(B, T, self.num_attention_heads, self.attention_head_size)
207
+ .transpose(1, 2)
208
+ ) # (B, nh, T, hs)
209
+
210
+ attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
211
+
212
+ if attention_mask is not None:
213
+ attention_mask = attention_mask[:, None, None, :]
214
+ attention_mask = (1.0 - attention_mask) * -10000.0
215
+
216
+ attention_scores = attention_scores + attention_mask
217
+
218
+ attention_scores = F.softmax(attention_scores, dim=-1)
219
+
220
+ attention_scores = self.dropout(attention_scores)
221
+
222
+ y = attention_scores @ v
223
+
224
+ y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj)
225
+
226
+ return y
227
+
228
+
229
+ class PositionalEncoding(nn.Module):
230
+ def __init__(self, model_config: DictConfig,):
231
+ super(PositionalEncoding, self).__init__()
232
+ self.dropout = nn.Dropout(p=model_config.dropout)
233
+ self.num_embed_hidden = model_config.num_embed_hidden
234
+ pe = torch.zeros(model_config.tokens_len, self.num_embed_hidden)
235
+ position = torch.arange(
236
+ 0, model_config.tokens_len, dtype=torch.float
237
+ ).unsqueeze(1)
238
+ div_term = torch.exp(
239
+ torch.arange(0, self.num_embed_hidden, 2).float()
240
+ * (-math.log(10000.0) / self.num_embed_hidden)
241
+ )
242
+ pe[:, 0::2] = torch.sin(position * div_term)
243
+ pe[:, 1::2] = torch.cos(position * div_term)
244
+ pe = pe.unsqueeze(0)
245
+ self.register_buffer("pe", pe)
246
+
247
+ def forward(self, x):
248
+ x = x + self.pe[:, : x.size(1), :]
249
+ return self.dropout(x)
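The `pe` buffer in `PositionalEncoding` follows the standard sinusoidal scheme: even columns hold the sine and odd columns the cosine of the same angle, with the frequency decaying geometrically across columns. A dependency-free sketch of the same table with plain lists; `sinusoidal_table` is a hypothetical name:

```python
import math
from typing import List


def sinusoidal_table(num_positions: int, dim: int) -> List[List[float]]:
    """table[pos][k] = sin(angle) for even k and table[pos][k + 1] = cos(angle),
    where angle = pos * exp(-ln(10000) * k / dim)."""
    table = [[0.0] * dim for _ in range(num_positions)]
    for pos in range(num_positions):
        for k in range(0, dim, 2):
            angle = pos * math.exp(-math.log(10000.0) * k / dim)
            table[pos][k] = math.sin(angle)
            if k + 1 < dim:
                table[pos][k + 1] = math.cos(angle)
    return table


table = sinusoidal_table(3, 4)
for row in table:
    print([round(v, 4) for v in row])
```

Position 0 always encodes to alternating zeros and ones, since sin(0) = 0 and cos(0) = 1.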
250
+
251
+
252
+ class RNAFFrwd(
253
+ nn.Module
254
+ ): # params are not shared for context and reply. so need two sets of weights
255
+ """Fully-Connected 3-layer Linear Model"""
256
+
257
+ def __init__(self, model_config: DictConfig):
258
+ """
259
+ :param model_config: model config; num_embed_hidden and num_attention_heads set the layer widths
259
+ :type model_config: DictConfig
265
+ """
266
+ # paper specifies,skip connections,layer normalization, and orthogonal initialization
267
+
268
+ super().__init__()
269
+ # 3,679,744 x2 params
270
+ self.rna_ffwd_input_dim = (
271
+ model_config.num_embed_hidden * model_config.num_attention_heads
272
+ )
273
+ self.linear_1 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
274
+ self.linear_2 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
275
+
276
+ self.norm1 = LayerNorm(self.rna_ffwd_input_dim)
277
+ self.norm2 = LayerNorm(self.rna_ffwd_input_dim)
278
+ self.final = nn.Linear(self.rna_ffwd_input_dim, model_config.num_embed_hidden)
279
+ self.orthogonal_initialization() # torch implementation works perfectly out the box,
280
+
281
+ def orthogonal_initialization(self):
282
+ for l in [
283
+ self.linear_1,
284
+ self.linear_2,
285
+ ]:
286
+ torch.nn.init.orthogonal_(l.weight)
287
+
288
+ def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor:
289
+ sentence_lengths = attn_msk.sum(1)
290
+
291
+ # the square-root reduction is applied separately because it is not a shared
291
+ # part of the diagram; input is torch.Size([Batch, sent_len, embedd_dim])
293
+
294
+ # x has dims B x T x 2*d_emb
295
+ norms = 1 / torch.sqrt(sentence_lengths.double()).float() # 64
296
+ # TODO: Aggregation is done on all words including the masked ones
297
+ x = norms.unsqueeze(1) * torch.sum(x, dim=1) # 64 x1024
298
+
299
+ x = x + F.gelu(self.linear_1(self.norm1(x)))
300
+ x = x + F.gelu(self.linear_2(self.norm2(x)))
301
+
302
+ return F.normalize(self.final(x), dim=1, p=2) # 64 512
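The reduction at the start of `RNAFFrwd.forward` sums the token vectors over the sequence axis and scales the result by one over the square root of the number of unmasked tokens. A minimal sketch of that step for a single sequence, using plain lists; `sqrt_n_reduce` is a hypothetical name, and as in the layer above (see its TODO) padded positions are still included in the sum:

```python
import math
from typing import List


def sqrt_n_reduce(token_vectors: List[List[float]], mask: List[int]) -> List[float]:
    """Sum token vectors over the sequence and scale by 1/sqrt(number of real tokens)."""
    length = sum(mask)  # number of non-padded tokens
    scale = 1.0 / math.sqrt(length)
    dim = len(token_vectors[0])
    # The sum runs over all positions, padded ones included.
    return [scale * sum(vec[d] for vec in token_vectors) for d in range(dim)]


tokens = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last row stands in for padding
reduced = sqrt_n_reduce(tokens, mask=[1, 1, 0])
print(reduced)
```

Scaling by the square root of the length, rather than the length itself, keeps the magnitude of the pooled vector roughly comparable between short and long sequences.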
303
+
304
+
305
+ class RNATransformer(nn.Module):
306
+ def __init__(self, model_config: DictConfig):
307
+ super().__init__()
308
+ self.num_embedd_hidden = model_config.num_embed_hidden
309
+ self.encoder = nn.Embedding(
310
+ model_config.vocab_size, model_config.num_embed_hidden
311
+ )
312
+ self.model_input = model_config.model_input
313
+ if 'baseline' not in self.model_input:
314
+ # positional encoder
315
+ self.pos_encoder = PositionalEncoding(model_config)
316
+
317
+ self.transformer_layers = nn.ModuleList(
318
+ [
319
+ SharedInnerBlock(model_config, int(window/model_config.window))
320
+ for window in model_config.relative_attns[
321
+ : model_config.num_encoder_layers
322
+ ]
323
+ ]
324
+ )
325
+ self.MHA = MultiheadAttention(model_config)
326
+ # self.concatenate = FeedForward2(model_config)
327
+
328
+ self.rna_ffrwd = RNAFFrwd(model_config)
329
+ self.pad_id = 0
330
+
331
+ def forward(self, x:torch.Tensor) -> torch.Tensor:
332
+ if x.is_cuda:
333
+ long_tensor = torch.cuda.LongTensor
334
+ else:
335
+ long_tensor = torch.LongTensor
336
+
337
+ embedds = self.encoder(x)
338
+ if 'baseline' not in self.model_input:
339
+ output = self.pos_encoder(embedds)
340
+ attention_mask = (x != self.pad_id).int()
341
+
342
+ for l in self.transformer_layers:
343
+ output,attn_scores = l(output, attention_mask)
344
+ output = self.MHA(output)
345
+ output = self.rna_ffrwd(output, attention_mask)
346
+ return output,attn_scores
347
+ else:
348
+ embedds = torch.flatten(embedds,start_dim=1)
349
+ return embedds,None
350
+
351
+ class GeneEmbeddModel(nn.Module):
352
+ def __init__(
353
+ self, main_config: DictConfig,
354
+ ):
355
+ super().__init__()
356
+ self.train_config = main_config["train_config"]
357
+ self.model_config = main_config["model_config"]
358
+ self.device = self.train_config.device
359
+ self.model_input = self.model_config["model_input"]
360
+ self.false_input_perc = self.model_config["false_input_perc"]
361
+ #adjust n (used to add rel bias on attn scores)
362
+ self.model_config.n = self.model_config.tokens_len*2+1
363
+ self.transformer_layers = RNATransformer(self.model_config)
364
+ #save tokens_len of sequences to be used to split ids between transformers
365
+ self.tokens_len = self.model_config.tokens_len
366
+ #reassign tokens_len and vocab_size to init a new transformer
367
+ #cleaner solution -> RNATransformer and its children should
368
+ # have a flag input indicating which transformer
369
+ self.model_config.tokens_len = self.model_config.second_input_token_len
370
+ self.model_config.n = self.model_config.tokens_len*2+1
371
+ self.seq_vocab_size = self.model_config.vocab_size
372
+ #this differs between both models not the token_len/ss_token_len
373
+ self.model_config.vocab_size = self.model_config.second_input_vocab_size
374
+
375
+ self.second_input_model = RNATransformer(self.model_config)
376
+
377
+ #num_transformers refers to using either one model or two in parallel
378
+ self.num_transformers = 2
379
+ if self.model_input == 'seq':
380
+ self.num_transformers = 1
381
+ # could be moved to model
382
+ self.weight_decay = self.train_config.l2_weight_decay
383
+ if 'baseline' in self.model_input:
384
+ self.num_transformers = 1
385
+ num_nodes = self.model_config.num_embed_hidden*self.tokens_len
386
+ self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
387
+ else:
388
+ #setting classification layer
389
+ num_nodes = self.num_transformers*self.model_config.num_embed_hidden
390
+ if self.num_transformers == 1:
391
+ self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
392
+ else:
393
+ self.final_clf_1 = nn.Linear(num_nodes,num_nodes)
394
+ self.final_clf_2 = nn.Linear(num_nodes,self.model_config.num_classes)
395
+ self.relu = nn.ReLU()
396
+ self.BN = nn.BatchNorm1d(num_nodes)
397
+ self.dropout = nn.Dropout(0.6)
398
+
399
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
400
+
401
+
402
+ def distort_input(self,x):
403
+ for sample_idx in range(x.shape[0]):
404
+ seq_length = x[sample_idx,-1]
405
+ num_tokens_flipped = int(self.false_input_perc*seq_length)
406
+ max_start_flip_idx = seq_length - num_tokens_flipped
407
+
408
+ random_feat_idx = random.randint(0,max_start_flip_idx-1)
409
+ x[sample_idx,random_feat_idx:random_feat_idx+num_tokens_flipped] = \
410
+ torch.tensor(np.random.choice(range(1,self.seq_vocab_size-1),size=num_tokens_flipped,replace=True))
411
+
412
+ x[sample_idx,random_feat_idx+self.tokens_len:random_feat_idx+self.tokens_len+num_tokens_flipped] = \
413
+ torch.tensor(np.random.choice(range(1,self.model_config.second_input_vocab_size-1),size=num_tokens_flipped,replace=True))
414
+ return x
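`distort_input` corrupts a contiguous window of each sequence with random tokens, where the window length is a fraction of the true sequence length. A simplified single-sequence sketch of that idea with plain lists; `distort_tokens` and its parameters are hypothetical, and the guard for an empty or full-length window is an addition that the method above does not have:

```python
import random
from typing import List


def distort_tokens(tokens: List[int], seq_length: int, frac: float,
                   vocab_size: int, rng: random.Random) -> List[int]:
    """Overwrite a random contiguous window inside the first seq_length tokens."""
    out = list(tokens)
    num_flipped = int(frac * seq_length)
    max_start = seq_length - num_flipped
    if num_flipped == 0 or max_start <= 0:
        return out  # nothing to flip, or no room to place the window
    start = rng.randint(0, max_start - 1)
    for i in range(start, start + num_flipped):
        # Draw from 1 .. vocab_size - 2, mirroring range(1, vocab_size - 1) above.
        out[i] = rng.randrange(1, vocab_size - 1)
    return out


original = [1, 2, 3, 4, 2, 1, 0, 0]  # two trailing pad tokens
distorted = distort_tokens(original, seq_length=6, frac=0.5, vocab_size=6,
                           rng=random.Random(0))
print(distorted)
```

Because the start index is drawn below `seq_length - num_flipped`, the window never reaches the last real token or the padding.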
415
+
416
+ def forward(self, x,train=False):
417
+ if self.device == 'cuda':
418
+ long_tensor = torch.cuda.LongTensor
419
+ float_tensor = torch.cuda.FloatTensor
420
+ else:
421
+ long_tensor = torch.LongTensor
422
+ float_tensor = torch.FloatTensor
423
+ if train:
424
+ if self.false_input_perc > 0:
425
+ x = self.distort_input(x)
426
+
427
+ gene_embedd,attn_scores_first = self.transformer_layers(
428
+ x[:, : self.tokens_len].type(long_tensor)
429
+ )
430
+ attn_scores_second = None
431
+ second_input_embedd,attn_scores_second = self.second_input_model(
432
+ x[:, self.tokens_len :-1].type(long_tensor)
433
+ )
434
+
435
+ #for tcga: if seq or baseline
436
+ if self.num_transformers == 1:
437
+ activations = self.final_clf_1(gene_embedd)
438
+ else:
439
+ out_clf_1 = self.final_clf_1(torch.cat((gene_embedd, second_input_embedd), 1))
440
+ out = self.BN(out_clf_1)
441
+ out = self.relu(out)
442
+ out = self.dropout(out)
443
+ activations = self.final_clf_2(out)
444
+
445
+ #create dummy attn scores for baseline
446
+ if 'baseline' in self.model_input:
447
+ attn_scores_first = torch.ones((1,2,2),device=x.device)
448
+
449
+ return [gene_embedd, second_input_embedd, activations,attn_scores_first,attn_scores_second]
transforna/src/model/skorchWrapper.py ADDED
@@ -0,0 +1,364 @@
1
+ import logging
2
+ import os
3
+ import pickle
4
+
5
+ import skorch
6
+ import torch
7
+ from skorch.dataset import Dataset, ValidSplit
8
+ from skorch.setter import optimizer_setter
9
+ from skorch.utils import is_dataset, to_device
10
+
11
+ logger = logging.getLogger(__name__)
12
+ #from ..tbWriter import writer
13
+
14
+
15
+ class Net(skorch.NeuralNet):
16
+ def __init__(
17
+ self,
18
+ clip=0.25,
19
+ top_k=1,
20
+ correct=0,
21
+ save_embedding=False,
22
+ gene_embedds=[],
23
+ second_input_embedd=[],
24
+ confidence_threshold = 0.95,
25
+ *args,
26
+ **kwargs
27
+ ):
28
+ self.clip = clip
29
+ self.curr_epoch = 0
30
+ super(Net, self).__init__(*args, **kwargs)
31
+ self.correct = correct
32
+ self.save_embedding = save_embedding
33
+ self.gene_embedds = gene_embedds
34
+ self.second_input_embedds = second_input_embedd
35
+ self.main_config = kwargs["module__main_config"]
36
+ self.train_config = self.main_config["train_config"]
37
+ self.top_k = self.train_config.top_k
38
+ self.num_classes = self.main_config["model_config"].num_classes
39
+ self.labels_mapping_path = self.train_config.labels_mapping_path
40
+ if self.labels_mapping_path:
41
+ with open(self.labels_mapping_path, 'rb') as handle:
42
+ self.labels_mapping_dict = pickle.load(handle)
43
+ self.confidence_threshold = confidence_threshold
44
+ self.max_epochs = kwargs["max_epochs"]
45
+ self.task = '' #is set in utils.instantiate_predictor
46
+ self.log_tb = False
47
+
48
+
49
+
50
+
51
+ def set_save_epoch(self):
52
+ '''
53
+ scale best train epoch by valid size
54
+ '''
55
+ if self.task !='tcga':
56
+ if self.train_split:
57
+ self.save_epoch = self.main_config["train_config"].train_epoch
58
+ else:
59
+ self.save_epoch = round(self.main_config["train_config"].train_epoch*\
60
+ (1+self.main_config["valid_size"]))
61
+
62
+ def save_benchmark_model(self):
63
+ '''
64
+ saves benchmark epochs when train_split is none
65
+ '''
66
+ os.makedirs("ckpt", exist_ok=True)
70
+ cwd = os.getcwd()+"/ckpt/"
71
+ self.save_params(f_params= f'{cwd}/model_params_{self.main_config["task"]}.pt')
72
+
73
+
74
+ def fit(self, X, y=None, valid_ds=None,**fit_params):
75
+ #all sequence lengths should be saved to compute the median based
76
+ self.all_lengths = [[] for i in range(self.num_classes)]
77
+ self.median_lengths = []
78
+
79
+ if not self.warm_start or not self.initialized_:
80
+ self.initialize()
81
+
82
+ if valid_ds:
83
+ self.validation_dataset = valid_ds
84
+ else:
85
+ self.validation_dataset = None
86
+
87
+ self.partial_fit(X, y, **fit_params)
88
+ return self
89
+
90
+ def fit_loop(self, X, y=None, epochs=None, **fit_params):
91
+ #if id then train longer otherwise stop at 0.99
92
+ rounding_digits = 3
93
+ if self.main_config['trained_on'] == 'full':
94
+ rounding_digits = 2
95
+ self.check_data(X, y)
96
+ epochs = epochs if epochs is not None else self.max_epochs
97
+
98
+ dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)
99
+
100
+ if self.validation_dataset is not None:
101
+ dataset_valid = self.validation_dataset.keywords["valid_ds"]
102
+
103
+ on_epoch_kwargs = {
104
+ "dataset_train": dataset_train,
105
+ "dataset_valid": dataset_valid,
106
+ }
107
+
108
+ iterator_train = self.get_iterator(dataset_train, training=True)
109
+ iterator_valid = None
110
+ if dataset_valid is not None:
111
+ iterator_valid = self.get_iterator(dataset_valid, training=False)
112
+
113
+ self.set_save_epoch()
114
+
115
+ for epoch_no in range(epochs):
116
+ #save model if training only on test set
117
+ self.curr_epoch = epoch_no
118
+ #save epoch is scaled by best train epoch
119
+ #save benchmark only when training on both train and val sets
119
+ if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split is None:
121
+ self.save_benchmark_model()
122
+
123
+ self.notify("on_epoch_begin", **on_epoch_kwargs)
124
+
125
+ self.run_single_epoch(
126
+ iterator_train,
127
+ training=True,
128
+ prefix="train",
129
+ step_fn=self.train_step,
130
+ **fit_params
131
+ )
132
+
133
+ if dataset_valid is not None:
134
+ self.run_single_epoch(
135
+ iterator_valid,
136
+ training=False,
137
+ prefix="valid",
138
+ step_fn=self.validation_step,
139
+ **fit_params
140
+ )
141
+
142
+
143
+ self.notify("on_epoch_end", **on_epoch_kwargs)
144
+ #manual early stopping for tcga
145
+ if self.task == 'tcga':
146
+ train_acc = round(self.history[:,'train_acc'][-1],rounding_digits)
147
+ if train_acc == 1:
148
+ break
149
+
150
+
151
+
152
+ return self
153
+
154
+ def train_step(self, X, y=None):
155
+ y = X[1]
156
+ X = X[0]
157
+ sample_weights = X[:,-1]
158
+ if self.device == 'cuda':
159
+ sample_weights = sample_weights.to(self.train_config.device)
160
+ self.module_.train()
161
+ self.module_.zero_grad()
162
+ gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1],train=True)
163
+ #curr_epoch is passed to loss as it is used to switch loss criteria from unsup. -> sup
164
+ loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
165
+
166
+ ###the supervised loss is multiplied by the sample weights and then averaged
167
+
168
+ loss = loss*sample_weights
169
+ loss = loss.mean()
170
+
171
+ loss.backward()
172
+
173
+ # TODO: clip only some parameters
174
+ torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
175
+ self.optimizer_.step()
176
+
177
+ return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
178
+
179
+ def validation_step(self, X, y=None):
180
+ y = X[1]
181
+ X = X[0]
182
+ sample_weights = X[:,-1]
183
+ if self.device == 'cuda':
184
+ sample_weights = sample_weights.to(self.train_config.device)
185
+ self.module_.eval()
186
+ with torch.no_grad():
187
+ gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1])
188
+ loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y)
189
+
190
+ ###the supervised loss is multiplied by the sample weights and then averaged
191
+
192
+ loss = loss*sample_weights
193
+ loss = loss.mean()
194
+
195
+ return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]}
196
+
197
+ def get_attention_scores(self, X, y=None):
198
+ '''
199
+ returns attention scores for a given input
200
+ '''
201
+ self.module_.eval()
202
+ with torch.no_grad():
203
+ _, _, _,attn_scores_first,attn_scores_second = self.module_(X[:,:-1])
204
+
205
+ attn_scores_first = attn_scores_first.detach().cpu().numpy()
206
+ if attn_scores_second is not None:
207
+ attn_scores_second = attn_scores_second.detach().cpu().numpy()
208
+ return attn_scores_first,attn_scores_second
209
+
210
+ def predict(self, X):
211
+ self.module_.train(False)
212
+ embedds = self.module_(X[:,:-1])
213
+ sample_weights = X[:,-1]
214
+ if self.device == 'cuda':
215
+ sample_weights = sample_weights.to(self.train_config.device)
216
+
217
+ gene_embedd, second_input_embedd, activations,_,_ = embedds
218
+ if self.save_embedding:
219
+ self.gene_embedds.append(gene_embedd.detach().cpu())
220
+ #if only a single transformer is deployed, second_input_embedd is None and has no detach()
221
+ if second_input_embedd is not None:
222
+ self.second_input_embedds.append(second_input_embedd.detach().cpu())
223
+
224
+ predictions = torch.cat([activations,sample_weights[:,None]],dim=1)
225
+ return predictions
226
+
227
+
228
+ def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
229
+ # log gradients and weights
230
+ for _, m in self.module_.named_modules():
231
+ for pn, p in m.named_parameters():
232
+ if pn.endswith("weight") and pn.find("norm") < 0:
233
+ if p.grad is not None:
234
+ if self.log_tb:
235
+ from ..callbacks.tbWriter import writer
236
+ writer.add_histogram("weights/" + pn, p, len(net.history))
237
+ writer.add_histogram(
238
+ "gradients/" + pn, p.grad.data, len(net.history)
239
+ )
240
+
241
+ return
+
+     def configure_opt(self, l2_weight_decay):
+         no_decay = ["bias", "LayerNorm.weight"]
+         params_decay = [
+             p
+             for n, p in self.module_.named_parameters()
+             if not any(nd in n for nd in no_decay)
+         ]
+         params_nodecay = [
+             p
+             for n, p in self.module_.named_parameters()
+             if any(nd in n for nd in no_decay)
+         ]
+         optim_groups = [
+             {"params": params_decay, "weight_decay": l2_weight_decay},
+             {"params": params_nodecay, "weight_decay": 0.0},
+         ]
+         return optim_groups
+
+     def initialize_optimizer(self, triggered_directly=True):
+         """Initialize the model optimizer. If ``self.optimizer__lr``
+         is not set, use ``self.lr`` instead.
+
+         Parameters
+         ----------
+         triggered_directly : bool (default=True)
+           Only relevant when optimizer is re-initialized.
+           Initialization of the optimizer can be triggered directly
+           (e.g. when lr was changed) or indirectly (e.g. when the
+           module was re-initialized). If and only if the former
+           happens, the user should receive a message informing them
+           about the parameters that caused the re-initialization.
+
+         """
+         # get learning rate from train config
+         optimizer_params = self.main_config["train_config"]
+         kwargs = {}
+         kwargs["lr"] = optimizer_params.learning_rate
+         # get l2 weight decay to init opt params
+         args = self.configure_opt(optimizer_params.l2_weight_decay)
+
+         if self.initialized_ and self.verbose:
+             msg = self._format_reinit_msg(
+                 "optimizer", kwargs, triggered_directly=triggered_directly
+             )
+             logger.info(msg)
+
+         self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])
+
+         self._register_virtual_param(
+             ["optimizer__param_groups__*__*", "optimizer__*", "lr"],
+             optimizer_setter,
+         )
+
+     def initialize_criterion(self):
+         """Initializes the criterion."""
+         # the criterion takes train_config and model_config as input;
+         # both are available through main_config
+         self.criterion_ = self.criterion(
+             self.main_config
+         )
+         if isinstance(self.criterion_, torch.nn.Module):
+             self.criterion_ = to_device(self.criterion_, self.device)
+         return self
+
+     def initialize_callbacks(self):
+         """Initializes all callbacks and save the result in the
+         ``callbacks_`` attribute.
+
+         Both ``default_callbacks`` and ``callbacks`` are used (in that
+         order). Callbacks may either be initialized or not, and if
+         they don't have a name, the name is inferred from the class
+         name. The ``initialize`` method is called on all callbacks.
+
+         The final result will be a list of tuples, where each tuple
+         consists of a name and an initialized callback. If names are
+         not unique, a ValueError is raised.
+
+         """
+         if self.callbacks == "disable":
+             self.callbacks_ = []
+             return self
+
+         callbacks_ = []
+
+         class Dummy:
+             # We cannot use None as dummy value since None is a
+             # legitimate value to be set.
+             pass
+
+         for name, cb in self._uniquely_named_callbacks():
+             # check if callback itself is changed
+             param_callback = getattr(self, "callbacks__" + name, Dummy)
+             if param_callback is not Dummy:  # callback itself was set
+                 cb = param_callback
+
+             # below: check for callback params
+             # don't set a parameter for non-existing callback
+             params = self.get_params_for("callbacks__{}".format(name))
+
+             # the lrcallback is additionally initialized with the train config,
+             # which is an input to the module
+             if name == "lrcallback":
+                 params["config"] = self.main_config["train_config"]
+             if (cb is None) and params:
+                 raise ValueError(
+                     "Trying to set a parameter for callback {} "
+                     "which does not exist.".format(name)
+                 )
+             if cb is None:
+                 continue
+
+             if isinstance(cb, type):  # uninitialized:
+                 cb = cb(**params)
+             else:
+                 cb.set_params(**params)
+             cb.initialize()
+             callbacks_.append((name, cb))
+
+         self.callbacks_ = callbacks_
+
+         return self
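
The parameter grouping done by `configure_opt` in this file can be tried out in isolation. The sketch below is only an illustration under stated assumptions: `split_decay_groups` is a hypothetical helper, the parameter names are invented, and plain strings stand in for tensors so the snippet runs without PyTorch. It mirrors the substring test against `["bias", "LayerNorm.weight"]` and returns groups in the same shape that `configure_opt` builds for the optimizer.

```python
def split_decay_groups(named_params, l2_weight_decay,
                       no_decay=("bias", "LayerNorm.weight")):
    """Split (name, param) pairs into decayed and non-decayed groups.

    Hypothetical helper: `named_params` plays the role of
    `module.named_parameters()` in the diff above.
    """
    decay, nodecay = [], []
    for name, param in named_params:
        # A parameter is excluded from weight decay when any of the
        # `no_decay` substrings occurs in its qualified name.
        if any(nd in name for nd in no_decay):
            nodecay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": l2_weight_decay},
        {"params": nodecay, "weight_decay": 0.0},
    ]


# Invented names for illustration; strings stand in for tensors.
named = [
    ("encoder.linear.weight", "W"),
    ("encoder.linear.bias", "b"),
    ("encoder.LayerNorm.weight", "g"),
    ("encoder.norm1.weight", "n"),
]
groups = split_decay_groups(named, l2_weight_decay=0.01)
print(groups[0]["params"])  # parameters that receive weight decay
print(groups[1]["params"])  # parameters excluded from weight decay
```

Because the test is a plain substring match, a normalization weight is only excluded when its qualified name contains `LayerNorm.weight`. The invented `encoder.norm1.weight` lands in the decayed group here. Whether that matters for TransfoRNA depends on how its modules name their normalization layers, which this diff does not show.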
transforna/src/novelty_prediction/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
+ from .id_vs_ood_entropy_clf import *
+ from .id_vs_ood_nld_clf import *