|
from transformers import PreTrainedModel, PretrainedConfig |
|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
|
|
class ZeroShotEmbeddingConfig(PretrainedConfig):
    """Hyper-parameters for the zero-shot embedding head and its clustering wrapper.

    Args:
        input_size: Dimensionality of one input embedding. The head concatenates
            a text embedding with a prompt embedding, so its first linear layer
            takes ``2 * input_size`` features.
        hidden_size: Width of the hidden layer of the projection MLP.
        output_size: Dimensionality of the projected output vectors.
        base_embedding_model: Name or path handed to ``SentenceTransformer`` by
            the clustering wrapper to build the base text encoder.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "embedding-head"

    def __init__(
        self,
        input_size=768,
        hidden_size=2048,
        output_size=128,
        base_embedding_model='all-mpnet-base-v2',
        **kwargs,
    ):
        # Store the custom fields first, then let the base class consume the
        # remaining keyword arguments.
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)
|
|
|
|
|
class ZeroShotEmbedding(PreTrainedModel):
    """Prompt-conditioned projection head over precomputed sentence embeddings.

    Each text embedding is concatenated with a prompt embedding and passed
    through a two-layer MLP (Linear -> GELU -> Linear). The result is
    L2-normalised, so dot products between outputs are cosine similarities.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)

        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # The MLP consumes [text_embedding ; prompt_embedding], hence 2 * input_size.
        self.fc1 = nn.Linear(self.input_size * 2, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)
        self.gelu = nn.GELU()

    def _project(self, text_embedding, prompt_embedding):
        """Return L2-normalised MLP outputs for a batch of (text, prompt) pairs."""
        x = torch.cat((text_embedding, prompt_embedding), dim=1)
        x = self.fc2(self.gelu(self.fc1(x)))
        return nn.functional.normalize(x, p=2, dim=1)

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None, **kwargs):
        """Project texts under a prompt, optionally scoring a second text against them.

        Args:
            prompt_embedding: Prompt embeddings. They are concatenated after the
                text embeddings along dim 1, so the batch (dim 0) sizes must
                match. Presumably shaped (batch, input_size); confirm with callers.
            text_a_embedding: Embeddings of the first texts, presumably
                (batch, input_size).
            text_b_embedding: Optional embeddings of second texts, projected
                under the same prompt.
            labels: Optional regression targets for the similarity, one per
                row. Ignored when ``text_b_embedding`` is None.
            **kwargs: Ignored.

        Returns:
            * ``text_b_embedding is None``: L2-normalised projections of
              ``text_a_embedding``, shape (batch, output_size).
            * otherwise, without labels: row-wise cosine similarity between
              the two projections, shape (batch,).
            * otherwise, with labels: ``(mse_loss, similarity)``.
        """
        x = self._project(text_a_embedding, prompt_embedding)
        if text_b_embedding is None:
            return x

        x2 = self._project(text_b_embedding, prompt_embedding)
        # Row-wise dot product: (B,1,D) @ (B,D,1) -> (B,1,1) -> (B,).
        # reshape(-1) rather than a bare squeeze(), so a batch of size 1 still
        # yields a 1-D tensor instead of collapsing to a 0-d scalar.
        dot_product = torch.bmm(x.unsqueeze(1), x2.unsqueeze(2)).reshape(-1)

        if labels is None:
            return dot_product

        # Align labels with the similarity vector. A (B, 1) labels tensor would
        # otherwise broadcast to (B, B) and yield a silently wrong loss.
        labels = labels.reshape(dot_product.shape)
        loss = torch.mean((dot_product - labels) ** 2)
        return loss, dot_product
|
|
|
|
|
class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """SentenceTransformer encoder followed by a ZeroShotEmbedding head.

    Raw texts and a prompt are encoded with the sentence-transformers model
    named by ``config.base_embedding_model``, projected through the head, and
    compared pairwise to produce a similarity matrix.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)

    def forward(self, texts, prompt, **kwargs):
        """Return the pairwise similarity matrix of prompt-conditioned embeddings.

        Args:
            texts: Sequence of strings to embed. A single string is treated as
                a one-element list.
            prompt: Prompt string shared by all texts.
            **kwargs: Ignored.

        Returns:
            Tensor of shape (len(texts), len(texts)). Entry [i, j] is the dot
            product of the head's L2-normalised outputs for texts i and j. An
            empty ``texts`` yields a (0, 0) tensor.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Create the inputs on the same device and with the same dtype as the
        # head's weights, so the model keeps working after .to(...) / .half().
        weight = self.head_model.fc1.weight
        if len(texts) == 0:
            return torch.empty((0, 0), device=weight.device, dtype=weight.dtype)

        text_embeddings = torch.as_tensor(
            self.base_embedding_model.encode(texts)
        ).to(device=weight.device, dtype=weight.dtype)
        prompt_embedding = torch.as_tensor(
            self.base_embedding_model.encode(prompt)
        ).to(device=weight.device, dtype=weight.dtype)

        # Repeat the single prompt vector once per text (expand creates a view).
        prompt_embeddings = prompt_embedding.reshape(1, -1).expand(
            text_embeddings.shape[0], -1
        )

        prompted_embeddings = self.head_model(prompt_embeddings, text_embeddings)
        return prompted_embeddings @ prompted_embeddings.T

    @classmethod
    def from_pretrained_base(cls, pretrained_model_name_or_path, **kwargs):
        """Build a clustering model around a pretrained ZeroShotEmbedding head.

        Args:
            pretrained_model_name_or_path: Checkpoint to load the head from.
            **kwargs: Forwarded to ``ZeroShotEmbedding.from_pretrained``.

        Returns:
            A new instance whose ``base_embedding_model`` is built from the
            head's config and whose ``head_model`` is the loaded head.
        """
        head_model = ZeroShotEmbedding.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        model = cls(head_model.config)
        # Assign on the instance, not on ``cls``. nn.Module.__setattr__ then
        # registers the loaded head as a submodule, replacing the freshly
        # initialised one, so parameters(), state_dict() and .to() all see it.
        # It also avoids mutating the class for every other instance.
        model.head_model = head_model
        return model
|
|
|
|
|
# Register the custom classes with the transformers Auto* API, so that
# checkpoints saved with save_pretrained() record which class to rebuild.
ZeroShotEmbeddingConfig.register_for_auto_class()
for _auto_model_cls in (ZeroShotEmbedding, ZeroShotEmbeddingForClustering):
    _auto_model_cls.register_for_auto_class("AutoModel")
del _auto_model_cls
|
|