from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import numpy as np


class ZeroShotEmbeddingConfig(PretrainedConfig):
    """Configuration for the prompt-conditioned embedding head.

    Args:
        input_size: Dimension of each incoming embedding (text and prompt).
            Presumably this must match the output dimension of
            ``base_embedding_model`` -- TODO confirm for non-default models.
        hidden_size: Width of the hidden layer of the MLP head.
        output_size: Dimension of the L2-normalized output embedding.
        base_embedding_model: SentenceTransformer model name/path that
            ``ZeroShotEmbeddingForClustering`` loads to embed raw text.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "embedding-head"

    def __init__(self, input_size=768, hidden_size=2048, output_size=128,
                 base_embedding_model='all-mpnet-base-v2', **kwargs):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)


class ZeroShotEmbedding(PreTrainedModel):
    """MLP head mapping a (text, prompt) embedding pair to a unit vector.

    The text embedding and the prompt embedding are concatenated along the
    feature axis and passed through ``Linear -> GELU -> Linear``; the result
    is L2-normalized along dim 1.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # Two linear layers: [text ; prompt] -> hidden -> output embedding.
        self.fc1 = nn.Linear(self.input_size * 2, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)
        self.gelu = nn.GELU()

    def _project(self, text_embedding, prompt_embedding):
        """Return the L2-normalized head output for one batch of text/prompt pairs."""
        # [batch_size, input_size * 2]; text comes first, then the prompt --
        # the order is baked into fc1's learned weights.
        x = torch.cat((text_embedding, prompt_embedding), dim=1)
        x = self.fc2(self.gelu(self.fc1(x)))
        return nn.functional.normalize(x, p=2, dim=1)

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None,
                labels=None, **kwargs):
        """Embed text under a prompt, optionally scoring a second text.

        Args:
            prompt_embedding: ``[batch_size, input_size]`` prompt embeddings.
            text_a_embedding: ``[batch_size, input_size]`` text embeddings.
            text_b_embedding: Optional second batch of text embeddings with the
                same shape; when given, the similarity to ``text_a`` is computed.
            labels: Optional ``[batch_size]`` similarity targets; only used when
                ``text_b_embedding`` is given.
            **kwargs: Ignored.

        Returns:
            * ``text_b_embedding is None``: ``[batch_size, output_size]``
              normalized embeddings of ``text_a``.
            * ``text_b_embedding`` given, ``labels is None``: ``[batch_size]``
              dot products between the two normalized embeddings.
            * both given: ``(loss, dot_product)`` where ``loss`` is the mean
              squared error between ``dot_product`` and ``labels``.
        """
        x = self._project(text_a_embedding, prompt_embedding)
        if text_b_embedding is None:
            return x

        x2 = self._project(text_b_embedding, prompt_embedding)
        # Row-wise dot product of unit vectors (i.e. cosine similarity).
        # Summing over dim 1 always yields shape [batch_size]; the previous
        # bmm(...).squeeze() collapsed a batch of one to a 0-d tensor.
        dot_product = (x * x2).sum(dim=1)

        if labels is None:
            return dot_product
        loss = torch.mean((dot_product - labels) ** 2)
        return loss, dot_product


class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """SentenceTransformer encoder followed by a ``ZeroShotEmbedding`` head.

    Produces a pairwise similarity matrix of texts under a single prompt.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(
            config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)

    def forward(self, texts, prompt, **kwargs):
        """Return the ``[len(texts), len(texts)]`` prompted similarity matrix.

        Args:
            texts: Sequence of strings (a single string is treated as one text).
            prompt: Prompt string applied to every text.
            **kwargs: Ignored.
        """
        if isinstance(texts, str):
            # len() of a bare string is its character count, which would
            # mis-size the tiled prompt batch below.
            texts = [texts]

        text_embeddings = self.base_embedding_model.encode(texts)
        prompt_embedding = self.base_embedding_model.encode(prompt)
        # Repeat the single prompt embedding once per text.
        prompt_embeddings = np.tile(prompt_embedding, (len(texts), 1))

        # Match the head's device/dtype so the model also works off-CPU.
        device, dtype = self.head_model.device, self.head_model.dtype
        text_embeddings = torch.as_tensor(
            text_embeddings, dtype=dtype, device=device)
        prompt_embeddings = torch.as_tensor(
            prompt_embeddings, dtype=dtype, device=device)

        prompted_embeddings = self.head_model(
            prompt_embeddings, text_embeddings)
        # Head outputs are unit vectors, so this is a cosine-similarity matrix.
        similarity = torch.mm(prompted_embeddings,
                              prompted_embeddings.transpose(0, 1))
        return similarity

    @classmethod
    def from_pretrained_base(cls, pretrained_model_name_or_path):
        """Build a clustering model around a pretrained ``ZeroShotEmbedding`` head.

        The head's config is reused, so the base SentenceTransformer named in
        it is loaded as well.
        """
        head_model = ZeroShotEmbedding.from_pretrained(
            pretrained_model_name_or_path)
        model = cls(head_model.config)
        # Attach to the *instance* (not the class) so the loaded head is a
        # registered submodule: it replaces the freshly initialized head in
        # state_dict(), parameters(), .to(), and is not shared across instances.
        model.head_model = head_model
        return model


# Register the custom classes with the transformers Auto* API.
ZeroShotEmbeddingConfig.register_for_auto_class()
ZeroShotEmbedding.register_for_auto_class("AutoModel")
ZeroShotEmbeddingForClustering.register_for_auto_class("AutoModel")