"""Gradio explorer for streaming image shards of a Hugging Face dataset.

Pick a dataset repository, a subset and a shard, then page through batches of
images (gallery) and their per-item metadata (table). With ``DEBUG = True``
nothing is downloaded and random images are shown instead.
"""
from datasets import load_dataset, IterableDataset
from functools import partial
from pandas import DataFrame
import gradio as gr
import numpy as np
import tqdm
import json
import os

DEBUG = False

# Per-subset settings: "shards" is the number of parquet shard files (used in
# the shard file name and for the slider range); "config" and "path" (dataset
# config name and directory of the shard files) default to the subset name.
sets = {
    "satellogic": {
        "shards": 3676,
    },
    "sentinel_1": {
        "shards": 1763,
    },
    "neon": {
        "config": "default",
        "shards": 607,
        "path": "data",
    },
}

# Streaming state shared by the callbacks; both are set by open_dataset().
# Initialised here so that fetching before "Load" warns instead of raising
# NameError.
ds = None
dsi = None


def open_dataset(dataset, set_name, split, batch_size, shard=-1):
    """Start streaming a subset (optionally a single shard) and fetch a batch.

    Args:
        dataset: Hugging Face dataset repository id.
        set_name: key of ``sets`` selecting the subset.
        split: split name, also used as the prefix of the shard file names.
        batch_size: number of dataset items to fetch for the first batch.
        shard: index of the single parquet shard to stream, or -1 to stream
            the whole subset.

    Returns:
        ``(slider_update, images, metadata_frame)`` for the
        ``[shard, gallery, table]`` outputs.
    """
    global dsi, ds
    shard = int(shard)
    # Resolve the config for both branches: load_dataset() needs it even when
    # no single shard is selected.
    config = sets[set_name].get("config", set_name)
    if shard == -1:
        data_files = None
        shards = 100
    else:
        shards = sets[set_name]["shards"]
        path = sets[set_name].get("path", set_name)
        # Key data_files by the requested split so that split=split resolves
        # for splits other than "train" too.
        data_files = {
            split: [f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]
        }

    if DEBUG:
        ds = lambda: None
        ds.n_shards = 1234
        dsi = range(100)
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            # Newer `datasets` releases accept `token` instead of the
            # removed `use_auth_token`.
            token=os.environ.get("HF_TOKEN"),
        )
        dsi = iter(ds)

    # NOTE(review): assumes the usual 0-based "<split>-NNNNN-of-MMMMM" shard
    # naming, so the last valid shard index is shards - 1.
    last_shard = shards - 1
    return (
        gr.update(
            label=f"Shards (max {last_shard})", value=shard, maximum=last_shard
        ),
        *get_images(batch_size),
    )


def _satellogic_images(item):
    """Return the first entry of ``item["rgb"]`` as a single HWC uint8 image."""
    image = np.asarray(item["rgb"][0]).astype(np.uint8)
    return [image.transpose(1, 2, 0)]


def _sentinel_1_images(item):
    """Return one false-colour HWC uint8 image per entry of ``item["10m"]``.

    Channels 0 and 1 of each entry become red and green; blue is their ratio.
    This V/H to RGB mapping may not be correct, see
    https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
    """
    data = np.asarray(item["10m"])
    images = []
    for frame in data:
        image = np.zeros((3, 384, 384), "uint8")
        # NOTE(review): source values outside 0..255 are cast into uint8
        # unchecked here — confirm the stored value range.
        image[0] = frame[0]
        image[1] = frame[1]
        # ratio * 256 easily exceeds 255; clip so the float -> uint8 cast
        # saturates instead of overflowing.
        ratio = (image[0] / (image[1] + 0.1)) * 256
        image[2] = np.clip(ratio, 0, 255)
        images.append(image.transpose(1, 2, 0))
    return images


def _neon_images(item):
    """Return, per entry, the RGB image then band 0 of "chm" and of "1m"."""
    data_rgb = np.asarray(item["rgb"]).astype("uint8")
    data_chm = np.asarray(item["chm"]).astype("uint8")
    data_1m = np.asarray(item["1m"]).astype("uint8")
    images = []
    for idx in range(data_rgb.shape[0]):
        images.append(data_rgb[idx, :, :, :].transpose(1, 2, 0))
        images.append(data_chm[idx, 0, :, :])
        images.append(data_1m[idx, 0, :, :])
    return images


# Image builders keyed by the streamed dataset's config_name ("default" is the
# config name that `sets` assigns to the "neon" subset).
_IMAGE_BUILDERS = {
    "satellogic": _satellogic_images,
    "sentinel_1": _sentinel_1_images,
    "default": _neon_images,
}


def get_images(batch_size):
    """Fetch the next ``batch_size`` items from the open stream.

    Items whose config has no image builder still contribute a metadata row.

    Returns:
        ``(images, metadata_frame)``: gallery images (HWC colour or HW
        single-band arrays) and a DataFrame with one metadata row per item;
        one item may yield several images. Stops early at the end of the
        stream.
    """
    if not DEBUG and dsi is None:
        gr.Warning("Load a dataset first")
        return [], DataFrame()
    images = []
    metadatas = []
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            # Random placeholder image plus dummy metadata.
            images.append(
                np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
            )
            metadata = {"bounds": [[1, 1, 4, 4]]}
        else:
            try:
                item = next(dsi)
            except StopIteration:
                break
            metadata = item["metadata"]
            if ds.config_name == "sentinel_1":
                # sentinel_1 metadata is decoded from a JSON string.
                metadata = json.loads(metadata)
            build = _IMAGE_BUILDERS.get(ds.config_name)
            if build is not None:
                images.extend(build(item))
        metadatas.append(metadata)
    return images, DataFrame(metadatas)


def skip(count, batch_size):
    """Discard ``count`` batches from the stream, then return the next batch.

    Returns the same ``(images, metadata_frame)`` pair as ``get_images``.
    Skipping stops early when the stream is exhausted.
    """
    n_skip = int(count) * int(batch_size)
    gr.Info(f"Skipping {n_skip} images (it's slow)")
    if not DEBUG and dsi is not None:
        for _ in tqdm.trange(n_skip, desc=f"Skipping {n_skip} images"):
            try:
                next(dsi)
            except StopIteration:
                break
    return get_images(batch_size)


def update_shape(rows, columns):
    """Resize the gallery grid to ``rows`` x ``columns``."""
    return gr.update(rows=rows, columns=columns)


with gr.Blocks(title="Dataset Explorer", fill_height=True) as demo:
    # Components created unrendered here so they can be wired into events
    # before being placed in the layout below.
    batch_size = gr.Number(10, label="Batch Size", precision=0, render=False)
    shard = gr.Slider(
        label="Shard", minimum=0, maximum=10000, step=1, render=False
    )
    table = gr.DataFrame(render=False)
    # headers=["Index","TimeStamp","Bounds","CRS"],
    gallery = gr.Gallery(
        label="satellogic/EarthView",
        interactive=False,
        columns=5,
        rows=2,
        render=False,
    )

    with gr.Row():
        dataset = gr.Textbox(label="Dataset", value="satellogic/EarthView")
        config = gr.Dropdown(
            choices=list(sets),
            label="Subset",
            value="satellogic",
        )
        split = gr.Textbox(label="Split", value="train")
        initial_shard = gr.Number(label="Initial shard", value=0, precision=0)
        gr.Button("Load (minutes)").click(
            open_dataset,
            inputs=[dataset, config, split, batch_size, initial_shard],
            outputs=[shard, gallery, table],
        )

    gallery.render()

    with gr.Row():
        batch_size.render()
        rows = gr.Number(2, label="Rows", precision=0)
        columns = gr.Number(5, label="Columns", precision=0)
        rows.change(update_shape, [rows, columns], [gallery])
        columns.change(update_shape, [rows, columns], [gallery])

    with gr.Row():
        shard.render()
        shard.release(
            open_dataset,
            inputs=[dataset, config, split, batch_size, shard],
            outputs=[shard, gallery, table],
        )
        btn = gr.Button("Get More Images", scale=0)
        btn.click(get_images, [batch_size], [gallery, table])
        # btn = gr.Button("Skip 10 Batches", scale=0)
        # btn.click(partial(skip, 10), [batch_size], [gallery, table])
        # btn = gr.Button("Skip 25 Batches", scale=0)
        # btn.click(partial(skip, 25), [batch_size], [gallery, table])

    table.render()

demo.launch(show_api=False)