andsteing committed
Commit 3cfc2e7
1 Parent(s): bc8a162

Reformatted code a bit.

Files changed (3):
  1. app.py +57 -35
  2. big_vision_contrastive_models.py +32 -17
  3. gradio_helpers.py +7 -3
app.py CHANGED
@@ -19,6 +19,7 @@ import urllib.request
 import gradio as gr
 import PIL.Image
 
+# pylint: disable=g-bad-import-order
 import big_vision_contrastive_models as models
 import gradio_helpers
 
@@ -37,26 +38,26 @@ LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
 MODEL_MAP = {
     'lit': {
         'B/16': {
-            224: 'lit_b16b',
+            224: 'lit_b16b',
         },
         'L/16': {
-            224: 'lit_l16l',
+            224: 'lit_l16l',
         },
     },
     'siglip': {
         'B/16': {
-            224: 'siglip_b16b_224',
-            256: 'siglip_b16b_256',
-            384: 'siglip_b16b_384',
-            512: 'siglip_b16b_512',
+            224: 'siglip_b16b_224',
+            256: 'siglip_b16b_256',
+            384: 'siglip_b16b_384',
+            512: 'siglip_b16b_512',
         },
         'L/16': {
-            256: 'siglip_l16l_256',
-            384: 'siglip_l16l_384',
+            256: 'siglip_l16l_256',
+            384: 'siglip_l16l_384',
         },
         'So400m/14': {
-            224: 'siglip_so400m14so440m_224',
-            384: 'siglip_so400m14so440m_384',
+            224: 'siglip_so400m14so440m_224',
+            384: 'siglip_so400m14so440m_384',
         },
     },
 }
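
Note: this nested map is consumed in two ways that the rest of the diff relies on: `compute()` resolves `MODEL_MAP[family][variant][res]` to a config name, and the dropdown factories further down take `list(...)` of each level as their choices. A standalone sketch (with a truncated copy of the map, not the commit's code):

MODEL_MAP = {
    'lit': {'B/16': {224: 'lit_b16b'}},
    'siglip': {'B/16': {224: 'siglip_b16b_224', 256: 'siglip_b16b_256'}},
}

families = list(MODEL_MAP)                       # ['lit', 'siglip']
variants = list(MODEL_MAP['siglip'])             # ['B/16']
resolutions = list(MODEL_MAP['siglip']['B/16'])  # [224, 256]
assert MODEL_MAP['siglip']['B/16'][256] == 'siglip_b16b_256'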
@@ -72,7 +73,9 @@ def get_cache_status():
   )
 
 
-def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
+def compute(
+    image_path, prompts, family, variant, res, bias, progress=gr.Progress()
+):
   """Loads model and computes answers."""
 
   if image_path is None:
@@ -83,7 +86,7 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
   model_name = MODEL_MAP[family][variant][res]
   config = models.MODEL_CONFIGS[model_name]
   local_ckpt = gradio_helpers.get_disk_cache(
-      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
+      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
   config = dataclasses.replace(config, ckpt=local_ckpt)
   params, model = gradio_helpers.get_memory_cache(
       config,
@@ -91,11 +94,11 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
       max_cache_size_bytes=MAX_RAM_CACHE,
       progress=progress,
       estimated_secs={
-          ('lit', 'B/16'): 1,
-          ('lit', 'L/16'): 2.5,
-          ('siglip', 'B/16'): 9,
-          ('siglip', 'L/16'): 28,
-          ('siglip', 'So400m/14'): 36,
+          ('lit', 'B/16'): 1,
+          ('lit', 'L/16'): 2.5,
+          ('siglip', 'B/16'): 9,
+          ('siglip', 'L/16'): 28,
+          ('siglip', 'So400m/14'): 36,
       }.get((family, variant))
   )
   model: models.ContrastiveModel = model
@@ -107,18 +110,19 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
   image = PIL.Image.open(image_path)
   next(it)
   with gradio_helpers.timed('image features'):
-    zimg, out = model.embed_images(
+    zimg, unused_out = model.embed_images(
         params, model.preprocess_images([image])
     )
   next(it)
   with gradio_helpers.timed('text features'):
     prompts = prompts.split('\n')
     ztxt, out = model.embed_texts(
-        params, model.preprocess_texts(prompts)
+        params, model.preprocess_texts(prompts)
     )
   next(it)
 
   t = model.get_temperature(out)
+  text_probs = []
   if family == 'lit':
     text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
   elif family == 'siglip':
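
Note: the two branches this hunk ends on differ in how scores become probabilities: LiT softmaxes the logits over prompts (they compete), while SigLIP puts each prompt through an independent sigmoid with a learned bias, matching `jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)` in big_vision_contrastive_models.py. A NumPy sketch of both rules (illustrative; the app's real code runs in JAX):

import numpy as np

rng = np.random.default_rng(0)
zimg = rng.normal(size=(1, 4))   # one image embedding
ztxt = rng.normal(size=(3, 4))   # three prompt embeddings
zimg /= np.linalg.norm(zimg, axis=-1, keepdims=True)
ztxt /= np.linalg.norm(ztxt, axis=-1, keepdims=True)
t, b = 100.0, -12.9  # temperature; the bias is used by SigLIP only

logits = zimg @ ztxt.T * t
lit_probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()  # sums to 1
siglip_probs = 1 / (1 + np.exp(-(logits + b)))  # each in (0, 1), need not sum to 1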
@@ -140,7 +144,8 @@ def update_answers(state):
   """Generates visible sliders for answers."""
   answers = []
   for prompt, prob in state[:MAX_ANSWERS]:
-    answers.append(gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
+    answers.append(
+        gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
   while len(answers) < MAX_ANSWERS:
     answers.append(gr.Slider(visible=False))
   return answers
@@ -159,7 +164,10 @@ def create_app():
 
   with gr.Blocks(css=css) as demo:
 
-    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')
+    gr.Markdown(
+        'Gradio clone of the original '
+        '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
+    )
 
     status = gr.Markdown(f'Ready ({get_cache_status()})')
 
@@ -168,12 +176,14 @@ def create_app():
       source = gr.Markdown('', visible=False)
       state = gr.State([])
       with gr.Column():
-        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
+        prompts = gr.Textbox(
+            label='Prompts (press Shift-ENTER to add a prompt)')
         with gr.Row():
 
           values = {}
 
-          family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
+          family = gr.Dropdown(
+              value='lit', choices=list(MODEL_MAP), label='Model family')
           values['family'] = family.value
 
           # Unfortunately below reactive UI code is a bit convoluted, because:
@@ -185,25 +195,34 @@ def create_app():
           def make_variant(family_value):
             choices = list(MODEL_MAP[family_value])
             values['variant'] = choices[0]
-            return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')
+            return gr.Dropdown(
+                value=values['variant'], choices=choices, label='Variant')
           variant = make_variant(family.value)
 
           def make_res(family, variant):
             choices = list(MODEL_MAP[family][variant])
             values['res'] = choices[0]
-            return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')
+            return gr.Dropdown(
+                value=values['res'], choices=choices, label='Resolution')
           res = make_res(family.value, variant.value)
           values['res'] = res.value
 
           def make_bias(family, variant, res):
             visible = family == 'siglip'
             value = {
-                ('siglip', 'B/16', 224): -12.9,
-                ('siglip', 'L/16', 256): -12.7,
-                ('siglip', 'L/16', 256): -16.5,
-                # ...
+                ('siglip', 'B/16', 224): -12.9,
+                ('siglip', 'L/16', 256): -12.7,
+                ('siglip', 'L/16', 256): -16.5,
+                # ...
             }.get((family, variant, res), -10.0)
-            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
+            return gr.Slider(
+                value=value,
+                minimum=-20,
+                maximum=0,
+                step=0.05,
+                label='Bias',
+                visible=visible,
+            )
           bias = make_bias(family.value, variant.value, res.value)
           values['bias'] = bias.value
 
@@ -248,7 +267,10 @@ def create_app():
     # a single `status` widget here, and store the computed information in
     # `state`...
     run.click(
-        fn=compute, inputs=[image, prompts, family, variant, res, bias], outputs=[status, state])
+        fn=compute,
+        inputs=[image, prompts, family, variant, res, bias],
+        outputs=[status, state],
+    )
     # ... then we use `state` to update UI components without showing a
     # progress bar in their place.
     status.change(fn=update_answers, inputs=state, outputs=answers)
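
Note: the `make_variant`/`make_res`/`make_bias` factories above return fresh components; in Gradio 4.x a component returned from an event handler updates the wired output in place, which is what makes this reactive pattern work together with event wiring like the `run.click`/`status.change` calls in this hunk. A minimal, hedged reproduction (the actual `.change()` wiring for the dropdowns lives in a part of app.py not shown in this diff):

import gradio as gr

MODEL_MAP = {
    'lit': {'B/16': {224: 'lit_b16b'}},
    'siglip': {'B/16': {224: 'siglip_b16b_224', 256: 'siglip_b16b_256'}},
}

with gr.Blocks() as demo:
  family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')

  def make_variant(family_value):
    choices = list(MODEL_MAP[family_value])
    # Returning a Dropdown from a handler replaces choices/value in place.
    return gr.Dropdown(value=choices[0], choices=choices, label='Variant')

  variant = make_variant(family.value)
  family.change(fn=make_variant, inputs=family, outputs=variant)

if __name__ == '__main__':
  demo.launch()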
@@ -258,9 +280,9 @@ def create_app():
     gr.Examples(
         examples=[
             [
-                IMG_URL_FMT.format(ex['id']),
-                ex['prompts'].replace(', ', '\n'),
-                '[source](%s)' % ex['source'],
+                IMG_URL_FMT.format(ex['id']),
+                ex['prompts'].replace(', ', '\n'),
+                '[source](%s)' % ex['source'],
             ]
             for ex in info
         ],
@@ -272,7 +294,7 @@ def create_app():
   return demo
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
 
   logging.basicConfig(level=logging.INFO,
                       format='%(asctime)s - %(levelname)s - %(message)s')
 
big_vision_contrastive_models.py CHANGED
@@ -27,15 +27,17 @@ import transformers
 
 
 def _clone_git(url, destination_folder, commit_hash=None):
-  subprocess.run([
-      'git', 'clone', '--depth=1',
-      url, destination_folder
-  ], check=True)
-  if commit_hash:
-    subprocess.run(['git', '-C', destination_folder, 'checkout', commit_hash], check=True)
+  subprocess.run(
+      ['git', 'clone', '--depth=1', url, destination_folder], check=True
+  )
+  if commit_hash:
+    subprocess.run(
+        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
+    )
 
 
 def setup(commit_hash=None):
+  """Checks out required non-pypi code from Github."""
   for url, dst_name in (
       ('https://github.com/google-research/big_vision', 'big_vision_repo'),
       ('https://github.com/google/flaxformer', 'flaxformer_repo'),
@@ -43,11 +45,12 @@ def setup(commit_hash=None):
     dst_path = os.path.join(tempfile.gettempdir(), dst_name)
     if not os.path.exists(dst_path):
       _clone_git(url, dst_path, commit_hash)
-    if not dst_path in sys.path:
+    if dst_path not in sys.path:
       sys.path.insert(0, dst_path)
 
 
 class ContrastiveModelFamily(enum.Enum):
+  """Defines a contrastive model family."""
   LIT = 'lit'
   SIGLIP = 'siglip'
 
@@ -96,18 +99,21 @@ class ContrastiveModel:
     return ztxt, out
 
   def preprocess_texts(self, texts):
+    """Converts texts to padded tokens."""
 
     def tokenize_pad(text, seqlen=self.config.seqlen):
 
       if self.config.family == ContrastiveModelFamily.LIT:
-        tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1]  # removes [SEP]
+        tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
+        tokens = tokens[:-1]  # removes [SEP]
         tokens = tokens[:seqlen]
         return tokens + [0] * (seqlen - len(tokens))
 
       if self.config.family == ContrastiveModelFamily.SIGLIP:
         tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
         if len(tokens) >= seqlen:
-          return tokens[:seqlen - 1] + [tok.eos_id()]  # "sticky" eos
+          eos_id = self.tokenizer_sp.eos_id()
+          return tokens[:seqlen - 1] + [eos_id]  # "sticky" eos
         return tokens + [0] * (seqlen - len(tokens))
 
     return np.array([tokenize_pad(text) for text in texts])
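
Note: besides wrapping long lines, this hunk fixes a real bug: the old `tok.eos_id()` referenced an undefined name, and the new code asks the sentencepiece tokenizer itself. The padding rule is worth spelling out; a toy reproduction (plain Python, no tokenizer):

def pad_or_truncate(tokens, seqlen, eos_id):
  if len(tokens) >= seqlen:
    return tokens[:seqlen - 1] + [eos_id]  # truncate, keep "sticky" eos last
  return tokens + [0] * (seqlen - len(tokens))  # right-pad with zeros

assert pad_or_truncate([5, 6, 7, 2], 6, eos_id=2) == [5, 6, 7, 2, 0, 0]
assert pad_or_truncate([5, 6, 7, 8, 9, 3, 2], 6, eos_id=2) == [5, 6, 7, 8, 9, 2]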
@@ -125,7 +131,9 @@ class ContrastiveModel:
     ]) / 127.5 - 1.0
 
   def get_bias(self, out):
-    assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
+    assert (
+        self.config.family == ContrastiveModelFamily.SIGLIP
+    ), self.config.family
     return out['b'].item()
 
   def get_temperature(self, out):
@@ -145,7 +153,9 @@ class ContrastiveModel:
     return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
 
 
-def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
+def _make_config(
+    family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
+):
   if family == 'lit':
     tokenizer = ckpt.replace('.npz', '.txt')
   else:
@@ -153,11 +163,12 @@ def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
   return ContrastiveModelConfig(
       family=ContrastiveModelFamily(family), variant=variant, res=res,
       textvariant=textvariant, embdim=embdim, seqlen=seqlen,
-      tokenizer=tokenizer, vocab_size=32_000,
+      tokenizer=tokenizer, vocab_size=vocab_size,
      ckpt=ckpt,
   )
 
 
+# pylint: disable=line-too-long
 MODEL_CONFIGS = dict(
     lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
     lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
@@ -173,6 +184,7 @@ MODEL_CONFIGS = dict(
     siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
     siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
 )
+# pylint: enable=line-too-long
 
 
 @functools.cache
@@ -187,7 +199,6 @@ def load_tokenizer_sp(name_or_path):
 
 @functools.cache
 def load_tokenizer_bert(path):
-  tok = sentencepiece.SentencePieceProcessor()
   if path.startswith('gs://'):
     dst = tempfile.mktemp()
     gfile.copy(path, dst)
@@ -203,7 +214,9 @@ def load_model(config, check_params=False):
   cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
   if config.family == ContrastiveModelFamily.LIT:
     cfg.text_model = 'proj.flaxformer.bert'
-    cfg.image = dict(variant=config.variant, pool_type='tok', head_zeroinit=False)
+    cfg.image = dict(
+        variant=config.variant, pool_type='tok', head_zeroinit=False
+    )
     bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
     cfg.text = dict(config=bert_config, head_zeroinit=False)
     tokenizer_bert = load_tokenizer_bert(config.tokenizer)
@@ -211,10 +224,12 @@ def load_model(config, check_params=False):
     if config.variant == 'L/16':
       cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
     else:
-      cfg.out_dim = (config.embdim, config.embdim)  # (image_out_dim, text_out_dim)
+      # (image_out_dim, text_out_dim)
+      cfg.out_dim = (config.embdim, config.embdim)
   else:
     cfg.image = dict(variant=config.variant, pool_type='map')
-    cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
+    # TODO(lbeyer): remove later, default
+    cfg.text_model = 'proj.image_text.text_transformer'
     cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
     cfg.bias_init = -10.0
     tokenizer_sp = load_tokenizer_sp(config.tokenizer)
@@ -223,7 +238,7 @@ def load_model(config, check_params=False):
     cfg.temperature_init = 10.0
 
   model_mod = importlib.import_module(
-      'big_vision.models.proj.image_text.two_towers')
+      'big_vision.models.proj.image_text.two_towers')
   model = model_mod.Model(**cfg)
 
   init_params = None  # Faster but bypasses loading sanity-checks.
 
gradio_helpers.py CHANGED
@@ -30,8 +30,9 @@ def timed(name):
     logging.info('Timed %s: %.1f secs', name, timing['secs'])
 
 
-
-def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
+def copy_file(
+    src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
+):
   """Copies a file with progress bar.
 
   Args:
@@ -39,6 +40,7 @@ def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
     dst: Destination file. Path must be readable by `tf.io.gfile`.
     progress: An object with a `.tqdm` attribute, or `None`.
     block_size: Size of individual blocks to be read/written.
+    overwrite: If `True`, overwrite `dst` if it exists.
   """
   if os.path.dirname(dst):
     os.makedirs(os.path.dirname(dst), exist_ok=True)
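
Note: a sketch of the block-wise copy the docstring describes, using plain open() instead of tf.io.gfile (an assumption here; the real helper also drives a progress bar via `progress.tqdm`):

import os

def copy_blocks(src, dst, *, block_size=1024 * 1024 * 10, overwrite=False):
  if os.path.exists(dst) and not overwrite:
    return  # keep the existing file
  with open(src, 'rb') as fin, open(dst, 'wb') as fout:
    while block := fin.read(block_size):  # one 10 MiB block at a time
      fout.write(block)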
@@ -87,7 +89,9 @@ def _get_array_sizes(tree):
   return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
 
 
-def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
+def get_memory_cache(
+    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
+):
   """Keeps cache below specified size by removing elements not last accessed."""
   if key in _memory_cache:
     _memory_cache[key] = _memory_cache.pop(key)  # updated "last accessed" order
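
Note: the `pop()`/re-insert idiom exploits Python's insertion-ordered dicts to track recency. A minimal sketch of the eviction policy the docstring names (illustrative names, not gradio_helpers' internals):

_cache, _sizes = {}, {}

def get_cached(key, getter, size_of, max_bytes):
  if key in _cache:
    _cache[key] = _cache.pop(key)  # re-append: now most recently used
    return _cache[key]
  value = getter()
  _sizes[key] = size_of(value)
  # Evict least-recently-used entries until the projected size fits.
  while _cache and sum(_sizes.values()) > max_bytes:
    oldest = next(iter(_cache))  # first inserted == least recently used
    del _cache[oldest], _sizes[oldest]
  _cache[key] = value
  return value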
 