|
import os |
|
import gradio as gr |
|
import json |
|
from rxnim import RXNIM |
|
from getReaction import generate_combined_image |
|
import torch |
|
from rxn.reaction import Reaction |
|
from rdkit import Chem |
|
from rdkit.Chem import rdChemReactions |
|
from rdkit.Chem import Draw |
|
from rdkit.Chem import AllChem |
|
from rdkit.Chem.Draw import rdMolDraw2D |
|
import cairosvg |
|
import re |
|
|
|
# Directory holding the task prompt templates (*.txt) offered in the UI.
PROMPT_DIR = "prompts/"

# Display names for known prompt files; any other *.txt file gets a generated
# "Task: <stem>" label in list_prompt_files_with_names().
PROMPT_NAMES = {"2_RxnOCR.txt": "Reaction Image Parsing Workflow"}

# Static images shown in the UI.
example_diagram = "examples/exp.png"
rdkit_image = "examples/rdkit.png"

# Reaction model, loaded once at import time on the CPU.
ckpt_path = "./rxn/model/model.ckpt"
_cpu_device = torch.device('cpu')
model = Reaction(ckpt_path, device=_cpu_device)
|
|
|
def list_prompt_files_with_names(prompt_dir=None, names=None):
    """Map a friendly display name to every ``.txt`` prompt file in a directory.

    Files listed in ``names`` use that display name; any other ``.txt`` file
    gets the default name ``"Task: <file stem>"``.

    Args:
        prompt_dir: Directory to scan. Defaults to ``PROMPT_DIR``.
        names: Mapping of file name -> display name. Defaults to
            ``PROMPT_NAMES``.

    Returns:
        dict mapping display name -> file name (relative to ``prompt_dir``),
        in file-name order so the UI choices are stable across runs. If two
        files map to the same display name, the later file wins.
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    if names is None:
        names = PROMPT_NAMES

    prompt_files = {}
    # os.listdir() returns entries in arbitrary order; sort so the radio
    # choices built from this dict appear in a deterministic order.
    for filename in sorted(os.listdir(prompt_dir)):
        if not filename.endswith(".txt"):
            continue
        # Skip directories (or other non-files) that merely end in ".txt".
        if not os.path.isfile(os.path.join(prompt_dir, filename)):
            continue
        default_name = f"Task: {os.path.splitext(filename)[0]}"
        prompt_files[names.get(filename, default_name)] = filename
    return prompt_files
|
|
|
def _escape_html(value):
    """Return ``str(value)`` with the HTML-special characters & < > escaped."""
    import html

    return html.escape(str(value), quote=False)


def _color_span(text, color):
    """Wrap already-escaped *text* in an inline ``<span>`` of the given colour."""
    return f"<span style='color:{color}'>{text}</span>"


def parse_reactions(output_json):
    """Parse JSON reaction data and format it as colour-coded HTML plus SMILES.

    Args:
        output_json: JSON text. Its top-level object may contain a
            ``"reactions"`` list; each reaction may carry ``reaction_id``,
            ``reactants`` / ``products`` (items with ``smiles``) and
            ``conditions`` (items with ``smiles`` or ``text``, plus ``role``).
            Missing fields fall back to ``"Unknown"`` / ``"Unknown ID"``.

    Returns:
        ``(detailed_output, smiles_output)``: one HTML snippet per reaction
        (reactants blue, conditions red, products orange, full reaction
        black), and one plain ``reactants>>products`` SMILES string per
        reaction.

    Raises:
        json.JSONDecodeError: if ``output_json`` is not valid JSON.
    """
    reactions_data = json.loads(output_json)
    reactions_list = reactions_data.get("reactions", [])

    detailed_output = []
    smiles_output = []

    for reaction in reactions_list:
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        # Raw SMILES, used verbatim for the machine-readable output.
        reactant_smiles = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
        product_smiles = [p.get("smiles", "Unknown") for p in reaction.get("products", [])]
        # A condition is shown by its SMILES (or free text) tagged with its role.
        condition_labels = [
            f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
            for c in reaction.get("conditions", [])
        ]

        # The HTML ends up in a gr.HTML component, so model-generated text is
        # escaped before it is embedded in markup.
        reactants_html = [_escape_html(r) for r in reactant_smiles]
        products_html = [_escape_html(p) for p in product_smiles]
        conditions_html = [_escape_html(c) for c in condition_labels]

        full_reaction = (
            f"{'.'.join(reactants_html)}>>"
            f"{'.'.join(_color_span(p, 'black') for p in products_html)} | "
            f"{', '.join(_color_span(c, 'black') for c in conditions_html)}"
        )
        full_reaction = _color_span(full_reaction, "black")

        reaction_output = f"<b>Reaction: </b> {_escape_html(reaction_id)}<br>"
        reaction_output += f" Reactants: {_color_span(', '.join(reactants_html), 'blue')}<br>"
        reaction_output += f" Conditions: {', '.join(_color_span(c, 'red') for c in conditions_html)}<br>"
        reaction_output += f" Products: {', '.join(_color_span(p, 'orange') for p in products_html)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br>"
        reaction_output += "<br>"
        detailed_output.append(reaction_output)

        smiles_output.append(f"{'.'.join(reactant_smiles)}>>{'.'.join(product_smiles)}")

    return detailed_output, smiles_output
|
|
|
def process_chem_image(image, selected_task):
    """Run both extraction pipelines on one uploaded reaction image.

    Args:
        image: The image from the ``gr.Image(type="pil")`` input, or None if
            nothing was uploaded.
        selected_task: Friendly task name chosen in the radio (a key of
            ``prompts_with_names``), or None if nothing was selected.

    Returns:
        A 5-tuple matching the Run button outputs: the joined reaction HTML,
        the list of reaction SMILES, the path of the combined visualization
        image, the schematic diagram path, and the path of the saved JSON file.

    Raises:
        gr.Error: if no image was uploaded or no known task was selected, so
            the UI shows a readable message instead of an AttributeError or
            KeyError traceback.
    """
    if image is None:
        raise gr.Error("Please upload a reaction image first.")
    if selected_task not in prompts_with_names:
        raise gr.Error("Please select a predefined task.")

    chem_mllm = RXNIM()
    prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])

    # NOTE(review): fixed file names in the working directory are shared by
    # every session; concurrent runs could overwrite each other's files.
    image_path = "temp_image.png"
    image.save(image_path)

    # Multimodal LLM pass; its result is JSON text (parsed by parse_reactions).
    rxnim_result = chem_mllm.process(image_path, prompt_path)
    detailed_reactions, smiles_output = parse_reactions(rxnim_result)

    # Reaction-model pass, used for the visualization image.
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    combined_image_path = generate_combined_image(predictions, image_path)

    # Re-serialise the LLM JSON with indentation for the download component.
    json_file_path = "output.json"
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(json.loads(rxnim_result), json_file, indent=4)

    return "\n\n".join(detailed_reactions), smiles_output, combined_image_path, example_diagram, json_file_path
|
|
|
|
|
|
|
# Friendly task name -> prompt file name, built once at startup. Supplies the
# task radio choices and resolves the selected task in process_chem_image().
prompts_with_names = list_prompt_files_with_names()
|
|
|
|
|
# Example rows for gr.Examples: [image path, task name]. All four bundled
# images run the reaction-parsing task.
_EXAMPLE_TASK = "Reaction Image Parsing Workflow"
examples = [[f"examples/reaction{n}.png", _EXAMPLE_TASK] for n in range(1, 5)]
|
|
|
|
|
def _split_reaction_smiles(text):
    """Split the hidden SMILES textbox value into individual reaction SMILES.

    The textbox is fed the list returned by process_chem_image; its value is
    presumably that list's ``str()`` form, e.g. ``"['A>>B', 'C>>D']"`` -- TODO
    confirm. Returns [] for None, blank text, or an empty list literal.
    """
    import ast

    if not text or not str(text).strip():
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        items = [str(item) for item in parsed]
    else:
        # Fallback: comma split, stripping list/quote residue only at each
        # item's edges so bracket atoms inside SMILES (e.g. [Na+]) survive.
        items = [re.sub(r"^\s*\[?'?|'\]?\s*$", "", item) for item in str(text).split(",")]
    return [item.strip() for item in items if item.strip()]


def _draw_reaction_png(smiles, index):
    """Draw one reaction SMILES with RDKit and save it as ``reaction_{index}.png``.

    The intermediate SVG is written to ``reaction{index}.svg`` and converted
    with cairosvg. Returns the PNG path, or None when RDKit cannot parse or
    draw the reaction, so one bad entry does not break the whole panel.
    """
    try:
        parsed = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
        if parsed is None:
            return None

        # Rebuild every template from a MolBlock round-trip; presumably this
        # gives each molecule 2D coordinates for MeanBondLength and the
        # drawer -- TODO confirm.
        rxn = AllChem.ChemicalReaction()
        for mol in parsed.GetReactants():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:  # MolFromMolBlock returns None on failure
                return None
            rxn.AddReactantTemplate(rebuilt)
        for mol in parsed.GetProducts():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:
                return None
            rxn.AddProductTemplate(rebuilt)
        if rxn.GetNumReactantTemplates() + rxn.GetNumProductTemplates() == 0:
            return None

        # Clear atom-map numbers (presumably so they are not shown in the drawing).
        for templates in (rxn.GetReactants(), rxn.GetProducts()):
            for template in templates:
                for atom in template.GetAtoms():
                    atom.SetAtomMapNum(0)

        # Scale the ACS 1996 style from the first of the first two reactants
        # that has bonds; fall back to 1.0 (also when there are no reactants).
        leading = [rxn.GetReactantTemplate(k) for k in range(min(2, rxn.GetNumReactantTemplates()))]
        bond_length_reference = next(
            (Draw.MeanBondLength(mol) for mol in leading if mol.GetNumBonds() > 0),
            1.0,
        )

        drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
        dopts = drawer.drawOptions()
        dopts.padding = 0.1
        dopts.includeRadicals = True
        Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
        dopts.bondLineWidth = 1.5
        drawer.DrawReaction(rxn)
        drawer.FinishDrawing()
        svg_content = drawer.GetDrawingText()
    except (ValueError, RuntimeError):
        # Assumed: RDKit surfaces unparsable / unsanitizable input as these types.
        return None

    svg_file = f"reaction{index}.svg"
    with open(svg_file, "w") as f:
        f.write(svg_content)
    png_file = f"reaction_{index}.png"
    cairosvg.svg2png(url=svg_file, write_to=png_file)
    return png_file


# UI layout: inputs and HTML output (left), visualizations (middle),
# machine-readable SMILES / drawings / JSON download (right).
with gr.Blocks() as demo:
    gr.Markdown(
        """



<center> <h1>Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model</h1></center>



Upload a reaction image and select a predefined task.

""")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")

            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            # Hidden carrier for the SMILES list; show_split re-renders whenever it changes.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render a SMILES textbox and an RDKit drawing for each reaction.

                With no reactions, shows a placeholder textbox and the
                ``rdkit_image`` example instead. A reaction RDKit cannot draw
                keeps its SMILES textbox but gets no image.
                """
                smiles_list = _split_reaction_smiles(inputs)
                if not smiles_list:
                    return (
                        gr.Textbox(label="SMILES of Reaction i"),
                        gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100),
                    )

                components = []
                for number, smiles in enumerate(smiles_list, start=1):
                    png_file = _draw_reaction_png(smiles, number)
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {number}",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    if png_file is not None:
                        components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {number}"))
                return components

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    # Reset inputs and every per-run output, including the JSON download so a
    # stale file from the previous run is not left behind.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
            download_json,
        ],
    )

    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[
            reaction_output,
            smiles_output,
            visualization_output,
            schematic_diagram,
            download_json,
        ],
    )
|
|
|
# Styling for the "Run" button (elem_id="submit-btn"); attached to the app
# before it is launched.
SUBMIT_BUTTON_CSS = """

#submit-btn {

    background-color: #FF914D;

    color: white;

    font-weight: bold;

}

"""
demo.css = SUBMIT_BUTTON_CSS
demo.launch()