iamroot's picture
Upload ZeroShotEmbedding
f64bef6
raw
history blame contribute delete
No virus
4.01 kB
from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import numpy as np
class ZeroShotEmbeddingConfig(PretrainedConfig):
    """Configuration for the prompt-conditioned embedding head.

    Args:
        input_size: Width of each incoming sentence embedding; the head's first
            layer takes twice this (text embedding concatenated with prompt).
        hidden_size: Width of the head's hidden layer.
        output_size: Width of the head's output embedding.
        base_embedding_model: Name handed to ``SentenceTransformer`` to build
            the base encoder used by ``ZeroShotEmbeddingForClustering``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "embedding-head"

    def __init__(
        self,
        input_size=768,
        hidden_size=2048,
        output_size=128,
        base_embedding_model='all-mpnet-base-v2',
        **kwargs,
    ):
        # Custom fields are stored first, then the base class consumes kwargs.
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)
class ZeroShotEmbedding(PreTrainedModel):
    """Prompt-conditioned projection head over precomputed sentence embeddings.

    Each text embedding is concatenated with a prompt embedding and passed
    through a two-layer MLP (Linear -> GELU -> Linear). The output is
    L2-normalised along the feature axis, so dot products between two outputs
    are cosine similarities.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size
        # Two-layer MLP: [text ; prompt] (2 * input_size) -> hidden -> output.
        # The attribute names fc1/fc2 appear in state_dict keys, so keep them
        # stable for loading existing checkpoints.
        self.fc1 = nn.Linear(config.input_size * 2, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.output_size)
        self.gelu = nn.GELU()

    def _embed_with_prompt(self, text_embedding, prompt_embedding):
        """Project one batch of text embeddings under the prompt and L2-normalise."""
        x = torch.cat((text_embedding, prompt_embedding), dim=1)
        x = self.fc2(self.gelu(self.fc1(x)))
        return nn.functional.normalize(x, p=2, dim=1)

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None,
                **kwargs):
        """Embed texts under a prompt, optionally scoring text pairs.

        Args:
            prompt_embedding: Prompt embeddings, one row per example; expected
                shape (batch_size, input_size).
            text_a_embedding: Text embeddings; expected shape (batch_size, input_size).
            text_b_embedding: Optional second set of text embeddings, same shape
                as ``text_a_embedding``, to compare row-by-row against it.
            labels: Optional target similarities, one per row (used only when
                ``text_b_embedding`` is given).
            **kwargs: Ignored.

        Returns:
            * Without ``text_b_embedding``: normalised embeddings of shape
              (batch_size, output_size).
            * With ``text_b_embedding`` and no ``labels``: row-wise cosine
              similarities of shape (batch_size,).
            * With both: ``(loss, similarities)``, where ``loss`` is the mean
              squared error between the similarities and ``labels``.
        """
        x = self._embed_with_prompt(text_a_embedding, prompt_embedding)
        if text_b_embedding is None:
            return x

        x2 = self._embed_with_prompt(text_b_embedding, prompt_embedding)
        # Row-wise dot product -> (batch_size,). Unlike bmm(...).squeeze(), this
        # keeps the batch dimension when batch_size == 1.
        dot_product = (x * x2).sum(dim=1)
        if labels is None:
            return dot_product

        # Mean squared error between predicted cosine similarity and the label.
        loss = torch.mean((dot_product - labels) ** 2)
        return loss, dot_product
class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """Sentence encoder plus prompt-conditioned head producing a similarity matrix.

    The ``SentenceTransformer`` named by ``config.base_embedding_model``
    encodes the raw texts and the prompt. ``ZeroShotEmbedding`` then projects
    every text under that prompt, and ``forward`` returns the pairwise cosine
    similarity matrix of the projected texts.
    """

    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)

    def forward(self, texts, prompt, **kwargs):
        """Return the prompt-conditioned pairwise similarity matrix of ``texts``.

        Args:
            texts: Sequence of strings to embed. A single string is treated as
                a one-element list.
            prompt: The prompt applied to every text. Its embedding is
                broadcast to one row per text.
            **kwargs: Ignored.

        Returns:
            Tensor of shape (len(texts), len(texts)). Entry [i, j] is the dot
            product of the L2-normalised head outputs for texts i and j.
            It lives on the head's device.
        """
        if isinstance(texts, str):
            # A bare string would otherwise be encoded as one 1-D vector while
            # len(texts) counted its characters.
            texts = [texts]

        # Build inputs on the head's device/dtype so a moved or half-precision
        # head does not receive CPU float32 tensors.
        head_weight = self.head_model.fc1.weight
        device, dtype = head_weight.device, head_weight.dtype

        if len(texts) == 0:
            # Nothing to compare; avoid torch.cat on an empty (0,) array.
            return torch.empty((0, 0), device=device, dtype=dtype)

        text_embeddings = torch.as_tensor(
            self.base_embedding_model.encode(texts), device=device, dtype=dtype
        )
        prompt_embedding = torch.as_tensor(
            self.base_embedding_model.encode(prompt), device=device, dtype=dtype
        )
        # Broadcast the single prompt embedding to one row per text (no copy).
        prompt_embeddings = prompt_embedding.reshape(1, -1).expand(
            text_embeddings.shape[0], -1
        )

        prompted_embeddings = self.head_model(prompt_embeddings, text_embeddings)
        # The head's rows are L2-normalised, so this Gram matrix holds cosine similarities.
        return prompted_embeddings @ prompted_embeddings.transpose(0, 1)

    @classmethod
    def from_pretrained_base(cls, pretrained_model_name_or_path, **kwargs):
        """Build a clustering model around a pretrained ``ZeroShotEmbedding`` head.

        The head's weights and config are loaded from
        ``pretrained_model_name_or_path``. The base SentenceTransformer is then
        built from that config's ``base_embedding_model``. Extra keyword
        arguments are forwarded to ``ZeroShotEmbedding.from_pretrained``.
        """
        head_model = ZeroShotEmbedding.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(head_model.config)
        # Assign on the instance, not on ``cls``: nn.Module registers the
        # pretrained head as the submodule, replacing the freshly initialised
        # one. It is then included in parameters()/state_dict()/.to() and is
        # not shared by every other instance of the class.
        model.head_model = head_model
        return model
# Register the custom classes with the transformers Auto* API: the config under
# its default auto class, and both model classes under "AutoModel".
ZeroShotEmbeddingConfig.register_for_auto_class()
for _model_class in (ZeroShotEmbedding, ZeroShotEmbeddingForClustering):
    _model_class.register_for_auto_class("AutoModel")
del _model_class