Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import numpy as np | |
import glob | |
import warnings | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from utils import OrthogonalRegularizer | |
from huggingface_hub.keras_mixin import from_pretrained_keras | |
# Load the pretrained PointNet segmentation model from the Hugging Face Hub.
# OrthogonalRegularizer is a project-defined Keras object, so it is passed via
# custom_objects to let Keras deserialize the saved model.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects={"OrthogonalRegularizer": OrthogonalRegularizer},
)
# Examples: one Gradio example row per point-cloud CSV under asset/source/.
samples = []  # NOTE(review): not referenced anywhere in the code visible here
# glob returns files in an arbitrary, filesystem-dependent order; sort so the
# example gallery is stable across machines and restarts.
# (Despite the name, these are CSV point-cloud files, not images.)
input_images = sorted(glob.glob("asset/source/*.csv"))
examples = [[im] for im in input_images]
# Class names for the model's output channels (inference appends "none" as the
# final class). COLORS[i] is the plot colour used for LABELS[i].
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]
def visualize_data(point_cloud, labels, output_path=None):
    """Plot a labelled point cloud as a 3-D scatter, one colour per part label.

    Args:
        point_cloud: array indexed as ``point_cloud[:, k]``; columns 0, 1 and 2
            are used as the x, y and z coordinates.
        labels: one label string per point (same length as ``point_cloud``).
            Only points whose label appears in ``LABELS`` are drawn; anything
            else (e.g. ``"none"``) is left out of the plot.
        output_path: optional file path for the rendered figure. Missing parent
            directories are created. When given, the figure is closed after
            saving; when omitted, the figure is left open so the caller can
            show or save it (e.g. via ``plt.show()``).
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection="3d")
    # zip stops at the shorter list, so a label without a colour is skipped --
    # the same outcome the old try/except IndexError produced, without hiding
    # unrelated errors from ax.scatter.
    for label, color in zip(LABELS, COLORS):
        c_df = df[df["label"] == label]
        ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color)
    ax.legend()
    if output_path:
        out_dir = os.path.dirname(output_path)
        # os.makedirs("") raises, so only create a directory when there is one.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
        # This runs once per request in a long-lived server; without closing,
        # every figure stays registered with pyplot and memory keeps growing.
        plt.close(fig)
def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point cloud and render the per-point predictions.

    Args:
        csv_file: the uploaded CSV, given either as a path string or as a
            file-like wrapper exposing its path as ``.name`` (what older Gradio
            ``"file"`` inputs pass). The CSV must contain ``x``, ``y`` and ``z``
            columns.
        output_path: directory the rendered PNG is written into.
        cpu: currently unused; kept for interface compatibility.

    Returns:
        The path of the saved PNG, or ``None`` (after emitting a warning) when
        ``csv_file`` does not point to an existing file.
    """
    # Accept both a plain path and an object carrying the path in ``.name``.
    csv_path = getattr(csv_file, "name", csv_file)
    # Drop the directory and only the final extension ("plane.v2.csv" -> "plane.v2").
    im_name = os.path.splitext(os.path.basename(csv_path))[0]
    if not os.path.isfile(csv_path):
        warnings.warn(f"{csv_path} not found for {im_name}")
        return None
    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values
    # TODO: columns 3+ (df.iloc[:, 3:]) may hold ground-truth labels; show the
    # ground-truth image alongside the prediction when they are present.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    label_map = LABELS + ["none"]
    # Assumes preds is (num_points, num_classes) -- TODO confirm against the
    # model; take the highest-scoring class for every point.
    pred_labels = [label_map[idx] for idx in np.argmax(preds, axis=-1)]
    image_path = f"{output_path}/{im_name}.png"
    visualize_data(inputs, pred_labels, image_path)
    return image_path
article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"
# Targets the Gradio >= 3 API:
# - components are top-level (gr.outputs.* is deprecated and later removed);
# - example caching is configured on the Interface, not passed to launch();
# - queuing is enabled with .queue() instead of launch(enable_queue=...).
iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    examples=examples,
    cache_examples=True,
)
# Keep ``iface`` bound to the Interface itself rather than to launch()'s result.
iface.queue().launch()