windy2612 committed
Commit
3dd5d3e
1 Parent(s): cf3c0c2

Upload 6 files

Files changed (6)
  1. .gitattributes +3 -35
  2. .gitignore +3 -0
  3. App.py +9 -0
  4. Model.py +279 -0
  5. Predict.py +35 -0
  6. last_checkpoint.pt +3 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ .pt filter=lfs diff=lfs merge=lfs -text
+ Checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ /Main.ipynb/
+ /Dataset/
+ /Checkpoint/
App.py ADDED
@@ -0,0 +1,9 @@
+ import gradio as gr
+ from Predict import generate_caption
+
+ interface = gr.Interface(
+     fn = generate_caption,
+     inputs = [gr.components.Image(), gr.components.Textbox(label = "Question")],
+     outputs = [gr.components.Textbox(label = "Answer", lines=3)]
+ )
+ interface.launch(share = True, debug = True)
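
Note: gr.Interface passes its inputs to fn positionally, so the image component feeds the first argument of generate_caption and the question textbox the second. A quick way to exercise the same callable without the UI is to call it directly (a minimal sketch, assuming the checkpoint and Dataset/answer.json that Predict.py loads are in place; the sample path is illustrative):

from PIL import Image
from Predict import generate_caption

# Any RGB image plus a Vietnamese question, exactly what the Gradio app passes along.
img = Image.open("Dataset/train/68857.jpg").convert("RGB")
print(generate_caption(img, "màu của chiếc bình là gì"))  # "what is the colour of the vase?"
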
Model.py ADDED
@@ -0,0 +1,279 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModel, AutoFeatureExtractor
+ import numpy as np
+ import math
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ vision_model_name = "google/vit-base-patch16-224-in21k"
+ language_model_name = "vinai/phobert-base"
+
+
+
+ def generate_padding_mask(sequences, padding_idx):
+     if sequences is None:
+         return None
+     if len(sequences.shape) == 2:
+         __seq = sequences.unsqueeze(dim=-1)
+     else:
+         __seq = sequences
+     mask = (torch.sum(__seq, dim=-1) == (padding_idx*__seq.shape[-1])).long() * -10e4
+     return mask.unsqueeze(1).unsqueeze(1)
+
+
+ class ScaledDotProduct(nn.Module):
+     def __init__(self, d_model = 512, h = 8, d_k = 64, d_v = 64):
+         super().__init__()
+
+         self.fc_q = nn.Linear(d_model, h * d_k)
+         self.fc_k = nn.Linear(d_model, h * d_k)
+         self.fc_v = nn.Linear(d_model, h * d_v)
+         self.fc_o = nn.Linear(h * d_v, d_model)
+
+         self.d_model = d_model
+         self.d_k = d_k
+         self.d_v = d_v
+         self.h = h
+
+         self.init_weights()
+
+     def init_weights(self):
+         nn.init.xavier_uniform_(self.fc_q.weight)
+         nn.init.xavier_uniform_(self.fc_k.weight)
+         nn.init.xavier_uniform_(self.fc_v.weight)
+         nn.init.xavier_uniform_(self.fc_o.weight)
+         nn.init.constant_(self.fc_q.bias, 0)
+         nn.init.constant_(self.fc_k.bias, 0)
+         nn.init.constant_(self.fc_v.bias, 0)
+         nn.init.constant_(self.fc_o.bias, 0)
+
+     def forward(self, queries, keys, values, attention_mask=None, **kwargs):
+         b_s, nq = queries.shape[:2]
+         nk = keys.shape[1]
+         q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
+         k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
+         v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
+
+         att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
+         if attention_mask is not None:
+             att += attention_mask
+         att = torch.softmax(att, dim=-1)
+         out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
+         out = self.fc_o(out) # (b_s, nq, d_model)
+
+         return out, att
+
+
+ class MultiheadAttention(nn.Module):
+
+     def __init__(self, d_model = 512, dropout = 0.1, use_aoa = True):
+         super().__init__()
+         self.d_model = d_model
+         self.use_aoa = use_aoa
+
+         self.attention = ScaledDotProduct()
+         self.norm = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         if self.use_aoa:
+             self.infomative_attention = nn.Linear(2 * self.d_model, self.d_model)
+             self.gated_attention = nn.Linear(2 * self.d_model, self.d_model)
+
+     def forward(self, q, k, v, mask = None):
+         out, _ = self.attention(q, k, v, mask)
+         if self.use_aoa:
+             aoa_input = torch.cat([q, out], dim = -1)
+             i = self.infomative_attention(aoa_input)
+             g = torch.sigmoid(self.gated_attention(aoa_input))
+             out = i * g
+         return out
+
+
+ class PositionWiseFeedForward(nn.Module):
+     def __init__(self, d_model = 512, d_ff = 2048, dropout = 0.1):
+         super().__init__()
+         self.fc1 = nn.Linear(d_model, d_ff)
+         self.fc2 = nn.Linear(d_ff, d_model)
+         self.relu = nn.ReLU()
+
+     def forward(self, input):
+         out = self.fc1(input)
+         out = self.fc2(self.relu(out))
+         return out
+
+ class AddNorm(nn.Module):
+     def __init__(self, dim = 512, dropout = 0.1):
+         super().__init__()
+         self.dropout = nn.Dropout(dropout)
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x, y):
+         return self.norm(x + self.dropout(y))
+
+
+ class SinusoidPositionalEmbedding(nn.Module):
+     def __init__(self, num_pos_feats=512, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, x, mask=None):
+         if mask is None:
+             mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=x.device)
+         not_mask = (mask == False)
+         embed = not_mask.cumsum(1, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             embed = embed / (embed[:, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / self.num_pos_feats)
+
+         pos = embed[:, :, None] / dim_t
+         pos = torch.stack((pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=-1).flatten(-2)
+
+         return pos
+
+
+ class GuidedEncoderLayer(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.self_mhatt = MultiheadAttention()
+         self.guided_mhatt = MultiheadAttention()
+         self.pwff = PositionWiseFeedForward()
+         self.first_norm = AddNorm()
+         self.second_norm = AddNorm()
+         self.third_norm = AddNorm()
+     def forward(self, q, k, v, self_mask, guided_mask):
+         self_att = self.self_mhatt(q, q, q, self_mask)
+         self_att = self.first_norm(self_att, q)
+         guided_att = self.guided_mhatt(self_att, k, v, guided_mask)
+         guided_att = self.second_norm(guided_att, self_att)
+         out = self.pwff(guided_att)
+         out = self.third_norm(out, guided_att)
+         return out
+
+
+ class GuidedAttentionEncoder(nn.Module):
+     def __init__(self, num_layers = 2, d_model = 512):
+         super().__init__()
+         self.pos_embedding = SinusoidPositionalEmbedding()
+         self.layer_norm = nn.LayerNorm(d_model)
+
+         self.guided_layers = nn.ModuleList([GuidedEncoderLayer() for _ in range(num_layers)])
+         self.language_layers = nn.ModuleList(GuidedEncoderLayer() for _ in range(num_layers))
+
+     def forward(self, vision_features, vision_mask, language_features, language_mask):
+         vision_features = self.layer_norm(vision_features) + self.pos_embedding(vision_features)
+         language_features = self.layer_norm(language_features) + self.pos_embedding(language_features)
+
+         for layers in zip(self.guided_layers, self.language_layers):
+             guided_layer, language_layer = layers
+             vision_features = guided_layer(q = vision_features,
+                                            k = language_features,
+                                            v = language_features,
+                                            self_mask = vision_mask,
+                                            guided_mask = language_mask)
+             language_features = language_layer(q = language_features,
+                                                k = vision_features,
+                                                v = vision_features,
+                                                self_mask = language_mask,
+                                                guided_mask = vision_mask)
+
+         return vision_features, language_features
+
+
+ class VisionEmbedding(nn.Module):
+     def __init__(self, out_dim = 768, hidden_dim = 512, dropout = 0.1):
+         super().__init__()
+         self.prep = AutoFeatureExtractor.from_pretrained(vision_model_name)
+         self.backbone = AutoModel.from_pretrained(vision_model_name)
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+         self.proj = nn.Linear(out_dim, hidden_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.gelu = nn.GELU()
+     def forward(self, images):
+         inputs = self.prep(images = images, return_tensors = "pt").to(device)
+         with torch.no_grad():
+             outputs = self.backbone(**inputs)
+         features = outputs.last_hidden_state
+         vision_mask = generate_padding_mask(features, padding_idx = 0)
+         out = self.proj(features)
+         out = self.gelu(out)
+         out = self.dropout(out)
+         return out, vision_mask
+
+
+ class LanguageEmbedding(nn.Module):
+     def __init__(self, out_dim = 768, hidden_dim = 512, dropout = 0.1):
+         super().__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)
+         self.embeding = AutoModel.from_pretrained(language_model_name)
+         for param in self.embeding.parameters():
+             param.requires_grad = False
+         self.proj = nn.Linear(out_dim, hidden_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.gelu = nn.GELU()
+     def forward(self, questions):
+         inputs = self.tokenizer(questions,
+                                 padding = 'max_length',
+                                 max_length = 30,
+                                 truncation = True,
+                                 return_tensors = 'pt',
+                                 return_token_type_ids = True,
+                                 return_attention_mask = True).to(device)
+
+         features = self.embeding(**inputs).last_hidden_state
+         language_mask = generate_padding_mask(inputs.input_ids, padding_idx=self.tokenizer.pad_token_id)
+         out = self.proj(features)
+         out = self.gelu(out)
+         out = self.dropout(out)
+         return out, language_mask
+
+ class BaseModel(nn.Module):
+     def __init__(self, num_classes = 353, d_model = 512):
+         super().__init__()
+         self.vision_embedding = VisionEmbedding()
+         self.language_embedding = LanguageEmbedding()
+         self.encoder = GuidedAttentionEncoder()
+         self.fusion = nn.Sequential(nn.Linear(2 * d_model, d_model),
+                                     nn.ReLU(),
+                                     nn.Dropout(0.2))
+         self.classify = nn.Linear(d_model, num_classes)
+         self.attention_weights = nn.Linear(d_model, 1)
+
+     def forward(self, images, questions):
+         embedded_text, text_mask = self.language_embedding(questions)
+         embedded_vision, vison_mask = self.vision_embedding(images)
+
+         encoded_image, encoded_text = self.encoder(embedded_vision, vison_mask, embedded_text, text_mask)
+         text_attended = self.attention_weights(torch.tanh(encoded_text))
+         image_attended = self.attention_weights(torch.tanh(encoded_image))
+
+         attention_weights = torch.softmax(torch.cat([text_attended, image_attended], dim=1), dim=1)
+
+         attended_text = torch.sum(attention_weights[:, 0].unsqueeze(-1) * encoded_text, dim=1)
+         attended_image = torch.sum(attention_weights[:, 1].unsqueeze(-1) * encoded_image, dim=1)
+
+         fused_output = self.fusion(torch.cat([attended_text, attended_image], dim=1))
+         logits = self.classify(fused_output)
+         logits = F.log_softmax(logits, dim=-1)
+         return logits
+
+
+
+ if __name__ == "__main__":
+     model = BaseModel().to(device)
+     print(model.eval())
+
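
A minimal smoke test of the forward pass (a sketch, not part of the upload; it assumes the ViT and PhoBERT backbones can be downloaded from the Hub and runs with a randomly initialised head):

import torch
from PIL import Image
from Model import BaseModel, device

model = BaseModel().to(device)
model.eval()  # disable dropout for a deterministic check

dummy_image = Image.new("RGB", (224, 224))   # blank test image
question = "màu của chiếc bình là gì"        # any Vietnamese question string

with torch.no_grad():
    log_probs = model(dummy_image, question)

print(log_probs.shape)  # expected: torch.Size([1, 353]), one log-probability per answer class
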
Predict.py ADDED
@@ -0,0 +1,35 @@
+ from Model import BaseModel
+ import json
+ import numpy as np
+ from PIL import Image
+ from torchvision import transforms as T
+ import torch
+
+
+ checkpoint = torch.load('Checkpoint/checkpoint.pt')
+ with open('Dataset/answer.json', 'r', encoding = 'utf8') as f:
+     answer_space = json.load(f)
+ swap_space = {v : k for k, v in answer_space.items()}
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = BaseModel().to(device)
+ model.load_state_dict(checkpoint['model_state_dict'])
+
+ def generate_caption(image, question):
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     elif isinstance(image, str):
+         image = Image.open(image).convert("RGB")
+     transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
+     image = transform(image).unsqueeze(0)
+     with torch.no_grad():
+         logits = model(image, question)
+     idx = torch.argmax(logits)
+     return swap_space[idx.item()]
+
+ if __name__ == "__main__":
+     image = 'Dataset/train/68857.jpg'
+     question = 'màu của chiếc bình là gì'  # "what is the colour of the vase?"
+     pred = generate_caption(image, question)
+     print(pred)
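
Predict.py loads the checkpoint before the device is chosen and without map_location, so it will fail on a CPU-only host if the checkpoint was saved from a CUDA run. A device-agnostic variant could look like this (a sketch that mirrors the paths and keys used above, not the committed code):

import torch
from Model import BaseModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps any CUDA tensors in the checkpoint onto whatever device is available.
checkpoint = torch.load("Checkpoint/checkpoint.pt", map_location=device)

model = BaseModel().to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables dropout
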
last_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ccdd750d65a5f4c508d8b0702ee792f9e27f9be52ffa5c42b98c15420c1c9d1
+ size 1105547492