gradio / lung_seg_gradio.py
eprakash's picture
Upload folder using huggingface_hub
711ffc5 verified
raw
history blame
6.26 kB
import gradio as gr
import random
import time
import os
from glob import glob
from PIL import Image
import torchvision.transforms as transforms
# Number of examples each rater ranks; load_next indexes image_ids with
# example indices 0..num_rank-1, so image_ids needs at least this many entries.
num_rank = 200
# Path prefixes for the synthetic samples and for the masks of each example.
image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/synth/"
mask_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/orig/"
# CSV listing the images; only the text before the first comma of each row is used.
img_list = "/deep/u/eprakash/lung_seg/train_60.csv"
image_ids = []
with open(img_list) as fp:
    for line in fp:
        first_col = line.strip().partition(",")[0]
        # Each id is stored as the str() of a 1-tuple, e.g. "('abc',)", because
        # load_next splices this exact string into the image/mask file names.
        image_ids.append("('" + first_col + "',)")
# Keep the num_rank-long window of ids that starts at row 301.
image_ids = image_ids[301:301 + num_rank]
# Directory that receives one "<example>.txt" ranking file per example.
save_path = "lung_seg_ranks"
def is_int(s):
    """Report whether ``s`` can be converted with ``int()``.

    Only ``ValueError`` is treated as "not an integer"; any other exception
    raised by ``int()`` (e.g. ``TypeError`` for ``None``) propagates.

    Args:
        s: Value to test, normally a string field from a ranking file.

    Returns:
        True if ``int(s)`` succeeds, False if it raises ``ValueError``.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
def load_img(img_path, size=512):
    """Open an image file, convert it to RGB and resize it to a square.

    Args:
        img_path: Path of the image file to open.
        size: Side length, in pixels, of the square output (default 512).

    Returns:
        The resized PIL image produced by torchvision's ``Resize((size, size))``.
    """
    resize = transforms.Resize((size, size))
    rgb_img = Image.open(img_path).convert('RGB')
    return resize(rgb_img)
def find_completed_idxs(save_path=save_path):
    """Scan the ranking directory and report which examples are done.

    Ranking files are named ``<example index>.txt``, and each line written by
    load_next has the form ``<image id>,<r1>,<r2>,<r3>,<r4>``.

    Args:
        save_path: Directory containing the ranking files.

    Returns:
        ``(file_list, incorrect_files)``:
        - ``file_list``: sorted example indices that have a ranking file.
        - ``incorrect_files``: sorted indices whose file is empty or has a line
          that is not an id followed by four integers.
        The blank-page index ``-1`` is never reported as incorrect. When there
        are no ranking files (or the directory does not exist yet), returns
        ``([-1], [])``, so the caller's ``file_list[-1] + 1`` yields example 0.
    """

    def _is_valid(items):
        # A ranking line is the image id followed by four integer ranks.
        return len(items) == 5 and all(is_int(x.strip()) for x in items[1:])

    file_list = []
    incorrect_files = []
    # A missing directory simply means nothing has been ranked yet.
    if os.path.isdir(save_path):
        for f in os.listdir(save_path):
            path = os.path.join(save_path, f)
            stem = f.split(".")[0]
            # Skip anything that is not a "<int>.*" file (".ipynb_checkpoints",
            # ".DS_Store", sub-directories, ...). Previously int() on such a name
            # raised ValueError and broke every "Next" click.
            if not os.path.isfile(path) or not is_int(stem):
                continue
            f_name = int(stem)
            file_list.append(f_name)
            if f_name == -1:
                # Blank-page file: never validated (the old code could raise
                # IndexError here on short lines).
                continue
            with open(path) as fp:
                rows = [line.strip().split(",") for line in fp]
            # An empty file holds no ranking, so it is re-served like a bad one.
            # Each file is reported at most once, even with several bad lines.
            if not rows or not all(_is_valid(items) for items in rows):
                incorrect_files.append(f_name)
    if not file_list:
        return [-1], []
    return sorted(file_list), sorted(incorrect_files)
def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example, ids=image_ids, image_prefix=image_prefix, save_path=save_path, mask_prefix=mask_prefix, blank_path="/deep/u/eprakash/blank.jpg"):
    """Save the rating for the shown example and load the next one to rank.

    This is the Gradio callback for the "Next" button. The ten positional
    arguments mirror the click handler's ``inputs``. Only ``rank`` and
    ``example`` are read; the image arguments exist only to match that list.

    Args:
        rank: Text typed by the rater; written verbatim after the image id.
        img_1, mask_1, ..., img_4, mask_4: Current image values (unused).
        example: Index into ``ids`` of the example currently shown. ``-1`` is
            the blank start/finish page and never produces a ranking file.
        ids: Image ids, stored as strings like ``"('<id>',)"``.
        image_prefix: Path prefix of the ``_synthetic_<k>.png`` samples.
        save_path: Directory holding one ``<example>.txt`` file per example.
        mask_prefix: Path prefix of the ``_mask.png`` files.
        blank_path: Placeholder image shown on the finish page.

    Returns:
        ``[rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4,
        example]``, in the click handler's ``outputs`` order:
        - the cleared (or ``"DONE!"``) rank text;
        - eight ``gr.Image`` components;
        - the index of the example now shown (``-1`` once ``num_rank`` is reached).
    """
    example = int(example)
    file_list, incorrect_files = find_completed_idxs(save_path)
    print(str(file_list) + " " + str(incorrect_files))
    # Write a ranking unless this example already has a valid file. Index -1 is
    # the blank page: writing there would silently use ids[-1] (the last image)
    # and leave a bogus "-1.txt" behind.
    if example >= 0 and (example not in file_list or example in incorrect_files):
        # "('abc',)" -> "'abc'"
        image_id = str(ids[example]).split(",")[0].split("(")[1]
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, f"{example}.txt"), "w") as r_fp:
            r_fp.write(image_id + "," + rank + "\n")

    # Re-serve the highest-numbered malformed ranking first; otherwise continue
    # after the highest completed index.
    file_list, incorrect_files = find_completed_idxs(save_path)
    if incorrect_files:
        example = incorrect_files[-1]
    else:
        example = file_list[-1] + 1

    # ">=" (not "==") so a stray higher-numbered file ends the session instead
    # of indexing past the end of ids.
    if example >= num_rank:
        rank = "DONE!"
        example = -1
        blank = load_img(blank_path)
        samples = [blank] * 4
        masks = [blank] * 4
    else:
        rank = ""
        image_id = str(ids[example])
        samples = [load_img(f"{image_prefix}{image_id}_synthetic_{k}.png") for k in range(4)]
        # All four samples share the same mask; load it once.
        masks = [load_img(f"{mask_prefix}{image_id}_mask.png")] * 4

    outputs = [rank]
    for slot, (sample_img, mask_img) in enumerate(zip(samples, masks), start=1):
        outputs.append(gr.Image(label=f"Sample #{slot}", value=sample_img, interactive=False))
        outputs.append(gr.Image(label="Mask", value=mask_img, interactive=False))
    outputs.append(example)
    return outputs
with gr.Blocks() as demo:
    # Initial example index; -1 is the blank starting page.
    last_idx = -1
    example = gr.Number(label="Example #. Click next for #-1 (blank starting page).", value=last_idx, interactive=False)
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")
    # Placeholder shown in every image slot before the first example loads.
    blank_img = load_img("/deep/u/eprakash/blank.jpg")
    with gr.Column(scale=1):
        # One row per synthetic sample: its mask on the left, the sample on the right.
        with gr.Row():
            mask_1 = gr.Image(label="Mask", value=blank_img, interactive=False)
            img_1 = gr.Image(label="Sample #1", value=blank_img, interactive=False)
        with gr.Row():
            mask_2 = gr.Image(label="Mask", value=blank_img, interactive=False)
            img_2 = gr.Image(label="Sample #2", value=blank_img, interactive=False)
        with gr.Row():
            mask_3 = gr.Image(label="Mask", value=blank_img, interactive=False)
            img_3 = gr.Image(label="Sample #3", value=blank_img, interactive=False)
        with gr.Row():
            mask_4 = gr.Image(label="Mask", value=blank_img, interactive=False)
            img_4 = gr.Image(label="Sample #4", value=blank_img, interactive=False)
    next_btn = gr.Button(value="Next")
    # load_next both receives and returns the components in exactly this order.
    ui_fields = [rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example]
    next_btn.click(fn=load_next, inputs=ui_fields, outputs=ui_fields, queue=False)
demo.queue()
demo.launch(share=True)