Update modeling_gigarembed.py
Browse files- modeling_gigarembed.py +1 -0
modeling_gigarembed.py
CHANGED
@@ -409,6 +409,7 @@ class GigarEmbedModel(PreTrainedModel):
|
|
409 |
|
410 |
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None,
|
411 |
return_dict: bool=True, **kwargs):
|
|
|
412 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
413 |
|
414 |
embeds = self.latent_attention_model(
|
|
|
409 |
|
410 |
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None,
|
411 |
return_dict: bool=True, **kwargs):
|
412 |
+
kwargs.pop('token_type_ids', None)
|
413 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
414 |
|
415 |
embeds = self.latent_attention_model(
|