Spaces:
Sleeping
Sleeping
Reformatted code a bit.
Browse files- app.py +57 -35
- big_vision_contrastive_models.py +32 -17
- gradio_helpers.py +7 -3
app.py
CHANGED
@@ -19,6 +19,7 @@ import urllib.request
|
|
19 |
import gradio as gr
|
20 |
import PIL.Image
|
21 |
|
|
|
22 |
import big_vision_contrastive_models as models
|
23 |
import gradio_helpers
|
24 |
|
@@ -37,26 +38,26 @@ LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
|
|
37 |
MODEL_MAP = {
|
38 |
'lit': {
|
39 |
'B/16': {
|
40 |
-
|
41 |
},
|
42 |
'L/16': {
|
43 |
-
|
44 |
},
|
45 |
},
|
46 |
'siglip': {
|
47 |
'B/16': {
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
},
|
53 |
'L/16': {
|
54 |
-
|
55 |
-
|
56 |
},
|
57 |
'So400m/14': {
|
58 |
-
|
59 |
-
|
60 |
},
|
61 |
},
|
62 |
}
|
@@ -72,7 +73,9 @@ def get_cache_status():
|
|
72 |
)
|
73 |
|
74 |
|
75 |
-
def compute(
|
|
|
|
|
76 |
"""Loads model and computes answers."""
|
77 |
|
78 |
if image_path is None:
|
@@ -83,7 +86,7 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
83 |
model_name = MODEL_MAP[family][variant][res]
|
84 |
config = models.MODEL_CONFIGS[model_name]
|
85 |
local_ckpt = gradio_helpers.get_disk_cache(
|
86 |
-
|
87 |
config = dataclasses.replace(config, ckpt=local_ckpt)
|
88 |
params, model = gradio_helpers.get_memory_cache(
|
89 |
config,
|
@@ -91,11 +94,11 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
91 |
max_cache_size_bytes=MAX_RAM_CACHE,
|
92 |
progress=progress,
|
93 |
estimated_secs={
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
}.get((family, variant))
|
100 |
)
|
101 |
model: models.ContrastiveModel = model
|
@@ -107,18 +110,19 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
107 |
image = PIL.Image.open(image_path)
|
108 |
next(it)
|
109 |
with gradio_helpers.timed('image features'):
|
110 |
-
zimg,
|
111 |
params, model.preprocess_images([image])
|
112 |
)
|
113 |
next(it)
|
114 |
with gradio_helpers.timed('text features'):
|
115 |
prompts = prompts.split('\n')
|
116 |
ztxt, out = model.embed_texts(
|
117 |
-
|
118 |
)
|
119 |
next(it)
|
120 |
|
121 |
t = model.get_temperature(out)
|
|
|
122 |
if family == 'lit':
|
123 |
text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
|
124 |
elif family == 'siglip':
|
@@ -140,7 +144,8 @@ def update_answers(state):
|
|
140 |
"""Generates visible sliders for answers."""
|
141 |
answers = []
|
142 |
for prompt, prob in state[:MAX_ANSWERS]:
|
143 |
-
answers.append(
|
|
|
144 |
while len(answers) < MAX_ANSWERS:
|
145 |
answers.append(gr.Slider(visible=False))
|
146 |
return answers
|
@@ -159,7 +164,10 @@ def create_app():
|
|
159 |
|
160 |
with gr.Blocks(css=css) as demo:
|
161 |
|
162 |
-
gr.Markdown(
|
|
|
|
|
|
|
163 |
|
164 |
status = gr.Markdown(f'Ready ({get_cache_status()})')
|
165 |
|
@@ -168,12 +176,14 @@ def create_app():
|
|
168 |
source = gr.Markdown('', visible=False)
|
169 |
state = gr.State([])
|
170 |
with gr.Column():
|
171 |
-
prompts = gr.Textbox(
|
|
|
172 |
with gr.Row():
|
173 |
|
174 |
values = {}
|
175 |
|
176 |
-
family = gr.Dropdown(
|
|
|
177 |
values['family'] = family.value
|
178 |
|
179 |
# Unfortunately below reactive UI code is a bit convoluted, because:
|
@@ -185,25 +195,34 @@ def create_app():
|
|
185 |
def make_variant(family_value):
|
186 |
choices = list(MODEL_MAP[family_value])
|
187 |
values['variant'] = choices[0]
|
188 |
-
return gr.Dropdown(
|
|
|
189 |
variant = make_variant(family.value)
|
190 |
|
191 |
def make_res(family, variant):
|
192 |
choices = list(MODEL_MAP[family][variant])
|
193 |
values['res'] = choices[0]
|
194 |
-
return gr.Dropdown(
|
|
|
195 |
res = make_res(family.value, variant.value)
|
196 |
values['res'] = res.value
|
197 |
|
198 |
def make_bias(family, variant, res):
|
199 |
visible = family == 'siglip'
|
200 |
value = {
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
}.get((family, variant, res), -10.0)
|
206 |
-
return gr.Slider(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
bias = make_bias(family.value, variant.value, res.value)
|
208 |
values['bias'] = bias.value
|
209 |
|
@@ -248,7 +267,10 @@ def create_app():
|
|
248 |
# a single `status` widget here, and store the computed information in
|
249 |
# `state`...
|
250 |
run.click(
|
251 |
-
fn=compute,
|
|
|
|
|
|
|
252 |
# ... then we use `state` to update UI components without showing a
|
253 |
# progress bar in their place.
|
254 |
status.change(fn=update_answers, inputs=state, outputs=answers)
|
@@ -258,9 +280,9 @@ def create_app():
|
|
258 |
gr.Examples(
|
259 |
examples=[
|
260 |
[
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
]
|
265 |
for ex in info
|
266 |
],
|
@@ -272,7 +294,7 @@ def create_app():
|
|
272 |
return demo
|
273 |
|
274 |
|
275 |
-
if __name__ ==
|
276 |
|
277 |
logging.basicConfig(level=logging.INFO,
|
278 |
format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
19 |
import gradio as gr
|
20 |
import PIL.Image
|
21 |
|
22 |
+
# pylint: disable=g-bad-import-order
|
23 |
import big_vision_contrastive_models as models
|
24 |
import gradio_helpers
|
25 |
|
|
|
38 |
# Maps model family -> variant -> image resolution -> key into
# `models.MODEL_CONFIGS`. Every leaf must name a registered config, because
# `compute()` looks it up with `models.MODEL_CONFIGS[model_name]`.
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so440m_224',
            # The 384px config is registered under the `so400m` spelling (the
            # 224px one uses `so440m`); the name here must match it exactly.
            384: 'siglip_so400m14so400m_384',
        },
    },
}
|
|
|
73 |
)
|
74 |
|
75 |
|
76 |
+
def compute(
|
77 |
+
image_path, prompts, family, variant, res, bias, progress=gr.Progress()
|
78 |
+
):
|
79 |
"""Loads model and computes answers."""
|
80 |
|
81 |
if image_path is None:
|
|
|
86 |
model_name = MODEL_MAP[family][variant][res]
|
87 |
config = models.MODEL_CONFIGS[model_name]
|
88 |
local_ckpt = gradio_helpers.get_disk_cache(
|
89 |
+
config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
|
90 |
config = dataclasses.replace(config, ckpt=local_ckpt)
|
91 |
params, model = gradio_helpers.get_memory_cache(
|
92 |
config,
|
|
|
94 |
max_cache_size_bytes=MAX_RAM_CACHE,
|
95 |
progress=progress,
|
96 |
estimated_secs={
|
97 |
+
('lit', 'B/16'): 1,
|
98 |
+
('lit', 'L/16'): 2.5,
|
99 |
+
('siglip', 'B/16'): 9,
|
100 |
+
('siglip', 'L/16'): 28,
|
101 |
+
('siglip', 'So400m/14'): 36,
|
102 |
}.get((family, variant))
|
103 |
)
|
104 |
model: models.ContrastiveModel = model
|
|
|
110 |
image = PIL.Image.open(image_path)
|
111 |
next(it)
|
112 |
with gradio_helpers.timed('image features'):
|
113 |
+
zimg, unused_out = model.embed_images(
|
114 |
params, model.preprocess_images([image])
|
115 |
)
|
116 |
next(it)
|
117 |
with gradio_helpers.timed('text features'):
|
118 |
prompts = prompts.split('\n')
|
119 |
ztxt, out = model.embed_texts(
|
120 |
+
params, model.preprocess_texts(prompts)
|
121 |
)
|
122 |
next(it)
|
123 |
|
124 |
t = model.get_temperature(out)
|
125 |
+
text_probs = []
|
126 |
if family == 'lit':
|
127 |
text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
|
128 |
elif family == 'siglip':
|
|
|
144 |
"""Generates visible sliders for answers."""
|
145 |
answers = []
|
146 |
for prompt, prob in state[:MAX_ANSWERS]:
|
147 |
+
answers.append(
|
148 |
+
gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
|
149 |
while len(answers) < MAX_ANSWERS:
|
150 |
answers.append(gr.Slider(visible=False))
|
151 |
return answers
|
|
|
164 |
|
165 |
with gr.Blocks(css=css) as demo:
|
166 |
|
167 |
+
gr.Markdown(
|
168 |
+
'Gradio clone of the original '
|
169 |
+
'[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
|
170 |
+
)
|
171 |
|
172 |
status = gr.Markdown(f'Ready ({get_cache_status()})')
|
173 |
|
|
|
176 |
source = gr.Markdown('', visible=False)
|
177 |
state = gr.State([])
|
178 |
with gr.Column():
|
179 |
+
prompts = gr.Textbox(
|
180 |
+
label='Prompts (press Shift-ENTER to add a prompt)')
|
181 |
with gr.Row():
|
182 |
|
183 |
values = {}
|
184 |
|
185 |
+
family = gr.Dropdown(
|
186 |
+
value='lit', choices=list(MODEL_MAP), label='Model family')
|
187 |
values['family'] = family.value
|
188 |
|
189 |
# Unfortunately below reactive UI code is a bit convoluted, because:
|
|
|
195 |
def make_variant(family_value):
  """Builds the variant dropdown for the given model family.

  The first variant listed in `MODEL_MAP[family_value]` is stored in
  `values['variant']` and used as the dropdown's preselected value.
  """
  variants = [*MODEL_MAP[family_value]]
  default_variant = variants[0]
  values['variant'] = default_variant
  return gr.Dropdown(
      label='Variant', choices=variants, value=default_variant)
|
200 |
variant = make_variant(family.value)
|
201 |
|
202 |
def make_res(family, variant):
  """Builds the resolution dropdown for the given family and variant.

  The first resolution listed in `MODEL_MAP[family][variant]` is stored in
  `values['res']` and used as the dropdown's preselected value.
  """
  resolutions = [*MODEL_MAP[family][variant]]
  default_res = resolutions[0]
  values['res'] = default_res
  return gr.Dropdown(
      label='Resolution', choices=resolutions, value=default_res)
|
207 |
res = make_res(family.value, variant.value)
|
208 |
values['res'] = res.value
|
209 |
|
210 |
def make_bias(family, variant, res):
  """Builds the bias slider for the selected model.

  The slider is only visible for the 'siglip' family. Its initial value comes
  from a per-model table, falling back to -10.0 for models not listed.

  Args:
    family: Model family name (e.g. 'lit' or 'siglip').
    variant: Model variant name (e.g. 'B/16').
    res: Image resolution.

  Returns:
    A `gr.Slider` ranging over [-20, 0] in steps of 0.05.
  """
  # NOTE(review): this table used to list the key ('siglip', 'L/16', 256)
  # twice, first with -12.7 and then with -16.5. In a dict literal the last
  # duplicate wins, so -12.7 was never used. The shadowed entry is dropped
  # here (runtime behavior unchanged); it was presumably meant for a different
  # (variant, res) pair -- TODO confirm the intended key and restore it.
  default_bias = {
      ('siglip', 'B/16', 224): -12.9,
      ('siglip', 'L/16', 256): -16.5,
      # ...
  }.get((family, variant, res), -10.0)
  return gr.Slider(
      value=default_bias,
      minimum=-20,
      maximum=0,
      step=0.05,
      label='Bias',
      visible=family == 'siglip',
  )
|
226 |
bias = make_bias(family.value, variant.value, res.value)
|
227 |
values['bias'] = bias.value
|
228 |
|
|
|
267 |
# a single `status` widget here, and store the computed information in
|
268 |
# `state`...
|
269 |
run.click(
|
270 |
+
fn=compute,
|
271 |
+
inputs=[image, prompts, family, variant, res, bias],
|
272 |
+
outputs=[status, state],
|
273 |
+
)
|
274 |
# ... then we use `state` to update UI components without showing a
|
275 |
# progress bar in their place.
|
276 |
status.change(fn=update_answers, inputs=state, outputs=answers)
|
|
|
280 |
gr.Examples(
|
281 |
examples=[
|
282 |
[
|
283 |
+
IMG_URL_FMT.format(ex['id']),
|
284 |
+
ex['prompts'].replace(', ', '\n'),
|
285 |
+
'[source](%s)' % ex['source'],
|
286 |
]
|
287 |
for ex in info
|
288 |
],
|
|
|
294 |
return demo
|
295 |
|
296 |
|
297 |
+
if __name__ == '__main__':
|
298 |
|
299 |
logging.basicConfig(level=logging.INFO,
|
300 |
format='%(asctime)s - %(levelname)s - %(message)s')
|
big_vision_contrastive_models.py
CHANGED
@@ -27,15 +27,17 @@ import transformers
|
|
27 |
|
28 |
|
29 |
def _clone_git(url, destination_folder, commit_hash=None):
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
36 |
|
37 |
|
38 |
def setup(commit_hash=None):
|
|
|
39 |
for url, dst_name in (
|
40 |
('https://github.com/google-research/big_vision', 'big_vision_repo'),
|
41 |
('https://github.com/google/flaxformer', 'flaxformer_repo'),
|
@@ -43,11 +45,12 @@ def setup(commit_hash=None):
|
|
43 |
dst_path = os.path.join(tempfile.gettempdir(), dst_name)
|
44 |
if not os.path.exists(dst_path):
|
45 |
_clone_git(url, dst_path, commit_hash)
|
46 |
-
if not
|
47 |
sys.path.insert(0, dst_path)
|
48 |
|
49 |
|
50 |
class ContrastiveModelFamily(enum.Enum):
|
|
|
51 |
LIT = 'lit'
|
52 |
SIGLIP = 'siglip'
|
53 |
|
@@ -96,18 +99,21 @@ class ContrastiveModel:
|
|
96 |
return ztxt, out
|
97 |
|
98 |
def preprocess_texts(self, texts):
|
|
|
99 |
|
100 |
def tokenize_pad(text, seqlen=self.config.seqlen):
|
101 |
|
102 |
if self.config.family == ContrastiveModelFamily.LIT:
|
103 |
-
tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
|
|
|
104 |
tokens = tokens[:seqlen]
|
105 |
return tokens + [0] * (seqlen - len(tokens))
|
106 |
|
107 |
if self.config.family == ContrastiveModelFamily.SIGLIP:
|
108 |
tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
|
109 |
if len(tokens) >= seqlen:
|
110 |
-
|
|
|
111 |
return tokens + [0] * (seqlen - len(tokens))
|
112 |
|
113 |
return np.array([tokenize_pad(text) for text in texts])
|
@@ -125,7 +131,9 @@ class ContrastiveModel:
|
|
125 |
]) / 127.5 - 1.0
|
126 |
|
127 |
def get_bias(self, out):
|
128 |
-
assert
|
|
|
|
|
129 |
return out['b'].item()
|
130 |
|
131 |
def get_temperature(self, out):
|
@@ -145,7 +153,9 @@ class ContrastiveModel:
|
|
145 |
return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
|
146 |
|
147 |
|
148 |
-
def _make_config(
|
|
|
|
|
149 |
if family == 'lit':
|
150 |
tokenizer = ckpt.replace('.npz', '.txt')
|
151 |
else:
|
@@ -153,11 +163,12 @@ def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_
|
|
153 |
return ContrastiveModelConfig(
|
154 |
family=ContrastiveModelFamily(family), variant=variant, res=res,
|
155 |
textvariant=textvariant, embdim=embdim, seqlen=seqlen,
|
156 |
-
tokenizer=tokenizer, vocab_size=
|
157 |
ckpt=ckpt,
|
158 |
)
|
159 |
|
160 |
|
|
|
161 |
MODEL_CONFIGS = dict(
|
162 |
lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
|
163 |
lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
|
@@ -173,6 +184,7 @@ MODEL_CONFIGS = dict(
|
|
173 |
siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
|
174 |
siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
|
175 |
)
|
|
|
176 |
|
177 |
|
178 |
@functools.cache
|
@@ -187,7 +199,6 @@ def load_tokenizer_sp(name_or_path):
|
|
187 |
|
188 |
@functools.cache
|
189 |
def load_tokenizer_bert(path):
|
190 |
-
tok = sentencepiece.SentencePieceProcessor()
|
191 |
if path.startswith('gs://'):
|
192 |
dst = tempfile.mktemp()
|
193 |
gfile.copy(path, dst)
|
@@ -203,7 +214,9 @@ def load_model(config, check_params=False):
|
|
203 |
cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
|
204 |
if config.family == ContrastiveModelFamily.LIT:
|
205 |
cfg.text_model = 'proj.flaxformer.bert'
|
206 |
-
cfg.image = dict(
|
|
|
|
|
207 |
bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
|
208 |
cfg.text = dict(config=bert_config, head_zeroinit=False)
|
209 |
tokenizer_bert = load_tokenizer_bert(config.tokenizer)
|
@@ -211,10 +224,12 @@ def load_model(config, check_params=False):
|
|
211 |
if config.variant == 'L/16':
|
212 |
cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
|
213 |
else:
|
214 |
-
|
|
|
215 |
else:
|
216 |
cfg.image = dict(variant=config.variant, pool_type='map')
|
217 |
-
|
|
|
218 |
cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
|
219 |
cfg.bias_init = -10.0
|
220 |
tokenizer_sp = load_tokenizer_sp(config.tokenizer)
|
@@ -223,7 +238,7 @@ def load_model(config, check_params=False):
|
|
223 |
cfg.temperature_init = 10.0
|
224 |
|
225 |
model_mod = importlib.import_module(
|
226 |
-
|
227 |
model = model_mod.Model(**cfg)
|
228 |
|
229 |
init_params = None # Faster but bypasses loading sanity-checks.
|
|
|
27 |
|
28 |
|
29 |
def _clone_git(url, destination_folder, commit_hash=None):
  """Clones a git repository, optionally checking out a specific commit.

  Args:
    url: URL of the repository to clone.
    destination_folder: Local directory the repository is cloned into.
    commit_hash: Optional commit to check out after cloning. When given, the
      full history is cloned so that the commit is guaranteed to be available
      locally.

  Raises:
    subprocess.CalledProcessError: If a git command exits with a non-zero
      status.
  """
  clone_cmd = ['git', 'clone']
  if not commit_hash:
    # Only the tip of the default branch is needed, so a shallow clone is
    # enough. A shallow clone must NOT be used when pinning a commit:
    # `--depth=1` fetches just the tip, and checking out any other commit
    # would fail because its objects are missing.
    clone_cmd.append('--depth=1')
  subprocess.run(clone_cmd + [url, destination_folder], check=True)
  if commit_hash:
    subprocess.run(
        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
    )
|
37 |
|
38 |
|
39 |
def setup(commit_hash=None):
|
40 |
+
"""Checks out required non-pypi code from Github."""
|
41 |
for url, dst_name in (
|
42 |
('https://github.com/google-research/big_vision', 'big_vision_repo'),
|
43 |
('https://github.com/google/flaxformer', 'flaxformer_repo'),
|
|
|
45 |
dst_path = os.path.join(tempfile.gettempdir(), dst_name)
|
46 |
if not os.path.exists(dst_path):
|
47 |
_clone_git(url, dst_path, commit_hash)
|
48 |
+
if dst_path not in sys.path:
|
49 |
sys.path.insert(0, dst_path)
|
50 |
|
51 |
|
52 |
class ContrastiveModelFamily(enum.Enum):
  """Enumerates the supported contrastive model families.

  Each member's value is the lower-case family name, so a member can be
  looked up from that string, e.g. `ContrastiveModelFamily('lit')`.
  """

  LIT = 'lit'
  SIGLIP = 'siglip'
|
56 |
|
|
|
99 |
return ztxt, out
|
100 |
|
101 |
def preprocess_texts(self, texts):
  """Tokenizes `texts` and pads/truncates each one to `config.seqlen` ids.

  Args:
    texts: Iterable of strings to tokenize.

  Returns:
    A numpy array built from one list of `config.seqlen` token ids per text.
  """
  seqlen = self.config.seqlen

  def _pad(ids):
    # Right-pads with zeros up to `seqlen`.
    return ids + [0] * (seqlen - len(ids))

  def _encode(text):
    family = self.config.family
    if family == ContrastiveModelFamily.LIT:
      ids = self.tokenizer_bert.encode(text, add_special_tokens=True)
      # Drop the trailing [SEP] token first, then truncate.
      return _pad(ids[:-1][:seqlen])
    if family == ContrastiveModelFamily.SIGLIP:
      ids = self.tokenizer_sp.tokenize(text, add_eos=True)
      if len(ids) < seqlen:
        return _pad(ids)
      # "Sticky" eos: a truncated sequence still ends with the eos id.
      return ids[:seqlen - 1] + [self.tokenizer_sp.eos_id()]
    return None

  return np.array(list(map(_encode, texts)))
|
|
|
131 |
]) / 127.5 - 1.0
|
132 |
|
133 |
def get_bias(self, out):
  """Returns the bias stored under `out['b']` as a Python scalar.

  Only valid for SigLIP models; asserts that `config.family` is SIGLIP.
  """
  family = self.config.family
  assert family == ContrastiveModelFamily.SIGLIP, family
  bias = out['b']
  return bias.item()
|
138 |
|
139 |
def get_temperature(self, out):
|
|
|
153 |
return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
|
154 |
|
155 |
|
156 |
+
def _make_config(
|
157 |
+
family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
|
158 |
+
):
|
159 |
if family == 'lit':
|
160 |
tokenizer = ckpt.replace('.npz', '.txt')
|
161 |
else:
|
|
|
163 |
return ContrastiveModelConfig(
|
164 |
family=ContrastiveModelFamily(family), variant=variant, res=res,
|
165 |
textvariant=textvariant, embdim=embdim, seqlen=seqlen,
|
166 |
+
tokenizer=tokenizer, vocab_size=vocab_size,
|
167 |
ckpt=ckpt,
|
168 |
)
|
169 |
|
170 |
|
171 |
+
# pylint: disable=line-too-long
|
172 |
MODEL_CONFIGS = dict(
|
173 |
lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
|
174 |
lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
|
|
|
184 |
siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
|
185 |
siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
|
186 |
)
|
187 |
+
# pylint: enable=line-too-long
|
188 |
|
189 |
|
190 |
@functools.cache
|
|
|
199 |
|
200 |
@functools.cache
|
201 |
def load_tokenizer_bert(path):
|
|
|
202 |
if path.startswith('gs://'):
|
203 |
dst = tempfile.mktemp()
|
204 |
gfile.copy(path, dst)
|
|
|
214 |
cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
|
215 |
if config.family == ContrastiveModelFamily.LIT:
|
216 |
cfg.text_model = 'proj.flaxformer.bert'
|
217 |
+
cfg.image = dict(
|
218 |
+
variant=config.variant, pool_type='tok', head_zeroinit=False
|
219 |
+
)
|
220 |
bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
|
221 |
cfg.text = dict(config=bert_config, head_zeroinit=False)
|
222 |
tokenizer_bert = load_tokenizer_bert(config.tokenizer)
|
|
|
224 |
if config.variant == 'L/16':
|
225 |
cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
|
226 |
else:
|
227 |
+
# (image_out_dim, text_out_dim)
|
228 |
+
cfg.out_dim = (config.embdim, config.embdim)
|
229 |
else:
|
230 |
cfg.image = dict(variant=config.variant, pool_type='map')
|
231 |
+
# TODO(lbeyer): remove later, default
|
232 |
+
cfg.text_model = 'proj.image_text.text_transformer'
|
233 |
cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
|
234 |
cfg.bias_init = -10.0
|
235 |
tokenizer_sp = load_tokenizer_sp(config.tokenizer)
|
|
|
238 |
cfg.temperature_init = 10.0
|
239 |
|
240 |
model_mod = importlib.import_module(
|
241 |
+
'big_vision.models.proj.image_text.two_towers')
|
242 |
model = model_mod.Model(**cfg)
|
243 |
|
244 |
init_params = None # Faster but bypasses loading sanity-checks.
|
gradio_helpers.py
CHANGED
@@ -30,8 +30,9 @@ def timed(name):
|
|
30 |
logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
31 |
|
32 |
|
33 |
-
|
34 |
-
|
|
|
35 |
"""Copies a file with progress bar.
|
36 |
|
37 |
Args:
|
@@ -39,6 +40,7 @@ def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite
|
|
39 |
dst: Destination file. Path must be readable by `tf.io.gfile`.
|
40 |
progress: An object with a `.tqdm` attribute, or `None`.
|
41 |
block_size: Size of individual blocks to be read/written.
|
|
|
42 |
"""
|
43 |
if os.path.dirname(dst):
|
44 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
@@ -87,7 +89,9 @@ def _get_array_sizes(tree):
|
|
87 |
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
|
88 |
|
89 |
|
90 |
-
def get_memory_cache(
|
|
|
|
|
91 |
"""Keeps cache below specified size by removing elements not last accessed."""
|
92 |
if key in _memory_cache:
|
93 |
_memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order
|
|
|
30 |
logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
31 |
|
32 |
|
33 |
+
def copy_file(
|
34 |
+
src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
|
35 |
+
):
|
36 |
"""Copies a file with progress bar.
|
37 |
|
38 |
Args:
|
|
|
40 |
dst: Destination file. Path must be readable by `tf.io.gfile`.
|
41 |
progress: An object with a `.tqdm` attribute, or `None`.
|
42 |
block_size: Size of individual blocks to be read/written.
|
43 |
+
overwrite: If `True`, overwrite `dst` if it exists.
|
44 |
"""
|
45 |
if os.path.dirname(dst):
|
46 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
|
|
89 |
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
|
90 |
|
91 |
|
92 |
+
def get_memory_cache(
|
93 |
+
key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
|
94 |
+
):
|
95 |
"""Keeps cache below specified size by removing elements not last accessed."""
|
96 |
if key in _memory_cache:
|
97 |
_memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order
|