from datasets import load_dataset, get_dataset_config_names
from functools import partial
from pandas import DataFrame
import earthview as ev
import gradio as gr 
import tqdm
import os

# Offline mode switch. When True, open_dataset skips load_dataset and
# get_images generates random arrays instead of reading dataset items.
DEBUG: bool = False

if DEBUG:
    # numpy is only needed to generate the random arrays in debug mode.
    import numpy as np

def open_dataset(dataset, set_name, split, batch_size, state, shard = -1):
    """Open a streaming iterator over ``dataset`` and fetch the first batch.

    Args:
        dataset: Hugging Face dataset repository id passed to ``load_dataset``.
        set_name: Key into ``ev.sets``; selects the dataset config and, when a
            single shard is opened, the parquet path prefix and shard count.
        split: Split name (e.g. ``"train"``); also part of the shard file name.
        batch_size: Number of items to read for the first batch.
        state: Per-session dict; ``"config"`` and ``"dsi"`` (the item
            iterator) are written into it.
        shard: Index of the parquet shard to open, or ``-1`` to open the
            whole dataset.

    Returns:
        ``(slider_update, images, metadata_frame, state)``, matching the
        ``[shard, gallery, table, state]`` outputs wired up in the UI.
    """
    # Gradio number widgets may deliver floats; the shard index is formatted
    # with ``:05d`` below, which only accepts ints.
    shard = int(shard)
    # The config is needed in both modes (by load_dataset and by get_images),
    # so resolve it before branching.
    config = ev.sets[set_name].get("config", set_name)

    if shard == -1:
        # Trick to open the whole dataset: no data_files filter is passed to
        # load_dataset. ``shards`` is only a placeholder for the slider range.
        data_files = None
        shards = 100
    else:
        shards = ev.sets[set_name]["shards"]
        path = ev.sets[set_name].get("path", set_name)
        # Key the file list by the requested split so that
        # load_dataset(split=split) finds it for splits other than "train".
        data_files = {split: [f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]}

    if DEBUG:
        # Offline stand-in for the dataset iterator.
        dsi = iter(range(100))
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            token=os.environ.get("HF_TOKEN", None))
        dsi = iter(ds)

    state["config"] = config
    state["dsi"] = dsi
    # Shard files are named "<split>-NNNNN-of-<shards>", i.e. indices
    # 0 .. shards-1, so the last valid index is the slider's maximum.
    last_shard = shards - 1
    return (
        gr.update(label=f"Shards (max {last_shard})", value=shard, maximum=last_shard),
        *get_images(batch_size, state),
        state,
    )
    
def get_images(batch_size, state):
    """Read up to ``batch_size`` items from the session iterator.

    Args:
        batch_size: Number of items to read from ``state["dsi"]``.
        state: Per-session dict populated by ``open_dataset``; its
            ``"config"`` and ``"dsi"`` entries are read.

    Returns:
        ``(images, metadata_frame)``: a flat list of images for the gallery
        and a DataFrame with one metadata row per item read. Fewer items are
        returned when the iterator is exhausted.

    Raises:
        gr.Error: if no dataset has been opened in this session yet.
    """
    if "dsi" not in state or "config" not in state:
        raise gr.Error("Load a dataset first.")
    config = state["config"]

    # Image groups (keys of ev.item_to_images' result) shown per config, in
    # display order. Unknown configs contribute metadata only.
    keys_by_config = {
        "satellogic": ("rgb",),  # "1m" is available but intentionally not shown
        "sentinel_1": ("10m",),
        "default": ("rgb", "chm", "1m"),
    }
    keys = keys_by_config.get(config, ())

    images = []
    metadatas = []

    # Gradio number widgets may deliver floats; range() needs an int.
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            images.append(np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue

        try:
            raw = next(state["dsi"])
        except StopIteration:
            break
        converted = ev.item_to_images(config, raw)
        for key in keys:
            images.extend(converted[key])
        metadatas.append(converted["metadata"])

    return images, DataFrame(metadatas)

def update_shape(rows, columns):
    """Build a Gradio update that resizes the gallery grid.

    Args:
        rows: Number of gallery rows.
        columns: Number of gallery columns.

    Returns:
        A ``gr.update`` payload setting ``rows`` and ``columns``.
    """
    grid = {"rows": rows, "columns": columns}
    return gr.update(**grid)

def new_state():
    """Return a ``gr.State`` whose initial value is an empty dict.

    open_dataset stores the active config and item iterator in this dict.
    """
    return gr.State(value=dict())

if __name__ == "__main__":
    # Layout (top to bottom): title, dataset selector row, gallery,
    # batch-size / gallery-shape controls, shard slider with next-batch
    # button, metadata table. Components created with render=False are
    # placed later via .render() so event handlers declared earlier in the
    # layout can already reference them.
    with gr.Blocks(title="Dataset Explorer", fill_height = True) as demo:
        state = new_state()

        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
        # precision=0 makes Gradio deliver ints for integer-valued inputs.
        batch_size = gr.Number(10, label = "Batch Size", precision=0, render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render = False)

        gallery = gr.Gallery(
            label=ev.DATASET,
            interactive=False,
            columns=5, rows=2, render=False)

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
            config = gr.Dropdown(choices=ev.get_sets(), label="Config", value="satellogic", )
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label = "Initial shard", value=10, precision=0, info="-1 for whole dataset")

            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, initial_shard],
                outputs=[shard, gallery, table, state])

        gallery.render()

        with gr.Row():
            batch_size.render()

            rows = gr.Number(2, label="Rows", precision=0)
            columns = gr.Number(5, label="Columns", precision=0)

            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Releasing the slider reopens the dataset at the selected shard.
            shard.release(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, shard],
                outputs=[shard, gallery, table, state])

            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)