mjbuehler committed on
Commit
1ff81a7
1 Parent(s): 4c637e7

Upload X-LoRA-Gemma_Inference.ipynb

Files changed (1)
  1. X-LoRA-Gemma_Inference.ipynb +1101 -0
X-LoRA-Gemma_Inference.ipynb ADDED
@@ -0,0 +1,1101 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3288987d",
6
+ "metadata": {},
7
+ "source": [
8
+ "# X-LoRA Inference: Gemma-7b model for molecular design \n"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "25beb240-1ae1-4537-9cc6-da621862d0bd",
14
+ "metadata": {},
15
+ "source": [
16
+ "### Helper functions "
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "e2c18b20-b1a9-4f3e-ae84-2a551e2ed69c",
23
+ "metadata": {
24
+ "tags": []
25
+ },
26
+ "outputs": [],
27
+ "source": [
28
+ "import os\n",
29
+ "import random\n",
30
+ "\n",
31
+ "import torch\n",
32
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
33
+ "import transformers\n",
34
+ "from datasets import load_dataset\n",
35
+ "from datasets import IterableDataset\n",
36
+ "\n",
37
+ "from transformers import Trainer\n",
38
+ "from transformers import TrainingArguments\n",
39
+ "from transformers import DataCollatorWithPadding\n",
40
+ "from transformers import TrainerCallback\n",
41
+ "from transformers import AutoConfig\n",
42
+ "from transformers import BitsAndBytesConfig\n",
43
+ "\n",
44
+ "from peft import LoraConfig, get_peft_model\n",
45
+ "from torch.utils.data import Dataset\n",
46
+ "from transformers import get_linear_schedule_with_warmup\n",
47
+ "from accelerate import infer_auto_device_map\n",
48
+ "import math\n",
49
+ "import numpy as np\n",
50
+ "import unidecode\n",
51
+ "import pandas as pd\n",
52
+ "from matplotlib import pyplot as plt\n",
53
+ "import peft\n",
54
+ "\n",
55
+ "from tqdm.notebook import tqdm\n",
56
+ "\n",
57
+ "device='cuda'\n",
58
+ "\n",
59
+ "def params(model):\n",
60
+ " model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
61
+ " params = sum([np.prod(p.size()) for p in model_parameters])\n",
62
+ "\n",
63
+ " print(\"Number of model arameters: \", params) \n",
64
+ "\n",
65
+ "def generate_response (model,tokenizer,text_input=\"Biology offers amazing\",\n",
66
+ " num_return_sequences=1,\n",
67
+ " temperature=1., #the higher the temperature, the more creative the model becomes\n",
68
+ " max_new_tokens=127,\n",
69
+ " num_beams=1,\n",
70
+ " top_k = 50,\n",
71
+ " top_p =0.9,repetition_penalty=1.,eos_token_id=107,verbatim=False,\n",
72
+ " exponential_decay_length_penalty_fac=None,add_special_tokens =True, eos_token=None, \n",
73
+ " ):\n",
74
+ "\n",
75
+ " if eos_token==None:\n",
76
+ " eos_token=tokenizer('<end_of_turn>', add_special_tokens =False, ) ['input_ids'][0]\n",
77
+ " \n",
78
+ " inputs = tokenizer(text_input, \n",
79
+ " add_special_tokens =add_special_tokens, \n",
80
+ " return_tensors ='pt').to(device)\n",
81
+ " if verbatim:\n",
82
+ " print (\"Length of input, tokenized: \", inputs[\"input_ids\"].shape, inputs[\"input_ids\"],\"eos_token: \", eos_token)\n",
83
+ " with torch.no_grad():\n",
84
+ " outputs = model.generate(#input_ids=inputs.to(device), \n",
85
+ " input_ids = inputs[\"input_ids\"],\n",
86
+ " attention_mask = inputs[\"attention_mask\"] , # This is usually done automatically by the tokenizer\n",
87
+ " max_new_tokens=max_new_tokens,\n",
88
+ " temperature=temperature, #value used to modulate the next token probabilities.\n",
89
+ " num_beams=num_beams,\n",
90
+ " top_k = top_k,\n",
91
+ " top_p = top_p,\n",
92
+ " num_return_sequences = num_return_sequences,\n",
93
+ " eos_token_id=eos_token,\n",
94
+ " pad_token_id = eos_token,\n",
95
+ " do_sample =True, \n",
96
+ " repetition_penalty=repetition_penalty, \n",
97
+ " )\n",
98
+ "\n",
99
+ " return tokenizer.batch_decode(outputs[:,inputs[\"input_ids\"].shape[1]:].detach().cpu().numpy(), skip_special_tokens=True)\n",
100
+ "\n",
101
+ "def generate_answer (model,tokenizer,system='You are a helpful assistant. You are familiar with materials science. ',\n",
102
+ " q='What is spider silk in the context of bioinspired materials?',\n",
103
+ " repetition_penalty=1.1,\n",
104
+ " top_p=0.1, top_k=32, \n",
105
+ " temperature=.6,max_new_tokens=512, verbatim=False, eos_token=None,add_special_tokens=True,\n",
106
+ " prepend_response='', messages=[],\n",
107
+ " ):\n",
108
+ "\n",
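+ " # Build a chat-templated prompt from the running 'messages' history and sample a completion via generate_response\n",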
109
+ " if eos_token==None:\n",
110
+ " eos_token= tokenizer.eos_token_id\n",
111
+ " \n",
112
+ " if system==None:\n",
113
+ " messages.append ({\"role\": \"user\", \"content\": q} )\n",
114
+ " else:\n",
115
+ " messages.append ({\"role\": \"user\", \"content\": system+q})\n",
116
+ " \n",
117
+ " txt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, )\n",
118
+ " txt=txt+prepend_response\n",
119
+ " \n",
120
+ " output_text=generate_response (model,tokenizer,text_input=txt,eos_token_id=eos_token,\n",
121
+ " num_return_sequences=1, repetition_penalty=repetition_penalty,\n",
122
+ " top_p=top_p, top_k=top_k, add_special_tokens =add_special_tokens,\n",
123
+ " \n",
124
+ " temperature=temperature,max_new_tokens=max_new_tokens, verbatim=verbatim, \n",
125
+ " \n",
126
+ " )\n",
127
+ " return ( output_text[0] )"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "id": "75d89d27-8386-4859-a36e-ce4842415b59",
133
+ "metadata": {},
134
+ "source": [
135
+ "### Load X-LoRA Gemma model "
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "raw",
140
+ "id": "cd1b66f6-1fe1-4b2c-9309-fe01d34d7d54",
141
+ "metadata": {},
142
+ "source": [
143
+ "https://github.com/EricLBuehler/xlora"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "12848c38-cc0c-41c7-bf04-9856730458df",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "import torch\n",
154
+ "from xlora.xlora_utils import load_model \n",
155
+ "\n",
156
+ "XLoRa_model_name = 'lamm-mit/x-lora-gemma-7b'\n",
157
+ "\n",
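+ "# Load the Gemma-7b base model together with the X-LoRA adapter (scaling head and LoRA experts) in bfloat16 on a single CUDA device\n",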
158
+ "model, tokenizer=load_model(model_name = XLoRa_model_name, \n",
159
+ " device='cuda:0',\n",
160
+ " use_flash_attention_2=True, \n",
161
+ " dtype=torch.bfloat16,\n",
162
+ " )\n",
163
+ "eos_token_id= tokenizer('<end_of_turn>', add_special_tokens=False, ) ['input_ids'][0]\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "id": "b197ffd5-7752-4081-9227-c46a485afeec",
169
+ "metadata": {},
170
+ "source": [
171
+ "### Inference using Guidance "
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "raw",
176
+ "id": "f7009898-17a9-468a-970a-59d7c80553ca",
177
+ "metadata": {},
178
+ "source": [
179
+ "https://github.com/guidance-ai/guidance"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "80b62bf2-a424-4858-a321-f55e3327b070",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "from guidance import models\n",
190
+ "from guidance import gen, select, system, user, assistant, newline\n",
191
+ "from IPython.display import display, Markdown\n",
192
+ "\n",
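+ "# Wrap the X-LoRA model and tokenizer as a guidance chat model so prompts can be composed from user()/assistant() blocks\n",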
193
+ "gpt = models.TransformersChat(model=model, tokenizer=tokenizer)\n",
194
+ "gpt_question_asker = gpt"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "id": "1cb5a867-a127-45c2-b75b-35883a78930b",
201
+ "metadata": {
202
+ "scrolled": true
203
+ },
204
+ "outputs": [],
205
+ "source": [
206
+ "with user(): \n",
207
+ " lm =gpt + f\"\"\"List the most important biomolecules used in biological materials to make polymers with multifunctional qualities.\"\"\" \n",
208
+ "\n",
209
+ "with assistant(): \n",
210
+ " lm+=\"[\"+gen('res1', max_tokens=1024)"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "markdown",
215
+ "id": "a841c58c-bded-4741-80df-66ca434bfac0",
216
+ "metadata": {},
217
+ "source": [
218
+ "### Inference using Hugging Face generate functions "
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "id": "26a27dc2-4e28-4fee-b37c-446281cd23da",
225
+ "metadata": {
226
+ "scrolled": true
227
+ },
228
+ "outputs": [],
229
+ "source": [
230
+ "system_prompt='You are an expert in biological molecular engineering. '\n",
231
+ "q=\"\"\"\n",
232
+ "What are potential molecular engineering approaches to create better materials? Name specific molecules of interest.\n",
233
+ "\"\"\"\n",
234
+ "\n",
235
+ "res=generate_answer (model, tokenizer,system=system_prompt,\n",
236
+ " q=q,\n",
237
+ " repetition_penalty=1., top_p=0.9, top_k=256, \n",
238
+ " temperature=.5,max_new_tokens=512, verbatim=False, \n",
239
+ " )\n",
240
+ "\n",
241
+ "display (Markdown (\"## X-LoRA:\\n\\n\"+res))"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "82d162fe-6149-44d4-afbe-63213b10f183",
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "system_prompt='You are an expert in biological molecular engineering. '\n",
252
+ "q=\"\"\"\n",
253
+ "List the most important biomolecules used in biological materials to make polymers with multifunctional qualities.\n",
254
+ "\"\"\"\n",
255
+ "messages=[]\n",
256
+ "res=generate_answer (model, tokenizer,system=system_prompt,\n",
257
+ " q=q, repetition_penalty=1., top_p=0.9, top_k=256, temperature=.5,max_new_tokens=512, verbatim=False,messages=messages )\n",
258
+ "\n",
259
+ "display (Markdown (\"## X-LoRA:\\n\\n\"+res))\n",
260
+ "messages.append ({\"role\": \"assistant\", \"content\": res} )"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "id": "f9ad8c01-7cfd-4017-a8af-0b72b4ea25fe",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "system_prompt=None\n",
271
+ "q=\"\"\"\n",
272
+ "How does chitin form a material, specifically in terms of molecular interactions? \n",
273
+ "\"\"\" \n",
274
+ "res=generate_answer (model, tokenizer,system=system_prompt,\n",
275
+ " q=q, repetition_penalty=1., top_p=0.9, top_k=256, temperature=.1,max_new_tokens=512, verbatim=False,messages=messages,\n",
276
+ " )\n",
277
+ "\n",
278
+ "display (Markdown (\"## X-LoRA:\\n\\n\"+res))\n",
279
+ "messages.append ({\"role\": \"assistant\", \"content\": res} )"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "id": "6f520fd9-0d06-4971-9b58-74d8d2c3e2ef",
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "system_prompt=None\n",
290
+ "q=\"\"\"\n",
291
+ "Thank you. What are potential chemical modifications of N-acetylglucosamine units that would improve mechanical properties?\n",
292
+ "\"\"\" \n",
293
+ "res=generate_answer (model, tokenizer,system=system_prompt,\n",
294
+ " q=q, repetition_penalty=1., top_p=0.9, top_k=256, temperature=.1,max_new_tokens=512, verbatim=False,messages=messages,\n",
295
+ " )\n",
296
+ "\n",
297
+ "display (Markdown (\"## X-LoRA:\\n\\n\"+res))\n",
298
+ "messages.append ({\"role\": \"assistant\", \"content\": res} )"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "id": "ce5a2293-b66d-4ef6-987e-451dc1a92621",
304
+ "metadata": {},
305
+ "source": [
306
+ "### Molecule design examples"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "e547bed4-da94-48c7-b9dd-00da7732ef20",
313
+ "metadata": {
314
+ "scrolled": true
315
+ },
316
+ "outputs": [],
317
+ "source": [
318
+ "import pandas as pd\n",
319
+ "from sklearn.preprocessing import MinMaxScaler\n",
320
+ "\n",
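+ "# QM9 reference data: the 'smiles' column is used for novelty checks, the remaining columns hold the regression targets\n",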
321
+ "df_smiles=pd.read_csv ('./QM9.csv')\n",
322
+ "SMILES_LIST=list (df_smiles['smiles'])\n",
323
+ "\n",
324
+ "X = df_smiles.iloc[:, 0].values.reshape(-1, 1) # Input feature, reshaped for compatibility\n",
325
+ "y = df_smiles.iloc[:, 1:] # Target features\n",
326
+ "\n",
327
+ "# Scaling the target features\n",
328
+ "scaler = MinMaxScaler()\n",
329
+ "y_scaled = scaler.fit_transform(y)\n",
330
+ "\n",
331
+ "from sklearn.model_selection import train_test_split\n",
332
+ "\n",
333
+ "X_train, X_test, y_train, y_test= train_test_split(X, y_scaled, test_size=0.2, random_state=42)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "44c43109-0606-42d2-a1b6-01278ff6432f",
340
+ "metadata": {
341
+ "scrolled": true
342
+ },
343
+ "outputs": [],
344
+ "source": [
345
+ "import os\n",
346
+ "import numpy as np\n",
347
+ "import pandas as pd\n",
348
+ "import matplotlib.pyplot as plt\n",
349
+ "import seaborn as sns\n",
350
+ "from sklearn.metrics import mean_squared_error\n",
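+ "# Names of the 12 QM9 properties used as targets/predictions (values are min-max scaled to [0, 1])\n",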
351
+ "labels = [\"mu\", \"alpha\", \"homo\", \"lumo\", \"gap\", \"r2\", \"zpve\", \"cv\", \"u0\", \"u298\", \"h298\", \"g298\"]\n",
352
+ "\n",
353
+ "def return_str(vals=np.array ([.1, .5, .6, 2.])):\n",
354
+ " ch=''\n",
355
+ " for i in range (len (vals)):\n",
356
+ " ch=ch+f'{vals[i]:1.3f},'\n",
357
+ " \n",
358
+ " return ch[:-1] \n",
359
+ "\n",
360
+ "def extract_start_and_end(string_input, start_token='[', end_token=']'):\n",
361
+ " \"\"\"\n",
362
+ " Extracts the substring from 'string_input' that is enclosed between the first occurrence of\n",
363
+ " 'start_token' and the last occurrence of 'end_token'.\n",
364
+ "\n",
365
+ " Args:\n",
366
+ " string_input (str): The string from which to extract the substring.\n",
367
+ " start_token (str): The starting delimiter. Default is '['.\n",
368
+ " end_token (str): The ending delimiter. Default is ']'.\n",
369
+ "\n",
370
+ " Returns:\n",
371
+ " str: The extracted substring. If 'start_token' or 'end_token' is not found, returns an empty string.\n",
372
+ " \"\"\"\n",
373
+ " # Find the index of the first occurrence of start_token\n",
374
+ " i = string_input.find(start_token)\n",
375
+ " # Find the index of the last occurrence of end_token\n",
376
+ " j = string_input.rfind(end_token)\n",
377
+ "\n",
378
+ " # Check if both tokens are found and i < j to ensure proper enclosure\n",
379
+ " if i == -1 or j == -1 or i >= j:\n",
380
+ " return \"\"\n",
381
+ " else:\n",
382
+ " # Extract and return the content between the first start_token and the last end_token\n",
383
+ " return string_input[i + 1:j]\n",
384
+ "\n",
385
+ "def is_SMILES_novel (SMILES, SMILES_LIST=None):\n",
386
+ "\n",
387
+ " if SMILES_LIST !=None:\n",
388
+ " \n",
389
+ " if SMILES not in SMILES_LIST:\n",
390
+ " is_novel=True\n",
391
+ " else:\n",
392
+ " is_novel=False\n",
393
+ " else:\n",
394
+ " is_novel=None\n",
395
+ " return is_novel\n",
396
+ " \n",
397
+ "def visualize_SMILES (smiles_code, dir_path='./' , root='', sample_count=0):\n",
398
+ " molecule = Chem.MolFromSmiles(smiles_code)\n",
399
+ " \n",
400
+ " # Generate an image of the molecule\n",
401
+ " molecule_image = Draw.MolToImage(molecule)\n",
402
+ " \n",
403
+ " # Display the image directly in Jupyter Notebook\n",
404
+ " display(molecule_image)\n",
405
+ " \n",
406
+ " image_path=f\"{dir_path}/SMILES_{sample_count}_{root}_molecule_image.png\"\n",
407
+ " molecule_image.save(image_path)\n",
408
+ "\n",
409
+ " return image_path\n",
410
+ "\n",
411
+ "\n",
412
+ "def design_from_target(\n",
413
+ " model,\n",
414
+ " tokenizer,\n",
415
+ " target,\n",
416
+ " temperature=0.1,\n",
417
+ " num_beams=1,\n",
418
+ " top_k=50,\n",
419
+ " top_p=0.95,\n",
420
+ " repetition_penalty=1.0,\n",
421
+ " messages=[]\n",
422
+ "):\n",
423
+ " # Format the target line for molecular property generation\n",
424
+ " line = f'GenerateMolecularProperties<{return_str(target)}>'\n",
425
+ " \n",
426
+ " # Add the line to the message history\n",
427
+ " messages.append({\"role\": \"user\", \"content\": line})\n",
428
+ " \n",
429
+ " # Apply chat template with optional tokenization\n",
430
+ " line = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
431
+ " \n",
432
+ " # Generate response with specified parameters\n",
433
+ " result = generate_response(\n",
434
+ " model,\n",
435
+ " tokenizer,\n",
436
+ " text_input=line,\n",
437
+ " num_return_sequences=1,\n",
438
+ " temperature=temperature,\n",
439
+ " top_k=top_k,\n",
440
+ " top_p=top_p,\n",
441
+ " max_new_tokens=256\n",
442
+ " )[0]\n",
443
+ " \n",
444
+ " return result\n",
445
+ "\n",
446
+ "def properties_from_SMILES(\n",
447
+ " model,\n",
448
+ " tokenizer,\n",
449
+ " target,\n",
450
+ " temperature=0.1,\n",
451
+ " top_k=128,\n",
452
+ " top_p=0.9,\n",
453
+ " num_beams=1,\n",
454
+ " repetition_penalty=1.0\n",
455
+ "):\n",
456
+ " # Format the target line for molecular property calculation\n",
457
+ " line = f'CalculateMolecularProperties<{target}>'\n",
458
+ " \n",
459
+ " # Initialize messages and add the formatted line\n",
460
+ " messages = [{\"role\": \"user\", \"content\": line}]\n",
461
+ " \n",
462
+ " # Apply chat template with optional tokenization\n",
463
+ " line = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
464
+ " \n",
465
+ " # Generate response with specified parameters\n",
466
+ " result = generate_response(\n",
467
+ " model,\n",
468
+ " tokenizer,\n",
469
+ " text_input=line,\n",
470
+ " num_return_sequences=1,\n",
471
+ " temperature=temperature,\n",
472
+ " top_k=top_k,\n",
473
+ " top_p=top_p,\n",
474
+ " max_new_tokens=256\n",
475
+ " )[0]\n",
476
+ " \n",
477
+ " # Extract relevant part of the result and convert to float list\n",
478
+ " result = extract_start_and_end(result, start_token='[', end_token=']')\n",
479
+ " return [float(i) for i in result.split(',')]\n",
480
+ "\n",
481
+ " \n",
482
+ "def avg_properties_from_SMILES (model, tokenizer, SMILES ='O=C(N)C1OC(CO)C(O)C(O)C1O', SMILES_dir='./',\n",
483
+ " temperature=0.01, top_k=50,top_p=0.95, num_beams=1, repetition_penalty=1.,\n",
484
+ " labels=None, N_prop=6, plot_results=True):\n",
485
+ " if not os.path.exists(SMILES_dir):\n",
486
+ " os.makedirs(SMILES_dir) \n",
487
+ " properties=[]\n",
488
+ " if labels==None and plot_results:\n",
489
+ " labels= ['mu',\n",
490
+ " 'alpha',\n",
491
+ " 'homo',\n",
492
+ " 'lumo',\n",
493
+ " 'gap',\n",
494
+ " 'r2',\n",
495
+ " 'zpve',\n",
496
+ " 'cv',\n",
497
+ " 'u0',\n",
498
+ " 'u298',\n",
499
+ " 'h298',\n",
500
+ " 'g298']\n",
501
+ " successful=0\n",
502
+ " for i in tqdm(range (N_prop)):\n",
503
+ " \n",
504
+ " try:\n",
505
+ " _prop=properties_from_SMILES (model, tokenizer, SMILES,temperature=temperature, top_k=top_k,top_p=top_p,\n",
506
+ " num_beams=num_beams, repetition_penalty=repetition_penalty,\n",
507
+ " )\n",
508
+ " if len (_prop)==len (labels):\n",
509
+ " \n",
510
+ " properties.append(np.array( _prop) )\n",
511
+ " successful+=1\n",
512
+ " except:\n",
513
+ " print (end=\"\")\n",
514
+ " \n",
515
+ " all_properties = np.array(properties)\n",
516
+ " \n",
517
+ " # Calculate mean and standard deviation for each property\n",
518
+ " means = np.mean(all_properties, axis=0)\n",
519
+ " std_devs = np.std(all_properties, axis=0)\n",
520
+ " \n",
521
+ " # Labels for the x-axis\n",
522
+ " if plot_results: \n",
523
+ " # Creating the plot with error bars\n",
524
+ " plt.figure(figsize=(6, 4))\n",
525
+ " plt.errorbar(labels, means, yerr=std_devs, fmt='o', ecolor='red', capsize=5, capthick=2, marker='s', color='blue')\n",
526
+ " plt.xticks(rotation=45)\n",
527
+ " plt.xlabel('Property')\n",
528
+ " plt.ylabel('Value')\n",
529
+ " plt.title('Average Properties with Error Bars')\n",
530
+ " plt.tight_layout()\n",
531
+ " plt.savefig(SMILES_dir + f\"avg_prop_{SMILES}.svg\", format=\"svg\")\n",
532
+ " \n",
533
+ " plt.show()\n",
534
+ " print (f\"Successful attempts: {successful}/{N_prop}\")\n",
535
+ " \n",
536
+ " return means, std_devs \n",
537
+ "\n",
538
+ "def is_valid_smiles(smiles):\n",
539
+ " # This function tries to create a molecule object from a SMILES string.\n",
540
+ " # If the molecule object is created successfully and is not None, the SMILES is valid.\n",
541
+ " mol = Chem.MolFromSmiles(smiles)\n",
542
+ " return mol is not None\n",
543
+ " \n",
544
+ "def design_molecule(model, tokenizer, target=None, temperature=0.1,\n",
545
+ " num_beams=1,top_k=50,top_p=0.95, repetition_penalty=1.,\n",
546
+ " SMILES_LIST=None, dir_path='./', messages=[],N_attempts_for_forward=1):\n",
547
+ "\n",
548
+ " if not os.path.exists(dir_path):\n",
549
+ " os.makedirs(dir_path)\n",
550
+ " if target is None or not target.any():\n",
551
+ " target = np.random.rand(12)\n",
552
+ " \n",
553
+ " try:\n",
554
+ " SMILES=design_from_target (model, tokenizer, target, messages=messages)\n",
555
+ " except:\n",
556
+ " SMILES=None\n",
557
+ " print (\"Generation failed.\")\n",
558
+ "\n",
559
+ " is_novel=is_SMILES_novel (SMILES, SMILES_LIST)\n",
560
+ " print (\"Result: \", SMILES, \"is novel: \", is_novel, \"is valid: \", is_valid_smiles(SMILES))\n",
561
+ " try:\n",
562
+ " visualize_SMILES (SMILES, dir_path=dir_path)\n",
563
+ " except:\n",
564
+ " print (\"Vis failed.\")\n",
565
+ "\n",
566
+ " try:\n",
567
+ " if N_attempts_for_forward==1:\n",
568
+ " predicted = properties_from_SMILES(model, tokenizer, SMILES, temperature=temperature, top_k=top_k,\n",
+ " top_p=top_p, num_beams=num_beams, repetition_penalty=repetition_penalty)\n",
+ " else:\n",
+ " predicted,_=avg_properties_from_SMILES(model, tokenizer, SMILES, SMILES_dir=dir_path,\n",
+ " temperature=temperature, top_k=top_k,top_p=top_p, num_beams=num_beams, repetition_penalty=repetition_penalty,\n",
+ " labels=labels, N_prop=N_attempts_for_forward, plot_results=False)\n",
574
+ "\n",
575
+ " sns.set_style(\"whitegrid\")\n",
576
+ " plt.gcf().set_facecolor('white')\n",
577
+ " # Assuming GT_res and predictions are your data arrays/lists for Ground Truth and Predictions respectively\n",
578
+ " \n",
579
+ " x = np.arange(len(labels)) # Label locations\n",
580
+ " width = 0.35 # Width of the bars\n",
581
+ " \n",
582
+ " fig, ax = plt.subplots(figsize=(9, 5))\n",
583
+ " rects1 = ax.bar(x - width/2, target, width, label='Target')\n",
584
+ " rects2 = ax.bar(x + width/2, predicted, width, label='Predicted properties')\n",
585
+ " \n",
586
+ " # Add some text for labels, title and custom x-axis tick labels, etc.\n",
587
+ " ax.set_ylabel('Values')\n",
588
+ " ax.set_title('Comparison of Target and Predicted Properties')\n",
589
+ " ax.set_xticks(x)\n",
590
+ " ax.set_xticklabels(labels, rotation=45, ha=\"right\")\n",
591
+ " ax.legend()\n",
592
+ "\n",
593
+ " except:\n",
594
+ " print(\"Forward analysis failed.\")\n",
595
+ " return SMILES, is_novel\n",
596
+ "\n",
597
+ "def design_molecule_loop(model, tokenizer, target=None, temperature_gen=0.3,temperature_pred=0.01, SMILES_LIST=None,\n",
598
+ " top_k=50, top_p=0.95, repetition_penalty=1., num_beams=1,update_primer_with_better_draft=False,\n",
599
+ " threshold=0.01, N_max=100, dir_path='./',lower_bound = 0.0,remove_duplicates=True,\n",
600
+ " upper_bound = 0.1,sample_count=0, messages=[], N_attempts_for_forward=1, set_opt=None):\n",
601
+ "\n",
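+ " # Iteratively sample candidate SMILES for the target property vector, re-predict their properties with the forward task,\n",
+ " # and stop once the MSE against the target drops below 'threshold' or N_max iterations are reached\n",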
602
+ " mse_smallest_current=9999\n",
603
+ " if not os.path.exists(dir_path):\n",
604
+ " os.makedirs(dir_path)\n",
605
+ " if target is None or not target.any():\n",
606
+ " target = np.random.rand(12)\n",
607
+ "\n",
608
+ " if len (messages) >0:\n",
609
+ " print (\"Using primed generation:\\n\", messages)\n",
610
+ " \n",
611
+ " records = [] # To store SMILES, properties, and MSE\n",
612
+ " for iteration in range(N_max):\n",
613
+ " try:\n",
614
+ " print (f\">>> Iteration={iteration}\")\n",
615
+ " original_messages=copy.deepcopy (messages)\n",
616
+ "\n",
617
+ " SMILES = design_from_target(model, tokenizer, target, temperature_gen, num_beams,\n",
618
+ " top_k, top_p, repetition_penalty, messages=original_messages)\n",
619
+ " is_novel=is_SMILES_novel (SMILES, SMILES_LIST)\n",
620
+ "\n",
621
+ " if is_novel and is_valid_smiles(SMILES):\n",
622
+ " print (f\"{SMILES} is novel: {is_novel}, is valid: {is_valid_smiles(SMILES)}\")\n",
623
+ " if N_attempts_for_forward==1:\n",
624
+ " predicted = properties_from_SMILES(model, tokenizer, SMILES, temperature=temperature_pred, top_k=top_k,\n",
+ " top_p=top_p, num_beams=num_beams, repetition_penalty=repetition_penalty)\n",
626
+ " else:\n",
627
+ " predicted,_=avg_properties_from_SMILES(model, tokenizer, SMILES, SMILES_dir=dir_path,\n",
628
+ " temperature=temperature_pred, top_k=top_k,top_p=top_p, repetition_penalty=repetition_penalty,\n",
629
+ " labels=labels, N_prop=N_attempts_for_forward, plot_results=False)\n",
630
+ "\n",
631
+ " if set_opt==None:\n",
632
+ " mse = mean_squared_error(target, predicted)\n",
633
+ " else:\n",
634
+ " mse = mean_squared_error(target[set_opt], predicted[set_opt])\n",
635
+ " if mse<mse_smallest_current:\n",
636
+ " mse_smallest_current=mse\n",
637
+ " if update_primer_with_better_draft:\n",
638
+ " messages=prime_messages (SMILES, predicted , N=1)\n",
639
+ " print (\"Smaller MSE found, updated messages primer! Messages: \", messages,\n",
640
+ " f\"\\n\\nCurrent MSE: {mse}\")\n",
641
+ " \n",
642
+ " records.append((SMILES, predicted, mse, is_novel))\n",
643
+ " \n",
644
+ " print (f\">>>Iteration={iteration}, MSE={mse} for SMILES={SMILES}, novel={is_novel}\")\n",
645
+ " if mse < threshold:\n",
646
+ " print(f\"Threshold met at iteration {iteration+1}\")\n",
647
+ " break\n",
648
+ " else:\n",
649
+ " print (f\"{SMILES} is not novel or not valid, validity: {is_valid_smiles(SMILES)}.\")\n",
650
+ " except Exception as e:\n",
651
+ " print(f\"Error during iteration {iteration+1}: {e}\")\n",
652
+ " continue\n",
653
+ "\n",
654
+ " # Sorting records based on MSE (most accurate first)\n",
655
+ " records.sort(key=lambda x: x[2])\n",
656
+ "\n",
657
+ " # Visualizing the best performing molecule\n",
658
+ " best_SMILES, best_predicted, best_mse, is_novel = records[0]\n",
659
+ "\n",
660
+ " print (\"Best SMILES: \", best_SMILES)\n",
661
+ " try:\n",
662
+ " print (f\"{best_SMILES} is novel: {is_novel}\")\n",
663
+ " \n",
664
+ " sns.set_style(\"whitegrid\")\n",
665
+ " \n",
666
+ " visualize_pred_vs_target (target, best_predicted, labels, dir_path=dir_path, best_SMILES=best_SMILES,sample_count=0)\n",
667
+ " \n",
668
+ " print(\"Process completed.\") \n",
669
+ " visualize_SMILES(best_SMILES, dir_path=dir_path, root=f'{target}_BEST')\n",
670
+ "\n",
671
+ " print(f\"Compute molecular structure, UFF eq, Gasteiger, etc.\") \n",
672
+ " \n",
673
+ " compute_gasteiger (best_SMILES, SMILES_dir=dir_path, target= np.array(best_predicted))\n",
674
+ "\n",
675
+ " mol = Chem.MolFromSmiles(best_SMILES)\n",
676
+ " inchi_str = Chem.MolToInchi(mol)\n",
677
+ " print(f\"InChI String of {best_SMILES}:\", inchi_str)\n",
678
+ " \n",
679
+ " \n",
680
+ " except Exception as e:\n",
681
+ " print(f\"Processing/visualization failed for {best_SMILES}: {e}\")\n",
682
+ "\n",
683
+ " # Writing records to a CSV file\n",
684
+ " df = pd.DataFrame(records, columns=['SMILES', 'Predicted Properties', 'MSE', 'is_novel'])\n",
685
+ " csv_path = os.path.join(dir_path, 'SMILES_designs.csv')\n",
686
+ " df.to_csv(csv_path, index=False)\n",
687
+ "\n",
688
+ " # Plot MSE against the index (which now corresponds to the ranking)\n",
689
+ " plt.figure(figsize=(10, 8)) # Adjust the size as needed\n",
690
+ " plt.plot(df['SMILES'], df['MSE'], 'o', markersize=5) # 'o' for circular markers\n",
691
+ " \n",
692
+ " # Adding labels for each point with the SMILES string\n",
693
+ " for i, txt in enumerate(df['SMILES']):\n",
694
+ " plt.annotate(txt, (i, df['MSE'].iloc[i]), fontsize=8, rotation=45, ha='right')\n",
695
+ " \n",
696
+ " visualize_over_SMILES (df,N_max=N_max,SMILES_dir=dir_path,\n",
697
+ " lower_bound = lower_bound,remove_duplicates=remove_duplicates,\n",
698
+ " upper_bound = upper_bound, target=target)\n",
699
+ " return df \n",
700
+ "\n",
701
+ "from rdkit import Chem\n",
702
+ "from rdkit.Chem import Draw\n",
703
+ "import os\n",
704
+ "\n",
705
+ "def visualize_smiles_and_save(smiles_list, per_row=4, dir_path='./', root=''):\n",
706
+ " \"\"\"\n",
707
+ " Visualizes a list of molecules from their SMILES strings with labels, checks for validity, \n",
708
+ " and saves the visualization as an SVG file.\n",
709
+ " \n",
710
+ " Parameters:\n",
711
+ " - smiles_list: List of SMILES strings to visualize.\n",
712
+ " - per_row: Number of molecule images per row in the assembly.\n",
713
+ " - dir_path: Directory path where the SVG file will be saved.\n",
714
+ " \"\"\"\n",
715
+ " if not os.path.exists(dir_path):\n",
716
+ " os.makedirs(dir_path)\n",
717
+ " valid_molecules = []\n",
718
+ " valid_smiles = [] # To store valid SMILES strings for labeling\n",
719
+ " for smile in smiles_list:\n",
720
+ " mol = Chem.MolFromSmiles(smile)\n",
721
+ " if mol: # If the molecule is valid\n",
722
+ " valid_molecules.append(mol)\n",
723
+ " valid_smiles.append(smile) # Add the valid SMILES string\n",
724
+ " \n",
725
+ " # Proceed only if there are valid molecules\n",
726
+ " if not valid_molecules:\n",
727
+ " print(\"No valid molecules found in the provided SMILES strings.\")\n",
728
+ " return\n",
729
+ " \n",
730
+ " # Ensure the directory exists\n",
731
+ " if not os.path.exists(dir_path):\n",
732
+ " os.makedirs(dir_path)\n",
733
+ " \n",
734
+ " # Define the SVG file path\n",
735
+ " svg_file_path = os.path.join(dir_path, f'molecules_with_labels_{root}.svg')\n",
736
+ " \n",
737
+ " # Use RDKit to draw the molecules grid with labels\n",
738
+ " fig = Draw.MolsToGridImage(valid_molecules, molsPerRow=per_row, subImgSize=(200, 200), \n",
739
+ " legends=valid_smiles, useSVG=True)\n",
740
+ " \n",
741
+ " # Saving the SVG content to a file\n",
742
+ " with open(svg_file_path, 'w') as svg_file:\n",
743
+ " svg_file.write(fig.data)\n",
744
+ " display (fig)\n",
745
+ " \n",
746
+ " print(f\"Visualization saved as SVG at: {svg_file_path}\")\n",
747
+ "\n",
748
+ " return valid_smiles \n",
749
+ "\n",
750
+ "def plot_MSE_over_SMILES (df_design,N_max=24,\n",
751
+ " lower_bound = 0.0,\n",
752
+ " upper_bound = 0.08, SMILES_dir='./', target='', ):\n",
753
+ " \n",
754
+ " if not os.path.exists(SMILES_dir):\n",
755
+ " os.makedirs(SMILES_dir) \n",
756
+ " df_sorted = df_design[:N_max].sort_values('MSE',ascending=False).reset_index(drop=True)\n",
757
+ "\n",
758
+ " \n",
759
+ " df_plot=df_sorted[(df_sorted['MSE'] > lower_bound) & (df_sorted['MSE'] < upper_bound)]\n",
760
+ " \n",
761
+ " # Plot MSE against the index (which now corresponds to the ranking)\n",
762
+ " fig, ax = plt.subplots(figsize=(8, 7))\n",
763
+ " plt.plot(df_plot['SMILES'], df_plot['MSE'], 'o-', markersize=5, ) # 'o' for circular markers\n",
764
+ " \n",
765
+ " # Improving the plot aesthetics\n",
766
+ " plt.xticks(rotation=90) # Rotate the x-axis labels for better readability\n",
767
+ " plt.xlabel('Molecule SMILES')\n",
768
+ " plt.ylabel('MSE')\n",
769
+ " #plt.title('Ordered from Best to Worst')\n",
770
+ " plt.tight_layout() # Adjust the layout to make room for the rotated x-axis labels\n",
771
+ " plt.savefig(SMILES_dir+f'SMILES_over_MSE_{target}.svg', format='svg')\n",
772
+ " plt.show()\n",
773
+ " \n",
774
+ "def visualize_over_SMILES (df_design,N_max=24,per_row=20,SMILES_dir='./',\n",
775
+ " lower_bound = 0.0,\n",
776
+ " upper_bound = 0.08, target='', remove_duplicates=True):\n",
777
+ "\n",
778
+ " if remove_duplicates:\n",
779
+ " # Example: Keep the entry with the best MSE among the novel molecules for each SMILES\n",
780
+ " df_design = df_design.sort_values(['MSE', 'is_novel', 'SMILES', ], ascending=[True, False, True]) \\\n",
781
+ " .drop_duplicates(subset='SMILES', keep='first')\n",
782
+ "\n",
783
+ " df_design.reset_index(drop=True, inplace=True)\n",
784
+ " df_design.to_csv(f'{SMILES_dir}/sorted_noduplicates_{N_max}.csv', index=False)\n",
785
+ " \n",
786
+ " valid_smiles=visualize_smiles_and_save(list(df_design['SMILES'][:N_max]), per_row=per_row, dir_path=SMILES_dir, root=f'{target}')\n",
787
+ " \n",
788
+ " smiles_df = pd.DataFrame(valid_smiles, columns=[\"SMILES\"])\n",
789
+ "\n",
790
+ " # Save the DataFrame to a CSV file\n",
791
+ " file_path = \"/smiles_data.csv\"\n",
792
+ " smiles_df.to_csv(f'{SMILES_dir}/valid_SMILES_{N_max}.csv', index=False )\n",
793
+ " \n",
794
+ " fig, ax = plt.subplots(figsize=(8, 5))\n",
795
+ " \n",
796
+ " df_plot=df_design[(df_design['MSE'] > lower_bound) & (df_design['MSE'] < upper_bound)]\n",
797
+ " df_plot.plot(kind='kde', color='darkblue', label='KDE', ax=ax)\n",
798
+ " \n",
799
+ " # Plot histogram with density=True for probability density representation\n",
800
+ " plt.hist(df_design['MSE'], density=True, alpha=0.5, color='skyblue', label='Histogram',bins=50, \n",
801
+ " range=[lower_bound,upper_bound]\n",
802
+ " )\n",
803
+ " plt.xlim(lower_bound, upper_bound)\n",
804
+ " plt.title('Density and Histogram Plot of MSE')\n",
805
+ " plt.xlabel('MSE')\n",
806
+ " plt.ylabel('Density')\n",
807
+ " \n",
808
+ " # Adding a legend to distinguish between the KDE and Histogram\n",
809
+ " plt.legend()\n",
810
+ " \n",
811
+ " plt.savefig(SMILES_dir+f'mse_histogram_{target}.svg', format='svg')\n",
812
+ " plt.show()\n",
813
+ "\n",
814
+ " plot_MSE_over_SMILES (df_design,N_max=N_max,\n",
815
+ " lower_bound = lower_bound,\n",
816
+ " upper_bound = upper_bound, target=target,SMILES_dir=SMILES_dir)\n",
817
+ " \n",
818
+ " return df_design\n",
819
+ "\n",
820
+ "import numpy as np\n",
821
+ "import matplotlib.pyplot as plt\n",
822
+ "import pandas as pd\n",
823
+ "from pandas.plotting import parallel_coordinates\n",
824
+ "\n",
825
+ "def plot_change_in_design(original, labels, target, SMILES_dir='./'):\n",
826
+ " if not os.path.exists(SMILES_dir):\n",
827
+ " os.makedirs(SMILES_dir)\n",
828
+ " \n",
829
+ " # Create a DataFrame to hold the original and target vectors with labels\n",
830
+ " df = pd.DataFrame([original, target], columns=labels)\n",
831
+ " df['Version'] = ['Original', 'Target'] # Add a 'Version' column for coloring\n",
832
+ " \n",
833
+ " # Plotting\n",
834
+ " plt.figure(figsize=(7, 4))\n",
835
+ " parallel_coordinates(df, 'Version', color=['blue', 'red'])\n",
836
+ " plt.title('Original vs Target Values across Properties')\n",
837
+ " plt.xticks(rotation=45)\n",
838
+ " plt.tight_layout()\n",
839
+ " \n",
840
+ " # Annotating changes with thicker arrows pointing towards the target\n",
841
+ " for i, label in enumerate(labels):\n",
842
+ " if original[i] < target[i]: # If the target value is greater, arrow points upwards\n",
843
+ " plt.annotate('', xy=(i, target[i]), xytext=(i, original[i]),\n",
844
+ " arrowprops=dict(arrowstyle=\"->\", color='black', lw=2))\n",
845
+ " else: # If the target value is lesser, arrow points downwards\n",
846
+ " plt.annotate('', xy=(i, target[i]), xytext=(i, original[i]),\n",
847
+ " arrowprops=dict(arrowstyle=\"->\", color='black', lw=2))\n",
848
+ " \n",
849
+ " # Save the plot as an SVG file in the specified directory\n",
850
+ " plt.savefig(SMILES_dir + \"parallel_coordinates_changes_direction.svg\", format=\"svg\")\n",
851
+ " \n",
852
+ " plt.show()\n",
853
+ " \n",
854
+ "def visualize_pred_vs_target (target, best_predicted, labels, dir_path='./', best_SMILES='',sample_count=0): \n",
855
+ " if not os.path.exists(dir_path):\n",
856
+ " os.makedirs(dir_path)\n",
857
+ " sns.set_style(\"whitegrid\")\n",
858
+ " plt.gcf().set_facecolor('white')\n",
859
+ " \n",
860
+ " x = np.arange(len(labels)) # Label locations\n",
861
+ " width = 0.35 # Width of the bars\n",
862
+ " \n",
863
+ " fig, ax = plt.subplots(figsize=(9, 5))\n",
864
+ " rects1 = ax.bar(x - width/2, target, width, label='Target')\n",
865
+ " rects2 = ax.bar(x + width/2, best_predicted, width, label='Predicted properties')\n",
866
+ " \n",
867
+ " # Add some text for labels, title and custom x-axis tick labels, etc.\n",
868
+ " ax.set_ylabel('Values')\n",
869
+ " ax.set_title(f'Comparison of Target and Predicted Properties, {best_SMILES}')\n",
870
+ " ax.set_xticks(x)\n",
871
+ " ax.set_xticklabels(labels, rotation=45, ha=\"right\")\n",
872
+ " ax.legend()\n",
873
+ " fig.tight_layout()\n",
874
+ " plt.savefig(f\"{dir_path}/QM9_best_design_{target}_barplot_{sample_count}.svg\")\n",
875
+ " plt.show()\n",
876
+ " #plt.show()\n",
877
+ "\n",
878
+ "from rdkit import Chem\n",
879
+ "from rdkit.Chem import AllChem, Draw\n",
880
+ "from rdkit.Chem import AllChem, rdDepictor\n",
881
+ "from rdkit.Chem.Draw import rdMolDraw2D\n",
882
+ " \n",
883
+ "def prime_messages (SMILES_chitin_monomer, target, N=1):\n",
884
+ " messages=[]\n",
885
+ " for i in range (N):\n",
886
+ " \n",
887
+ " line=f'GenerateMolecularProperties<{return_str( target)}>'\n",
888
+ " messages.append ({\"role\": \"user\", \"content\": line}, )\n",
889
+ " line=f'[{SMILES_chitin_monomer}]'\n",
890
+ " messages.append ({\"role\": \"assistant\", \"content\": line}, )\n",
891
+ " \n",
892
+ " return messages\n",
893
+ "\n",
894
+ "from rdkit import Chem\n",
895
+ "from rdkit.Chem import AllChem\n",
896
+ "\n",
897
+ "def smiles_to_3d(smiles, num_confs=100):\n",
898
+ " mol = Chem.MolFromSmiles(smiles)\n",
899
+ " if mol is None:\n",
900
+ " print(\"Failed to create molecule from SMILES\")\n",
901
+ " return None\n",
902
+ "\n",
903
+ " mol = Chem.AddHs(mol)\n",
904
+ " params = AllChem.ETKDGv3()\n",
905
+ " params.randomSeed = 42\n",
906
+ " if not AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, params=params):\n",
907
+ " print(\"Embedding conformations failed.\")\n",
908
+ " return None\n",
909
+ "\n",
910
+ " results = []\n",
911
+ " for conf_id in range(num_confs):\n",
912
+ " ff = AllChem.MMFFGetMoleculeForceField(mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id)\n",
913
+ " if ff is None:\n",
914
+ " print(f\"Failed to setup MMFF for conformer {conf_id}\")\n",
915
+ " continue\n",
916
+ " energy = ff.Minimize()\n",
917
+ " results.append((conf_id, ff.CalcEnergy()))\n",
918
+ "\n",
919
+ " if not results:\n",
920
+ " print(\"No successful energy minimization.\")\n",
921
+ " return None\n",
922
+ "\n",
+ " # Select the conformer with the lowest MMFF energy\n",
+ " min_energy_conf = min(results, key=lambda x: x[1])\n",
+ " best_conf = mol.GetConformer(min_energy_conf[0])\n",
925
+ " best_mol = Chem.Mol(mol)\n",
926
+ " best_mol.RemoveAllConformers()\n",
927
+ " best_mol.AddConformer(best_conf, assignId=True)\n",
928
+ "\n",
929
+ " coords = best_conf.GetPositions()\n",
930
+ " atom_symbols = [atom.GetSymbol() for atom in best_mol.GetAtoms()]\n",
931
+ " geometry = '\\n'.join(f'{atom} {coord[0]} {coord[1]} {coord[2]}' for atom, coord in zip(atom_symbols, coords))\n",
932
+ "\n",
933
+ " display (best_mol)\n",
934
+ " \n",
935
+ " return geometry, best_mol"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "markdown",
940
+ "id": "23f18039-4441-496c-89b0-9e467eaac83e",
941
+ "metadata": {},
942
+ "source": [
943
+ "### Property calculation as possible starting point for design iterations "
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": null,
949
+ "id": "6519474d-4e03-4273-a79e-454d5845e6d6",
950
+ "metadata": {
951
+ "scrolled": true
952
+ },
953
+ "outputs": [],
954
+ "source": [
955
+ "SMILES_dir='./SMILES_design/' # Output directory for generated figures and CSV files\n",
+ "SMILES_START='O1C2C3OC2C13'\n",
956
+ "properties,_=avg_properties_from_SMILES (model, tokenizer, SMILES_START, SMILES_dir=SMILES_dir,\n",
957
+ " temperature=0.3, top_k=256,top_p=0.9, num_beams=1, repetition_penalty=1.,\n",
958
+ " labels=labels, N_prop=3, plot_results=True)\n"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "code",
963
+ "execution_count": null,
964
+ "id": "198840ea-21f8-41eb-b62a-bc325261b731",
965
+ "metadata": {},
966
+ "outputs": [],
967
+ "source": [
968
+ "# Retrieve the scaling parameters\n",
969
+ "data_min = scaler.data_min_\n",
970
+ "data_max = scaler.data_max_\n",
971
+ "scale = scaler.scale_\n",
972
+ "feature_min = scaler.min_\n",
973
+ "\n",
974
+ "print(\"Feature Scaling Parameters:\")\n",
975
+ "print(\"{:<20} {:<20} {:<20} {:<20}\".format(\"Feature Index\", \"Min Value\", \"Max Value\", \"Scale Factor\"))\n",
976
+ "for i in range(len(data_min)):\n",
977
+ " print(\"{:<20} {:<20} {:<20} {:<20}\".format(i, data_min[i], data_max[i], scale[i]))\n",
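+ "\n",
+ "# Scaled model predictions can be mapped back to physical units with the fitted scaler,\n",
+ "# e.g. scaler.inverse_transform(np.array(predicted).reshape(1, -1)) for a 12-element property vector\n",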
978
+ "\n",
979
+ "print(\"\\nPer-feature Shifts (Min):\")\n",
980
+ "for i, min_val in enumerate(feature_min):\n",
981
+ " print(\"Feature {}: {:.6f}\".format(i, min_val))"
982
+ ]
983
+ },
984
+ {
985
+ "cell_type": "markdown",
986
+ "id": "0dd1f217-74c0-40f3-8edc-4b610c12e0ea",
987
+ "metadata": {},
988
+ "source": [
989
+ "### Molecular design: Iterative solution "
990
+ ]
991
+ },
992
+ {
993
+ "cell_type": "code",
994
+ "execution_count": null,
995
+ "id": "fc2747b6-90cc-4d42-bf93-bb39dc6d9198",
996
+ "metadata": {},
997
+ "outputs": [],
998
+ "source": [
999
+ "import copy \n",
1000
+ "properties=y_test[4]\n",
1001
+ "\n",
1002
+ "#Create new set of properties based on existing molecule (from test set)\n",
1003
+ "properties_new=copy.deepcopy (properties)\n",
1004
+ "properties_new[0]=properties[0]+0.2\n",
1005
+ "properties_new[1]=properties[1]+0.2\n",
1006
+ "plot_change_in_design (properties, labels, properties_new,SMILES_dir)"
1007
+ ]
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "execution_count": null,
1012
+ "id": "c5f9b9c5-c746-48d1-841a-a2113d13279e",
1013
+ "metadata": {
1014
+ "scrolled": true
1015
+ },
1016
+ "outputs": [],
1017
+ "source": [
1018
+ "df_design=design_molecule_loop (model, tokenizer, np.array(properties_new), SMILES_LIST=SMILES_LIST, dir_path=SMILES_dir,\n",
1019
+ " temperature_pred=0.1, temperature_gen=0.3, top_k=32,top_p=0.1, repetition_penalty=1.,\n",
1020
+ " threshold=0.001, N_max=64, \n",
1021
+ " N_attempts_for_forward=6,\n",
1022
+ " )"
1023
+ ]
1024
+ },
1025
+ {
1026
+ "cell_type": "code",
1027
+ "execution_count": null,
1028
+ "id": "8c5be323-aa47-49dd-bc44-be74936c62c8",
1029
+ "metadata": {
1030
+ "scrolled": true
1031
+ },
1032
+ "outputs": [],
1033
+ "source": [
1034
+ "df_design_2=visualize_over_SMILES (df_design,N_max=30,SMILES_dir=SMILES_dir,per_row=5,\n",
1035
+ " lower_bound = 0.0, remove_duplicates=True,\n",
1036
+ " upper_bound = 0.02, target=np.array(properties_new))\n",
1037
+ "\n",
1038
+ "target=np.array(properties_new)\n",
1039
+ "best_SMILES, best_predicted, best_mse, is_novel = df_design_2.iloc[5]\n",
1040
+ "\n",
1041
+ "print (\"Best SMILES: \", best_SMILES)\n",
1042
+ "print (f\"{best_SMILES} is novel: {is_novel}\")\n",
1043
+ "\n",
1044
+ "sns.set_style(\"whitegrid\")\n",
1045
+ "\n",
1046
+ "visualize_pred_vs_target (target, best_predicted, labels, dir_path=SMILES_dir, best_SMILES=best_SMILES,sample_count=0)\n",
1047
+ " \n",
1048
+ "visualize_SMILES(best_SMILES, dir_path=SMILES_dir, root=f'{target}_BEST')"
1049
+ ]
1050
+ },
1051
+ {
1052
+ "cell_type": "code",
1053
+ "execution_count": null,
1054
+ "id": "25fd7dbe-95fd-4169-86e9-b05e86bbfb3a",
1055
+ "metadata": {
1056
+ "scrolled": true
1057
+ },
1058
+ "outputs": [],
1059
+ "source": [
1060
+ "target=np.array(properties_new)\n",
1061
+ "best_SMILES, best_predicted, best_mse, is_novel = df_design_2.iloc[5]\n",
1062
+ "\n",
1063
+ "print (\"Best SMILES: \", best_SMILES)\n",
1064
+ "print (f\"{best_SMILES} is novel: {is_novel}\")\n",
1065
+ "\n",
1066
+ "sns.set_style(\"whitegrid\")\n",
1067
+ "\n",
1068
+ "visualize_pred_vs_target (target, best_predicted, labels, dir_path=SMILES_dir, best_SMILES=best_SMILES,sample_count=0)\n",
1069
+ " \n",
1070
+ "visualize_SMILES(best_SMILES, dir_path=SMILES_dir, root=f'{target}_BEST')"
1071
+ ]
1072
+ }
1073
+ ],
1074
+ "metadata": {
1075
+ "environment": {
1076
+ "kernel": "python3",
1077
+ "name": ".m115",
1078
+ "type": "gcloud",
1079
+ "uri": "gcr.io/deeplearning-platform-release/:m115"
1080
+ },
1081
+ "kernelspec": {
1082
+ "display_name": "Python 3 (ipykernel)",
1083
+ "language": "python",
1084
+ "name": "python3"
1085
+ },
1086
+ "language_info": {
1087
+ "codemirror_mode": {
1088
+ "name": "ipython",
1089
+ "version": 3
1090
+ },
1091
+ "file_extension": ".py",
1092
+ "mimetype": "text/x-python",
1093
+ "name": "python",
1094
+ "nbconvert_exporter": "python",
1095
+ "pygments_lexer": "ipython3",
1096
+ "version": "3.11.7"
1097
+ }
1098
+ },
1099
+ "nbformat": 4,
1100
+ "nbformat_minor": 5
1101
+ }