edyoshikun committed
Commit
1f6e4a8
1 Parent(s): 6b043c2

custom CSS and adding title, description and references

Files changed (2)
  1. app.py +86 -28
  2. style.css +45 -0
app.py CHANGED
@@ -1,30 +1,28 @@
-from viscy.light.engine import VSUNet
-import torch
 import gradio as gr
-import numpy as np
+import torch
+from viscy.light.engine import VSUNet
+from huggingface_hub import hf_hub_download
 from numpy.typing import ArrayLike
+import numpy as np
 from skimage import exposure
-from huggingface_hub import hf_hub_download
 
 
 class VSGradio:
     def __init__(self, model_config, model_ckpt_path):
         self.model_config = model_config
         self.model_ckpt_path = model_ckpt_path
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu"
-        )  # Check if GPU is available
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model = None
         self.load_model()
 
     def load_model(self):
-        # Load the model checkpoint
+        # Load the model checkpoint and move it to the correct device (GPU or CPU)
         self.model = VSUNet.load_from_checkpoint(
             self.model_ckpt_path,
             architecture="UNeXt2_2D",
             model_config=self.model_config,
         )
-        self.model.to(self.device)
+        self.model.to(self.device)  # Move the model to the correct device (GPU/CPU)
         self.model.eval()
 
     def normalize_fov(self, input: ArrayLike):
@@ -34,31 +32,46 @@ class VSGradio:
         return (input - mean) / std
 
     def predict(self, inp):
-        # Setup the Trainer
-        # ensure inp is tensor has to be a (B,C,D,H,W) tensor
+        # Normalize the input and convert to tensor
         inp = self.normalize_fov(inp)
         inp = torch.from_numpy(np.array(inp).astype(np.float32))
+
+        # Prepare the input dictionary and move input to the correct device (GPU or CPU)
         test_dict = dict(
             index=None,
             source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
         )
+
+        # Run model inference
         with torch.inference_mode():
-            self.model.on_predict_start()
-            pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
-        # Return a 2D image
+            self.model.on_predict_start()  # Necessary preprocessing for the model
+            pred = (
+                self.model.predict_step(test_dict, 0, 0).cpu().numpy()
+            )  # Move output back to CPU for post-processing
+
+        # Post-process the model output and rescale intensity
         nuc_pred = pred[0, 0, 0]
         mem_pred = pred[0, 1, 0]
         nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
        mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))
+
         return nuc_pred, mem_pred
 
 
+# Load the custom CSS from the file
+def load_css(file_path):
+    with open(file_path, "r") as file:
+        return file.read()
+
+
 # %%
 if __name__ == "__main__":
+    # Download the model checkpoint from Hugging Face
     model_ckpt_path = hf_hub_download(
         repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
     )
 
+    # Model configuration
     model_config = {
         "in_channels": 1,
         "out_channels": 2,
@@ -70,17 +83,62 @@ if __name__ == "__main__":
         "pretraining": False,
     }
 
-    vsgradio = VSGradio(model_config, model_ckpt_path)
-
-    gr.Interface(
-        fn=vsgradio.predict,
-        inputs=gr.Image(type="numpy", image_mode="L", format="png"),
-        outputs=[
-            gr.Image(type="numpy", format="png"),
-            gr.Image(type="numpy", format="png"),
-        ],
-        examples=[
-            "examples/a549.png",
-            "examples/hek.png",
-        ],
-    ).launch()
+    # Initialize the Gradio app using Blocks
+    with gr.Blocks(css=load_css("style.css")) as demo:
+        # Title and description
+        gr.HTML(
+            "<div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>"
+        )
+        # Improved description block with better formatting
+        gr.HTML(
+            """
+            <div class='description-block'>
+                <p><b>Model:</b> VSCyto2D</p>
+                <p>
+                    <b>Input:</b> label-free image (e.g., QPI or phase contrast) <br>
+                    <b>Output:</b> two virtually stained channels: one for the <b>nucleus</b> and one for the <b>cell membrane</b>.
+                </p>
+                <p>
+                    Check out our preprint:
+                    <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al., Robust virtual staining of landmark organelles</i></a>
+                </p>
+            </div>
+            """
+        )
+
+        vsgradio = VSGradio(model_config, model_ckpt_path)
+
+        # Layout for input and output images
+        with gr.Row():
+            input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
+            with gr.Column():
+                output_nucleus = gr.Image(type="numpy", label="VS Nucleus")
+                output_membrane = gr.Image(type="numpy", label="VS Membrane")
+
+        # Button to trigger prediction
+        submit_button = gr.Button("Submit")
+
+        # Define what happens when the button is clicked
+        submit_button.click(
+            vsgradio.predict,
+            inputs=input_image,
+            outputs=[output_nucleus, output_membrane],
+        )
+
+        # Example images and article
+        gr.Examples(
+            examples=["examples/a549.png", "examples/hek.png"], inputs=input_image
+        )
+
+        # Article or footer information
+        gr.HTML(
+            """
+            <div class='article-block'>
+                <p> Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI) or Zernike phase contrast.</p>
+                <p> For training, inference and evaluation of the model refer to the <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
+            </div>
+            """
+        )
+
+    # Launch the Gradio app
+    demo.launch()
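For quick orientation, the pattern this revision switches to — reading a stylesheet from disk, passing it to `gr.Blocks(css=...)`, and wiring a Submit button to a prediction callback — can be sketched in isolation. This is a minimal, hypothetical stand-alone sketch: `fake_predict` and the inline CSS string are placeholders for the real `VSGradio.predict` and `style.css`, not part of the committed app.

```python
import gradio as gr
import numpy as np

# Placeholder CSS string; the committed app reads style.css via load_css() instead.
css = ".title-block { font-size: 28px; font-weight: bold; text-align: center; }"


def fake_predict(img):
    # Stand-in for VSGradio.predict: returns two grayscale "channels"
    # (the input and its inverse) so the layout can be tested without the model.
    if img is None:
        return None, None
    return img, 255 - img


with gr.Blocks(css=css) as demo:
    gr.HTML("<div class='title-block'>Virtual staining demo (layout sketch)</div>")
    with gr.Row():
        input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
        with gr.Column():
            output_nucleus = gr.Image(type="numpy", label="VS Nucleus")
            output_membrane = gr.Image(type="numpy", label="VS Membrane")
    submit_button = gr.Button("Submit")
    submit_button.click(
        fake_predict, inputs=input_image, outputs=[output_nucleus, output_membrane]
    )

if __name__ == "__main__":
    demo.launch()
```

If this sketch renders and styles as expected, swapping `fake_predict` for `vsgradio.predict` and the inline string for `load_css("style.css")` gives the behavior of the committed app.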
style.css ADDED
@@ -0,0 +1,45 @@
+/* Default styling for light mode */
+.title-block, .description-block, .article-block {
+    background-color: #f0f0f0; /* Light background for light mode */
+    border-radius: 10px;
+    padding: 20px;
+    margin-bottom: 20px;
+    text-align: center;
+}
+
+.title-block {
+    font-size: 28px;
+    font-weight: bold;
+    color: #333; /* Dark text for light mode */
+}
+
+.description-block {
+    font-size: 18px;
+    color: #444; /* Slightly lighter text for light mode */
+}
+
+.article-block {
+    font-size: 16px;
+    margin-top: 30px;
+    color: #555; /* Even lighter text for light mode */
+}
+
+/* Dark mode styling */
+@media (prefers-color-scheme: dark) {
+    .title-block, .description-block, .article-block {
+        background-color: #2b2b2b; /* Dark background for dark mode */
+        color: #f0f0f0; /* Light text for dark mode */
+    }
+
+    .title-block {
+        color: #e0e0e0; /* Light text for dark mode */
+    }
+
+    .description-block {
+        color: #d0d0d0; /* Lighter text for dark mode */
+    }
+
+    .article-block {
+        color: #c0c0c0; /* Even lighter text for dark mode */
+    }
+}