Spaces:
Running
Running
File size: 2,700 Bytes
5d9796e 90ad424 5d9796e 90ad424 5d9796e 90ad424 5d9796e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from viscy.light.engine import VSUNet
import torch
import gradio as gr
import numpy as np
from numpy.typing import ArrayLike
from skimage import exposure
from huggingface_hub import hf_hub_download
class VSGradio:
    """Gradio-facing wrapper around a VSUNet virtual-staining checkpoint.

    The checkpoint is loaded once at construction time; :meth:`predict`
    turns a single 2D grayscale image into two predicted images (returned
    as ``nuc_pred`` and ``mem_pred``), each rescaled to [0, 1].
    """

    def __init__(self, model_config, model_ckpt_path, architecture="UNeXt2_2D"):
        """Store the configuration and load the model onto the best device.

        Args:
            model_config: Configuration forwarded to
                ``VSUNet.load_from_checkpoint`` as ``model_config``.
            model_ckpt_path: Path to the checkpoint file to load.
            architecture: Architecture name forwarded to
                ``VSUNet.load_from_checkpoint``. Defaults to ``"UNeXt2_2D"``,
                the value that was previously hard-coded.
        """
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        self.architecture = architecture
        # Prefer the GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint, move it to ``self.device`` and set eval mode."""
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            architecture=self.architecture,
            model_config=self.model_config,
        )
        self.model.to(self.device)
        self.model.eval()

    def normalize_fov(self, input: ArrayLike):
        """Normalize a field of view to zero mean and unit variance.

        A constant image has a standard deviation of zero; dividing by it
        would produce NaNs, so in that case the mean-centred array (all
        zeros) is returned instead.

        Args:
            input: Array-like image data.

        Returns:
            A float64 ``np.ndarray`` with the normalized values.
        """
        fov = np.asarray(input, dtype=np.float64)
        mean = fov.mean()
        std = fov.std()
        if std == 0:
            # Avoid 0/0 -> NaN for flat images.
            return fov - mean
        return (fov - mean) / std

    def predict(self, inp):
        """Run the model on one 2D grayscale image.

        Args:
            inp: Image array, assumed to be 2D (H, W) as supplied by
                ``gr.Image(type="numpy", image_mode="L")``; ``None`` when
                no image was submitted.

        Returns:
            Tuple ``(nuc_pred, mem_pred)`` of 2D arrays rescaled to [0, 1]
            (output channels 0 and 1 respectively), or ``(None, None)``
            when ``inp`` is ``None``.
        """
        if inp is None:
            # Nothing submitted: return empty outputs rather than failing
            # inside normalize_fov.
            return None, None
        fov = self.normalize_fov(inp)
        source = torch.from_numpy(np.asarray(fov, dtype=np.float32))
        # Add batch, channel and depth axes: (H, W) -> (1, 1, 1, H, W).
        source = source[None, None, None].to(self.device)
        test_dict = dict(index=None, source=source)
        with torch.inference_mode():
            self.model.on_predict_start()
            pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
        # Take the first batch item and depth slice of output channels 0 and 1
        # to obtain two 2D images.
        nuc_pred = exposure.rescale_intensity(pred[0, 0, 0], out_range=(0, 1))
        mem_pred = exposure.rescale_intensity(pred[0, 1, 0], out_range=(0, 1))
        return nuc_pred, mem_pred
# %%
if __name__ == "__main__":
    # Download the pretrained VSCyto2D checkpoint from the Hugging Face Hub.
    model_ckpt_path = hf_hub_download(
        repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
    )
    # Architecture hyperparameters for the network; presumably these must
    # match the configuration the checkpoint was trained with.
    model_config = {
        "in_channels": 1,
        "out_channels": 2,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [1, 2, 2],
        "in_stack_depth": 1,
        "pretraining": False,
    }
    vsgradio = VSGradio(model_config, model_ckpt_path)
    gr.Interface(
        fn=vsgradio.predict,
        inputs=gr.Image(type="numpy", image_mode="L", format="png"),
        # Labels follow the order VSGradio.predict returns its outputs
        # (nuc_pred, mem_pred) so users can tell the two images apart.
        outputs=[
            gr.Image(type="numpy", format="png", label="Nuclei"),
            gr.Image(type="numpy", format="png", label="Membrane"),
        ],
        examples=[
            "examples/a549.png",
            "examples/hek.png",
        ],
    ).launch()
|