"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.

Features:

- Models are downloaded dynamically.
- Models are cached on local disk, and in RAM.
- Progress bars when downloading/reading/computing.
- Dynamic update of model controls.
- Dynamic generation of output sliders.
- Use of `gr.State()` for better use of progress bars.
"""
import dataclasses
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import PIL.Image

# pylint: disable=g-bad-import-order
import big_vision_contrastive_models as models
import gradio_helpers


INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
MAX_ANSWERS = 10

MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM

LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}


# family/variant/res -> name
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so440m_224',
            384: 'siglip_so400m14so440m_384',
        },
    },
}


def get_cache_status():
  """Returns a string summarizing cache status."""
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  return (
      f'memory cache {mem_n} items [{mem_sz/1e9:.2f}G], '
      f'disk cache {disk_n} items [{disk_sz/1e9:.2f}G]'
  )


def compute(
    image_path, prompts, family, variant, res, bias, progress=gr.Progress()
):
  """Loads model and computes answers."""

  if image_path is None:
    raise gr.Error('Must first select an image!')

  t0 = time.monotonic()

  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  model: models.ContrastiveModel = model

  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
    next(it)
  with gradio_helpers.timed('image features'):
    zimg, unused_out = model.embed_images(
        params, model.preprocess_images([image])
    )
    next(it)
  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
    next(it)

  t = model.get_temperature(out)
  text_probs = []
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])

  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ({get_cache_status()})')

  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state


def update_answers(state):
  """Generates visible sliders for answers."""
  answers = []
  for prompt, prob in state[:MAX_ANSWERS]:
    answers.append(
        gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
  while len(answers) < MAX_ANSWERS:
    answers.append(gr.Slider(visible=False))
  return answers


def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown(
        'Gradio clone of the original '
        '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
    )

    status = gr.Markdown(f'Ready ({get_cache_status()})')

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(
            label='Prompts (press Shift-ENTER to add a prompt)')
        with gr.Row():

          family = gr.Dropdown(
              value='lit', choices=list(MODEL_MAP), label='Model family')

          make_variant = functools.partial(gr.Dropdown, label='Variant')
          variant = make_variant(list(MODEL_MAP['lit']), value='B/16')

          make_res = functools.partial(gr.Dropdown, label='Resolution')
          res = make_res(list(MODEL_MAP['lit']['B/16']), value=224)

          def make_bias(family, variant, res):
            visible = family == 'siglip'
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'L/16', 256): -12.7,
                ('siglip', 'L/16', 256): -16.5,
                # ...
            }.get((family, variant, res), -10.0)
            return gr.Slider(
                value=value,
                minimum=-20,
                maximum=0,
                step=0.05,
                label='Bias',
                visible=visible,
            )
          bias = make_bias(family.value, variant.value, res.value)

          def update_inputs(family, variant, res):
            d = MODEL_MAP[family]
            variants = list(d)
            variant = variant if variant in variants else variants[0]
            d = d[variant]
            ress = list(d)
            res = res if res in ress else ress[0]
            return [
                make_variant(variants, value=variant),
                make_res(ress, value=res),
                make_bias(family, variant, res),
            ]

          gr.on(
              [family.change, variant.change, res.change],
              update_inputs,
              [family, variant, res],
              [variant, res, bias],
          )

          # (end of code for reactive UI)

        run = gr.Button('Run')
        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

        # We want to avoid showing multiple progress bars, so we   only update
        # a single `status` widget here, and store the computed information in
        # `state`...
        run.click(
            fn=compute,
            inputs=[image, prompts, family, variant, res, bias],
            outputs=[status, state],
        )
        # ... then we use `state` to update UI components without showing a
        # progress bar in their place.
        status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: below images have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        inputs=[image, prompts, source, license],
        outputs=answers,
        elem_id='examples',
    )

  return demo


if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  models.setup()

  create_app().queue().launch()