File size: 3,646 Bytes
def591b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import numpy as np


class ZeroShotEmbeddingConfig(PretrainedConfig):
    """Hyperparameters for the prompt-conditioned embedding head.

    Args:
        input_size: Width of each incoming embedding (text or prompt). The
            head concatenates one text and one prompt embedding, so its first
            layer consumes ``2 * input_size`` features; this must match the
            base embedding model's output size.
        hidden_size: Width of the head's hidden layer.
        output_size: Width of the projected (output) embedding.
        base_embedding_model: Identifier handed to ``SentenceTransformer`` by
            models that embed raw text themselves.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "embedding-head"

    def __init__(
        self,
        input_size=768,
        hidden_size=2048,
        output_size=128,
        base_embedding_model='all-mpnet-base-v2',
        **kwargs,
    ):
        # Record the head's own settings first, then let the base class
        # consume the generic configuration keywords.
        self.input_size, self.hidden_size, self.output_size = (
            input_size,
            hidden_size,
            output_size,
        )
        self.base_embedding_model = base_embedding_model
        super().__init__(**kwargs)


class ZeroShotEmbedding(PreTrainedModel):
    """Prompt-conditioned projection head over precomputed embeddings.

    Each text embedding is concatenated with a prompt embedding and passed
    through a two-layer MLP (Linear -> GELU -> Linear). The result is
    L2-normalised, so dot products between outputs are cosine similarities.
    """

    # Lets ``from_pretrained`` construct the right config type when the
    # caller does not pass one explicitly.
    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)

        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size
        # Two-layer MLP: [text; prompt] (2 * input_size) -> hidden -> output.
        self.fc1 = nn.Linear(self.input_size * 2, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)
        self.gelu = nn.GELU()

    def _project(self, text_embedding, prompt_embedding):
        """Return the unit-norm projection of ``[text; prompt]``.

        Both inputs are [batch_size, input_size]; the result is
        [batch_size, output_size] with every row L2-normalised.
        """
        x = torch.cat((text_embedding, prompt_embedding), dim=1)
        x = self.fc2(self.gelu(self.fc1(x)))
        return nn.functional.normalize(x, p=2, dim=1)

    def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None, **kwargs):
        """Project texts under a prompt, optionally scoring a pair of texts.

        Args:
            prompt_embedding: [batch_size, input_size] prompt embedding, one
                row per example.
            text_a_embedding: [batch_size, input_size] embedding of the first
                text.
            text_b_embedding: Optional [batch_size, input_size] embedding of a
                second text. When given, the pair similarity is returned
                instead of the projection.
            labels: Optional [batch_size] target similarities. Only used when
                ``text_b_embedding`` is given.
            **kwargs: Ignored.

        Returns:
            - Without ``text_b_embedding``: the [batch_size, output_size]
              unit-norm projection of text A.
            - With ``text_b_embedding``: the [batch_size] cosine similarity
              between the projections of text A and text B.
            - With ``text_b_embedding`` and ``labels``: the tuple
              ``(loss, similarity)``, where ``loss`` is the mean squared error
              between similarity and label.
        """
        projected_a = self._project(text_a_embedding, prompt_embedding)
        if text_b_embedding is None:
            return projected_a

        projected_b = self._project(text_b_embedding, prompt_embedding)
        # Row-wise dot product; summing over the feature axis keeps the
        # result shaped [batch_size] even when batch_size == 1.
        dot_product = (projected_a * projected_b).sum(dim=1)
        if labels is not None:
            loss = torch.mean((dot_product - labels) ** 2)
            return loss, dot_product
        return dot_product


class ZeroShotEmbeddingForClustering(PreTrainedModel):
    """Embed raw texts under a prompt and return their pairwise similarities.

    A ``SentenceTransformer`` produces the base embeddings for the texts and
    the prompt. A ``ZeroShotEmbedding`` head then projects each text
    conditioned on the prompt.
    """

    # Lets ``from_pretrained`` construct the right config type when the
    # caller does not pass one explicitly.
    config_class = ZeroShotEmbeddingConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_embedding_model = SentenceTransformer(
            config.base_embedding_model)
        self.head_model = ZeroShotEmbedding(config)

    def forward(self, texts, prompt, **kwargs):
        """Return the cosine-similarity matrix of the prompted text embeddings.

        Args:
            texts: Sequence of strings to embed. A single string is treated as
                a one-element list.
            prompt: Prompt applied to every text.
            **kwargs: Ignored.

        Returns:
            A tensor of shape [n, n] with ``n = len(texts)``. Entry (i, j) is
            the dot product of the unit-norm prompted embeddings of texts i
            and j. An empty ``texts`` yields a [0, 0] tensor.
        """
        if isinstance(texts, str):
            # len() of a bare string counts characters, not texts.
            texts = [texts]

        # Build inputs on the head's device and dtype so the model keeps
        # working after ``.to(device)`` or a precision change.
        weight = self.head_model.fc1.weight
        if len(texts) == 0:
            return torch.empty((0, 0), dtype=weight.dtype, device=weight.device)

        text_embeddings = torch.as_tensor(
            self.base_embedding_model.encode(texts),
            dtype=weight.dtype, device=weight.device)
        prompt_embedding = torch.as_tensor(
            self.base_embedding_model.encode(prompt),
            dtype=weight.dtype, device=weight.device)
        # Repeat the prompt embedding once per text (a broadcast view).
        prompt_embeddings = prompt_embedding.reshape(1, -1).expand(
            text_embeddings.shape[0], -1)

        prompted_embeddings = self.head_model(
            prompt_embeddings, text_embeddings)
        # The head L2-normalises its rows, so this is a cosine-similarity
        # matrix.
        similarity = torch.mm(prompted_embeddings,
                              prompted_embeddings.transpose(0, 1))
        return similarity


# Expose the custom classes to the transformers Auto* factories, so that
# checkpoints saved from this module can be reloaded through
# AutoConfig / AutoModel (with trust_remote_code).
ZeroShotEmbeddingConfig.register_for_auto_class(auto_class="AutoConfig")
ZeroShotEmbedding.register_for_auto_class(auto_class="AutoModel")
ZeroShotEmbeddingForClustering.register_for_auto_class(auto_class="AutoModel")