edyoshikun committed
Commit 5d9796e
1 Parent(s): 2021fee

adding vs-gradio

Files changed (6)
  1. a549.png +0 -0
  2. app.py +84 -0
  3. examples/a549.png +0 -0
  4. examples/hek.png +0 -0
  5. hek.png +0 -0
  6. requirements.txt +3 -0
a549.png ADDED
app.py ADDED
@@ -0,0 +1,84 @@
+ from viscy.light.engine import VSUNet
+ import torch
+ import gradio as gr
+ import numpy as np
+ from numpy.typing import ArrayLike
+ from skimage import exposure
+ from huggingface_hub import hf_hub_download
+
+
+ class VSGradio:
+     def __init__(self, model_config, model_ckpt_path):
+         self.model_config = model_config
+         self.model_ckpt_path = model_ckpt_path
+         self.model = None
+         self.load_model()
+
+     def load_model(self):
+         # Load the model checkpoint
+         self.model = VSUNet.load_from_checkpoint(
+             self.model_ckpt_path,
+             architecture="UNeXt2_2D",
+             model_config=self.model_config,
+             accelerator="gpu",
+         )
+         self.model.eval()
+         # the Lightning module tracks its own device; predict() moves inputs to it
+
+     def normalize_fov(self, input: ArrayLike):
+         """Normalize the FOV to zero mean and unit variance."""
+         mean = np.mean(input)
+         std = np.std(input)
+         return (input - mean) / std
+
+     def predict(self, inp):
+         # Normalize the input and convert it to a float32 tensor
+         # the model expects a (B, C, D, H, W) input
+         inp = self.normalize_fov(inp)
+         inp = torch.from_numpy(np.array(inp).astype(np.float32))
+         test_dict = dict(
+             index=None,
+             source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.model.device),
+         )
+         with torch.inference_mode():
+             self.model.on_predict_start()
+             pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
+         # Return 2D images for the nuclei and membrane channels
+         nuc_pred = pred[0, 0, 0]
+         mem_pred = pred[0, 1, 0]
+         nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
+         mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))
+         return nuc_pred, mem_pred
+
+
+ # %%
+ if __name__ == "__main__":
+     model_ckpt_path = hf_hub_download(
+         repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
+     )
+
+     model_config = {
+         "in_channels": 1,
+         "out_channels": 2,
+         "encoder_blocks": [3, 3, 9, 3],
+         "dims": [96, 192, 384, 768],
+         "decoder_conv_blocks": 2,
+         "stem_kernel_size": [1, 2, 2],
+         "in_stack_depth": 1,
+         "pretraining": False,
+     }
+
+     vsgradio = VSGradio(model_config, model_ckpt_path)
+
+     gr.Interface(
+         fn=vsgradio.predict,
+         inputs=gr.Image(type="numpy", image_mode="L", format="png"),
+         outputs=[
+             gr.Image(type="numpy", format="png"),
+             gr.Image(type="numpy", format="png"),
+         ],
+         examples=[
+             "examples/a549.png",
+             "examples/hek.png",
+         ],
+     ).launch()
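The VSGradio class above can also be exercised outside the Gradio UI. Below is a minimal smoke-test sketch (not part of this commit): it repeats the checkpoint download and model_config from app.py's __main__ block, assumes a GPU is available as required by accelerator="gpu", and feeds one of the bundled example images directly to VSGradio.predict.

# Hypothetical smoke test (not in this commit): call VSGradio.predict directly
# on a bundled example image, mirroring the setup in app.py's __main__ block.
import numpy as np
from huggingface_hub import hf_hub_download
from skimage import io

from app import VSGradio  # the class added in this commit

ckpt = hf_hub_download(
    repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
)
config = {
    "in_channels": 1,
    "out_channels": 2,
    "encoder_blocks": [3, 3, 9, 3],
    "dims": [96, 192, 384, 768],
    "decoder_conv_blocks": 2,
    "stem_kernel_size": [1, 2, 2],
    "in_stack_depth": 1,
    "pretraining": False,
}
vsgradio = VSGradio(config, ckpt)  # same constructor call as in app.py

# Read a phase-contrast example as a 2D grayscale array and run inference;
# predict() returns the nuclei and membrane predictions rescaled to [0, 1].
phase = io.imread("examples/a549.png", as_gray=True).astype(np.float32)
nuclei, membrane = vsgradio.predict(phase)
print(nuclei.shape, membrane.shape)  # each matches the input height x width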
examples/a549.png ADDED
examples/hek.png ADDED
hek.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ viscy<0.3.0
+ gradio
+ scikit-image