File size: 2,700 Bytes
5d9796e
 
 
 
 
 
 
 
 
 
 
 
 
90ad424
 
 
5d9796e
 
 
 
 
 
 
 
 
 
90ad424
5d9796e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90ad424
5d9796e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from viscy.light.engine import VSUNet
import torch
import gradio as gr
import numpy as np
from numpy.typing import ArrayLike
from skimage import exposure
from huggingface_hub import hf_hub_download


class VSGradio:
    """Gradio-facing wrapper around a VSUNet virtual-staining checkpoint.

    The checkpoint is loaded once at construction; :meth:`predict` then maps a
    single 2D image to two 2D prediction channels rescaled to [0, 1].
    """

    def __init__(self, model_config, model_ckpt_path):
        """Store the configuration and load the model onto the best device.

        Args:
            model_config: Architecture hyper-parameters forwarded as
                ``model_config`` to ``VSUNet.load_from_checkpoint``.
            model_ckpt_path: Path of the checkpoint file to load.
        """
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        # Prefer the GPU when one is available; otherwise run on the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint as a ``UNeXt2_2D`` VSUNet, move it to
        ``self.device`` and switch it to eval mode."""
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            architecture="UNeXt2_2D",
            model_config=self.model_config,
        )
        self.model.to(self.device)
        self.model.eval()

    def normalize_fov(self, input: ArrayLike):
        """Normalize the field of view to zero mean and unit variance.

        A constant image has zero standard deviation. Dividing by it would
        produce NaNs (0/0), so in that case the image is only mean-centred
        and comes back as all zeros.

        Args:
            input: Image values; anything ``np.asarray`` accepts.

        Returns:
            A floating-point ``np.ndarray`` with the same shape as ``input``.
        """
        image = np.asarray(input)
        mean = np.mean(image)
        std = np.std(image)
        if std == 0:
            # Avoid 0/0 -> NaN for flat images; mean-centring alone is enough.
            std = 1.0
        return (image - mean) / std

    def predict(self, inp):
        """Run the model on one 2D image.

        Args:
            inp: A 2D image (the demo's ``gr.Image`` input uses
                ``image_mode="L"``), or ``None`` when nothing was submitted.

        Returns:
            ``(nuc_pred, mem_pred)``: output channels 0 and 1, each rescaled
            to the range [0, 1]. Returns ``(None, None)`` if ``inp`` is ``None``.
        """
        if inp is None:
            # e.g. the interface was submitted without an image; without this
            # guard np.mean(None) raises a TypeError.
            return None, None
        inp = self.normalize_fov(inp)
        inp = torch.from_numpy(np.asarray(inp, dtype=np.float32))
        # Prepend batch, channel and depth axes: (H, W) -> (1, 1, 1, H, W),
        # i.e. the (B, C, D, H, W) layout fed to predict_step.
        test_dict = dict(
            index=None,
            source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
        )
        with torch.inference_mode():
            self.model.on_predict_start()
            pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
        # Drop batch and depth axes to get one 2D image per output channel.
        nuc_pred = pred[0, 0, 0]
        mem_pred = pred[0, 1, 0]
        nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
        mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))
        return nuc_pred, mem_pred


# %%
if __name__ == "__main__":
    # Fetch the pretrained VSCyto2D checkpoint from the Hugging Face Hub.
    ckpt_path = hf_hub_download(
        repo_id="compmicro-czb/VSCyto2D",
        filename="epoch=399-step=23200.ckpt",
    )

    # Architecture hyper-parameters handed to VSUNet together with the
    # checkpoint (see VSGradio.load_model).
    unext2_config = dict(
        in_channels=1,
        out_channels=2,
        encoder_blocks=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        decoder_conv_blocks=2,
        stem_kernel_size=[1, 2, 2],
        in_stack_depth=1,
        pretraining=False,
    )

    predictor = VSGradio(unext2_config, ckpt_path)

    # One grayscale PNG in; two PNGs out, in the order VSGradio.predict
    # returns them (nuc_pred, mem_pred).
    input_image = gr.Image(type="numpy", image_mode="L", format="png")
    output_images = [gr.Image(type="numpy", format="png") for _ in range(2)]
    example_files = ["examples/a549.png", "examples/hek.png"]

    demo = gr.Interface(
        fn=predictor.predict,
        inputs=input_image,
        outputs=output_images,
        examples=example_files,
    )
    demo.launch()