from torch import nn
from transformers import PretrainedConfig, PreTrainedModel, BertModel

from .configuration_bert import SimBertConfig


class SimBertModel(PreTrainedModel):
    """BERT sentence-pair similarity model.

    The pair is encoded as a single sequence and the softmax probability of
    class 1 is regressed against a float similarity label with an MSE loss.
    """

    config_class = SimBertConfig

    def __init__(
        self,
        config: PretrainedConfig
    ) -> None:
        super().__init__(config)
        # BERT encoder; the pooler output (tanh-projected [CLS] state) is
        # used as the pair representation.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # Binary classification head over the pooled representation.
        self.fc = nn.Linear(config.hidden_size, 2)

        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Keep the softmax probability of class 1 as the similarity score.
        logits = self.softmax(logits)[:, 1]
        if labels is not None:
            # Labels are expected to be floats, matching the MSE objective.
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits


class CosSimBertModel(PreTrainedModel):
    """Siamese BERT similarity model.

    The two sentences arrive concatenated along the sequence dimension, are
    encoded separately by a shared BERT encoder, and are scored by the cosine
    similarity of their pooled outputs.
    """

    config_class = SimBertConfig

    def __init__(
        self,
        config: PretrainedConfig
    ) -> None:
        super().__init__(config)
        # Shared (siamese) BERT encoder applied to both halves of the input.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # The cosine similarity is regressed against a float label with MSE.
        self.loss_fct = nn.MSELoss()

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        # The two sentences are packed side by side along the sequence
        # dimension; split every tensor back into the "a" and "b" halves.
        seq_length = input_ids.size(-1)
        a = {
            "input_ids": input_ids[:, :seq_length // 2],
            "token_type_ids": token_type_ids[:, :seq_length // 2],
            "attention_mask": attention_mask[:, :seq_length // 2]
        }
        b = {
            "input_ids": input_ids[:, seq_length // 2:],
            "token_type_ids": token_type_ids[:, seq_length // 2:],
            "attention_mask": attention_mask[:, seq_length // 2:]
        }
        outputs_a = self.bert(**a)
        outputs_b = self.bert(**b)
        pooled_a_output = outputs_a.pooler_output
        pooled_b_output = outputs_b.pooler_output
        # Cosine similarity between the two pooled sentence embeddings.
        logits = nn.functional.cosine_similarity(pooled_a_output, pooled_b_output)
        if labels is not None:
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits

    def encode(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
    ):
        # Embed a single (already split) sentence, e.g. for building an
        # embedding index at inference time.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        return pooled_output
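

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The checkpoint paths below
# are placeholders; substitute the tokenizer and checkpoints this repo
# actually trains and saves with save_pretrained().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("path/to/simbert-checkpoint")

    # SimBertModel scores a pair packed into a single sequence
    # ([CLS] a [SEP] b [SEP]), i.e. the tokenizer's standard pair encoding.
    sim_model = SimBertModel.from_pretrained("path/to/simbert-checkpoint")
    sim_model.eval()
    pair = tokenizer("the weather is nice today", "it is a lovely day",
                     padding="max_length", max_length=128,
                     truncation=True, return_tensors="pt")
    with torch.no_grad():
        _, prob = sim_model(**pair)
    print(prob)  # softmax probability of class 1, used as the similarity score

    # CosSimBertModel expects the two sentences encoded separately to the same
    # fixed length and concatenated along the sequence dimension, so that
    # forward() can split them back in half.
    cos_model = CosSimBertModel.from_pretrained("path/to/cossimbert-checkpoint")
    cos_model.eval()
    enc_a = tokenizer("the weather is nice today", padding="max_length",
                      max_length=64, truncation=True, return_tensors="pt")
    enc_b = tokenizer("it is a lovely day", padding="max_length",
                      max_length=64, truncation=True, return_tensors="pt")
    batch = {
        key: torch.cat([enc_a[key], enc_b[key]], dim=-1)
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }
    with torch.no_grad():
        _, cos = cos_model(**batch)
    print(cos)  # cosine similarity of the pooled embeddings, in [-1, 1]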