LinWeizheDragon commited on
Commit
d9d68f5
·
verified ·
1 Parent(s): 3dd99ee

Update modeling_flmr.py

Browse files
Files changed (1) hide show
  1. modeling_flmr.py +16 -14
modeling_flmr.py CHANGED
@@ -584,13 +584,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
584
  self.text_encoder_embedding_size = self.config.text_config.hidden_size
585
  self.late_interaction_embedding_size = self.config.dim
586
 
587
- self.context_vision_projection = FLMRMultiLayerPerceptron(
588
- (
589
- self.vision_encoder_embedding_size,
590
- (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
591
- self.late_interaction_embedding_size * self.mapping_network_prefix_length,
 
 
592
  )
593
- )
594
 
595
  if self.config.use_vision_encoder:
596
  self.context_vision_encoder = FLMRVisionModel(config.vision_config)
@@ -636,13 +637,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
636
  self.query_text_encoder_linear = self.context_text_encoder_linear
637
  self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
638
 
639
- if self.config.separate_query_and_context_vision_encoder:
640
- self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
641
- self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
642
- else:
643
- self.query_vision_encoder = self.context_vision_encoder
644
- self.query_vision_projection = self.context_vision_projection
645
- self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
 
646
 
647
  if self.config.load_cpu_extension:
648
  try:
@@ -1304,7 +1306,7 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
1304
  # TODO: fix the engine to support masks with discontinuous 0 and 1.
1305
  D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1306
  # concatenate the mask
1307
- mask = torch.cat([mask, image_mask], dim=1)
1308
  elif concat_output_from_vision_encoder:
1309
  D = vision_embeddings
1310
  mask = image_mask
 
584
  self.text_encoder_embedding_size = self.config.text_config.hidden_size
585
  self.late_interaction_embedding_size = self.config.dim
586
 
587
+ if self.config.use_vision_encoder:
588
+ self.context_vision_projection = FLMRMultiLayerPerceptron(
589
+ (
590
+ self.vision_encoder_embedding_size,
591
+ (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
592
+ self.late_interaction_embedding_size * self.mapping_network_prefix_length,
593
+ )
594
  )
 
595
 
596
  if self.config.use_vision_encoder:
597
  self.context_vision_encoder = FLMRVisionModel(config.vision_config)
 
637
  self.query_text_encoder_linear = self.context_text_encoder_linear
638
  self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
639
 
640
+ if self.config.use_vision_encoder:
641
+ if self.config.separate_query_and_context_vision_encoder:
642
+ self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
643
+ self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
644
+ else:
645
+ self.query_vision_encoder = self.context_vision_encoder
646
+ self.query_vision_projection = self.context_vision_projection
647
+ self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
648
 
649
  if self.config.load_cpu_extension:
650
  try:
 
1306
  # TODO: fix the engine to support masks with discontinuous 0 and 1.
1307
  D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1308
  # concatenate the mask
1309
+ mask = torch.cat([image_mask, mask], dim=1)
1310
  elif concat_output_from_vision_encoder:
1311
  D = vision_embeddings
1312
  mask = image_mask