"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.
Features:
- Models are downloaded dynamically.
- Models are cached on local disk, and in RAM.
- Progress bars when downloading/reading/computing.
- Dynamic update of model controls.
- Dynamic generation of output sliders.
- Use of `gr.State()` for better use of progress bars.
"""
import dataclasses
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import PIL.Image

# pylint: disable=g-bad-import-order
import big_vision_contrastive_models as models
import gradio_helpers

INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
MAX_ANSWERS = 10
MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM
LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}

# family/variant/res -> name
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so400m_224',
            384: 'siglip_so400m14so400m_384',
        },
    },
}
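
# Example lookup: MODEL_MAP['siglip']['B/16'][384] -> 'siglip_b16b_384'; the
# resolved name indexes `models.MODEL_CONFIGS` in `compute()` below.
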

def get_cache_status():
  """Returns a string summarizing cache status."""
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  return (
      f'memory cache {mem_n} items [{mem_sz/1e9:.2f}G], '
      f'disk cache {disk_n} items [{disk_sz/1e9:.2f}G]'
  )


def compute(
    image_path, prompts, family, variant, res, bias, progress=gr.Progress()
):
  """Loads model and computes answers."""
  if image_path is None:
    raise gr.Error('Must first select an image!')

  t0 = time.monotonic()
  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
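
  # Fetch the checkpoint to local disk, then load params/model into RAM.
  # Sketch of the assumed `gradio_helpers` semantics: `get_disk_cache(url)`
  # downloads `url` once and returns its local path, evicting older files
  # beyond `max_cache_size_bytes`; `get_memory_cache(key, fn)` memoizes
  # `fn()` keyed by `config`, evicting older entries beyond the RAM budget.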
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  model: models.ContrastiveModel = model

  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)

  with gradio_helpers.timed('image features'):
    zimg, unused_out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)

  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)

  t = model.get_temperature(out)
  text_probs = []
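  # LiT normalizes across prompts with a softmax, so probabilities sum to 1;
  # SigLIP scores each prompt independently with a sigmoid, roughly
  # sigmoid(t * zimg @ ztxt.T + bias), so probabilities need not sum to 1.
  # (Hedged sketch only; the exact math lives in `get_probabilities()`.)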
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])
  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ({get_cache_status()})')

  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state


def update_answers(state):
  """Generates visible sliders for answers."""
  answers = []
  for prompt, prob in state[:MAX_ANSWERS]:
    answers.append(
        gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
  while len(answers) < MAX_ANSWERS:
    answers.append(gr.Slider(visible=False))
  return answers


def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown(
        'Gradio clone of the original '
        '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
    )
    status = gr.Markdown(f'Ready ({get_cache_status()})')

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])

      with gr.Column():
        prompts = gr.Textbox(
            label='Prompts (press Shift-ENTER to add a prompt)')

        with gr.Row():
          family = gr.Dropdown(
              value='lit', choices=list(MODEL_MAP), label='Model family')
          make_variant = functools.partial(gr.Dropdown, label='Variant')
          variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
          make_res = functools.partial(gr.Dropdown, label='Resolution')
          res = make_res(list(MODEL_MAP['lit']['B/16']), value=224)

        def make_bias(family, variant, res):
          visible = family == 'siglip'
          value = {
              ('siglip', 'B/16', 224): -12.9,
              ('siglip', 'B/16', 256): -12.7,
              ('siglip', 'L/16', 256): -16.5,
              # ...
          }.get((family, variant, res), -10.0)
          return gr.Slider(
              value=value,
              minimum=-20,
              maximum=0,
              step=0.05,
              label='Bias',
              visible=visible,
          )

        bias = make_bias(family.value, variant.value, res.value)

        def update_inputs(family, variant, res):
          d = MODEL_MAP[family]
          variants = list(d)
          variant = variant if variant in variants else variants[0]
          d = d[variant]
          ress = list(d)
          res = res if res in ress else ress[0]
          return [
              make_variant(variants, value=variant),
              make_res(ress, value=res),
              make_bias(family, variant, res),
          ]
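
        # Example cascade: switching family to 'lit' while variant is
        # 'So400m/14' resets variant to 'B/16' (the first key of
        # MODEL_MAP['lit']), res to 224, and hides the bias slider.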

        gr.on(
            [family.change, variant.change, res.change],
            update_inputs,
            [family, variant, res],
            [variant, res, bias],
        )
        # (end of code for reactive UI)

        run = gr.Button('Run')

    answers = [
        # Will be set to visible in `update_answers()`.
        gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
        for _ in range(MAX_ANSWERS)
    ]

    # We want to avoid showing multiple progress bars, so we only update a
    # single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute,
        inputs=[image, prompts, family, variant, res, bias],
        outputs=[status, state],
    )
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: the images below are available at 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo


if __name__ == '__main__':
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(levelname)s - %(message)s')

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  models.setup()
  create_app().queue().launch()