edyoshikun commited on
Commit
8cfcc1c
1 Parent(s): 100fc99

image adjustments, colormap and cell diameter

Browse files
Files changed (4) hide show
  1. app.py +136 -23
  2. examples/ctc_HeLa.png +0 -0
  3. examples/livecell_A172.png +0 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -5,6 +5,10 @@ from huggingface_hub import hf_hub_download
5
  from numpy.typing import ArrayLike
6
  import numpy as np
7
  from skimage import exposure
 
 
 
 
8
 
9
 
10
  class VSGradio:
@@ -31,9 +35,26 @@ class VSGradio:
31
  std = np.std(input)
32
  return (input - mean) / std
33
 
34
- def predict(self, inp):
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Normalize the input and convert to tensor
36
  inp = self.normalize_fov(inp)
 
 
 
 
 
37
  inp = torch.from_numpy(np.array(inp).astype(np.float32))
38
 
39
  # Prepare the input dictionary and move input to the correct device (GPU or CPU)
@@ -52,10 +73,60 @@ class VSGradio:
52
  # Post-process the model output and rescale intensity
53
  nuc_pred = pred[0, 0, 0]
54
  mem_pred = pred[0, 1, 0]
55
- nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
56
- mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))
57
 
58
- return nuc_pred, mem_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  # Load the custom CSS from the file
@@ -64,7 +135,6 @@ def load_css(file_path):
64
  return file.read()
65
 
66
 
67
- # %%
68
  if __name__ == "__main__":
69
  # Download the model checkpoint from Hugging Face
70
  model_ckpt_path = hf_hub_download(
@@ -83,59 +153,102 @@ if __name__ == "__main__":
83
  "pretraining": False,
84
  }
85
 
 
 
86
  # Initialize the Gradio app using Blocks
87
  with gr.Blocks(css=load_css("style.css")) as demo:
88
  # Title and description
89
  gr.HTML(
90
  "<div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>"
91
  )
92
- # Improved description block with better formatting
93
  gr.HTML(
94
  """
95
  <div class='description-block'>
96
  <p><b>Model:</b> VSCyto2D</p>
97
- <p>
98
- <b>Input:</b> label-free image (e.g., QPI or phase contrast) <br>
99
- <b>Output:</b> two virtually stained channels: one for the <b>nucleus</b> and one for the <b>cell membrane</b>.
100
- </p>
101
- <p>
102
- Check out our preprint:
103
- <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al.,Robust virtual staining of landmark organelles</i></a>
104
- </p>
105
  </div>
106
  """
107
  )
108
 
109
- vsgradio = VSGradio(model_config, model_ckpt_path)
110
-
111
  # Layout for input and output images
112
  with gr.Row():
113
  input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
 
 
 
 
114
  with gr.Column():
115
- output_nucleus = gr.Image(type="numpy", label="VS Nucleus")
116
- output_membrane = gr.Image(type="numpy", label="VS Membrane")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Button to trigger prediction
119
  submit_button = gr.Button("Submit")
120
 
121
- # Define what happens when the button is clicked
122
  submit_button.click(
123
  vsgradio.predict,
124
- inputs=input_image,
125
  outputs=[output_nucleus, output_membrane],
126
  )
127
 
128
  # Example images and article
129
  gr.Examples(
130
- examples=["examples/a549.png", "examples/hek.png"], inputs=input_image
 
 
 
 
 
 
131
  )
132
 
133
  # Article or footer information
134
  gr.HTML(
135
  """
136
  <div class='article-block'>
137
- <p> Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI) or Zernike phase contrast.</p>
138
- <p> For training, inference and evaluation of the model refer to the <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
139
  </div>
140
  """
141
  )
 
5
  from numpy.typing import ArrayLike
6
  import numpy as np
7
  from skimage import exposure
8
+ from skimage.transform import resize
9
+ from skimage import img_as_float
10
+ from skimage.util import invert
11
+ import cmap
12
 
13
 
14
  class VSGradio:
 
35
  std = np.std(input)
36
  return (input - mean) / std
37
 
38
+ def preprocess_image_standard(self, input: ArrayLike):
39
+ # Perform standard preprocessing here
40
+ input = exposure.equalize_adapthist(input)
41
+ return input
42
+
43
+ def downscale_image(self, inp: ArrayLike, scale_factor: float):
44
+ """Downscales the image by the given scaling factor"""
45
+ height, width = inp.shape
46
+ new_height = int(height * scale_factor)
47
+ new_width = int(width * scale_factor)
48
+ return resize(inp, (new_height, new_width), anti_aliasing=True)
49
+
50
+ def predict(self, inp, cell_diameter: float):
51
  # Normalize the input and convert to tensor
52
  inp = self.normalize_fov(inp)
53
+ original_shape = inp.shape
54
+ # Resize the input image to the expected cell diameter
55
+ inp = apply_rescale_image(inp, cell_diameter, expected_cell_diameter=30)
56
+
57
+ # Convert the input to a tensor
58
  inp = torch.from_numpy(np.array(inp).astype(np.float32))
59
 
60
  # Prepare the input dictionary and move input to the correct device (GPU or CPU)
 
73
  # Post-process the model output and rescale intensity
74
  nuc_pred = pred[0, 0, 0]
75
  mem_pred = pred[0, 1, 0]
 
 
76
 
77
+ # Resize predictions back to the original image size
78
+ nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
79
+ mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
80
+
81
+ # Define colormaps
82
+ green_colormap = cmap.Colormap("green") # Nucleus: black to green
83
+ magenta_colormap = cmap.Colormap("magenta")
84
+
85
+ # Apply the colormap to the predictions
86
+ nuc_rgb = apply_colormap(nuc_pred, green_colormap)
87
+ mem_rgb = apply_colormap(mem_pred, magenta_colormap)
88
+
89
+ return nuc_rgb, mem_rgb
90
+
91
+
92
+ def apply_colormap(prediction, colormap: cmap.Colormap):
93
+ """Apply a colormap to a single-channel prediction image."""
94
+ # Ensure the prediction is within the valid range [0, 1]
95
+ prediction = exposure.rescale_intensity(prediction, out_range=(0, 1))
96
+
97
+ # Apply the colormap to get an RGB image
98
+ rgb_image = colormap(prediction)
99
+
100
+ # Convert the output from [0, 1] to [0, 255] for display
101
+ rgb_image_uint8 = (rgb_image * 255).astype(np.uint8)
102
+
103
+ return rgb_image_uint8
104
+
105
+
106
+ def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
107
+ """Applies all the image adjustments (invert, contrast, gamma) in sequence"""
108
+ # Apply invert
109
+ if invert_image:
110
+ image = invert(image, signed_float=False)
111
+
112
+ # Apply gamma adjustment
113
+ image = exposure.adjust_gamma(image, gamma_factor)
114
+
115
+ return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
116
+
117
+
118
+ def apply_rescale_image(
119
+ image, cell_diameter: float, expected_cell_diameter: float = 30
120
+ ):
121
+ # Assume the model was trained with cells ~30 microns in diameter
122
+ # Resize the input image according to the scaling factor
123
+ scale_factor = expected_cell_diameter / float(cell_diameter)
124
+ image = resize(
125
+ image,
126
+ (int(image.shape[0] * scale_factor), int(image.shape[1] * scale_factor)),
127
+ anti_aliasing=True,
128
+ )
129
+ return image
130
 
131
 
132
  # Load the custom CSS from the file
 
135
  return file.read()
136
 
137
 
 
138
  if __name__ == "__main__":
139
  # Download the model checkpoint from Hugging Face
140
  model_ckpt_path = hf_hub_download(
 
153
  "pretraining": False,
154
  }
155
 
156
+ vsgradio = VSGradio(model_config, model_ckpt_path)
157
+
158
  # Initialize the Gradio app using Blocks
159
  with gr.Blocks(css=load_css("style.css")) as demo:
160
  # Title and description
161
  gr.HTML(
162
  "<div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>"
163
  )
 
164
  gr.HTML(
165
  """
166
  <div class='description-block'>
167
  <p><b>Model:</b> VSCyto2D</p>
168
+ <p><b>Input:</b> label-free image (e.g., QPI or phase contrast).</p>
169
+ <p><b>Output:</b> Virtual staining of nucleus and membrane.</p>
170
+ <p><b>Note:</b> The model works well with QPI, and sometimes generalizes to phase contrast and DIC. We continue to diagnose and improve generalization<p>
171
+ <p>Check out our preprint: <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al., Robust virtual staining of landmark organelles</i></a></p>
172
+ <p> For training, inference and evaluation of the model refer to the <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
 
 
 
173
  </div>
174
  """
175
  )
176
 
 
 
177
  # Layout for input and output images
178
  with gr.Row():
179
  input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
180
+ adjusted_image = gr.Image(
181
+ type="numpy", image_mode="L", label="Adjusted Image (Preview)"
182
+ )
183
+
184
  with gr.Column():
185
+ output_nucleus = gr.Image(
186
+ type="numpy", image_mode="RGB", label="VS Nucleus"
187
+ )
188
+ output_membrane = gr.Image(
189
+ type="numpy", image_mode="RGB", label="VS Membrane"
190
+ )
191
+
192
+ # Checkbox for applying invert
193
+ preprocess_invert = gr.Checkbox(label="Apply Invert", value=False)
194
+
195
+ # Slider for gamma adjustment
196
+ gamma_factor = gr.Slider(
197
+ label="Adjust Gamma", minimum=0.1, maximum=5.0, value=1.0, step=0.1
198
+ )
199
+
200
+ # Input field for the cell diameter in microns
201
+ cell_diameter = gr.Textbox(
202
+ label="Cell Diameter [um]",
203
+ value="30.0",
204
+ placeholder="Enter cell diameter in microns",
205
+ )
206
+
207
+ # Update the adjusted image based on all the transformations
208
+ input_image.change(
209
+ fn=apply_image_adjustments,
210
+ inputs=[input_image, preprocess_invert, gamma_factor],
211
+ outputs=adjusted_image,
212
+ )
213
+
214
+ gamma_factor.change(
215
+ fn=apply_image_adjustments,
216
+ inputs=[input_image, preprocess_invert, gamma_factor],
217
+ outputs=adjusted_image,
218
+ )
219
+
220
+ preprocess_invert.change(
221
+ fn=apply_image_adjustments,
222
+ inputs=[input_image, preprocess_invert, gamma_factor],
223
+ outputs=adjusted_image,
224
+ )
225
 
226
  # Button to trigger prediction
227
  submit_button = gr.Button("Submit")
228
 
229
+ # Define what happens when the button is clicked (send adjusted image to predict)
230
  submit_button.click(
231
  vsgradio.predict,
232
+ inputs=[adjusted_image, cell_diameter],
233
  outputs=[output_nucleus, output_membrane],
234
  )
235
 
236
  # Example images and article
237
  gr.Examples(
238
+ examples=[
239
+ "examples/a549.png",
240
+ "examples/hek.png",
241
+ "examples/ctc_HeLa.png",
242
+ "examples/livecell_A172.png",
243
+ ],
244
+ inputs=input_image,
245
  )
246
 
247
  # Article or footer information
248
  gr.HTML(
249
  """
250
  <div class='article-block'>
251
+ <p> Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI)</p>
 
252
  </div>
253
  """
254
  )
examples/ctc_HeLa.png ADDED
examples/livecell_A172.png ADDED
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  viscy<0.3.0
2
  gradio
3
- scikit-image
 
 
1
  viscy<0.3.0
2
  gradio
3
+ scikit-image
4
+ cmap