import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from numpy.typing import ArrayLike
from skimage import exposure
from viscy.light.engine import VSUNet


class VSGradio:
    """Serve a VSUNet checkpoint as a Gradio image-to-image function.

    The model is loaded once at construction time and moved to the GPU when
    one is available, otherwise it stays on the CPU.
    """

    def __init__(self, model_config, model_ckpt_path):
        """Store the configuration and load the model.

        Args:
            model_config: Keyword configuration forwarded to the model as
                ``model_config``.
            model_ckpt_path: Path of the checkpoint file to load.
        """
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        # Prefer the GPU when one is available.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint, move it to ``self.device`` and set eval mode."""
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            architecture="UNeXt2_2D",
            model_config=self.model_config,
        )
        self.model.to(self.device)
        self.model.eval()

    def normalize_fov(self, input: ArrayLike) -> np.ndarray:
        """Normalize a field of view to zero mean and unit variance.

        Args:
            input: Image data; anything ``np.asarray`` accepts.

        Returns:
            The zero-centred array divided by its standard deviation. When
            the standard deviation is zero (constant image) the zero-centred
            array is returned unscaled, i.e. all zeros, instead of NaN/inf.
        """
        fov = np.asarray(input)
        centered = fov - np.mean(fov)
        std = np.std(fov)
        if std == 0:
            # A constant image has no contrast to scale; dividing would
            # produce NaN/inf and poison the model input.
            return centered
        return centered / std

    def predict(self, inp):
        """Run the model on one image and return two rescaled 2D predictions.

        Args:
            inp: Input image as a NumPy array. Three leading singleton axes
                are added before inference, so a 2D (H, W) image becomes the
                (B, C, D, H, W) tensor the model requires.

        Returns:
            A ``(nuc_pred, mem_pred)`` tuple: output channels 0 and 1 of the
            first batch item at depth index 0, each rescaled to [0, 1].

        Raises:
            gr.Error: If no image was provided (``inp`` is ``None``).
        """
        if inp is None:
            # Fail with a readable message instead of a TypeError from NumPy.
            raise gr.Error("Please provide an input image.")

        fov = self.normalize_fov(inp)
        source = torch.from_numpy(np.asarray(fov, dtype=np.float32))
        test_dict = dict(
            index=None,
            # Prepend batch, channel and depth axes.
            source=source.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
        )
        with torch.inference_mode():
            # Invoke the predict-start hook ourselves, since predict_step is
            # called directly here rather than through a trainer.
            self.model.on_predict_start()
            pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()

        # Return 2D images: first batch item, depth index 0, channels 0 and 1.
        nuc_pred = exposure.rescale_intensity(pred[0, 0, 0], out_range=(0, 1))
        mem_pred = exposure.rescale_intensity(pred[0, 1, 0], out_range=(0, 1))
        return nuc_pred, mem_pred


# %%
if __name__ == "__main__":
    # Fetch the checkpoint from the Hugging Face Hub.
    model_ckpt_path = hf_hub_download(
        repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
    )
    model_config = {
        "in_channels": 1,
        "out_channels": 2,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [1, 2, 2],
        "in_stack_depth": 1,
        "pretraining": False,
    }
    vsgradio = VSGradio(model_config, model_ckpt_path)

    # One grayscale image in, two predicted images out.
    gr.Interface(
        fn=vsgradio.predict,
        inputs=gr.Image(type="numpy", image_mode="L", format="png"),
        outputs=[
            gr.Image(type="numpy", format="png"),
            gr.Image(type="numpy", format="png"),
        ],
        examples=[
            "examples/a549.png",
            "examples/hek.png",
        ],
    ).launch()