iamroot commited on
Commit
def591b
1 Parent(s): 8f9259e

Upload ZeroShotEmbedding

Browse files
Files changed (3) hide show
  1. config.json +16 -0
  2. model.py +90 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ZeroShotEmbedding"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model.ZeroShotEmbeddingConfig",
7
+ "AutoModel": "model.ZeroShotEmbedding"
8
+ },
9
+ "base_embedding_model": "all-mpnet-base-v2",
10
+ "hidden_size": 2048,
11
+ "input_size": 768,
12
+ "model_type": "embedding-head",
13
+ "output_size": 128,
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.35.0"
16
+ }
model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ from sentence_transformers import SentenceTransformer
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ class ZeroShotEmbeddingConfig(PretrainedConfig):
9
+ model_type = "embedding-head"
10
+
11
+ def __init__(self, input_size=768, hidden_size=2048, output_size=128, base_embedding_model='all-mpnet-base-v2', **kwargs):
12
+ self.input_size = input_size
13
+ self.hidden_size = hidden_size
14
+ self.output_size = output_size
15
+ self.base_embedding_model = base_embedding_model
16
+ super().__init__(**kwargs)
17
+
18
+
19
+ class ZeroShotEmbedding(PreTrainedModel):
20
+ def __init__(self, config):
21
+ super(ZeroShotEmbedding, self).__init__(config)
22
+
23
+ input_size = config.input_size
24
+ hidden_size = config.hidden_size
25
+ output_size = config.output_size
26
+
27
+ self.input_size = input_size
28
+ self.hidden_size = hidden_size
29
+ self.output_size = output_size
30
+ # 3-layer MLP: input embedding -> hidden -> output embedding
31
+ self.fc1 = nn.Linear(input_size * 2, hidden_size)
32
+ self.fc2 = nn.Linear(hidden_size, output_size)
33
+ self.gelu = nn.GELU()
34
+
35
+ def forward(self, prompt_embedding, text_a_embedding, text_b_embedding=None, labels=None, **kwargs):
36
+ # document_embedding: [batch_size, input_size]
37
+ # prompt_embedding: [batch_size, input_size]
38
+ # output: [batch_size, output_size]
39
+
40
+ # concatenate document embedding and prompt embedding
41
+ # [batch_size, input_size * 2]
42
+ x = torch.cat((text_a_embedding, prompt_embedding), dim=1)
43
+ if text_b_embedding is not None:
44
+ # concatenate document embedding and prompt embedding
45
+ # [batch_size, input_size * 2]
46
+ x2 = torch.cat((text_b_embedding, prompt_embedding), dim=1)
47
+
48
+ # 3-layer MLP
49
+ x = self.fc1(x)
50
+ x = self.gelu(x)
51
+ x = self.fc2(x)
52
+ x = nn.functional.normalize(x, p=2, dim=1)
53
+ if text_b_embedding is not None:
54
+ x2 = self.fc1(x2)
55
+ x2 = self.gelu(x2)
56
+ x2 = self.fc2(x2)
57
+ x2 = nn.functional.normalize(x2, p=2, dim=1)
58
+ # Compute dot product for batches of output vectors
59
+ dot_product = torch.bmm(x.unsqueeze(1), x2.unsqueeze(2)).squeeze()
60
+ if labels is not None:
61
+ # Compute loss (magnitude of dot product minus label)
62
+ loss = torch.mean((dot_product - labels) ** 2)
63
+ return loss, dot_product
64
+ return dot_product
65
+ return x
66
+
67
+
68
+ class ZeroShotEmbeddingForClustering(PreTrainedModel):
69
+ def __init__(self, config):
70
+ super(ZeroShotEmbeddingForClustering, self).__init__(config)
71
+ self.base_embedding_model = SentenceTransformer(
72
+ config.base_embedding_model)
73
+ self.head_model = ZeroShotEmbedding(config)
74
+
75
+ def forward(self, texts, prompt, **kwargs):
76
+ text_embeddings = self.base_embedding_model.encode(texts)
77
+ prompt_embedding = self.base_embedding_model.encode(prompt)
78
+ prompt_embeddings = np.tile(prompt_embedding, (len(texts), 1))
79
+ text_embeddings = torch.tensor(text_embeddings)
80
+ prompt_embeddings = torch.tensor(prompt_embeddings)
81
+ prompted_embeddings = self.head_model(
82
+ prompt_embeddings, text_embeddings)
83
+ similarity = torch.mm(prompted_embeddings,
84
+ prompted_embeddings.transpose(0, 1))
85
+ return similarity
86
+
87
+
88
+ ZeroShotEmbeddingConfig.register_for_auto_class()
89
+ ZeroShotEmbedding.register_for_auto_class("AutoModel")
90
+ ZeroShotEmbeddingForClustering.register_for_auto_class("AutoModel")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d449b9dd7a347194e3596a3398581419c5f13a2efad7a82fe1478df0b152eec6
3
+ size 13640544