Takashi Itoh committed on
Commit
f140fcb
1 Parent(s): fcac6c0
Files changed (1)
  1. app.py +451 -131
app.py CHANGED
@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutput
 os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"

 import models.fm4m as fm4m
+
 RDLogger.logger().setLevel(RDLogger.ERROR)


@@ -32,20 +33,41 @@ def smiles_to_image(smiles):

 # Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
 smiles_image_mapping = {
-    "Mol 1": {"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", "image": "img/img1.png"},
+    "Mol 1": {
+        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
+        "image": "img/img1.png",
+    },
     # Example SMILES for ethanol
-    "Mol 2": {"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", "image": "img/img2.png"},
+    "Mol 2": {
+        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
+        "image": "img/img2.png",
+    },
     # Example SMILES for butane
-    "Mol 3": {"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
-              "image": "img/img3.png"},  # Example SMILES for ethylamine
-    "Mol 4": {"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", "image": "img/img4.png"},
+    "Mol 3": {
+        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
+        "image": "img/img3.png",
+    },  # Example SMILES for ethylamine
+    "Mol 4": {
+        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
+        "image": "img/img4.png",
+    },
     # Example SMILES for diethyl ether
-    "Mol 5": {"smiles": "C=CCS[C@@H](C)CC(=O)OCC", "image": "img/img5.png"}  # Example SMILES for chloroethane
+    "Mol 5": {
+        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
+        "image": "img/img5.png",
+    },  # Example SMILES for chloroethane
 }

 datasets = [" ", "BACE", "ESOL", "Load Custom Dataset"]

-models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED", "Mordred", "MorganFingerprint"]
+models_enabled = [
+    "SELFIES-TED",
+    "MHG-GED",
+    "MolFormer",
+    "SMI-TED",
+    "Mordred",
+    "MorganFingerprint",
+]

 fusion_available = ["Concat"]

@@ -56,7 +78,12 @@ def evaluate_and_log(models, dataset, task_type, eval_output, state):
     result = f"{eval_output}"
     result = result.replace(" Score", "")

-    new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_dic[task_type], "Result": result}
+    new_entry = {
+        "Selected Models": str(models),
+        "Dataset": dataset,
+        "Task": task_dic[task_type],
+        "Result": result,
+    }
     new_entry_df = pd.DataFrame([new_entry])

     state["log_df"] = pd.concat([new_entry_df, state["log_df"]])
@@ -106,22 +133,42 @@ gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-

 def generate(latent_vector, mask):
     encoder_outputs = BaseModelOutput(latent_vector)
-    decoder_output = gen_model.generate(encoder_outputs=encoder_outputs, attention_mask=mask,
-                                        max_new_tokens=64, do_sample=True, top_k=5, top_p=0.95, num_return_sequences=1)
+    decoder_output = gen_model.generate(
+        encoder_outputs=encoder_outputs,
+        attention_mask=mask,
+        max_new_tokens=64,
+        do_sample=True,
+        top_k=5,
+        top_p=0.95,
+        num_return_sequences=1,
+    )
     selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
     return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies]


 def perturb_latent(latent_vecs, noise_scale=0.5):
-    return torch.tensor(np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
-                        dtype=torch.float32) + latent_vecs
+    return (
+        torch.tensor(
+            np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
+            dtype=torch.float32,
+        )
+        + latent_vecs
+    )


 def encode(selfies):
-    encoding = gen_tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
+    encoding = gen_tokenizer(
+        selfies,
+        return_tensors='pt',
+        max_length=128,
+        truncation=True,
+        padding='max_length',
+    )
     input_ids = encoding['input_ids']
     attention_mask = encoding['attention_mask']
-    outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
+    outputs = gen_model.model.encoder(
+        input_ids=input_ids, attention_mask=attention_mask
+    )
     model_output = outputs.last_hidden_state
     return model_output, attention_mask

@@ -140,7 +187,8 @@ def generate_canonical(smiles):
         mol = Chem.MolFromSmiles(gen[0])
         if mol:
             gen_mol = Chem.MolToSmiles(mol)
-            if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
+            if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)):
+                break
         else:
             print('Abnormal molecule:', gen[0])

@@ -154,9 +202,20 @@ def generate_canonical(smiles):
     # Prepare the table with ref mol and gen mol
     data = {
         "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
-        "Reference Mol": [ref_properties[0], ref_properties[1], ref_properties[2], ref_properties[3],
-                          tanimoto_similarity],
-        "Generated Mol": [gen_properties[0], gen_properties[1], gen_properties[2], gen_properties[3], ""]
+        "Reference Mol": [
+            ref_properties[0],
+            ref_properties[1],
+            ref_properties[2],
+            ref_properties[3],
+            tanimoto_similarity,
+        ],
+        "Generated Mol": [
+            gen_properties[0],
+            gen_properties[1],
+            gen_properties[2],
+            gen_properties[3],
+            "",
+        ],
     }
     df = pd.DataFrame(data)

@@ -184,7 +243,6 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
     downstream_model = downstream_model.rstrip()
     params = None

-
     try:
         if not selected_models:
             return "Please select at least one enabled model."
@@ -195,22 +253,38 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
                     downstream_model = "DefaultClassifier"
                     params = None

-                result, state["roc_auc"], state["fpr"], state["tpr"], state["x_batch"], state["y_batch"] = \
-                    fm4m.multi_modal(model_list=selected_models,
-                                     downstream_model=downstream_model,
-                                     params = params,
-                                     dataset=dataset)
+                (
+                    result,
+                    state["roc_auc"],
+                    state["fpr"],
+                    state["tpr"],
+                    state["x_batch"],
+                    state["y_batch"],
+                ) = fm4m.multi_modal(
+                    model_list=selected_models,
+                    downstream_model=downstream_model,
+                    params=params,
+                    dataset=dataset,
+                )

             elif task_type == "Regression":
                 if downstream_model == "Default Settings":
                     downstream_model = "DefaultRegressor"
                     params = None

-                result, state["RMSE"], state["y_batch_test"], state["y_prob"], state["x_batch"], state["y_batch"] = \
-                    fm4m.multi_modal(model_list=selected_models,
-                                     downstream_model=downstream_model,
-                                     params=params,
-                                     dataset=dataset)
+                (
+                    result,
+                    state["RMSE"],
+                    state["y_batch_test"],
+                    state["y_prob"],
+                    state["x_batch"],
+                    state["y_batch"],
+                ) = fm4m.multi_modal(
+                    model_list=selected_models,
+                    downstream_model=downstream_model,
+                    params=params,
+                    dataset=dataset,
+                )

         else:
             if task_type == "Classification":
@@ -218,22 +292,38 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
                     downstream_model = "DefaultClassifier"
                     params = None

-                result, state["roc_auc"], state["fpr"], state["tpr"], state["x_batch"], state["y_batch"] = \
-                    fm4m.single_modal(model=selected_models[0],
-                                      downstream_model=downstream_model,
-                                      params=params,
-                                      dataset=dataset)
+                (
+                    result,
+                    state["roc_auc"],
+                    state["fpr"],
+                    state["tpr"],
+                    state["x_batch"],
+                    state["y_batch"],
+                ) = fm4m.single_modal(
+                    model=selected_models[0],
+                    downstream_model=downstream_model,
+                    params=params,
+                    dataset=dataset,
+                )

             elif task_type == "Regression":
                 if downstream_model == "Default Settings":
                     downstream_model = "DefaultRegressor"
                     params = None

-                result, state["RMSE"], state["y_batch_test"], state["y_prob"], state["x_batch"], state["y_batch"] = \
-                    fm4m.single_modal(model=selected_models[0],
-                                      downstream_model=downstream_model,
-                                      params=params,
-                                      dataset=dataset)
+                (
+                    result,
+                    state["RMSE"],
+                    state["y_batch_test"],
+                    state["y_prob"],
+                    state["x_batch"],
+                    state["y_batch"],
+                ) = fm4m.single_modal(
+                    model=selected_models[0],
+                    downstream_model=downstream_model,
+                    params=params,
+                    dataset=dataset,
+                )

         if result == None:
             result = "Data & Model Setting is incorrect"
@@ -263,7 +353,13 @@ def display_plot(plot_type, state):
         roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
         ax.set_title("ROC-AUC Curve")
         try:
-            ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
+            ax.plot(
+                fpr,
+                tpr,
+                color='darkorange',
+                lw=2,
+                label=f'ROC curve (area = {roc_auc:.4f})',
+            )
             ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
             ax.set_xlim([0.0, 1.0])
             ax.set_ylim([0.0, 1.05])
@@ -275,7 +371,11 @@ def display_plot(plot_type, state):
             ax.legend(loc='lower right')

     elif plot_type == "Parity Plot":
-        RMSE, y_batch_test, y_prob = state.get("RMSE"), state.get("y_batch_test"), state.get("y_prob")
+        RMSE, y_batch_test, y_prob = (
+            state.get("RMSE"),
+            state.get("y_batch_test"),
+            state.get("y_prob"),
+        )
         ax.set_title("Parity plot")

         # change format
@@ -284,7 +384,12 @@ def display_plot(plot_type, state):
             print(y_prob)
             y_batch_test = np.array(y_batch_test, dtype=float)
             y_prob = np.array(y_prob, dtype=float)
-            ax.scatter(y_batch_test, y_prob, color="blue", label=f"Predicted vs Actual (RMSE: {RMSE:.4f})")
+            ax.scatter(
+                y_batch_test,
+                y_prob,
+                color="blue",
+                label=f"Predicted vs Actual (RMSE: {RMSE:.4f})",
+            )
             min_val = min(min(y_batch_test), min(y_prob))
             max_val = max(max(y_batch_test), max(y_prob))
             ax.plot([min_val, max_val], [min_val, max_val], 'r-')
@@ -297,7 +402,6 @@ def display_plot(plot_type, state):
             print(y_batch_test)
             print(y_prob)

-
         ax.set_xlabel('Actual Values')
         ax.set_ylabel('Predicted Values')

@@ -316,13 +420,25 @@ predefined_datasets = {
 # Function to load a predefined dataset from the local path
 def load_predefined_dataset(dataset_name):
     val = predefined_datasets.get(dataset_name)
-    try: file_path = val.split(",")[0]
-    except:file_path=False
+    try:
+        file_path = val.split(",")[0]
+    except:
+        file_path = False

     if file_path:
         df = pd.read_csv(file_path)
-        return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)), f"{dataset_name.lower()}"
-    return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), f"Dataset not found"
+        return (
+            df.head(),
+            gr.update(choices=list(df.columns)),
+            gr.update(choices=list(df.columns)),
+            f"{dataset_name.lower()}",
+        )
+    return (
+        pd.DataFrame(),
+        gr.update(choices=[]),
+        gr.update(choices=[]),
+        f"Dataset not found",
+    )


 # Function to display the head of the uploaded CSV file
@@ -330,7 +446,11 @@ def display_csv_head(file):
     if file is not None:
         # Load the CSV file into a DataFrame
         df = pd.read_csv(file.name)
-        return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns))
+        return (
+            df.head(),
+            gr.update(choices=list(df.columns)),
+            gr.update(choices=list(df.columns)),
+        )
     return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])


@@ -338,28 +458,54 @@ def display_csv_head(file):
 def handle_dataset_selection(selected_dataset):
     if selected_dataset == "Custom Dataset":
         # Show file upload fields for train and test datasets if "Custom Dataset" is selected
-        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
-            visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
+        return (
+            gr.update(visible=True),
+            gr.update(visible=True),
+            gr.update(visible=True),
+            gr.update(visible=True),
+            gr.update(visible=True),
+            gr.update(visible=False),
+            gr.update(visible=True),
+            gr.update(visible=True),
+        )
     else:
-        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(
-            visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+        return (
+            gr.update(visible=True),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False),
+        )


 # Function to select input and output columns and display a message
-def select_columns(input_column, output_column, train_data, test_data,dataset_name):
+def select_columns(input_column, output_column, train_data, test_data, dataset_name):
     if input_column and output_column:
         return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
     return "Please select both input and output columns."

-def set_dataname(dataset_name, dataset_selector ):
+
+def set_dataname(dataset_name, dataset_selector):
     if dataset_selector == "Custom Dataset":
         return f"{dataset_name}"
     return f"{dataset_selector}"

+
 # Function to create model based on user input
-def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None):
+def create_model(
+    model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None
+):
     if model_name == "XGBClassifier":
-        model = xgb.XGBClassifier(objective='binary:logistic',eval_metric= 'auc', max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
+        model = xgb.XGBClassifier(
+            objective='binary:logistic',
+            eval_metric='auc',
+            max_depth=max_depth,
+            n_estimators=n_estimators,
+            alpha=alpha,
+        )
     elif model_name == "SVR":
         model = SVR(degree=degree, kernel=kernel)
     elif model_name == "Kernel Ridge":
@@ -374,6 +520,7 @@ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degr

     return f"{model_name} * {model.get_params()}"

+
 # Define the Gradio layout
 with gr.Blocks() as demo:
     log_df = pd.DataFrame(
@@ -383,79 +530,174 @@ with gr.Blocks() as demo:
     with gr.Row():
         # Left Column
         with gr.Column():
-            gr.HTML('''
+            gr.HTML(
+                '''
            <div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;">
            <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
            </div>
-            ''')
+                '''
+            )
             # Dropdown menu for predefined datasets including "Custom Dataset" option
-            dataset_selector = gr.Dropdown(label="Select Dataset",
-                                           choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
+            dataset_selector = gr.Dropdown(
+                label="Select Dataset",
+                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
+            )
             # Display the message for selected columns
-            selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=False)
+            selected_columns_message = gr.Textbox(
+                label="Selected Columns Info", visible=False
+            )

             with gr.Accordion("Dataset Settings", open=True):
                 # File upload options for custom dataset (train and test)
                 dataset_name = gr.Textbox(label="Dataset Name", visible=False)
-                train_file = gr.File(label="Upload Custom Train Dataset", file_types=[".csv"], visible=False)
-                train_display = gr.Dataframe(label="Train Dataset Preview (First 5 Rows)", visible=False, interactive=False)
+                train_file = gr.File(
+                    label="Upload Custom Train Dataset",
+                    file_types=[".csv"],
+                    visible=False,
+                )
+                train_display = gr.Dataframe(
+                    label="Train Dataset Preview (First 5 Rows)",
+                    visible=False,
+                    interactive=False,
+                )

-                test_file = gr.File(label="Upload Custom Test Dataset", file_types=[".csv"], visible=False)
-                test_display = gr.Dataframe(label="Test Dataset Preview (First 5 Rows)", visible=False, interactive=False)
+                test_file = gr.File(
+                    label="Upload Custom Test Dataset",
+                    file_types=[".csv"],
+                    visible=False,
+                )
+                test_display = gr.Dataframe(
+                    label="Test Dataset Preview (First 5 Rows)",
+                    visible=False,
+                    interactive=False,
+                )

                 # Predefined dataset displays
-                predefined_display = gr.Dataframe(label="Predefined Dataset Preview (First 5 Rows)", visible=False,
-                                                  interactive=False)
+                predefined_display = gr.Dataframe(
+                    label="Predefined Dataset Preview (First 5 Rows)",
+                    visible=False,
+                    interactive=False,
+                )

                 # Dropdowns for selecting input and output columns for the custom dataset
-                input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
-                output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
+                input_column_selector = gr.Dropdown(
+                    label="Select Input Column", choices=[], visible=False
+                )
+                output_column_selector = gr.Dropdown(
+                    label="Select Output Column", choices=[], visible=False
+                )

                 # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
-                dataset_selector.change(handle_dataset_selection,
-                                        inputs=dataset_selector,
-                                        outputs=[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
-                                                 input_column_selector, output_column_selector])
+                dataset_selector.change(
+                    handle_dataset_selection,
+                    inputs=dataset_selector,
+                    outputs=[
+                        dataset_name,
+                        train_file,
+                        train_display,
+                        test_file,
+                        test_display,
+                        predefined_display,
+                        input_column_selector,
+                        output_column_selector,
+                    ],
+                )

                 # When a predefined dataset is selected, load its head and update column selectors
-                dataset_selector.change(load_predefined_dataset,
-                                        inputs=dataset_selector,
-                                        outputs=[predefined_display, input_column_selector, output_column_selector, selected_columns_message])
+                dataset_selector.change(
+                    load_predefined_dataset,
+                    inputs=dataset_selector,
+                    outputs=[
+                        predefined_display,
+                        input_column_selector,
+                        output_column_selector,
+                        selected_columns_message,
+                    ],
+                )

                 # When a custom train file is uploaded, display its head and update column selectors
-                train_file.change(display_csv_head, inputs=train_file,
-                                  outputs=[train_display, input_column_selector, output_column_selector])
+                train_file.change(
+                    display_csv_head,
+                    inputs=train_file,
+                    outputs=[
+                        train_display,
+                        input_column_selector,
+                        output_column_selector,
+                    ],
+                )

                 # When a custom test file is uploaded, display its head
-                test_file.change(display_csv_head, inputs=test_file,
-                                 outputs=[test_display, input_column_selector, output_column_selector])
+                test_file.change(
+                    display_csv_head,
+                    inputs=test_file,
+                    outputs=[
+                        test_display,
+                        input_column_selector,
+                        output_column_selector,
+                    ],
+                )

-                dataset_selector.change(set_dataname,
-                                        inputs=[dataset_name, dataset_selector],
-                                        outputs=dataset_name)
+                dataset_selector.change(
+                    set_dataname,
+                    inputs=[dataset_name, dataset_selector],
+                    outputs=dataset_name,
+                )

                 # Update the selected columns information when dropdown values are changed
-                input_column_selector.change(select_columns,
-                                             inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
-                                             outputs=selected_columns_message)
+                input_column_selector.change(
+                    select_columns,
+                    inputs=[
+                        input_column_selector,
+                        output_column_selector,
+                        train_file,
+                        test_file,
+                        dataset_name,
+                    ],
+                    outputs=selected_columns_message,
+                )

-                output_column_selector.change(select_columns,
-                                              inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
-                                              outputs=selected_columns_message)
+                output_column_selector.change(
+                    select_columns,
+                    inputs=[
+                        input_column_selector,
+                        output_column_selector,
+                        train_file,
+                        test_file,
+                        dataset_name,
+                    ],
+                    outputs=selected_columns_message,
+                )

-            model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
+            model_checkbox = gr.CheckboxGroup(
+                choices=models_enabled, label="Select Model"
+            )

-            task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
+            task_radiobutton = gr.Radio(
+                choices=["Classification", "Regression"], label="Task Type"
+            )

             ####### adding hyper parameter tuning ###########
-            model_name = gr.Dropdown(["Default - Auto", "XGBClassifier", "SVR", "Kernel Ridge", "Linear Regression"], label="Select Downstream Model")
+            model_name = gr.Dropdown(
+                [
+                    "Default - Auto",
+                    "XGBClassifier",
+                    "SVR",
+                    "Kernel Ridge",
+                    "Linear Regression",
+                ],
+                label="Select Downstream Model",
+            )
             with gr.Accordion("Downstream Hyperparameter Settings", open=True):
                 # Create placeholders for hyperparameter components
-                max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth")
-                n_estimators = gr.Slider(100, 5000, step=100, visible=False, label="n_estimators")
+                max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
+                n_estimators = gr.Slider(
+                    100, 5000, step=100, visible=False, label="n_estimators"
+                )
                 alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
-                degree = gr.Slider(1, 20, step=1,visible=False, label="degree")
-                kernel = gr.Dropdown(choices=["rbf", "poly", "linear"], visible=False, label="kernel")
+                degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
+                kernel = gr.Dropdown(
+                    choices=["rbf", "poly", "linear"], visible=False, label="kernel"
+                )

                 # Output textbox
                 output = gr.Textbox(label="Loaded Parameters")
@@ -463,24 +705,52 @@ with gr.Blocks() as demo:
             # Dynamically show relevant hyperparameters based on selected model
             def update_hyperparameters(model_name):
                 if model_name == "XGBClassifier":
-                    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
-                        visible=False), gr.update(visible=False)
+                    return (
+                        gr.update(visible=True),
+                        gr.update(visible=True),
+                        gr.update(visible=True),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                    )
                 elif model_name == "SVR":
-                    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
-                        visible=True), gr.update(visible=True)
+                    return (
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=True),
+                        gr.update(visible=True),
+                    )
                 elif model_name == "Kernel Ridge":
-                    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(
-                        visible=True), gr.update(visible=True)
+                    return (
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=True),
+                        gr.update(visible=True),
+                        gr.update(visible=True),
+                    )
                 elif model_name == "Linear Regression":
-                    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
-                        visible=False), gr.update(visible=False)
+                    return (
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                    )
                 elif model_name == "Default - Auto":
-                    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
-                        visible=False), gr.update(visible=False)
+                    return (
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                        gr.update(visible=False),
+                    )

             # When model is selected, update which hyperparameters are visible
-            model_name.change(update_hyperparameters, inputs=[model_name],
-                              outputs=[max_depth, n_estimators, alpha, degree, kernel])
+            model_name.change(
+                update_hyperparameters,
+                inputs=[model_name],
+                outputs=[max_depth, n_estimators, alpha, degree, kernel],
+            )

             # Submit button to create the model with selected hyperparameters
             submit_button = gr.Button("Create Downstream Model")
@@ -488,19 +758,29 @@ with gr.Blocks() as demo:
             # Function to handle model creation based on input parameters
             def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
                 if model_name == "XGBClassifier":
-                    return create_model(model_name, max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
+                    return create_model(
+                        model_name,
+                        max_depth=max_depth,
+                        n_estimators=n_estimators,
+                        alpha=alpha,
+                    )
                 elif model_name == "SVR":
                     return create_model(model_name, degree=degree, kernel=kernel)
                 elif model_name == "Kernel Ridge":
-                    return create_model(model_name, alpha=alpha, degree=degree, kernel=kernel)
+                    return create_model(
+                        model_name, alpha=alpha, degree=degree, kernel=kernel
+                    )
                 elif model_name == "Linear Regression":
                     return create_model(model_name)
                 elif model_name == "Default - Auto":
                     return create_model(model_name)

             # When the submit button is clicked, run the on_submit function
-            submit_button.click(on_submit, inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
-                                outputs=output)
+            submit_button.click(
+                on_submit,
+                inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
+                outputs=output,
+            )
             ###### End of hyper param tuning #########

             fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
@@ -509,40 +789,69 @@ with gr.Blocks() as demo:

         # Middle Column
         with gr.Column():
-            gr.HTML('''
+            gr.HTML(
+                '''
            <div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;">
            <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
            </div>
-            ''')
+                '''
+            )
             eval_output = gr.Textbox(label="Train downstream model")

-            plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
+            plot_radio = gr.Radio(
+                choices=["ROC-AUC", "Parity Plot", "Latent Space"],
+                label="Select Plot Type",
+            )
             plot_output = gr.Plot(label="Visualization")

             create_log = gr.Button("Store log")

-            log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
-
-            eval_button.click(display_eval,
-                              inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton, state],
-                              outputs=eval_output)
-
-            plot_radio.change(display_plot, inputs=[plot_radio, state], outputs=plot_output)
+            log_table = gr.Dataframe(
+                value=log_df, label="Log of Selections and Results", interactive=False
+            )
+
+            eval_button.click(
+                display_eval,
+                inputs=[
+                    model_checkbox,
+                    selected_columns_message,
+                    task_radiobutton,
+                    output,
+                    fusion_radiobutton,
+                    state,
+                ],
+                outputs=eval_output,
+            )
+
+            plot_radio.change(
+                display_plot, inputs=[plot_radio, state], outputs=plot_output
+            )

             # Function to gather selected models
             def gather_selected_models(*models):
                 selected = [model for model in models if model]
                 return selected

-            create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output, state],
-                             outputs=log_table)
+            create_log.click(
+                evaluate_and_log,
+                inputs=[
+                    model_checkbox,
+                    dataset_name,
+                    task_radiobutton,
+                    eval_output,
+                    state,
+                ],
+                outputs=log_table,
+            )
         # Right Column
         with gr.Column():
-            gr.HTML('''
+            gr.HTML(
+                '''
            <div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;">
            <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
            </div>
-            ''')
+                '''
+            )
             smiles_input = gr.Textbox(label="Input SMILES String")
             image_display = gr.Image(label="Molecule Image", height=250, width=250)
             # Show images for selection
@@ -554,17 +863,28 @@ with gr.Blocks() as demo:
             )
             image_selector.change(load_image, image_selector, image_display)
             generate_button = gr.Button("Generate")
-            gen_image_display = gr.Image(label="Generated Molecule Image", height=250, width=250)
+            gen_image_display = gr.Image(
+                label="Generated Molecule Image", height=250, width=250
+            )
             generated_output = gr.Textbox(label="Generated Output")
             property_table = gr.Dataframe(label="Molecular Properties Comparison")

             # Handle image selection
-            image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
-            smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
+            image_selector.change(
+                handle_image_selection,
+                inputs=image_selector,
+                outputs=[smiles_input, image_display],
+            )
+            smiles_input.change(
+                smiles_to_image, inputs=smiles_input, outputs=image_display
+            )

             # Generate button to display canonical SMILES and molecule image
-            generate_button.click(generate_canonical, inputs=smiles_input,
-                                  outputs=[property_table, generated_output, gen_image_display])
+            generate_button.click(
+                generate_canonical,
+                inputs=smiles_input,
+                outputs=[property_table, generated_output, gen_image_display],
+            )


 if __name__ == "__main__":
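
The reformatted encode, perturb_latent, and generate helpers in this diff form the latent-space round trip that generate_canonical relies on. The lines below are a minimal usage sketch, not part of the commit; they assume they run inside app.py after gen_tokenizer and gen_model are loaded, and "CCO" is just a placeholder reference SMILES.

# Illustrative sketch only (assumes app.py's encode, perturb_latent, generate are in scope)
import selfies as sf  # already imported as sf in app.py

ref_smiles = "CCO"                               # placeholder reference molecule
selfie = sf.encoder(ref_smiles)                  # SMILES -> SELFIES
latent, mask = encode([selfie])                  # encoder hidden states + attention mask
noisy = perturb_latent(latent, noise_scale=0.5)  # add scaled uniform noise in latent space
candidates = generate(noisy, mask)               # decode perturbed latents back to SMILES
print(candidates)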