windy2612 committed on
Commit 36de638
1 Parent(s): fdc4612

Update Model.py

Files changed (1):
  1. Model.py +278 -278
Model.py CHANGED
@@ -1,279 +1,279 @@
The commit re-uploads the whole file, so the diff marks all 279 lines as removed and re-added. The only line whose content actually changes is the device selection, which now pins the model to CPU; note that the removed line referenced torch.cuda.is_available without calling it, so its condition was always truthy:

- device = torch.device("cuda" if torch.cuda.is_available else "cpu")
+ device = torch.device("cpu")
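The removed line was presumably aiming for the usual guarded selection; a minimal sketch of that form (with the missing call parentheses) would be:

import torch

# Guarded device selection: is_available() is actually called here, so the
# check falls back to CPU on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")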
 
 
The full updated Model.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoFeatureExtractor
import numpy as np
import math
import warnings
warnings.filterwarnings("ignore")


device = torch.device("cpu")
vision_model_name = "google/vit-base-patch16-224-in21k"
language_model_name = "vinai/phobert-base"


def generate_padding_mask(sequences, padding_idx):
    if sequences is None:
        return None
    if len(sequences.shape) == 2:
        __seq = sequences.unsqueeze(dim=-1)
    else:
        __seq = sequences
    mask = (torch.sum(__seq, dim=-1) == (padding_idx * __seq.shape[-1])).long() * -10e4
    return mask.unsqueeze(1).unsqueeze(1)
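# generate_padding_mask returns an additive mask of shape (batch, 1, 1, seq_len):
# 0 at real positions and -10e4 at padding positions, so it can be added directly
# to the (batch, h, nq, nk) attention scores before the softmax.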


class ScaledDotProduct(nn.Module):
    def __init__(self, d_model=512, h=8, d_k=64, d_v=64):
        super().__init__()

        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)
        nn.init.constant_(self.fc_o.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, **kwargs):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_mask is not None:
            att += attention_mask
        att = torch.softmax(att, dim=-1)
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        return out, att
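# ScaledDotProduct above is standard multi-head attention: per head,
# att = softmax(Q K^T / sqrt(d_k) + mask), and the output is fc_o(att V) with h = 8 heads.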


class MultiheadAttention(nn.Module):
    def __init__(self, d_model=512, dropout=0.1, use_aoa=True):
        super().__init__()
        self.d_model = d_model
        self.use_aoa = use_aoa

        self.attention = ScaledDotProduct()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if self.use_aoa:
            self.informative_attention = nn.Linear(2 * self.d_model, self.d_model)
            self.gated_attention = nn.Linear(2 * self.d_model, self.d_model)

    def forward(self, q, k, v, mask=None):
        out, _ = self.attention(q, k, v, mask)
        if self.use_aoa:
            aoa_input = torch.cat([q, out], dim=-1)
            i = self.informative_attention(aoa_input)
            g = torch.sigmoid(self.gated_attention(aoa_input))
            out = i * g
        return out
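# With use_aoa=True, MultiheadAttention applies Attention-on-Attention gating:
# the query is concatenated with the attended output, projected to an information
# vector i and a sigmoid gate g, and the layer returns i * g.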


class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, input):
        out = self.fc1(input)
        out = self.fc2(self.relu(out))
        return out


class AddNorm(nn.Module):
    def __init__(self, dim=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, y):
        return self.norm(x + self.dropout(y))


class SinusoidPositionalEmbedding(nn.Module):
    def __init__(self, num_pos_feats=512, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=x.device)
        not_mask = (mask == False)
        embed = not_mask.cumsum(1, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            embed = embed / (embed[:, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / self.num_pos_feats)

        pos = embed[:, :, None] / dim_t
        pos = torch.stack((pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=-1).flatten(-2)

        return pos
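# The block above yields the standard sinusoidal encoding:
# PE(p, 2i) = sin(p / T^(2i/d)) and PE(p, 2i+1) = cos(p / T^(2i/d)),
# with T = temperature and d = num_pos_feats.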


class GuidedEncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_mhatt = MultiheadAttention()
        self.guided_mhatt = MultiheadAttention()
        self.pwff = PositionWiseFeedForward()
        self.first_norm = AddNorm()
        self.second_norm = AddNorm()
        self.third_norm = AddNorm()

    def forward(self, q, k, v, self_mask, guided_mask):
        self_att = self.self_mhatt(q, q, q, self_mask)
        self_att = self.first_norm(self_att, q)
        guided_att = self.guided_mhatt(self_att, k, v, guided_mask)
        guided_att = self.second_norm(guided_att, self_att)
        out = self.pwff(guided_att)
        out = self.third_norm(out, guided_att)
        return out


class GuidedAttentionEncoder(nn.Module):
    def __init__(self, num_layers=2, d_model=512):
        super().__init__()
        self.pos_embedding = SinusoidPositionalEmbedding()
        self.layer_norm = nn.LayerNorm(d_model)

        self.guided_layers = nn.ModuleList([GuidedEncoderLayer() for _ in range(num_layers)])
        self.language_layers = nn.ModuleList([GuidedEncoderLayer() for _ in range(num_layers)])

    def forward(self, vision_features, vision_mask, language_features, language_mask):
        vision_features = self.layer_norm(vision_features) + self.pos_embedding(vision_features)
        language_features = self.layer_norm(language_features) + self.pos_embedding(language_features)

        for guided_layer, language_layer in zip(self.guided_layers, self.language_layers):
            vision_features = guided_layer(q=vision_features,
                                           k=language_features,
                                           v=language_features,
                                           self_mask=vision_mask,
                                           guided_mask=language_mask)
            language_features = language_layer(q=language_features,
                                               k=vision_features,
                                               v=vision_features,
                                               self_mask=language_mask,
                                               guided_mask=vision_mask)

        return vision_features, language_features
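# The encoder runs two parallel stacks of guided layers: image features self-attend
# and are then guided by (cross-attend to) the text features, while text features do
# the reverse against the just-updated image features, for num_layers rounds.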


class VisionEmbedding(nn.Module):
    def __init__(self, out_dim=768, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.prep = AutoFeatureExtractor.from_pretrained(vision_model_name)
        self.backbone = AutoModel.from_pretrained(vision_model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.proj = nn.Linear(out_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()

    def forward(self, images):
        inputs = self.prep(images=images, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = self.backbone(**inputs)
        features = outputs.last_hidden_state
        vision_mask = generate_padding_mask(features, padding_idx=0)
        out = self.proj(features)
        out = self.gelu(out)
        out = self.dropout(out)
        return out, vision_mask
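# The ViT backbone is frozen (requires_grad=False) and run under no_grad; only the
# 768 -> 512 projection is trained. Preprocessing (resize/normalize) happens inside
# forward via the feature extractor, so raw images (e.g. PIL) can be passed in directly.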


class LanguageEmbedding(nn.Module):
    def __init__(self, out_dim=768, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)
        self.embedding = AutoModel.from_pretrained(language_model_name)
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.proj = nn.Linear(out_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()

    def forward(self, questions):
        inputs = self.tokenizer(questions,
                                padding='max_length',
                                max_length=30,
                                truncation=True,
                                return_tensors='pt',
                                return_token_type_ids=True,
                                return_attention_mask=True).to(device)

        features = self.embedding(**inputs).last_hidden_state
        language_mask = generate_padding_mask(inputs.input_ids, padding_idx=self.tokenizer.pad_token_id)
        out = self.proj(features)
        out = self.gelu(out)
        out = self.dropout(out)
        return out, language_mask


class BaseModel(nn.Module):
    def __init__(self, num_classes=353, d_model=512):
        super().__init__()
        self.vision_embedding = VisionEmbedding()
        self.language_embedding = LanguageEmbedding()
        self.encoder = GuidedAttentionEncoder()
        self.fusion = nn.Sequential(nn.Linear(2 * d_model, d_model),
                                    nn.ReLU(),
                                    nn.Dropout(0.2))
        self.classify = nn.Linear(d_model, num_classes)
        self.attention_weights = nn.Linear(d_model, 1)

    def forward(self, images, questions):
        embedded_text, text_mask = self.language_embedding(questions)
        embedded_vision, vision_mask = self.vision_embedding(images)

        encoded_image, encoded_text = self.encoder(embedded_vision, vision_mask, embedded_text, text_mask)
        text_attended = self.attention_weights(torch.tanh(encoded_text))
        image_attended = self.attention_weights(torch.tanh(encoded_image))

        attention_weights = torch.softmax(torch.cat([text_attended, image_attended], dim=1), dim=1)

        attended_text = torch.sum(attention_weights[:, 0].unsqueeze(-1) * encoded_text, dim=1)
        attended_image = torch.sum(attention_weights[:, 1].unsqueeze(-1) * encoded_image, dim=1)

        fused_output = self.fusion(torch.cat([attended_text, attended_image], dim=1))
        logits = self.classify(fused_output)
        logits = F.log_softmax(logits, dim=-1)
        return logits


if __name__ == "__main__":
    model = BaseModel().to(device)
    print(model.eval())
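A minimal smoke-test sketch for the model above. It assumes Model.py is importable from the working directory, that Pillow is installed, and that the pretrained ViT and PhoBERT checkpoints can be downloaded; the 353-way output indexes whatever answer vocabulary the original training used, and the sample question is just an illustrative Vietnamese string.

import numpy as np
import torch
from PIL import Image

from Model import BaseModel, device

model = BaseModel().to(device)
model.eval()

# One dummy 224x224 RGB image and one question (batch size 1 for both inputs).
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
questions = ["Đây là màu gì?"]

with torch.no_grad():
    log_probs = model([image], questions)   # shape (1, 353), log-probabilities
    answer_id = log_probs.argmax(dim=-1)    # index into the answer vocabulary
print(answer_id)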