from __future__ import annotations

import random

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from CCAgT_utils.categories import CategoriesInfos
from CCAgT_utils.types.mask import Mask
from CCAgT_utils.visualization import plot
from PIL import Image
from torch import nn
from transformers import SegformerFeatureExtractor
from transformers import SegformerForSemanticSegmentation
from transformers.modeling_outputs import SemanticSegmenterOutput

# Non-interactive backend: figures are only rendered to images for the UI.
matplotlib.use('Agg')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
# Loaded once at import time and shared by every request.
model = SegformerForSemanticSegmentation.from_pretrained(
    model_hub_name,
).to(device)
model.eval()

feature_extractor = SegformerFeatureExtractor.from_pretrained(
    model_hub_name,
)


def segment(
    image: Image.Image,
) -> SemanticSegmenterOutput:
    """Run the SegFormer model on a single image.

    Args:
        image: The input image; it is preprocessed by ``feature_extractor``
            and moved to ``device`` before the forward pass.

    Returns:
        The raw model output. ``outputs.logits`` holds the per-class scores;
        ``post_processing`` resizes them to the image size, so they are
        presumably at a different (lower) resolution — confirm with the model.
    """
    inputs = feature_extractor(
        image,
        return_tensors='pt',
    ).to(device)
    # Pure inference: without no_grad PyTorch records the autograd graph on
    # every request, wasting memory and time for gradients nobody uses.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs


def post_processing(
    outputs: SemanticSegmenterOutput,
    target_size: tuple[int, int],
) -> np.ndarray:
    """Turn model logits into a per-pixel class-id map.

    Args:
        outputs: Output of ``segment``.
        target_size: ``(height, width)`` to resize the logits to; callers
            pass the size of the image that was segmented.

    Returns:
        A 2-D array holding, for each pixel, the index of the highest-scoring
        class, for the first (only) image of the batch.
    """
    logits = outputs.logits.cpu()
    # Resize the continuous logits (not the discrete labels) so the argmax is
    # taken at full resolution with smooth boundaries.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    segmentation_mask = upsampled_logits.argmax(dim=1)[0]
    return np.array(segmentation_mask)


def colorize(
    mask: Mask,
) -> np.ndarray:
    """Return an RGB rendering of ``mask`` using the CCAgT category colors.

    The colorized values are divided by 255; this presumably maps 8-bit
    colors into the [0, 1] float range expected for display — confirm
    against ``Mask.colorized``.
    """
    return mask.colorized(CategoriesInfos()) / 255


# Copied from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L35
def get_random_crop_coords(
    height: int,
    width: int,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
):
    """Compute the corners of a crop window inside a ``height x width`` image.

    ``h_start`` / ``w_start`` are fractions (callers use ``random.random()``,
    i.e. values in [0, 1)) selecting where the window starts; 0 means the
    top/left edge. The window only stays inside the image when the crop size
    does not exceed the image size.

    Returns:
        ``(x1, y1, x2, y2)`` with exclusive end coordinates.
    """
    y1 = int((height - crop_height + 1) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width + 1) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2
# Adapted from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L46
# (keeps upstream's crop-size check, which the original copy had dropped).
def random_crop(
    img: np.ndarray,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
) -> np.ndarray:
    """Return a ``crop_height x crop_width`` window of ``img``.

    The window position is chosen by the fractions ``h_start``/``w_start``
    (see ``get_random_crop_coords``).

    Raises:
        ValueError: If the requested crop is larger than the image. Without
            this check the negative start index would wrap around and
            silently return a much smaller crop.
    """
    height, width = img.shape[:2]
    if height < crop_height or width < crop_width:
        raise ValueError(
            f'Requested crop size ({crop_height}, {crop_width}) is '
            f'larger than the image size ({height}, {width})',
        )
    x1, y1, x2, y2 = get_random_crop_coords(
        height,
        width,
        crop_height,
        crop_width,
        h_start,
        w_start,
    )
    return img[y1:y2, x1:x2]


def process_big_images(
    image: Image.Image,
    crop_height: int = 300,
    crop_width: int = 400,
) -> tuple[np.ndarray, Mask]:
    """Randomly crop an oversized image, then segment it.

    Images taller than ``crop_height`` and/or wider than ``crop_width`` are
    randomly cropped down to at most ``crop_height x crop_width``; a
    dimension that already fits is kept whole.

    Returns:
        ``(img, mask)``: the (possibly cropped) image as an array and the
        predicted segmentation for exactly that array.
    """
    img = np.asarray(image)
    height, width = img.shape[:2]
    if height > crop_height or width > crop_width:
        # Clamp the crop to the image: requesting the full crop size when only
        # one dimension is too big would ask for a window larger than the image.
        img = random_crop(
            img,
            min(height, crop_height),
            min(width, crop_width),
            random.random(),
            random.random(),
        )

    target_size = (img.shape[0], img.shape[1])
    outputs = segment(Image.fromarray(img))
    msk = post_processing(outputs, target_size)

    return img, Mask(msk)


def image_with_mask(
    image: np.ndarray | Image.Image,
    mask: Mask,
) -> plt.Figure:
    """Draw ``mask`` semi-transparently (alpha 0.4) on top of ``image``.

    Builds an unmanaged ``Figure`` instead of using ``plt.figure``: pyplot
    would keep every figure alive (one per request) and its global "current
    figure" state is shared between concurrent requests.
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(image)
    ax.imshow(
        mask.categorical,
        cmap=mask.cmap(CategoriesInfos()),
        vmax=max(mask.unique_ids),
        vmin=min(mask.unique_ids),
        interpolation='nearest',
        alpha=0.4,
    )
    ax.axis('off')
    return fig


def categories_map(
    mask: Mask,
) -> plt.Figure:
    """Return a figure with a legend of the categories present in ``mask``.

    Uses an unmanaged ``Figure`` for the same reasons as ``image_with_mask``.
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot(1, 1, 1)
    handles = plot.create_handles(
        CategoriesInfos(),
        selected_categories=mask.unique_ids,
    )
    ax.legend(handles=handles, fontsize=24, loc='center')
    ax.axis('off')
    return fig


def main(image):
    """Gradio callback: segment ``image`` and build every output panel.

    ``image`` is assumed to be the array gradio's ``gr.Image()`` delivers
    (it is passed to ``Image.fromarray``). The returned tuple matches the
    order of ``outputs`` in the ``gr.Interface`` below.
    """
    image = Image.fromarray(image)
    img, mask = process_big_images(image)
    mask_colorized = colorize(mask)
    fig = image_with_mask(img, mask)
    return categories_map(mask), Image.fromarray(img), mask_colorized, fig


title = 'SegFormer (b3) - CCAgT dataset'
description = f"""
This is demo for the SegFormer fine-tuned on sub-dataset from [CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). This model was trained to segment cervical cells silver-stained (AgNOR technique) images with resolution of 400x300.
The model was available at HF hub at [{model_hub_name}](https://huggingface.co/{model_hub_name}). If input an image bigger than 400x300, the demo will random crop it.
"""

# A tuple (not a set) so the order of the dataset examples is explicit.
examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
] + [
    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
    for x in (3, 10, 12, 18, 35, 78, 89)
]

demo = gr.Interface(
    main,
    inputs=[gr.Image()],
    outputs=[
        gr.Plot(label='Categories map'),
        gr.Image(label='Image'),
        gr.Image(label='Mask'),
        gr.Plot(label='Image with mask'),
    ],
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
)

if __name__ == '__main__':
    demo.launch()