tadsatlawa-na committed on
Commit
15f8f23
·
1 Parent(s): 5b6bdaa

Add example notebook

Browse files
Files changed (1) hide show
  1. example.ipynb +155 -0
example.ipynb ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "# nanoBERT Example\n",
21
+ "\n",
22
+ "Here we present nanoBERT, a nanobody-specific transformer to predict amino\n",
23
+ " acids at a given position in a query sequence."
24
+ ],
25
+ "metadata": {
26
+ "id": "JU2dnhr24egK"
27
+ }
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 1,
32
+ "metadata": {
33
+ "colab": {
34
+ "base_uri": "https://localhost:8080/"
35
+ },
36
+ "id": "gxL4QKeNqYXI",
37
+ "outputId": "ad6c9ed6-8d6a-45f7-ba15-4026b17906d4"
38
+ },
39
+ "outputs": [
40
+ {
41
+ "output_type": "stream",
42
+ "name": "stdout",
43
+ "text": [
44
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.34.0)\n",
45
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n",
46
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.3)\n",
47
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
48
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
49
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
50
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
51
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
52
+ "Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.14.1)\n",
53
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n",
54
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
55
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n",
56
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n",
57
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.0)\n",
58
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
59
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.6)\n",
60
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n"
61
+ ]
62
+ }
63
+ ],
64
+ "source": [
65
+ "# Install the required library (Hugging Face transformers)\n",
66
+ "! pip install --upgrade transformers"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "from transformers import pipeline, RobertaTokenizer, AutoModel"
73
+ ],
74
+ "metadata": {
75
+ "id": "vG5ndbr_rYjL"
76
+ },
77
+ "execution_count": 2,
78
+ "outputs": []
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "source": [
83
+ "# Initialise the tokenizer\n",
84
+ "tokenizer = RobertaTokenizer.from_pretrained(\"tadsatlawa/nanoBERT\", return_tensors=\"pt\")"
85
+ ],
86
+ "metadata": {
87
+ "id": "1GNqH8HlrzmF"
88
+ },
89
+ "execution_count": 3,
90
+ "outputs": []
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "# Initialise model\n",
96
+ "unmasker = pipeline('fill-mask', model=\"tadsatlawa/nanoBERT\", tokenizer=tokenizer, top_k=20 )"
97
+ ],
98
+ "metadata": {
99
+ "id": "3CYcwIOU3xCY"
100
+ },
101
+ "execution_count": 4,
102
+ "outputs": []
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "source": [
107
+ "# Predict the residue probability at one or more masked positions\n",
108
+ "# Mark each position to predict with '<mask>'\n",
109
+ "seq = \"QLVSGPEVKKP<mask>ASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS\"\n",
110
+ "\n",
111
+ "residueProbability = unmasker(seq)\n",
112
+ "\n",
113
+ "# Print residue probabilities\n",
114
+ "for probability in residueProbability:\n",
115
+ " print(probability)"
116
+ ],
117
+ "metadata": {
118
+ "colab": {
119
+ "base_uri": "https://localhost:8080/"
120
+ },
121
+ "id": "6rtUxgbYsygY",
122
+ "outputId": "38fdd80d-cf30-4573-dbe5-c40f9b306470"
123
+ },
124
+ "execution_count": 5,
125
+ "outputs": [
126
+ {
127
+ "output_type": "stream",
128
+ "name": "stdout",
129
+ "text": [
130
+ "{'score': 0.7448901534080505, 'token': 10, 'token_str': 'G', 'sequence': 'QLVSGPEVKKPGASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
131
+ "{'score': 0.04520424082875252, 'token': 19, 'token_str': 'R', 'sequence': 'QLVSGPEVKKPRASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
132
+ "{'score': 0.029332099482417107, 'token': 5, 'token_str': 'A', 'sequence': 'QLVSGPEVKKPAASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
133
+ "{'score': 0.023554226383566856, 'token': 20, 'token_str': 'S', 'sequence': 'QLVSGPEVKKPSASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
134
+ "{'score': 0.022556299343705177, 'token': 17, 'token_str': 'P', 'sequence': 'QLVSGPEVKKPPASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
135
+ "{'score': 0.02046232856810093, 'token': 22, 'token_str': 'V', 'sequence': 'QLVSGPEVKKPVASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
136
+ "{'score': 0.017790036275982857, 'token': 8, 'token_str': 'E', 'sequence': 'QLVSGPEVKKPEASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
137
+ "{'score': 0.015881769359111786, 'token': 6, 'token_str': 'C', 'sequence': 'QLVSGPEVKKPCASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
138
+ "{'score': 0.014478186145424843, 'token': 23, 'token_str': 'W', 'sequence': 'QLVSGPEVKKPWASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
139
+ "{'score': 0.013189132325351238, 'token': 14, 'token_str': 'L', 'sequence': 'QLVSGPEVKKPLASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
140
+ "{'score': 0.010759864002466202, 'token': 9, 'token_str': 'F', 'sequence': 'QLVSGPEVKKPFASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
141
+ "{'score': 0.010044544003903866, 'token': 7, 'token_str': 'D', 'sequence': 'QLVSGPEVKKPDASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
142
+ "{'score': 0.00823446735739708, 'token': 21, 'token_str': 'T', 'sequence': 'QLVSGPEVKKPTASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
143
+ "{'score': 0.005904716905206442, 'token': 24, 'token_str': 'Y', 'sequence': 'QLVSGPEVKKPYASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
144
+ "{'score': 0.004586651921272278, 'token': 12, 'token_str': 'I', 'sequence': 'QLVSGPEVKKPIASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
145
+ "{'score': 0.004159640986472368, 'token': 18, 'token_str': 'Q', 'sequence': 'QLVSGPEVKKPQASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
146
+ "{'score': 0.0033481556456536055, 'token': 15, 'token_str': 'M', 'sequence': 'QLVSGPEVKKPMASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
147
+ "{'score': 0.0021347403526306152, 'token': 13, 'token_str': 'K', 'sequence': 'QLVSGPEVKKPKASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
148
+ "{'score': 0.0021168291568756104, 'token': 16, 'token_str': 'N', 'sequence': 'QLVSGPEVKKPNASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n",
149
+ "{'score': 0.0013719164999201894, 'token': 11, 'token_str': 'H', 'sequence': 'QLVSGPEVKKPHASVKVSCKASGYIFNNYGISWVRQAPGQGLEWMGWISTDNGNTNYAQKVQGRVTMTTDTSTSTAYMELRSLRYDDTAVYYCANNWGSYFEHWGQGTLVTVSS'}\n"
150
+ ]
151
+ }
152
+ ]
153
+ }
154
+ ]
155
+ }