Spaces:
Running
Running
Takashi Itoh
commited on
Commit
•
f140fcb
1
Parent(s):
fcac6c0
Reformat
Browse files
app.py
CHANGED
@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutput
|
|
21 |
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
|
22 |
|
23 |
import models.fm4m as fm4m
|
|
|
24 |
RDLogger.logger().setLevel(RDLogger.ERROR)
|
25 |
|
26 |
|
@@ -32,20 +33,41 @@ def smiles_to_image(smiles):
|
|
32 |
|
33 |
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
|
34 |
smiles_image_mapping = {
|
35 |
-
"Mol 1": {
|
|
|
|
|
|
|
36 |
# Example SMILES for ethanol
|
37 |
-
"Mol 2": {
|
|
|
|
|
|
|
38 |
# Example SMILES for butane
|
39 |
-
"Mol 3": {
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
42 |
# Example SMILES for diethyl ether
|
43 |
-
"Mol 5": {
|
|
|
|
|
|
|
44 |
}
|
45 |
|
46 |
datasets = [" ", "BACE", "ESOL", "Load Custom Dataset"]
|
47 |
|
48 |
-
models_enabled = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
fusion_available = ["Concat"]
|
51 |
|
@@ -56,7 +78,12 @@ def evaluate_and_log(models, dataset, task_type, eval_output, state):
|
|
56 |
result = f"{eval_output}"
|
57 |
result = result.replace(" Score", "")
|
58 |
|
59 |
-
new_entry = {
|
|
|
|
|
|
|
|
|
|
|
60 |
new_entry_df = pd.DataFrame([new_entry])
|
61 |
|
62 |
state["log_df"] = pd.concat([new_entry_df, state["log_df"]])
|
@@ -106,22 +133,42 @@ gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-
|
|
106 |
|
107 |
def generate(latent_vector, mask):
|
108 |
encoder_outputs = BaseModelOutput(latent_vector)
|
109 |
-
decoder_output = gen_model.generate(
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
|
112 |
return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies]
|
113 |
|
114 |
|
115 |
def perturb_latent(latent_vecs, noise_scale=0.5):
|
116 |
-
return
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
|
120 |
def encode(selfies):
|
121 |
-
encoding = gen_tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
input_ids = encoding['input_ids']
|
123 |
attention_mask = encoding['attention_mask']
|
124 |
-
outputs = gen_model.model.encoder(
|
|
|
|
|
125 |
model_output = outputs.last_hidden_state
|
126 |
return model_output, attention_mask
|
127 |
|
@@ -140,7 +187,8 @@ def generate_canonical(smiles):
|
|
140 |
mol = Chem.MolFromSmiles(gen[0])
|
141 |
if mol:
|
142 |
gen_mol = Chem.MolToSmiles(mol)
|
143 |
-
if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)):
|
|
|
144 |
else:
|
145 |
print('Abnormal molecule:', gen[0])
|
146 |
|
@@ -154,9 +202,20 @@ def generate_canonical(smiles):
|
|
154 |
# Prepare the table with ref mol and gen mol
|
155 |
data = {
|
156 |
"Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
|
157 |
-
"Reference Mol": [
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
}
|
161 |
df = pd.DataFrame(data)
|
162 |
|
@@ -184,7 +243,6 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
|
|
184 |
downstream_model = downstream_model.rstrip()
|
185 |
params = None
|
186 |
|
187 |
-
|
188 |
try:
|
189 |
if not selected_models:
|
190 |
return "Please select at least one enabled model."
|
@@ -195,22 +253,38 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
|
|
195 |
downstream_model = "DefaultClassifier"
|
196 |
params = None
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
elif task_type == "Regression":
|
205 |
if downstream_model == "Default Settings":
|
206 |
downstream_model = "DefaultRegressor"
|
207 |
params = None
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
else:
|
216 |
if task_type == "Classification":
|
@@ -218,22 +292,38 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type, s
|
|
218 |
downstream_model = "DefaultClassifier"
|
219 |
params = None
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
elif task_type == "Regression":
|
228 |
if downstream_model == "Default Settings":
|
229 |
downstream_model = "DefaultRegressor"
|
230 |
params = None
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
if result == None:
|
239 |
result = "Data & Model Setting is incorrect"
|
@@ -263,7 +353,13 @@ def display_plot(plot_type, state):
|
|
263 |
roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
|
264 |
ax.set_title("ROC-AUC Curve")
|
265 |
try:
|
266 |
-
ax.plot(
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
268 |
ax.set_xlim([0.0, 1.0])
|
269 |
ax.set_ylim([0.0, 1.05])
|
@@ -275,7 +371,11 @@ def display_plot(plot_type, state):
|
|
275 |
ax.legend(loc='lower right')
|
276 |
|
277 |
elif plot_type == "Parity Plot":
|
278 |
-
RMSE, y_batch_test, y_prob =
|
|
|
|
|
|
|
|
|
279 |
ax.set_title("Parity plot")
|
280 |
|
281 |
# change format
|
@@ -284,7 +384,12 @@ def display_plot(plot_type, state):
|
|
284 |
print(y_prob)
|
285 |
y_batch_test = np.array(y_batch_test, dtype=float)
|
286 |
y_prob = np.array(y_prob, dtype=float)
|
287 |
-
ax.scatter(
|
|
|
|
|
|
|
|
|
|
|
288 |
min_val = min(min(y_batch_test), min(y_prob))
|
289 |
max_val = max(max(y_batch_test), max(y_prob))
|
290 |
ax.plot([min_val, max_val], [min_val, max_val], 'r-')
|
@@ -297,7 +402,6 @@ def display_plot(plot_type, state):
|
|
297 |
print(y_batch_test)
|
298 |
print(y_prob)
|
299 |
|
300 |
-
|
301 |
ax.set_xlabel('Actual Values')
|
302 |
ax.set_ylabel('Predicted Values')
|
303 |
|
@@ -316,13 +420,25 @@ predefined_datasets = {
|
|
316 |
# Function to load a predefined dataset from the local path
|
317 |
def load_predefined_dataset(dataset_name):
|
318 |
val = predefined_datasets.get(dataset_name)
|
319 |
-
try:
|
320 |
-
|
|
|
|
|
321 |
|
322 |
if file_path:
|
323 |
df = pd.read_csv(file_path)
|
324 |
-
return
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
|
328 |
# Function to display the head of the uploaded CSV file
|
@@ -330,7 +446,11 @@ def display_csv_head(file):
|
|
330 |
if file is not None:
|
331 |
# Load the CSV file into a DataFrame
|
332 |
df = pd.read_csv(file.name)
|
333 |
-
return
|
|
|
|
|
|
|
|
|
334 |
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
|
335 |
|
336 |
|
@@ -338,28 +458,54 @@ def display_csv_head(file):
|
|
338 |
def handle_dataset_selection(selected_dataset):
|
339 |
if selected_dataset == "Custom Dataset":
|
340 |
# Show file upload fields for train and test datasets if "Custom Dataset" is selected
|
341 |
-
return
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
else:
|
344 |
-
return
|
345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
347 |
|
348 |
# Function to select input and output columns and display a message
|
349 |
-
def select_columns(input_column, output_column, train_data, test_data,dataset_name):
|
350 |
if input_column and output_column:
|
351 |
return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
|
352 |
return "Please select both input and output columns."
|
353 |
|
354 |
-
|
|
|
355 |
if dataset_selector == "Custom Dataset":
|
356 |
return f"{dataset_name}"
|
357 |
return f"{dataset_selector}"
|
358 |
|
|
|
359 |
# Function to create model based on user input
|
360 |
-
def create_model(
|
|
|
|
|
361 |
if model_name == "XGBClassifier":
|
362 |
-
model = xgb.XGBClassifier(
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
elif model_name == "SVR":
|
364 |
model = SVR(degree=degree, kernel=kernel)
|
365 |
elif model_name == "Kernel Ridge":
|
@@ -374,6 +520,7 @@ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degr
|
|
374 |
|
375 |
return f"{model_name} * {model.get_params()}"
|
376 |
|
|
|
377 |
# Define the Gradio layout
|
378 |
with gr.Blocks() as demo:
|
379 |
log_df = pd.DataFrame(
|
@@ -383,79 +530,174 @@ with gr.Blocks() as demo:
|
|
383 |
with gr.Row():
|
384 |
# Left Column
|
385 |
with gr.Column():
|
386 |
-
gr.HTML(
|
|
|
387 |
<div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;">
|
388 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
|
389 |
</div>
|
390 |
-
'''
|
|
|
391 |
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
392 |
-
dataset_selector = gr.Dropdown(
|
393 |
-
|
|
|
|
|
394 |
# Display the message for selected columns
|
395 |
-
selected_columns_message = gr.Textbox(
|
|
|
|
|
396 |
|
397 |
with gr.Accordion("Dataset Settings", open=True):
|
398 |
# File upload options for custom dataset (train and test)
|
399 |
dataset_name = gr.Textbox(label="Dataset Name", visible=False)
|
400 |
-
train_file = gr.File(
|
401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
|
403 |
-
test_file = gr.File(
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
|
406 |
# Predefined dataset displays
|
407 |
-
predefined_display = gr.Dataframe(
|
408 |
-
|
|
|
|
|
|
|
409 |
|
410 |
# Dropdowns for selecting input and output columns for the custom dataset
|
411 |
-
input_column_selector = gr.Dropdown(
|
412 |
-
|
|
|
|
|
|
|
|
|
413 |
|
414 |
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
415 |
-
dataset_selector.change(
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
|
420 |
# When a predefined dataset is selected, load its head and update column selectors
|
421 |
-
dataset_selector.change(
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
|
425 |
# When a custom train file is uploaded, display its head and update column selectors
|
426 |
-
train_file.change(
|
427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
|
429 |
# When a custom test file is uploaded, display its head
|
430 |
-
test_file.change(
|
431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
|
433 |
-
dataset_selector.change(
|
434 |
-
|
435 |
-
|
|
|
|
|
436 |
|
437 |
# Update the selected columns information when dropdown values are changed
|
438 |
-
input_column_selector.change(
|
439 |
-
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
|
442 |
-
output_column_selector.change(
|
443 |
-
|
444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
|
446 |
-
model_checkbox = gr.CheckboxGroup(
|
|
|
|
|
447 |
|
448 |
-
task_radiobutton = gr.Radio(
|
|
|
|
|
449 |
|
450 |
####### adding hyper parameter tuning ###########
|
451 |
-
model_name = gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
with gr.Accordion("Downstream Hyperparameter Settings", open=True):
|
453 |
# Create placeholders for hyperparameter components
|
454 |
-
max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth")
|
455 |
-
n_estimators = gr.Slider(
|
|
|
|
|
456 |
alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
|
457 |
-
degree = gr.Slider(1, 20, step=1,visible=False, label="degree")
|
458 |
-
kernel = gr.Dropdown(
|
|
|
|
|
459 |
|
460 |
# Output textbox
|
461 |
output = gr.Textbox(label="Loaded Parameters")
|
@@ -463,24 +705,52 @@ with gr.Blocks() as demo:
|
|
463 |
# Dynamically show relevant hyperparameters based on selected model
|
464 |
def update_hyperparameters(model_name):
|
465 |
if model_name == "XGBClassifier":
|
466 |
-
return
|
467 |
-
|
|
|
|
|
|
|
|
|
|
|
468 |
elif model_name == "SVR":
|
469 |
-
return
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
471 |
elif model_name == "Kernel Ridge":
|
472 |
-
return
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
474 |
elif model_name == "Linear Regression":
|
475 |
-
return
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
477 |
elif model_name == "Default - Auto":
|
478 |
-
return
|
479 |
-
|
|
|
|
|
|
|
|
|
|
|
480 |
|
481 |
# When model is selected, update which hyperparameters are visible
|
482 |
-
model_name.change(
|
483 |
-
|
|
|
|
|
|
|
484 |
|
485 |
# Submit button to create the model with selected hyperparameters
|
486 |
submit_button = gr.Button("Create Downstream Model")
|
@@ -488,19 +758,29 @@ with gr.Blocks() as demo:
|
|
488 |
# Function to handle model creation based on input parameters
|
489 |
def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
|
490 |
if model_name == "XGBClassifier":
|
491 |
-
return create_model(
|
|
|
|
|
|
|
|
|
|
|
492 |
elif model_name == "SVR":
|
493 |
return create_model(model_name, degree=degree, kernel=kernel)
|
494 |
elif model_name == "Kernel Ridge":
|
495 |
-
return create_model(
|
|
|
|
|
496 |
elif model_name == "Linear Regression":
|
497 |
return create_model(model_name)
|
498 |
elif model_name == "Default - Auto":
|
499 |
return create_model(model_name)
|
500 |
|
501 |
# When the submit button is clicked, run the on_submit function
|
502 |
-
submit_button.click(
|
503 |
-
|
|
|
|
|
|
|
504 |
###### End of hyper param tuning #########
|
505 |
|
506 |
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
|
@@ -509,40 +789,69 @@ with gr.Blocks() as demo:
|
|
509 |
|
510 |
# Middle Column
|
511 |
with gr.Column():
|
512 |
-
gr.HTML(
|
|
|
513 |
<div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;">
|
514 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
|
515 |
</div>
|
516 |
-
'''
|
|
|
517 |
eval_output = gr.Textbox(label="Train downstream model")
|
518 |
|
519 |
-
plot_radio = gr.Radio(
|
|
|
|
|
|
|
520 |
plot_output = gr.Plot(label="Visualization")
|
521 |
|
522 |
create_log = gr.Button("Store log")
|
523 |
|
524 |
-
log_table = gr.Dataframe(
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
|
532 |
# Function to gather selected models
|
533 |
def gather_selected_models(*models):
|
534 |
selected = [model for model in models if model]
|
535 |
return selected
|
536 |
|
537 |
-
create_log.click(
|
538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
# Right Column
|
540 |
with gr.Column():
|
541 |
-
gr.HTML(
|
|
|
542 |
<div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;">
|
543 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
|
544 |
</div>
|
545 |
-
'''
|
|
|
546 |
smiles_input = gr.Textbox(label="Input SMILES String")
|
547 |
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
548 |
# Show images for selection
|
@@ -554,17 +863,28 @@ with gr.Blocks() as demo:
|
|
554 |
)
|
555 |
image_selector.change(load_image, image_selector, image_display)
|
556 |
generate_button = gr.Button("Generate")
|
557 |
-
gen_image_display = gr.Image(
|
|
|
|
|
558 |
generated_output = gr.Textbox(label="Generated Output")
|
559 |
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
560 |
|
561 |
# Handle image selection
|
562 |
-
image_selector.change(
|
563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
|
565 |
# Generate button to display canonical SMILES and molecule image
|
566 |
-
generate_button.click(
|
567 |
-
|
|
|
|
|
|
|
568 |
|
569 |
|
570 |
if __name__ == "__main__":
|
|
|
21 |
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
|
22 |
|
23 |
import models.fm4m as fm4m
|
24 |
+
|
25 |
RDLogger.logger().setLevel(RDLogger.ERROR)
|
26 |
|
27 |
|
|
|
33 |
|
34 |
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
|
35 |
smiles_image_mapping = {
|
36 |
+
"Mol 1": {
|
37 |
+
"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
|
38 |
+
"image": "img/img1.png",
|
39 |
+
},
|
40 |
# Example SMILES for ethanol
|
41 |
+
"Mol 2": {
|
42 |
+
"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
|
43 |
+
"image": "img/img2.png",
|
44 |
+
},
|
45 |
# Example SMILES for butane
|
46 |
+
"Mol 3": {
|
47 |
+
"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
|
48 |
+
"image": "img/img3.png",
|
49 |
+
}, # Example SMILES for ethylamine
|
50 |
+
"Mol 4": {
|
51 |
+
"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
|
52 |
+
"image": "img/img4.png",
|
53 |
+
},
|
54 |
# Example SMILES for diethyl ether
|
55 |
+
"Mol 5": {
|
56 |
+
"smiles": "C=CCS[C@@H](C)CC(=O)OCC",
|
57 |
+
"image": "img/img5.png",
|
58 |
+
}, # Example SMILES for chloroethane
|
59 |
}
|
60 |
|
61 |
datasets = [" ", "BACE", "ESOL", "Load Custom Dataset"]
|
62 |
|
63 |
+
models_enabled = [
|
64 |
+
"SELFIES-TED",
|
65 |
+
"MHG-GED",
|
66 |
+
"MolFormer",
|
67 |
+
"SMI-TED",
|
68 |
+
"Mordred",
|
69 |
+
"MorganFingerprint",
|
70 |
+
]
|
71 |
|
72 |
fusion_available = ["Concat"]
|
73 |
|
|
|
78 |
result = f"{eval_output}"
|
79 |
result = result.replace(" Score", "")
|
80 |
|
81 |
+
new_entry = {
|
82 |
+
"Selected Models": str(models),
|
83 |
+
"Dataset": dataset,
|
84 |
+
"Task": task_dic[task_type],
|
85 |
+
"Result": result,
|
86 |
+
}
|
87 |
new_entry_df = pd.DataFrame([new_entry])
|
88 |
|
89 |
state["log_df"] = pd.concat([new_entry_df, state["log_df"]])
|
|
|
133 |
|
134 |
def generate(latent_vector, mask):
|
135 |
encoder_outputs = BaseModelOutput(latent_vector)
|
136 |
+
decoder_output = gen_model.generate(
|
137 |
+
encoder_outputs=encoder_outputs,
|
138 |
+
attention_mask=mask,
|
139 |
+
max_new_tokens=64,
|
140 |
+
do_sample=True,
|
141 |
+
top_k=5,
|
142 |
+
top_p=0.95,
|
143 |
+
num_return_sequences=1,
|
144 |
+
)
|
145 |
selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
|
146 |
return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies]
|
147 |
|
148 |
|
149 |
def perturb_latent(latent_vecs, noise_scale=0.5):
|
150 |
+
return (
|
151 |
+
torch.tensor(
|
152 |
+
np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
|
153 |
+
dtype=torch.float32,
|
154 |
+
)
|
155 |
+
+ latent_vecs
|
156 |
+
)
|
157 |
|
158 |
|
159 |
def encode(selfies):
|
160 |
+
encoding = gen_tokenizer(
|
161 |
+
selfies,
|
162 |
+
return_tensors='pt',
|
163 |
+
max_length=128,
|
164 |
+
truncation=True,
|
165 |
+
padding='max_length',
|
166 |
+
)
|
167 |
input_ids = encoding['input_ids']
|
168 |
attention_mask = encoding['attention_mask']
|
169 |
+
outputs = gen_model.model.encoder(
|
170 |
+
input_ids=input_ids, attention_mask=attention_mask
|
171 |
+
)
|
172 |
model_output = outputs.last_hidden_state
|
173 |
return model_output, attention_mask
|
174 |
|
|
|
187 |
mol = Chem.MolFromSmiles(gen[0])
|
188 |
if mol:
|
189 |
gen_mol = Chem.MolToSmiles(mol)
|
190 |
+
if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)):
|
191 |
+
break
|
192 |
else:
|
193 |
print('Abnormal molecule:', gen[0])
|
194 |
|
|
|
202 |
# Prepare the table with ref mol and gen mol
|
203 |
data = {
|
204 |
"Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
|
205 |
+
"Reference Mol": [
|
206 |
+
ref_properties[0],
|
207 |
+
ref_properties[1],
|
208 |
+
ref_properties[2],
|
209 |
+
ref_properties[3],
|
210 |
+
tanimoto_similarity,
|
211 |
+
],
|
212 |
+
"Generated Mol": [
|
213 |
+
gen_properties[0],
|
214 |
+
gen_properties[1],
|
215 |
+
gen_properties[2],
|
216 |
+
gen_properties[3],
|
217 |
+
"",
|
218 |
+
],
|
219 |
}
|
220 |
df = pd.DataFrame(data)
|
221 |
|
|
|
243 |
downstream_model = downstream_model.rstrip()
|
244 |
params = None
|
245 |
|
|
|
246 |
try:
|
247 |
if not selected_models:
|
248 |
return "Please select at least one enabled model."
|
|
|
253 |
downstream_model = "DefaultClassifier"
|
254 |
params = None
|
255 |
|
256 |
+
(
|
257 |
+
result,
|
258 |
+
state["roc_auc"],
|
259 |
+
state["fpr"],
|
260 |
+
state["tpr"],
|
261 |
+
state["x_batch"],
|
262 |
+
state["y_batch"],
|
263 |
+
) = fm4m.multi_modal(
|
264 |
+
model_list=selected_models,
|
265 |
+
downstream_model=downstream_model,
|
266 |
+
params=params,
|
267 |
+
dataset=dataset,
|
268 |
+
)
|
269 |
|
270 |
elif task_type == "Regression":
|
271 |
if downstream_model == "Default Settings":
|
272 |
downstream_model = "DefaultRegressor"
|
273 |
params = None
|
274 |
|
275 |
+
(
|
276 |
+
result,
|
277 |
+
state["RMSE"],
|
278 |
+
state["y_batch_test"],
|
279 |
+
state["y_prob"],
|
280 |
+
state["x_batch"],
|
281 |
+
state["y_batch"],
|
282 |
+
) = fm4m.multi_modal(
|
283 |
+
model_list=selected_models,
|
284 |
+
downstream_model=downstream_model,
|
285 |
+
params=params,
|
286 |
+
dataset=dataset,
|
287 |
+
)
|
288 |
|
289 |
else:
|
290 |
if task_type == "Classification":
|
|
|
292 |
downstream_model = "DefaultClassifier"
|
293 |
params = None
|
294 |
|
295 |
+
(
|
296 |
+
result,
|
297 |
+
state["roc_auc"],
|
298 |
+
state["fpr"],
|
299 |
+
state["tpr"],
|
300 |
+
state["x_batch"],
|
301 |
+
state["y_batch"],
|
302 |
+
) = fm4m.single_modal(
|
303 |
+
model=selected_models[0],
|
304 |
+
downstream_model=downstream_model,
|
305 |
+
params=params,
|
306 |
+
dataset=dataset,
|
307 |
+
)
|
308 |
|
309 |
elif task_type == "Regression":
|
310 |
if downstream_model == "Default Settings":
|
311 |
downstream_model = "DefaultRegressor"
|
312 |
params = None
|
313 |
|
314 |
+
(
|
315 |
+
result,
|
316 |
+
state["RMSE"],
|
317 |
+
state["y_batch_test"],
|
318 |
+
state["y_prob"],
|
319 |
+
state["x_batch"],
|
320 |
+
state["y_batch"],
|
321 |
+
) = fm4m.single_modal(
|
322 |
+
model=selected_models[0],
|
323 |
+
downstream_model=downstream_model,
|
324 |
+
params=params,
|
325 |
+
dataset=dataset,
|
326 |
+
)
|
327 |
|
328 |
if result == None:
|
329 |
result = "Data & Model Setting is incorrect"
|
|
|
353 |
roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
|
354 |
ax.set_title("ROC-AUC Curve")
|
355 |
try:
|
356 |
+
ax.plot(
|
357 |
+
fpr,
|
358 |
+
tpr,
|
359 |
+
color='darkorange',
|
360 |
+
lw=2,
|
361 |
+
label=f'ROC curve (area = {roc_auc:.4f})',
|
362 |
+
)
|
363 |
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
364 |
ax.set_xlim([0.0, 1.0])
|
365 |
ax.set_ylim([0.0, 1.05])
|
|
|
371 |
ax.legend(loc='lower right')
|
372 |
|
373 |
elif plot_type == "Parity Plot":
|
374 |
+
RMSE, y_batch_test, y_prob = (
|
375 |
+
state.get("RMSE"),
|
376 |
+
state.get("y_batch_test"),
|
377 |
+
state.get("y_prob"),
|
378 |
+
)
|
379 |
ax.set_title("Parity plot")
|
380 |
|
381 |
# change format
|
|
|
384 |
print(y_prob)
|
385 |
y_batch_test = np.array(y_batch_test, dtype=float)
|
386 |
y_prob = np.array(y_prob, dtype=float)
|
387 |
+
ax.scatter(
|
388 |
+
y_batch_test,
|
389 |
+
y_prob,
|
390 |
+
color="blue",
|
391 |
+
label=f"Predicted vs Actual (RMSE: {RMSE:.4f})",
|
392 |
+
)
|
393 |
min_val = min(min(y_batch_test), min(y_prob))
|
394 |
max_val = max(max(y_batch_test), max(y_prob))
|
395 |
ax.plot([min_val, max_val], [min_val, max_val], 'r-')
|
|
|
402 |
print(y_batch_test)
|
403 |
print(y_prob)
|
404 |
|
|
|
405 |
ax.set_xlabel('Actual Values')
|
406 |
ax.set_ylabel('Predicted Values')
|
407 |
|
|
|
420 |
# Function to load a predefined dataset from the local path
|
421 |
def load_predefined_dataset(dataset_name):
|
422 |
val = predefined_datasets.get(dataset_name)
|
423 |
+
try:
|
424 |
+
file_path = val.split(",")[0]
|
425 |
+
except:
|
426 |
+
file_path = False
|
427 |
|
428 |
if file_path:
|
429 |
df = pd.read_csv(file_path)
|
430 |
+
return (
|
431 |
+
df.head(),
|
432 |
+
gr.update(choices=list(df.columns)),
|
433 |
+
gr.update(choices=list(df.columns)),
|
434 |
+
f"{dataset_name.lower()}",
|
435 |
+
)
|
436 |
+
return (
|
437 |
+
pd.DataFrame(),
|
438 |
+
gr.update(choices=[]),
|
439 |
+
gr.update(choices=[]),
|
440 |
+
f"Dataset not found",
|
441 |
+
)
|
442 |
|
443 |
|
444 |
# Function to display the head of the uploaded CSV file
|
|
|
446 |
if file is not None:
|
447 |
# Load the CSV file into a DataFrame
|
448 |
df = pd.read_csv(file.name)
|
449 |
+
return (
|
450 |
+
df.head(),
|
451 |
+
gr.update(choices=list(df.columns)),
|
452 |
+
gr.update(choices=list(df.columns)),
|
453 |
+
)
|
454 |
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
|
455 |
|
456 |
|
|
|
458 |
def handle_dataset_selection(selected_dataset):
|
459 |
if selected_dataset == "Custom Dataset":
|
460 |
# Show file upload fields for train and test datasets if "Custom Dataset" is selected
|
461 |
+
return (
|
462 |
+
gr.update(visible=True),
|
463 |
+
gr.update(visible=True),
|
464 |
+
gr.update(visible=True),
|
465 |
+
gr.update(visible=True),
|
466 |
+
gr.update(visible=True),
|
467 |
+
gr.update(visible=False),
|
468 |
+
gr.update(visible=True),
|
469 |
+
gr.update(visible=True),
|
470 |
+
)
|
471 |
else:
|
472 |
+
return (
|
473 |
+
gr.update(visible=True),
|
474 |
+
gr.update(visible=False),
|
475 |
+
gr.update(visible=False),
|
476 |
+
gr.update(visible=False),
|
477 |
+
gr.update(visible=False),
|
478 |
+
gr.update(visible=False),
|
479 |
+
gr.update(visible=False),
|
480 |
+
gr.update(visible=False),
|
481 |
+
)
|
482 |
|
483 |
|
484 |
# Function to select input and output columns and display a message
|
485 |
+
def select_columns(input_column, output_column, train_data, test_data, dataset_name):
|
486 |
if input_column and output_column:
|
487 |
return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
|
488 |
return "Please select both input and output columns."
|
489 |
|
490 |
+
|
491 |
+
def set_dataname(dataset_name, dataset_selector):
|
492 |
if dataset_selector == "Custom Dataset":
|
493 |
return f"{dataset_name}"
|
494 |
return f"{dataset_selector}"
|
495 |
|
496 |
+
|
497 |
# Function to create model based on user input
|
498 |
+
def create_model(
|
499 |
+
model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None
|
500 |
+
):
|
501 |
if model_name == "XGBClassifier":
|
502 |
+
model = xgb.XGBClassifier(
|
503 |
+
objective='binary:logistic',
|
504 |
+
eval_metric='auc',
|
505 |
+
max_depth=max_depth,
|
506 |
+
n_estimators=n_estimators,
|
507 |
+
alpha=alpha,
|
508 |
+
)
|
509 |
elif model_name == "SVR":
|
510 |
model = SVR(degree=degree, kernel=kernel)
|
511 |
elif model_name == "Kernel Ridge":
|
|
|
520 |
|
521 |
return f"{model_name} * {model.get_params()}"
|
522 |
|
523 |
+
|
524 |
# Define the Gradio layout
|
525 |
with gr.Blocks() as demo:
|
526 |
log_df = pd.DataFrame(
|
|
|
530 |
with gr.Row():
|
531 |
# Left Column
|
532 |
with gr.Column():
|
533 |
+
gr.HTML(
|
534 |
+
'''
|
535 |
<div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;">
|
536 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
|
537 |
</div>
|
538 |
+
'''
|
539 |
+
)
|
540 |
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
541 |
+
dataset_selector = gr.Dropdown(
|
542 |
+
label="Select Dataset",
|
543 |
+
choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
|
544 |
+
)
|
545 |
# Display the message for selected columns
|
546 |
+
selected_columns_message = gr.Textbox(
|
547 |
+
label="Selected Columns Info", visible=False
|
548 |
+
)
|
549 |
|
550 |
with gr.Accordion("Dataset Settings", open=True):
|
551 |
# File upload options for custom dataset (train and test)
|
552 |
dataset_name = gr.Textbox(label="Dataset Name", visible=False)
|
553 |
+
train_file = gr.File(
|
554 |
+
label="Upload Custom Train Dataset",
|
555 |
+
file_types=[".csv"],
|
556 |
+
visible=False,
|
557 |
+
)
|
558 |
+
train_display = gr.Dataframe(
|
559 |
+
label="Train Dataset Preview (First 5 Rows)",
|
560 |
+
visible=False,
|
561 |
+
interactive=False,
|
562 |
+
)
|
563 |
|
564 |
+
test_file = gr.File(
|
565 |
+
label="Upload Custom Test Dataset",
|
566 |
+
file_types=[".csv"],
|
567 |
+
visible=False,
|
568 |
+
)
|
569 |
+
test_display = gr.Dataframe(
|
570 |
+
label="Test Dataset Preview (First 5 Rows)",
|
571 |
+
visible=False,
|
572 |
+
interactive=False,
|
573 |
+
)
|
574 |
|
575 |
# Predefined dataset displays
|
576 |
+
predefined_display = gr.Dataframe(
|
577 |
+
label="Predefined Dataset Preview (First 5 Rows)",
|
578 |
+
visible=False,
|
579 |
+
interactive=False,
|
580 |
+
)
|
581 |
|
582 |
# Dropdowns for selecting input and output columns for the custom dataset
|
583 |
+
input_column_selector = gr.Dropdown(
|
584 |
+
label="Select Input Column", choices=[], visible=False
|
585 |
+
)
|
586 |
+
output_column_selector = gr.Dropdown(
|
587 |
+
label="Select Output Column", choices=[], visible=False
|
588 |
+
)
|
589 |
|
590 |
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
591 |
+
dataset_selector.change(
|
592 |
+
handle_dataset_selection,
|
593 |
+
inputs=dataset_selector,
|
594 |
+
outputs=[
|
595 |
+
dataset_name,
|
596 |
+
train_file,
|
597 |
+
train_display,
|
598 |
+
test_file,
|
599 |
+
test_display,
|
600 |
+
predefined_display,
|
601 |
+
input_column_selector,
|
602 |
+
output_column_selector,
|
603 |
+
],
|
604 |
+
)
|
605 |
|
606 |
# When a predefined dataset is selected, load its head and update column selectors
|
607 |
+
dataset_selector.change(
|
608 |
+
load_predefined_dataset,
|
609 |
+
inputs=dataset_selector,
|
610 |
+
outputs=[
|
611 |
+
predefined_display,
|
612 |
+
input_column_selector,
|
613 |
+
output_column_selector,
|
614 |
+
selected_columns_message,
|
615 |
+
],
|
616 |
+
)
|
617 |
|
618 |
# When a custom train file is uploaded, display its head and update column selectors
|
619 |
+
train_file.change(
|
620 |
+
display_csv_head,
|
621 |
+
inputs=train_file,
|
622 |
+
outputs=[
|
623 |
+
train_display,
|
624 |
+
input_column_selector,
|
625 |
+
output_column_selector,
|
626 |
+
],
|
627 |
+
)
|
628 |
|
629 |
# When a custom test file is uploaded, display its head
|
630 |
+
test_file.change(
|
631 |
+
display_csv_head,
|
632 |
+
inputs=test_file,
|
633 |
+
outputs=[
|
634 |
+
test_display,
|
635 |
+
input_column_selector,
|
636 |
+
output_column_selector,
|
637 |
+
],
|
638 |
+
)
|
639 |
|
640 |
+
dataset_selector.change(
|
641 |
+
set_dataname,
|
642 |
+
inputs=[dataset_name, dataset_selector],
|
643 |
+
outputs=dataset_name,
|
644 |
+
)
|
645 |
|
646 |
# Update the selected columns information when dropdown values are changed
|
647 |
+
input_column_selector.change(
|
648 |
+
select_columns,
|
649 |
+
inputs=[
|
650 |
+
input_column_selector,
|
651 |
+
output_column_selector,
|
652 |
+
train_file,
|
653 |
+
test_file,
|
654 |
+
dataset_name,
|
655 |
+
],
|
656 |
+
outputs=selected_columns_message,
|
657 |
+
)
|
658 |
|
659 |
+
output_column_selector.change(
|
660 |
+
select_columns,
|
661 |
+
inputs=[
|
662 |
+
input_column_selector,
|
663 |
+
output_column_selector,
|
664 |
+
train_file,
|
665 |
+
test_file,
|
666 |
+
dataset_name,
|
667 |
+
],
|
668 |
+
outputs=selected_columns_message,
|
669 |
+
)
|
670 |
|
671 |
+
model_checkbox = gr.CheckboxGroup(
|
672 |
+
choices=models_enabled, label="Select Model"
|
673 |
+
)
|
674 |
|
675 |
+
task_radiobutton = gr.Radio(
|
676 |
+
choices=["Classification", "Regression"], label="Task Type"
|
677 |
+
)
|
678 |
|
679 |
####### adding hyper parameter tuning ###########
|
680 |
+
model_name = gr.Dropdown(
|
681 |
+
[
|
682 |
+
"Default - Auto",
|
683 |
+
"XGBClassifier",
|
684 |
+
"SVR",
|
685 |
+
"Kernel Ridge",
|
686 |
+
"Linear Regression",
|
687 |
+
],
|
688 |
+
label="Select Downstream Model",
|
689 |
+
)
|
690 |
with gr.Accordion("Downstream Hyperparameter Settings", open=True):
|
691 |
# Create placeholders for hyperparameter components
|
692 |
+
max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
|
693 |
+
n_estimators = gr.Slider(
|
694 |
+
100, 5000, step=100, visible=False, label="n_estimators"
|
695 |
+
)
|
696 |
alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
|
697 |
+
degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
|
698 |
+
kernel = gr.Dropdown(
|
699 |
+
choices=["rbf", "poly", "linear"], visible=False, label="kernel"
|
700 |
+
)
|
701 |
|
702 |
# Output textbox
|
703 |
output = gr.Textbox(label="Loaded Parameters")
|
|
|
705 |
# Dynamically show relevant hyperparameters based on selected model
|
706 |
def update_hyperparameters(model_name):
|
707 |
if model_name == "XGBClassifier":
|
708 |
+
return (
|
709 |
+
gr.update(visible=True),
|
710 |
+
gr.update(visible=True),
|
711 |
+
gr.update(visible=True),
|
712 |
+
gr.update(visible=False),
|
713 |
+
gr.update(visible=False),
|
714 |
+
)
|
715 |
elif model_name == "SVR":
|
716 |
+
return (
|
717 |
+
gr.update(visible=False),
|
718 |
+
gr.update(visible=False),
|
719 |
+
gr.update(visible=False),
|
720 |
+
gr.update(visible=True),
|
721 |
+
gr.update(visible=True),
|
722 |
+
)
|
723 |
elif model_name == "Kernel Ridge":
|
724 |
+
return (
|
725 |
+
gr.update(visible=False),
|
726 |
+
gr.update(visible=False),
|
727 |
+
gr.update(visible=True),
|
728 |
+
gr.update(visible=True),
|
729 |
+
gr.update(visible=True),
|
730 |
+
)
|
731 |
elif model_name == "Linear Regression":
|
732 |
+
return (
|
733 |
+
gr.update(visible=False),
|
734 |
+
gr.update(visible=False),
|
735 |
+
gr.update(visible=False),
|
736 |
+
gr.update(visible=False),
|
737 |
+
gr.update(visible=False),
|
738 |
+
)
|
739 |
elif model_name == "Default - Auto":
|
740 |
+
return (
|
741 |
+
gr.update(visible=False),
|
742 |
+
gr.update(visible=False),
|
743 |
+
gr.update(visible=False),
|
744 |
+
gr.update(visible=False),
|
745 |
+
gr.update(visible=False),
|
746 |
+
)
|
747 |
|
748 |
# When model is selected, update which hyperparameters are visible
|
749 |
+
model_name.change(
|
750 |
+
update_hyperparameters,
|
751 |
+
inputs=[model_name],
|
752 |
+
outputs=[max_depth, n_estimators, alpha, degree, kernel],
|
753 |
+
)
|
754 |
|
755 |
# Submit button to create the model with selected hyperparameters
|
756 |
submit_button = gr.Button("Create Downstream Model")
|
|
|
758 |
# Function to handle model creation based on input parameters
|
759 |
def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
|
760 |
if model_name == "XGBClassifier":
|
761 |
+
return create_model(
|
762 |
+
model_name,
|
763 |
+
max_depth=max_depth,
|
764 |
+
n_estimators=n_estimators,
|
765 |
+
alpha=alpha,
|
766 |
+
)
|
767 |
elif model_name == "SVR":
|
768 |
return create_model(model_name, degree=degree, kernel=kernel)
|
769 |
elif model_name == "Kernel Ridge":
|
770 |
+
return create_model(
|
771 |
+
model_name, alpha=alpha, degree=degree, kernel=kernel
|
772 |
+
)
|
773 |
elif model_name == "Linear Regression":
|
774 |
return create_model(model_name)
|
775 |
elif model_name == "Default - Auto":
|
776 |
return create_model(model_name)
|
777 |
|
778 |
# When the submit button is clicked, run the on_submit function
|
779 |
+
submit_button.click(
|
780 |
+
on_submit,
|
781 |
+
inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
|
782 |
+
outputs=output,
|
783 |
+
)
|
784 |
###### End of hyper param tuning #########
|
785 |
|
786 |
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
|
|
|
789 |
|
790 |
# Middle Column
|
791 |
with gr.Column():
|
792 |
+
gr.HTML(
|
793 |
+
'''
|
794 |
<div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;">
|
795 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
|
796 |
</div>
|
797 |
+
'''
|
798 |
+
)
|
799 |
eval_output = gr.Textbox(label="Train downstream model")
|
800 |
|
801 |
+
plot_radio = gr.Radio(
|
802 |
+
choices=["ROC-AUC", "Parity Plot", "Latent Space"],
|
803 |
+
label="Select Plot Type",
|
804 |
+
)
|
805 |
plot_output = gr.Plot(label="Visualization")
|
806 |
|
807 |
create_log = gr.Button("Store log")
|
808 |
|
809 |
+
log_table = gr.Dataframe(
|
810 |
+
value=log_df, label="Log of Selections and Results", interactive=False
|
811 |
+
)
|
812 |
+
|
813 |
+
eval_button.click(
|
814 |
+
display_eval,
|
815 |
+
inputs=[
|
816 |
+
model_checkbox,
|
817 |
+
selected_columns_message,
|
818 |
+
task_radiobutton,
|
819 |
+
output,
|
820 |
+
fusion_radiobutton,
|
821 |
+
state,
|
822 |
+
],
|
823 |
+
outputs=eval_output,
|
824 |
+
)
|
825 |
+
|
826 |
+
plot_radio.change(
|
827 |
+
display_plot, inputs=[plot_radio, state], outputs=plot_output
|
828 |
+
)
|
829 |
|
830 |
# Function to gather selected models
|
831 |
def gather_selected_models(*models):
|
832 |
selected = [model for model in models if model]
|
833 |
return selected
|
834 |
|
835 |
+
create_log.click(
|
836 |
+
evaluate_and_log,
|
837 |
+
inputs=[
|
838 |
+
model_checkbox,
|
839 |
+
dataset_name,
|
840 |
+
task_radiobutton,
|
841 |
+
eval_output,
|
842 |
+
state,
|
843 |
+
],
|
844 |
+
outputs=log_table,
|
845 |
+
)
|
846 |
# Right Column
|
847 |
with gr.Column():
|
848 |
+
gr.HTML(
|
849 |
+
'''
|
850 |
<div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;">
|
851 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
|
852 |
</div>
|
853 |
+
'''
|
854 |
+
)
|
855 |
smiles_input = gr.Textbox(label="Input SMILES String")
|
856 |
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
857 |
# Show images for selection
|
|
|
863 |
)
|
864 |
image_selector.change(load_image, image_selector, image_display)
|
865 |
generate_button = gr.Button("Generate")
|
866 |
+
gen_image_display = gr.Image(
|
867 |
+
label="Generated Molecule Image", height=250, width=250
|
868 |
+
)
|
869 |
generated_output = gr.Textbox(label="Generated Output")
|
870 |
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
871 |
|
872 |
# Handle image selection
|
873 |
+
image_selector.change(
|
874 |
+
handle_image_selection,
|
875 |
+
inputs=image_selector,
|
876 |
+
outputs=[smiles_input, image_display],
|
877 |
+
)
|
878 |
+
smiles_input.change(
|
879 |
+
smiles_to_image, inputs=smiles_input, outputs=image_display
|
880 |
+
)
|
881 |
|
882 |
# Generate button to display canonical SMILES and molecule image
|
883 |
+
generate_button.click(
|
884 |
+
generate_canonical,
|
885 |
+
inputs=smiles_input,
|
886 |
+
outputs=[property_table, generated_output, gen_image_display],
|
887 |
+
)
|
888 |
|
889 |
|
890 |
if __name__ == "__main__":
|