Spaces:
Running
Running
edyoshikun
commited on
Commit
•
5d9796e
1
Parent(s):
2021fee
adding vs-gradio
Browse files- a549.png +0 -0
- app.py +84 -0
- examples/a549.png +0 -0
- examples/hek.png +0 -0
- hek.png +0 -0
- requirements.txt +3 -0
a549.png
ADDED
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from viscy.light.engine import VSUNet
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
from numpy.typing import ArrayLike
|
6 |
+
from skimage import exposure
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
|
9 |
+
|
10 |
+
class VSGradio:
|
11 |
+
def __init__(self, model_config, model_ckpt_path):
|
12 |
+
self.model_config = model_config
|
13 |
+
self.model_ckpt_path = model_ckpt_path
|
14 |
+
self.model = None
|
15 |
+
self.load_model()
|
16 |
+
|
17 |
+
def load_model(self):
|
18 |
+
# Load the model checkpoint
|
19 |
+
self.model = VSUNet.load_from_checkpoint(
|
20 |
+
self.model_ckpt_path,
|
21 |
+
architecture="UNeXt2_2D",
|
22 |
+
model_config=self.model_config,
|
23 |
+
accelerator="gpu",
|
24 |
+
)
|
25 |
+
self.model.eval()
|
26 |
+
self.model
|
27 |
+
|
28 |
+
def normalize_fov(self, input: ArrayLike):
|
29 |
+
"Normalizing the fov with zero mean and unit variance"
|
30 |
+
mean = np.mean(input)
|
31 |
+
std = np.std(input)
|
32 |
+
return (input - mean) / std
|
33 |
+
|
34 |
+
def predict(self, inp):
|
35 |
+
# Setup the Trainer
|
36 |
+
# ensure inp is tensor has to be a (B,C,D,H,W) tensor
|
37 |
+
inp = self.normalize_fov(inp)
|
38 |
+
inp = torch.from_numpy(np.array(inp).astype(np.float32))
|
39 |
+
test_dict = dict(
|
40 |
+
index=None,
|
41 |
+
source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.model.device),
|
42 |
+
)
|
43 |
+
with torch.inference_mode():
|
44 |
+
self.model.on_predict_start()
|
45 |
+
pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
|
46 |
+
# Return a 2D image
|
47 |
+
nuc_pred = pred[0, 0, 0]
|
48 |
+
mem_pred = pred[0, 1, 0]
|
49 |
+
nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
|
50 |
+
mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))
|
51 |
+
return nuc_pred, mem_pred
|
52 |
+
|
53 |
+
|
54 |
+
# %%
|
55 |
+
if __name__ == "__main__":
|
56 |
+
model_ckpt_path = hf_hub_download(
|
57 |
+
repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
|
58 |
+
)
|
59 |
+
|
60 |
+
model_config = {
|
61 |
+
"in_channels": 1,
|
62 |
+
"out_channels": 2,
|
63 |
+
"encoder_blocks": [3, 3, 9, 3],
|
64 |
+
"dims": [96, 192, 384, 768],
|
65 |
+
"decoder_conv_blocks": 2,
|
66 |
+
"stem_kernel_size": [1, 2, 2],
|
67 |
+
"in_stack_depth": 1,
|
68 |
+
"pretraining": False,
|
69 |
+
}
|
70 |
+
|
71 |
+
vsgradio = VSGradio(model_config, model_ckpt_path)
|
72 |
+
|
73 |
+
gr.Interface(
|
74 |
+
fn=vsgradio.predict,
|
75 |
+
inputs=gr.Image(type="numpy", image_mode="L", format="png"),
|
76 |
+
outputs=[
|
77 |
+
gr.Image(type="numpy", format="png"),
|
78 |
+
gr.Image(type="numpy", format="png"),
|
79 |
+
],
|
80 |
+
examples=[
|
81 |
+
"examples/a549.png",
|
82 |
+
"examples/hek.png",
|
83 |
+
],
|
84 |
+
).launch()
|
examples/a549.png
ADDED
examples/hek.png
ADDED
hek.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
viscy<0.3.0
|
2 |
+
gradio
|
3 |
+
skimage
|