| | import numpy as np |
| | import gradio as gr |
| | import spaces |
| | import cv2 |
| | from cellpose import models |
| | from matplotlib.colors import hsv_to_rgb |
| | import matplotlib.pyplot as plt |
| | import os, io, base64 |
| | from PIL import Image |
| | from cellpose.io import imread, imsave |
| | import glob |
| |
|
| | from huggingface_hub import hf_hub_download |
| |
|
# Blank 96x128 grayscale placeholder. It is the initial value of the
# "Outlines"/"Cellpose flows" panels and is returned by the upload handlers
# to reset those panels.
img = np.zeros(shape=(96, 128), dtype=np.uint8)
fp0 = Image.fromarray(img)
| | |
| | |
| |
|
| | |
def download_weights():
    """Fetch the Cellpose-SAM checkpoint from the Hugging Face Hub.

    Returns:
        Local filesystem path of file ``cpsam`` from repo
        ``mouseland/cellpose-sam``, as returned by ``hf_hub_download``.
    """
    checkpoint_path = hf_hub_download(filename="cpsam", repo_id="mouseland/cellpose-sam")
    return checkpoint_path
| | |
| | |
| |
|
def download_weights_old():
    """Legacy downloader: fetch the ``cpsam`` checkpoint from OSF into the CWD.

    Nothing is downloaded if a file named ``cpsam`` already exists. Each URL
    is requested up to 10 times, stopping at the first response. Failures
    are reported with ``print`` rather than raised (best-effort behaviour).
    """
    import os, requests

    fname = ['cpsam']
    url = ["https://osf.io/d7c8e/download"]

    for target, link in zip(fname, url):
        if os.path.isfile(target):
            continue

        r = None  # stays None if every attempt raised
        for ntries in range(1, 11):
            try:
                r = requests.get(link)
            except requests.RequestException:
                print("!!! Failed to download data !!!")
                print(ntries)
            else:
                break  # got a response: stop retrying

        if r is None or r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(target, "wb") as fid:
                fid.write(r.content)
| |
|
# Download the checkpoint and build the model once, at import time. The app
# cannot serve requests without it, so start-up is aborted on failure.
try:
    fpath = download_weights()
    model = models.CellposeModel(gpu=True, pretrained_model=fpath)
except Exception as e:
    import sys
    print(f"Error loading model: {e}", file=sys.stderr)
    # `exit` is injected by the `site` module and is not guaranteed to exist
    # (e.g. `python -S`, embedded interpreters); SystemExit always is.
    raise SystemExit(1) from e
| |
|
| |
|
| |
|
| | |
def plot_flows(y):
    """Colour-code a flow field as an 8-bit RGB image.

    ``y[0][0]`` and ``y[1][0]`` are read as the Y and X flow components
    (assumed to be 2-D arrays of equal shape - TODO confirm).

    Encoding:
        hue: flow direction.
        saturation and value: percentile-normalised squared magnitude.

    Returns:
        uint8 array: ``hsv_to_rgb`` output scaled by 255.
    """
    flow_y, flow_x = y[0][0], y[1][0]

    def _centered(component):
        # robust-normalise, clip to [0, 1], then map to [-1, 1]
        return (np.clip(normalize99(component), 0, 1) - 0.5) * 2

    hue = (np.arctan2(_centered(flow_y), _centered(flow_x)) + np.pi) / (2 * np.pi)
    strength = normalize99(flow_y**2 + flow_x**2)
    hsv = np.concatenate([ch[:, :, np.newaxis] for ch in (hue, strength, strength)], axis=-1)
    hsv = np.clip(hsv, 0.0, 1.0)
    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)
| |
|
def plot_outlines(img, masks):
    """Draw the outline of every mask in red over the image and return it as a PIL image.

    Args:
        img: image array; it is normalised with ``normalize99`` and clipped
            to [0, 1] before display.
        masks: label image with the same height/width as ``img`` (assumed
            integer labels with 0 = background - TODO confirm).

    Returns:
        ``PIL.Image`` decoded from a PNG rendering of the figure.
    """
    img = np.clip(normalize99(img), 0, 1)
    Ly, Lx = img.shape[0], img.shape[1]

    contours, _ = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL,
                                   method=cv2.CHAIN_APPROX_SIMPLE)
    outpix = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4  # skip degenerate contours with <= 4 points
    ]

    # the longest side of the figure is 6 inches; keep the image aspect ratio
    if Ly > Lx:
        figsize = (6 * Lx / Ly, 6)
    else:
        figsize = (6, 6 * Ly / Lx)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, Lx])
        ax.set_ylim([0, Ly])
        # the y-axis runs 0..Ly upwards, so both the image rows and the contour
        # y-coordinates are flipped to keep row 0 at the top
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for o in outpix:
            ax.plot(o[:, 0], Ly - o[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive; always release it, even on error
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
def plot_overlay(img, masks):
    """Colour each mask with a random hue over a grayscale version of the image.

    Args:
        img: 2-D image, or an image whose last axis is averaged to gray.
        masks: label image (0 = background, labels 1..N); assumed to hold
            integer labels and to match ``img`` in height/width.

    Returns:
        uint8 RGB array of shape (Ly, Lx, 3).
    """
    if img.ndim > 2:
        img_gray = img.astype(np.float32).mean(axis=-1)
    else:
        img_gray = img.astype(np.float32)

    img = normalize99(img_gray)

    HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
    HSV[:, :, 2] = np.clip(img * 1.5, 0, 1.0)

    nmasks = int(masks.max())
    if nmasks > 0:
        # One hue per label, drawn in label order: the same random sequence as
        # calling np.random.rand() once per label. It is assigned with a single
        # lookup instead of scanning the whole image once per label.
        hues = np.random.rand(nmasks)
        labels = np.asarray(masks).astype(np.int64)
        # only exact integer labels in 1..nmasks are coloured
        fg = (labels >= 1) & (labels <= nmasks) & (labels == masks)
        HSV[fg, 0] = hues[labels[fg] - 1]
        HSV[fg, 1] = 1.0

    RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
    return RGB
| |
|
def normalize99(img):
    """Linearly rescale values so the 1st percentile maps to 0 and the 99th to 1.

    Values outside that range are not clipped; callers clip as needed. The
    ``1e-10`` term prevents division by zero for constant inputs.

    Args:
        img: array (or array-like) of numbers.

    Returns:
        Floating-point array with the same shape as ``img``; the input is not
        modified.
    """
    X = np.asarray(img)
    # a single call computes both percentiles (previously three separate sorts)
    p1, p99 = np.percentile(X, [1, 99])
    return (X - p1) / (1e-10 + p99 - p1)
| |
|
def image_resize(img, resize=400):
    """Downscale ``img`` so that its largest dimension is at most ``resize``.

    The aspect ratio of the first two axes is kept, and images already within
    the limit are not resampled. The result is always cast to uint8.

    Args:
        img: array whose first two axes are (rows, cols).
        resize: maximum allowed size. It may arrive as a float (e.g. from
            ``gr.Number``) and is truncated to an int, because ``cv2.resize``
            only accepts integer target sizes.

    Returns:
        uint8 array.
    """
    resize = int(resize)
    ny, nx = img.shape[:2]
    if max(img.shape) > resize:
        if ny > nx:
            nx = max(1, int(nx / ny * resize))  # never shrink a side to 0 px
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        # cv2 takes the target size as (width, height)
        img = cv2.resize(img, (nx, ny))
    # NOTE(review): a plain astype wraps values above 255 (e.g. 16-bit TIFFs)
    # instead of rescaling them - confirm this is intended.
    return img.astype(np.uint8)
| |
|
| | |
@spaces.GPU(duration=10)
def run_model_gpu(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` with the global model inside a ``spaces.GPU(duration=10)`` slot.

    Returns:
        ``(masks, flows)``: the first two of the three values returned by
        ``model.eval``.
    """
    eval_kwargs = dict(niter=max_iter,
                       flow_threshold=flow_threshold,
                       cellprob_threshold=cellprob_threshold)
    masks, flows, _unused = model.eval(img, **eval_kwargs)
    return masks, flows
| |
|
@spaces.GPU(duration=60)
def run_model_gpu60(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` with the global model inside a ``spaces.GPU(duration=60)`` slot.

    Returns:
        ``(masks, flows)``: the first two of the three values returned by
        ``model.eval``.
    """
    eval_kwargs = dict(niter=max_iter,
                       flow_threshold=flow_threshold,
                       cellprob_threshold=cellprob_threshold)
    masks, flows, _unused = model.eval(img, **eval_kwargs)
    return masks, flows
| |
|
@spaces.GPU(duration=240)
def run_model_gpu240(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` with the global model inside a ``spaces.GPU(duration=240)`` slot.

    Returns:
        ``(masks, flows)``: the first two of the three values returned by
        ``model.eval``.
    """
    eval_kwargs = dict(niter=max_iter,
                       flow_threshold=flow_threshold,
                       cellprob_threshold=cellprob_threshold)
    masks, flows, _unused = model.eval(img, **eval_kwargs)
    return masks, flows
| |
|
| | import datetime |
| | from zipfile import ZipFile |
def _run_model_sized(img, max_iter, flow_threshold, cellprob_threshold):
    """Run the model through the GPU wrapper whose time budget fits the image size.

    Raises:
        ValueError: if the largest dimension of ``img`` is 20,000 or more.
    """
    maxsize = np.max(img.shape)
    args = (img, max_iter, flow_threshold, cellprob_threshold)
    if maxsize <= 1000:
        return run_model_gpu(*args)
    if maxsize < 5000:
        return run_model_gpu60(*args)
    if maxsize < 20000:
        return run_model_gpu240(*args)
    raise ValueError("Image size must be less than 20,000")


def cellpose_segment(filepath, resize=1000, max_iter=250, flow_threshold=0.4, cellprob_threshold=0):
    """Segment every uploaded image and build the UI outputs.

    For each path in ``filepath``:
        * the image is downscaled with ``image_resize`` and segmented;
        * its masks are resized back to the original size;
        * the masks are saved next to the input as ``<name>_masks.tif`` and
          added to ``<last name>_masks.zip``.
    Outlines and flows are rendered for the LAST image only.

    Args:
        filepath: list of image paths (value of the upload button).
        resize: maximum image size passed to ``image_resize``.
        max_iter: passed to the model as ``niter``.
        flow_threshold: passed to the model.
        cellprob_threshold: passed to the model.

    Returns:
        A 4-tuple:
            * outlines ``PIL.Image``;
            * flows ``PIL.Image``;
            * masks download button: the zip when more than one file was
              processed, otherwise the single TIFF;
            * outlines-PNG download button.

    Raises:
        gr.Error: if nothing has been uploaded.
        ValueError: if an image is 20,000 px or larger after resizing.
    """
    if not filepath:
        raise gr.Error("Please upload at least one image before running Cellpose-SAM.")
    # gr.Number can deliver floats; image sizes and iteration counts must be ints
    resize = int(resize)
    if max_iter is not None:
        max_iter = int(max_iter)

    zip_path = os.path.splitext(filepath[-1])[0] + "_masks.zip"

    with ZipFile(zip_path, 'w') as myzip:
        for j, path in enumerate(filepath):
            formatted_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            img_input = imread(path)
            img = image_resize(img_input, resize=resize)
            masks, flows = _run_model_sized(img, max_iter, flow_threshold, cellprob_threshold)

            print(formatted_now, j, masks.max(), os.path.split(path)[-1])

            # Bring the masks back to the input resolution. Nearest-neighbour
            # interpolation keeps label values intact. cv2 sizes are (width, height).
            target_size = (img_input.shape[1], img_input.shape[0])
            if target_size != (img.shape[1], img.shape[0]):
                # NOTE(review): uint16 cannot represent more than 65535 labels
                masks_rsz = cv2.resize(masks.astype('uint16'), target_size,
                                       interpolation=cv2.INTER_NEAREST).astype('uint16')
            else:
                masks_rsz = masks.copy()

            fname_masks = os.path.splitext(path)[0] + "_masks.tif"
            imsave(fname_masks, masks_rsz)
            myzip.write(fname_masks, arcname=os.path.split(fname_masks)[-1])

    # Only the last processed image is visualised.
    outpix = plot_outlines(img, masks)
    flows = Image.fromarray(flows[0])

    # the rendered figure has its own pixel size; match the (resized) image
    Ly, Lx = img.shape[:2]
    outpix = outpix.resize((Lx, Ly), resample=Image.BICUBIC)
    flows = flows.resize((Lx, Ly), resample=Image.BICUBIC)

    fname_out = os.path.splitext(filepath[-1])[0] + "_outlines.png"
    outpix.save(fname_out)

    masks_download = zip_path if len(filepath) > 1 else fname_masks
    b1 = gr.DownloadButton(visible=True, value=masks_download)
    b2 = gr.DownloadButton(visible=True, value=fname_out)

    return outpix, flows, b1, b2
| |
|
def download_function():
    """Return a pair of hidden download buttons (masks TIFF, outlines PNG)."""
    labels = ("Download masks as TIFF", "Download outline image as PNG")
    return tuple(gr.DownloadButton(label, visible=False) for label in labels)
| |
|
def tif_view(filepath):
    """Rewrite a TIFF upload in place as a 3-channel, channels-last image.

    Non-TIFF paths are returned without touching the file. For TIFFs:
        * 2-D images are repeated into 3 identical channels;
        * for 3-D images the smallest axis is treated as channels and moved
          last;
        * channels beyond 3 are dropped and missing ones are zero-filled.
    The result is saved back to ``filepath`` with the input's dtype.

    NOTE(review): the file is overwritten, and the upload handlers later
    segment the same path, so this also changes the segmentation input -
    confirm intended.

    Returns:
        ``filepath`` (unchanged).

    Raises:
        ValueError: if the TIFF is not 2-D or 3-D.
    """
    # splitext keeps the leading dot ('.tif'), so compare against dotted names
    fext = os.path.splitext(filepath)[1].lower()
    if fext not in ('.tif', '.tiff'):
        return filepath

    img = imread(filepath)
    if img.ndim == 2:
        img = np.tile(img[:, :, np.newaxis], [1, 1, 3])
    elif img.ndim == 3:
        # the smallest dimension is interpreted as channels
        img = np.moveaxis(img, int(np.argmin(img.shape)), -1)
    else:
        raise ValueError("TIF cannot have more than three dimensions")

    Ly, Lx, nchan = img.shape
    nn = min(3, nchan)
    # keep the input dtype rather than silently promoting to float64
    imgi = np.zeros((Ly, Lx, 3), dtype=img.dtype)
    imgi[:, :, :nn] = img[:, :, :nn]

    imsave(filepath, imgi)
    return filepath
| |
|
def norm_path(filepath):
    """Write an 8-bit, percentile-normalised PNG preview of an image.

    The preview is written next to the input as ``<name>.png``. If the input
    is itself a PNG, the preview is written as ``<name>_norm.png`` instead, so
    the original file is not overwritten: the upload handlers also pass that
    original path on for segmentation.

    Returns:
        Path of the PNG that was written.
    """
    img = np.clip(normalize99(imread(filepath)), 0, 1)

    fpath, fext = os.path.splitext(filepath)
    out_path = fpath + '.png'
    if fext.lower() == '.png':
        out_path = fpath + '_norm.png'

    Image.fromarray((255. * img).astype(np.uint8)).save(out_path)
    return out_path
| | |
def update_image(filepath):
    """Handle a multi-file upload: prepare each file and preview the last one.

    Each path goes through ``tif_view``, and the paths it returns are kept, so
    the upload list reflects any rewritten files. The previously discarded
    return value is now used.

    Returns:
        ``(preview path, list of file paths, blank outlines, blank flows)``
        for the outputs ``[input_image, up_btn, outlines, flows]``.
    """
    filepath = [tif_view(f) for f in filepath]
    filepath_show = norm_path(filepath[-1])
    return filepath_show, filepath, fp0, fp0
| |
|
def update_button(filepath):
    """Handle a single-image upload or an example click.

    Steps:
        * the image is passed through ``tif_view``;
        * a normalised PNG preview is written with ``norm_path``;
        * the upload list is replaced by this one file;
        * both output panels are reset to the blank placeholder ``fp0``.
    """
    prepared = tif_view(filepath)
    preview = norm_path(prepared)
    blank = fp0
    return preview, [prepared], blank, blank
| | |
# Build the Gradio UI and start the server.
with gr.Blocks(title="Cellpose-SAM",
               css=".gradio-container {background:purple;}") as demo:

    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular
segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">[paper]</a>
<a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
<a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>
</div>""")
            gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")

            input_image = gr.Image(label="Input", type="filepath")

            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        resize = gr.Number(label='max resize', value=1000)
                        max_iter = gr.Number(label='max iterations', value=250)
                        flow_threshold = gr.Number(label='flow threshold', value=0.4)
                        cellprob_threshold = gr.Number(label='cellprob threshold', value=0)

                    up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count="multiple")

                with gr.Column(scale=1):
                    send_btn = gr.Button("Run Cellpose-SAM")
                    down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
                    down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)

        with gr.Column(scale=2):
            outlines = gr.Image(label="Outlines", type="pil", format='png', value=fp0)
            flows = gr.Image(label="Cellpose flows", type="pil", format='png', value=fp0)

    # glob order is filesystem-dependent; sort for a stable example gallery
    sample_list = sorted(glob.glob("samples/*.png"))

    # run_on_click makes a click call update_button (which fills up_btn) even
    # when examples are not cached; without it, Run would see an empty upload.
    gr.Examples(sample_list, fn=update_button, inputs=input_image,
                outputs=[input_image, up_btn, outlines, flows],
                examples_per_page=50, label="Click on an example to try it",
                run_on_click=True)
    input_image.upload(update_button, input_image, [input_image, up_btn, outlines, flows])
    up_btn.upload(update_image, up_btn, [input_image, up_btn, outlines, flows])

    send_btn.click(cellpose_segment,
                   [up_btn, resize, max_iter, flow_threshold, cellprob_threshold],
                   [outlines, flows, down_btn, down_btn2])

    gr.HTML("""<h4 style="color:white;"> Notes:<br>
<li>you can load and process 2D, multi-channel tifs.
<li>the smallest dimension of a tif --> channels
<li>you can upload multiple files and download a zip of the segmentations
<li>install Cellpose-SAM locally for full functionality.
</h4>""")

demo.launch()
| |
|