bwang0911 commited on
Commit
cd0cf85
1 Parent(s): 0e50fd1

refactor: optimize training and tokenizer

Browse files
Files changed (1) hide show
  1. modeling_clip.py +11 -1
modeling_clip.py CHANGED
@@ -222,7 +222,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
222
  self.visual_projection = nn.Identity()
223
  self.text_projection = nn.Identity()
224
 
225
- self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
226
  self.post_init()
227
 
228
  def get_text_features(
@@ -247,6 +247,12 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
247
  )
248
  return self.visual_projection(self.vision_model(x=x))
249
 
 
 
 
 
 
 
250
  @torch.inference_mode()
251
  def encode_text(
252
  self,
@@ -291,7 +297,10 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
291
  If convert_to_tensor, a stacked tensor is returned.
292
  If convert_to_numpy, a numpy matrix is returned.
293
  """
 
294
  self.eval()
 
 
295
 
296
  if show_progress_bar is None:
297
  show_progress_bar = (
@@ -362,6 +371,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
362
  if input_was_string:
363
  all_embeddings = all_embeddings[0]
364
 
 
365
  return all_embeddings
366
 
367
  def encode_image(
 
222
  self.visual_projection = nn.Identity()
223
  self.text_projection = nn.Identity()
224
 
225
+ self.tokenizer = None
226
  self.post_init()
227
 
228
  def get_text_features(
 
247
  )
248
  return self.visual_projection(self.vision_model(x=x))
249
 
250
+ @property
251
+ def get_tokenizer(self):
252
+ if not self.tokenizer:
253
+ self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
254
+ return self.tokenizer
255
+
256
  @torch.inference_mode()
257
  def encode_text(
258
  self,
 
297
  If convert_to_tensor, a stacked tensor is returned.
298
  If convert_to_numpy, a numpy matrix is returned.
299
  """
300
+ is_training = self.training
301
  self.eval()
302
+
303
+ self.tokenizer = self.get_tokenizer()
304
 
305
  if show_progress_bar is None:
306
  show_progress_bar = (
 
371
  if input_was_string:
372
  all_embeddings = all_embeddings[0]
373
 
374
+ self.train(is_training)
375
  return all_embeddings
376
 
377
  def encode_image(