# gradio/cxr_gradio.py — uploaded by eprakash via huggingface_hub
# ("Upload folder using huggingface_hub", commit 711ffc5, 4.42 kB).
import gradio as gr
import random
import time
import os
from glob import glob
from PIL import Image
import torchvision.transforms as transforms
# Directory holding the synthetic samples and masks to rank.
# NOTE(review): hard-coded cluster path; adjust when running elsewhere.
image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/cxr_synthetic_data_25_no_transform/synth/"


def list_image_ids(prefix):
    """Return the sorted, de-duplicated example ids for the PNG files in *prefix*.

    An id is the part of a file's basename before its first underscore, so
    ``123_synthetic_0.png`` and ``123_synthetic_mask_1.png`` both map to ``"123"``.

    The result is sorted because rankings are saved and resumed by *index*
    into this list. A plain ``list(set(...))`` has an order that can change
    from one process to the next (string hashing is randomized), which would
    make a restarted session skip or repeat the wrong examples.
    NOTE(review): rank files written under the old unordered list used a
    different index->id mapping; the id stored inside each file stays correct.
    """
    paths = glob(os.path.join(prefix, "*.png"))
    return sorted({os.path.splitext(os.path.basename(p))[0].split("_")[0] for p in paths})


# Example ids in a stable order; an example's number is its index here.
image_ids = list_image_ids(image_prefix)
# Directory receiving one "<index>.txt" ranking file per example.
save_path = "cxr_ranks"
def load_img(img_path, size=1024):
    """Load an image as RGB and resize it to a ``size`` x ``size`` square.

    Args:
        img_path: Path of the image file to read.
        size: Target height and width in pixels. Aspect ratio is not
            preserved, because an explicit (h, w) pair is passed to Resize.

    Returns:
        The resized RGB ``PIL.Image.Image``.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been converted, including when decoding fails part-way.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    # A single Resize needs no Compose wrapper; it accepts and returns PIL images.
    return transforms.Resize((size, size))(img)
def find_completed_idxs(save_path=save_path):
    """Return the sorted example indices that already have a saved ranking file.

    The index is read from the part of each filename before its first "."
    (rankings are stored as ``<index>.txt``). Entries whose prefix is not an
    integer, e.g. ``.DS_Store`` or ``.ipynb_checkpoints``, are skipped instead
    of raising ``ValueError``.

    Args:
        save_path: Directory containing the ranking files.

    Returns:
        Ascending list of indices, or ``[-1]`` when no ranking has been saved,
        including when *save_path* does not exist. The list is never empty,
        so callers can always compute the next index as ``result[-1] + 1``.
    """
    try:
        names = os.listdir(save_path)
    except FileNotFoundError:
        return [-1]
    completed = []
    for name in names:
        try:
            completed.append(int(name.split(".")[0]))
        except ValueError:
            # Not a ranking file (hidden, system or editor file): ignore it.
            continue
    return sorted(completed) if completed else [-1]
def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example, ids=image_ids, image_prefix=image_prefix, save_path=save_path):
    """Save the ranking for the example on screen, then load the next unranked one.

    This is the click handler for the "Next" button. Its inputs and outputs
    are, in order: the rank textbox, four (sample, mask) image pairs, and the
    example counter. It therefore always returns a 10-element list.

    Args:
        rank: Ranking text typed for the example currently displayed.
        img_1, mask_1, ..., img_4, mask_4: Current image values. They are
            unused and accepted only because the click handler passes them.
        example: Index of the displayed example; -1 is the blank start page.
        ids: Example ids, indexed by example number.
        image_prefix: Path prefix that the sample/mask PNG names are appended to.
        save_path: Directory receiving one ``<index>.txt`` file per ranked example.

    Returns:
        ``[rank, img_1, mask_1, ..., img_4, mask_4, example]``: an empty
        textbox, eight ``gr.Image`` updates and the new example index. When
        every example has been ranked, the images are cleared (``None``) and
        the counter is returned unchanged.
    """
    current = int(example)
    # Make sure the output directory exists before it is listed or written to.
    os.makedirs(save_path, exist_ok=True)
    completed = find_completed_idxs(save_path)

    # Save the ranking for the example on screen.
    # - -1 is the blank start page. It has no example, and ids[-1] would
    #   silently pick the last id.
    # - An example that already has a file keeps its first ranking.
    if 0 <= current < len(ids) and current not in completed:
        with open(os.path.join(save_path, str(current) + ".txt"), "w") as r_fp:
            r_fp.write(str(ids[current]) + "," + (rank or "") + "\n")
        completed = find_completed_idxs(save_path)

    # Continue after the highest saved index, so a restarted session resumes
    # where the previous one stopped.
    next_idx = completed[-1] + 1
    if next_idx >= len(ids):
        # Nothing left to rank. Gradio needs exactly one value per output (10):
        # clear the textbox and the eight images, and keep the counter.
        return [""] + [None] * 8 + [example]

    stem = image_prefix + str(ids[next_idx])
    updates = []
    for k in range(1, 5):
        # NOTE(review): sample k is read from "_synthetic_{k-1}.png" but its
        # mask from "_synthetic_mask_{k}.png". The pairing is kept as-is;
        # confirm it is intended.
        updates.append(gr.Image(label=f"Sample #{k}", value=load_img(f"{stem}_synthetic_{k - 1}.png"), interactive=False))
        updates.append(gr.Image(label="Mask", value=load_img(f"{stem}_synthetic_mask_{k}.png"), interactive=False))
    return [""] + updates + [next_idx]
# Placeholder image shown in every slot before the first example is loaded.
BLANK_IMAGE_PATH = "/deep/u/eprakash/blank.jpg"

with gr.Blocks() as demo:
    # -1 is the blank start page shown before any example is loaded.
    last_idx = -1
    example = gr.Number(label="Example #. Click next for #-1 (blank starting page).", value=last_idx, interactive=False)
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")
    # Decode and resize the placeholder once and share it, rather than reading
    # the same file eight times.
    blank = load_img(BLANK_IMAGE_PATH)
    with gr.Column(scale=1):
        with gr.Row():
            mask_1 = gr.Image(label="Mask", value=blank, interactive=False)
            img_1 = gr.Image(label="Sample #1", value=blank, interactive=False)
        with gr.Row():
            mask_2 = gr.Image(label="Mask", value=blank, interactive=False)
            img_2 = gr.Image(label="Sample #2", value=blank, interactive=False)
        with gr.Row():
            mask_3 = gr.Image(label="Mask", value=blank, interactive=False)
            img_3 = gr.Image(label="Sample #3", value=blank, interactive=False)
        with gr.Row():
            mask_4 = gr.Image(label="Mask", value=blank, interactive=False)
            img_4 = gr.Image(label="Sample #4", value=blank, interactive=False)
    next_btn = gr.Button(value="Next")
    # The order must match load_next's parameter list and its returned list.
    io_components = [rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example]
    next_btn.click(fn=load_next, inputs=io_components, outputs=io_components, queue=False)

demo.queue()

# Launch only when run as a script, so the module can be imported without
# starting a (publicly shared) server.
if __name__ == "__main__":
    demo.launch(share=True)