|
import sys |
|
sys.path.append('./rxn/') |
|
import torch |
|
from rxn.reaction import Reaction |
|
import json |
|
from matplotlib import pyplot as plt |
|
import numpy as np |
|
|
|
# All inference runs on the CPU; build the device once and hand the same
# object to the model.
device = torch.device('cpu')

# Checkpoint location is relative to the current working directory.
ckpt_path = "./rxn/model/model.ckpt"
model = Reaction(ckpt_path, device=device)
|
|
|
def get_reaction(image_path: str) -> str:
    """Run reaction extraction on an image and return the result as JSON text.

    Calls ``model.predict_image_file`` with ``molscribe=True`` and
    ``ocr=True``, then serializes whatever it returns with ``json.dumps``.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        The prediction encoded as a JSON string.

    Raises:
        TypeError: If the prediction contains objects that ``json.dumps``
            cannot serialize.
    """
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    return json.dumps(predictions)
|
|
|
|
|
|
|
def _to_rgb_uint8(img):
    """Normalize one drawn image so ``Axes.imshow`` can display it.

    - 2-D (grayscale) arrays and single-channel 3-D arrays are replicated to
      three channels (``imshow`` rejects an ``(H, W, 1)`` array).
    - Channels beyond the third (e.g. alpha) are dropped.
    - Floating-point images are assumed to be scaled to [0, 1]. They are
      clipped to that range before scaling to uint8, so out-of-range values
      saturate instead of wrapping around in the integer cast.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    if np.issubdtype(img.dtype, np.floating):
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    return img


def generate_combined_image(predictions, image_file, n_cols=1,
                            output_path="combined_output.png"):
    """Combine the drawn prediction images into one grid image and save it.

    ``model.draw_predictions`` renders ``predictions`` for ``image_file``.
    Each returned image becomes one titled panel ("### Reaction k ###") in a
    grid with ``n_cols`` columns, and the grid is written to ``output_path``.

    Args:
        predictions: Passed unchanged to ``model.draw_predictions``
            (presumably the output of ``model.predict_image_file`` -- confirm
            against callers).
        image_file: Path of the source image, forwarded to
            ``model.draw_predictions``.
        n_cols: Number of grid columns. The default of 1 stacks one panel
            per row.
        output_path: File the combined image is saved to.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``n_cols`` is less than 1 or ``draw_predictions``
            returns no images.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    # NOTE(review): assumes draw_predictions returns an iterable of
    # array-like images (2-D or H x W x C) -- confirm.
    drawn = model.draw_predictions(predictions, image_file=image_file)
    images = [_to_rgb_uint8(img) for img in drawn]
    n_images = len(images)
    if n_images == 0:
        # Without this check, plt.subplots(0, ...) fails with a less helpful message.
        raise ValueError("draw_predictions returned no images to combine")

    n_rows = (n_images + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(20 * n_cols, 12 * n_rows), squeeze=False
    )
    try:
        flat_axes = axes.ravel()
        for idx, (ax, img) in enumerate(zip(flat_axes, images), start=1):
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f"### Reaction {idx} ###", fontsize=42)

        # Remove unused cells when the images do not fill the last row.
        for ax in flat_axes[n_images:]:
            fig.delaxes(ax)

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(fig)

    return output_path
|
|