# Spaces: Running on Zero
import spaces | |
import gradio as gr | |
from util import imread, imsave, copy_skimage_data | |
import torch | |
from PIL import Image, ImageDraw | |
import numpy as np | |
from os.path import join | |
def torch_compile(*args, **kwargs):
    """No-op stand-in for ``torch.compile`` (temporary workaround).

    Nothing is ever compiled. The shim mirrors the calling conventions of
    ``torch.compile`` so every usage keeps working:

    * ``torch.compile(fn_or_module, ...)`` or ``torch.compile(model=...)``
      returns the object unchanged;
    * bare ``@torch.compile`` behaves the same as the previous case;
    * ``@torch.compile(**options)`` (no target) returns an identity decorator.
    """
    # ``torch.compile`` takes the object to compile as its first positional
    # argument or as ``model=``; when present, hand it back untouched.
    target = args[0] if args else kwargs.get('model')
    if target is not None:
        return target

    def decorator(func):
        return func

    return decorator


torch.compile = torch_compile  # temporary workaround: disable compilation globally
# Default inference settings shared by the UI widgets and the example rows.
default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'
default_score_thresh = .9
default_nms_thresh = np.round(np.pi / 10, 4)
default_samples = 128
default_order = 5

# Folder holding the example images; presumably populated by
# util.copy_skimage_data (opaque here — verify in util).
examples_dir = 'examples'
copy_skimage_data(examples_dir)

# One row per example image. Column order matches the widgets passed as
# ``inputs`` to gr.Examples / the Run button:
# [image, model, score_ck, score, nms_ck, nms, samples_ck, samples, overlapping]
examples = [
    [join(examples_dir, image_name), default_model, False, default_score_thresh, False,
     default_nms_thresh, True, 64, True]
    for image_name in ('bbbc039_test_00014.png', 'coins.png', 'cell.png')
]
def predict(
        filename, model=None,
        enable_score_threshold=False, score_threshold=.9,
        enable_nms_threshold=False, nms_threshold=0.3141592653589793,
        enable_samples=False, samples=128,
        use_label_channels=False,
        enable_order=False, order=5,
        device=None,
):
    """Segment objects in an image with a Contour Proposal Network (CPN).

    Runs ``CpnInterface`` on the normalized image. Two files are written next
    to the input image (same path, extension replaced): the raw model outputs
    as ``.h5`` and a per-object region-property table as ``.csv``.

    Args:
        filename: Path of the input image.
        model: Model name for ``CpnInterface``; falls back to ``default_model``
            when ``None`` or blank. Surrounding whitespace is stripped.
        enable_score_threshold, score_threshold: When enabled, forwarded as
            ``score_thresh``.
        enable_nms_threshold, nms_threshold: When enabled, forwarded as
            ``nms_thresh``. NOTE: this keyword default (pi/10) differs slightly
            from the UI default ``default_nms_thresh`` (pi/10 rounded to 4 places).
        enable_samples, samples: When enabled, forwarded as ``samples``
            (number of contour sample points).
        use_label_channels: If True, calls the model with
            ``reduce_labels=False`` (the UI labels this "Allow overlapping objects").
        enable_order, order: When enabled, forwarded as ``order``.
        device: Torch device string; defaults to ``'cuda'`` when
            ``torch.cuda.device_count()`` is non-zero, else ``'cpu'``.

    Returns:
        ``(vis_labels, img, dst_h5, dst_csv)``: the color-mapped label image,
        the normalized input image, and the paths of the HDF5 and CSV files.

    Raises:
        TypeError: If ``filename`` is not a string (e.g. no image uploaded).
    """
    from os.path import splitext
    from cpn import CpnInterface
    from prep import multi_norm
    from celldetection import label_cmap, to_h5, data, __version__

    if not isinstance(filename, str):
        raise TypeError(f'Expected an image file path, got {type(filename).__name__}')
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'

    # Resolve the model name before building the metadata so the logged and
    # stored attributes name the model that is actually loaded (never None).
    if model is None or not str(model).strip():
        model = default_model
    model = str(model).strip()

    meta = dict(
        cd_version=__version__,
        filename=str(filename),
        model=model,
        device=device,
        use_label_channels=use_label_channels,
        enable_score_threshold=enable_score_threshold,
        score_threshold=float(score_threshold),
        enable_order=enable_order,
        order=order,
        enable_nms_threshold=enable_nms_threshold,
        nms_threshold=float(nms_threshold),
        enable_samples=enable_samples,
        samples=int(samples),
    )
    print(meta, flush=True)

    raw = img = imread(filename)
    print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)
    img = multi_norm(img, 'cstm-mix')  # TODO

    # Only forward the settings the user explicitly enabled; everything else
    # keeps the model's own defaults.
    kw = {}
    if enable_score_threshold:
        kw['score_thresh'] = score_threshold
    if enable_nms_threshold:
        kw['nms_thresh'] = nms_threshold
    if enable_order:
        kw['order'] = int(order)
    if enable_samples:
        kw['samples'] = int(samples)  # UI sliders may deliver floats
    m = CpnInterface(model, device=device, **kw)
    y = m(img, reduce_labels=not use_label_channels)

    # splitext strips only the real extension; splitting on every '.' broke
    # extension-less names (-> '.h5' in the cwd) and dotted directory names.
    stem = splitext(filename)[0]

    dst_h5 = stem + '.h5'
    to_h5(
        dst_h5, inputs=img, **y,
        attributes=dict(inputs=meta)
    )

    labels = y['labels']
    vis_labels = label_cmap(labels)

    dst_csv = stem + '.csv'
    data.labels2property_table(
        labels,
        "label", "area", "feret_diameter_max", "bbox", "centroid", "convex_area",
        "eccentricity", "equivalent_diameter",
        "extent", "filled_area", "major_axis_length",
        "minor_axis_length", "orientation", "perimeter",
        "solidity", "mean_intensity", "max_intensity", "min_intensity",
        intensity_image=raw
    ).to_csv(dst_csv)
    return vis_labels, img, dst_h5, dst_csv
with gr.Blocks(title='Cell Segmentation with Contour Proposal Networks') as app:
    with gr.Row():
        gr.Markdown("<center><strong><font size='7'>"
                    "Cell Segmentation with Contour Proposal Networks 🤗</font></strong></center>")

    # Input image on the left, model settings and controls on the right.
    with gr.Row():
        with gr.Column():
            img = gr.components.Image(label="Upload Input Image", type="filepath", interactive=True,
                                      value=examples[0][0])
        with gr.Column():
            model_name = gr.components.Textbox(label='Model Name', value=default_model, max_lines=1)
            with gr.Row():
                score_thresh_ck = gr.components.Checkbox(label="Use custom Score Threshold", value=False)
                score_thresh = gr.components.Slider(minimum=0, maximum=1, label="Score Threshold",
                                                    value=default_score_thresh)
            with gr.Row():
                nms_thresh_ck = gr.components.Checkbox(label="Use custom NMS Threshold", value=False)
                nms_thresh = gr.components.Slider(minimum=0, maximum=1, label="NMS Threshold",
                                                  value=default_nms_thresh)
            # No custom "order" control is offered: its valid range would need
            # to be model dependent.
            with gr.Row():
                samples_ck = gr.components.Checkbox(label="Use custom Sample Points", value=False)
                samples = gr.components.Slider(minimum=8, maximum=256, label="Sample Points",
                                               value=default_samples)
            with gr.Row():
                channels = gr.components.Checkbox(label="Allow overlapping objects", value=True)
            with gr.Row():
                clr = gr.Button('Reset')
                btn = gr.Button('Run')

    with gr.Row():
        with gr.Column():
            out_img = gr.Image(label="Processed Image")
        with gr.Column():
            out_vis = gr.Image(label="Label Image (random colors, transparent overlap)")
    with gr.Row():
        out_h5 = gr.File(label="Download Results as HDF5 File")
        out_csv = gr.File(label="Download Properties as CSV File")

    # Widgets feeding ``predict`` in its positional-argument order
    # (filename, model, score flag/value, NMS flag/value, samples flag/value,
    # use_label_channels), and the widgets receiving its 4-tuple result.
    run_inputs = [img, model_name, score_thresh_ck, score_thresh, nms_thresh_ck, nms_thresh,
                  samples_ck, samples, channels]
    run_outputs = [out_vis, out_img, out_h5, out_csv]

    with gr.Row():
        gr.Examples(
            fn=predict,
            examples=examples,
            inputs=run_inputs,
            outputs=run_outputs,
            cache_examples=True,
            batch=False
        )

    btn.click(
        predict,
        inputs=run_inputs,
        outputs=run_outputs
    )

    # Reset restores the parameter widgets to their defaults and clears every
    # result widget, including the CSV download (previously left stale).
    clr.click(
        lambda: (
            None, default_score_thresh, default_nms_thresh, False, False,
            None, None, None, None, False, default_samples),
        inputs=[],
        outputs=[img, score_thresh, nms_thresh, score_thresh_ck, nms_thresh_ck,
                 out_img, out_h5, out_vis, out_csv, samples_ck, samples]
    )

    with gr.Row():
        gr.Markdown("<center><font size='3'>"
                    "<a href='https://github.com/FZJ-INM1-BDA/celldetection'>Visit us on GitHub</a></font></center>")

app.launch()