# (Removed "Spaces: Sleeping" banner — Hugging Face page-scrape artifact, not part of the source file.)
"""Gradio clone of https://google-research.github.io/vision_transformer/lit/. | |
Features: | |
- Models are downloaded dynamically. | |
- Models are cached on local disk, and in RAM. | |
- Progress bars when downloading/reading/computing. | |
- Dynamic update of model controls. | |
- Dynamic generation of output sliders. | |
- Use of `gr.State()` for better use of progress bars. | |
""" | |
import dataclasses | |
import functools | |
import json | |
import logging | |
import os | |
import time | |
import urllib.request | |
import gradio as gr | |
import PIL.Image | |
# pylint: disable=g-bad-import-order | |
import big_vision_contrastive_models as models | |
import gradio_helpers | |
# Index of example images (ids, prompts, sources) served by the original demo.
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
# Format string: fill in an image id from `info.json` to get its JPEG URL.
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
# Maximum number of answer sliders shown in the UI (see `update_answers()`).
MAX_ANSWERS = 10
# Cache budgets in bytes, enforced by `gradio_helpers`.
MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM
# Rough per-variant model loading time estimates, in seconds (for progress UI).
LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}

# Nested lookup: family/variant/res -> model name in `models.MODEL_CONFIGS`.
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            # NOTE(review): "so440m" (not "so400m") in the text-tower suffix
            # looks like a typo, but it must match the key declared in
            # `models.MODEL_CONFIGS` — verify there before "fixing" it.
            224: 'siglip_so400m14so440m_224',
            384: 'siglip_so400m14so440m_384',
        },
    },
}
def get_cache_status():
  """Returns a one-line, human-readable summary of both cache levels."""
  num_mem, bytes_mem = gradio_helpers.get_memory_cache_info()
  num_disk, bytes_disk = gradio_helpers.get_disk_cache_info()
  # Sizes are reported in gigabytes with two decimals, e.g. "[1.23G]".
  mem_part = f'memory cache {num_mem} items [{bytes_mem/1e9:.2f}G]'
  disk_part = f'disk cache {num_disk} items [{bytes_disk/1e9:.2f}G]'
  return f'{mem_part}, {disk_part}'
def compute(
    image_path, prompts, family, variant, res, bias, progress=gr.Progress()
):
  """Loads model and computes answers.

  Args:
    image_path: Path of the image selected in the UI, or `None` if no image
      was selected yet.
    prompts: Newline-separated prompt strings from the textbox.
    family: Model family, a top-level key of `MODEL_MAP` ('lit' or 'siglip').
    variant: Model variant, e.g. 'B/16' (second-level key of `MODEL_MAP`).
    res: Input resolution (third-level key of `MODEL_MAP`).
    bias: Logit bias; only used for the 'siglip' family.
    progress: Gradio progress tracker, used for download/load/compute bars.

  Returns:
    Tuple `(status, state)`: a `gr.Markdown` status line and a list of
    `(prompt, probability)` pairs stored in `gr.State` for `update_answers()`.

  Raises:
    gr.Error: If no image was selected.
  """
  if image_path is None:
    raise gr.Error('Must first select an image!')
  t0 = time.monotonic()
  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
  # Download (or reuse) the checkpoint on local disk, then point the config
  # at the local copy before loading into memory.
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      # Rough load-time estimates so the progress bar can animate; missing
      # combinations simply get no estimate (`.get()` returns None).
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  # Annotation only: helps editors/type-checkers with the cached value.
  model: models.ContrastiveModel = model
  # Three manual steps (image open / image features / text features), each
  # advanced via `next(it)` so a single progress bar tracks all of them.
  it = progress.tqdm(list(range(3)), desc='compute')
  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)
  with gradio_helpers.timed('image features'):
    zimg, unused_out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)
  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)
  t = model.get_temperature(out)
  text_probs = []
  # LiT uses a softmax over prompts (axis=-1); SigLIP uses independent
  # sigmoids with a learned/user-supplied bias.
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])
  # `state` feeds `update_answers()` via the `status.change` event.
  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))
  dt = time.monotonic() - t0
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ({get_cache_status()})')
  if 'b' in out:
    # The model's own default bias, useful for calibrating the bias slider.
    logging.info('model_name=%s default bias=%f', model_name, out['b'])
  return status, state
def update_answers(state):
  """Generates visible sliders for answers."""
  # One visible slider per (prompt, probability) pair, capped at MAX_ANSWERS.
  visible_sliders = [
      gr.Slider(value=round(100 * prob, 2), label=prompt, visible=True)
      for prompt, prob in state[:MAX_ANSWERS]
  ]
  # Pad with hidden sliders so the output component count is always constant.
  padding = [
      gr.Slider(visible=False)
      for _ in range(MAX_ANSWERS - len(visible_sliders))
  ]
  return visible_sliders + padding
def create_app():
  """Creates demo UI.

  Returns:
    The `gr.Blocks` demo, ready to be `.queue().launch()`-ed.
  """
  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''
  with gr.Blocks(css=css) as demo:
    gr.Markdown(
        'Gradio clone of the original '
        '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
    )
    status = gr.Markdown(f'Ready ({get_cache_status()})')
    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      # Hidden markdown holding the example image's attribution link.
      source = gr.Markdown('', visible=False)
      # Holds the (prompt, probability) pairs computed by `compute()`.
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(
            label='Prompts (press Shift-ENTER to add a prompt)')
        with gr.Row():
          family = gr.Dropdown(
              value='lit', choices=list(MODEL_MAP), label='Model family')
          make_variant = functools.partial(gr.Dropdown, label='Variant')
          variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
          make_res = functools.partial(gr.Dropdown, label='Resolution')
          res = make_res(list(MODEL_MAP['lit']['B/16']), value=224)

          def make_bias(family, variant, res):
            # The bias slider only applies to the sigmoid-based SigLIP family.
            visible = family == 'siglip'
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'L/16', 256): -12.7,
                # BUG FIX: this key was a duplicate of the previous entry
                # (('siglip', 'L/16', 256) twice), so -12.7 was silently
                # overridden. The 384-res model is the only other L/16 entry
                # in MODEL_MAP, so -16.5 presumably belongs to it — verify
                # against the model's logged default bias (see `compute()`).
                ('siglip', 'L/16', 384): -16.5,
                # ...
            }.get((family, variant, res), -10.0)
            return gr.Slider(
                value=value,
                minimum=-20,
                maximum=0,
                step=0.05,
                label='Bias',
                visible=visible,
            )

          bias = make_bias(family.value, variant.value, res.value)

        def update_inputs(family, variant, res):
          """Keeps variant/res/bias choices consistent with selected family."""
          d = MODEL_MAP[family]
          variants = list(d)
          # Fall back to the first valid choice when the old one disappears.
          variant = variant if variant in variants else variants[0]
          d = d[variant]
          ress = list(d)
          res = res if res in ress else ress[0]
          return [
              make_variant(variants, value=variant),
              make_res(ress, value=res),
              make_bias(family, variant, res),
          ]

        gr.on(
            [family.change, variant.change, res.change],
            update_inputs,
            [family, variant, res],
            [variant, res, bias],
        )
        # (end of code for reactive UI)

        run = gr.Button('Run')
        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

    # We want to avoid showing multiple progress bars, so we only update
    # a single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute,
        inputs=[image, prompts, family, variant, res, bias],
        outputs=[status, state],
    )
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: below images have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        # BUG FIX: `inputs` previously listed a fourth component, `license`,
        # which is Python's *builtin* `license()` helper — no such component
        # is defined in this file — while each example row only has three
        # fields. The three components below match the example columns.
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )
  return demo
if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')
  # Log the environment for debugging, but redact values of keys that look
  # like credentials — previously every value (HF tokens, API keys, ...)
  # was written verbatim into the logs.
  for k, v in os.environ.items():
    if any(s in k.upper() for s in ('TOKEN', 'SECRET', 'KEY', 'PASSWORD')):
      v = '<redacted>'
    logging.info('environ["%s"] = %r', k, v)
  models.setup()
  create_app().queue().launch()