import gradio as gr
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

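# Load the image processor and the MiT-B5 SegFormer checkpoint fine-tuned on
# the sidewalk-semantic dataset.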
image_processor = SegformerImageProcessor.from_pretrained("zoheb/mit-b5-finetuned-sidewalk-semantic")
model = SegformerForSemanticSegmentation.from_pretrained("zoheb/mit-b5-finetuned-sidewalk-semantic")
model.eval()  # inference only

def sidewalk_palette():
    """Sidewalk palette that maps each class to RGB values."""
    return [
        [0, 0, 0],
        [216, 82, 24],
        [255, 255, 0],
        [125, 46, 141],
        [118, 171, 47],
        [161, 19, 46],
        [255, 0, 0],
        [0, 128, 128],
        [190, 190, 0],
        [0, 255, 0],
        [0, 0, 255],
        [170, 0, 255],
        [84, 84, 0],
        [84, 170, 0],
        [84, 255, 0],
        [170, 84, 0],
        [170, 170, 0],
        [170, 255, 0],
        [255, 84, 0],
        [255, 170, 0],
        [255, 255, 0],
        [33, 138, 200],
        [0, 170, 127],
        [0, 255, 127],
        [84, 0, 127],
        [84, 84, 127],
        [84, 170, 127],
        [84, 255, 127],
        [170, 0, 127],
        [170, 84, 127],
        [170, 170, 127],
        [170, 255, 127],
        [255, 0, 127],
        [255, 84, 127],
        [255, 170, 127],
    ]

# Class names, one per line, in the same order as the palette above.
labels_list = []

with open("labels.txt", "r") as fp:
    labels_list.extend(line.strip() for line in fp)

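# Palette as a NumPy array so a 2-D map of class indices can index into colors directly.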
colormap = np.asarray(sidewalk_palette())

def label_to_color_image(label):
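    """Convert a 2-D array of class indices into an RGB color image."""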
    if label.ndim != 2:
        raise ValueError("Expected a 2-D input label")

    if np.max(label) >= len(colormap):
        raise ValueError("Label value too large for the colormap")

    return colormap[label]

def draw_plot(pred_img, seg):
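    """Plot the blended prediction next to a legend of the classes it contains."""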
    fig = plt.figure(figsize=(20, 15))

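    # Wide panel for the segmented image, narrow panel for the legend.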
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis('off')

    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

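    # Build the legend from only the classes present in this prediction.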
    unique_labels = np.unique(seg.numpy().astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig

def main(input_img):
    """Segment a sidewalk image and return a figure with the blended mask."""
    input_img = Image.fromarray(input_img)

    inputs = image_processor(images=input_img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)

    # First, rescale logits to the original image size.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=input_img.size[::-1],  # PIL size is (width, height), so reverse to (height, width)
        mode="bilinear",
        align_corners=False,
    )

    # Second, apply argmax on the class dimension.
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    # Paint each predicted class with its palette color.
    color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)  # (height, width, 3)
    palette = np.array(sidewalk_palette())
    for label, color in enumerate(palette):
        color_seg[pred_seg == label, :] = color

    # Blend the original image with the colored mask (50/50).
    img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = img.astype(np.uint8)

    return draw_plot(pred_img, pred_seg)

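# Wire everything into a simple Gradio demo: an image in, the rendered
# matplotlib figure (blended mask plus legend) out. The image processor
# resizes inputs, so no fixed input shape is imposed here.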
demo = gr.Interface(
    fn=main,
    inputs=gr.Image(),
    outputs=gr.Plot(),
    examples=["test.jpg"],
    allow_flagging="never",
)

demo.launch()