Abdullah-Nazhat committed · verified
Commit 5e6b20f · 1 Parent(s): c411986

Upload 2 files

Files changed (2)
  1. linearizer.py +221 -0
  2. train.py +197 -0
linearizer.py ADDED
@@ -0,0 +1,221 @@
+ import torch
+ import numpy as np
+ from torch import nn
+ from torch.nn import functional as F
+ from einops.layers.torch import Rearrange
+ import math
+
+
+ # helper functions
+
+ def default(val, default_val):
+     return val if val is not None else default_val
+
+ def init_(tensor):
+     dim = tensor.shape[-1]
+     std = 1 / math.sqrt(dim)
+     tensor.uniform_(-std, std)
+     return tensor
+
+ # helper classes
+
+ class Residual(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+     def forward(self, x):
+         return x + self.fn(x)
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = nn.LayerNorm(dim)
+     def forward(self, x):
+         x = self.norm(x)
+         return self.fn(x)
+
+ class GELU_(nn.Module):
+     def forward(self, x):
+         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+ GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout = 0., activation = None, glu = False):
+         super().__init__()
+         activation = default(activation, GELU)
+
+         self.glu = glu
+         self.w1 = nn.Linear(dim, hidden_dim * (2 if glu else 1))
+         self.act = activation()
+         self.dropout = nn.Dropout(dropout)
+         self.w2 = nn.Linear(hidden_dim, dim)
+
+     def forward(self, x, **kwargs):
+         if not self.glu:
+             x = self.w1(x)
+             x = self.act(x)
+         else:
+             # gated linear unit: one half of the projection gates the activated other half
+             x, v = self.w1(x).chunk(2, dim=-1)
+             x = self.act(x) * v
+
+         x = self.dropout(x)
+         x = self.w2(x)
+         return x
+
+ class LinformerSelfAttention(nn.Module):
+     def __init__(self, dim, seq_len, k = 16, heads = 4, dim_head = None, one_kv_head = False, share_kv = False, dropout = 0.):
+         super().__init__()
+
+         assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'
+
+         self.seq_len = seq_len
+         self.k = k
+
+         self.heads = heads
+
+         dim_head = default(dim_head, dim // heads)
+         self.dim_head = dim_head
+
+         self.to_q = nn.Linear(dim, dim_head * heads, bias = False)
+
+         kv_dim = dim_head if one_kv_head else (dim_head * heads)
+         self.to_k = nn.Linear(dim, kv_dim, bias = False)
+         self.proj_k = nn.Parameter(init_(torch.zeros(seq_len, k)))
+
+         self.share_kv = share_kv
+         if not share_kv:
+             self.to_v = nn.Linear(dim, kv_dim, bias = False)
+             self.proj_v = nn.Parameter(init_(torch.zeros(seq_len, k)))
+
+         self.dropout = nn.Dropout(dropout)
+         self.to_out = nn.Linear(dim_head * heads, dim)
+
+     def forward(self, x, context = None, **kwargs):
+         b, n, d, d_h, h, k = *x.shape, self.dim_head, self.heads, self.k
+
+         kv_len = n if context is None else context.shape[1]
+         assert kv_len == self.seq_len, f'the sequence length of the key / values must be {self.seq_len} - {kv_len} given'
+
+         queries = self.to_q(x)
+
+         proj_seq_len = lambda args: torch.einsum('bnd,nk->bkd', *args)
+
+         kv_input = x if context is None else context
+
+         keys = self.to_k(kv_input)
+         values = self.to_v(kv_input) if not self.share_kv else keys
+
+         kv_projs = (self.proj_k, self.proj_v if not self.share_kv else self.proj_k)
+
+         # project keys and values along the sequence length dimension to k
+         keys, values = map(proj_seq_len, zip((keys, values), kv_projs))
+
+         # merge head into batch for queries and key / values
+         queries = queries.reshape(b, n, h, -1).transpose(1, 2)
+
+         merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
+         keys, values = map(merge_key_values, (keys, values))
+
+         # attention
+         dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
+         attn = dots.softmax(dim=-1)
+         attn = self.dropout(attn)
+         out = torch.einsum('bhnk,bhkd->bhnd', attn, values)
+
+         # split heads
+         out = out.transpose(1, 2).reshape(b, n, -1)
+         return self.to_out(out)
+
+
+ class LinformerBlock(nn.Module):
+     # Linformer attention sublayer followed by a feed-forward sublayer,
+     # each pre-normalised (sharing one LayerNorm) and wrapped in a residual connection
+     def __init__(self, d_model, d_ffn, seq_len, dropout):
+         super().__init__()
+
+         self.norm = nn.LayerNorm(d_model)
+         self.Linformer_unit = LinformerSelfAttention(d_model, seq_len, k = 256, heads = 8, dim_head = None, one_kv_head = False, share_kv = False, dropout = dropout)
+         self.ffn = FeedForward(d_model, d_ffn, dropout)
+
+     def forward(self, x):
+         residual = x
+         x = self.norm(x)
+         x = self.Linformer_unit(x)
+         x = x + residual
+         residual = x
+         x = self.norm(x)
+         x = self.ffn(x)
+         out = x + residual
+         return out
+
+
+ class LinearizerGatingUnit(nn.Module):
+     # multiplicative gate: a linear projection of the input gates the output of a Linformer block
+     def __init__(self, d_model, d_ffn, seq_len, dropout):
+         super().__init__()
+         self.proj = nn.Linear(d_model, d_model)
+         self.Linz = LinformerBlock(
+             d_model, d_ffn, seq_len, dropout
+         )
+
+     def forward(self, x):
+         u, v = x, x
+         u = self.proj(u)
+         v = self.Linz(v)
+         out = u * v
+         return out
+
+
+ class LinearizerBlock(nn.Module):
+     # gating-unit sublayer followed by a feed-forward sublayer,
+     # each pre-normalised (sharing one LayerNorm) and wrapped in a residual connection
+     def __init__(self, d_model, d_ffn, seq_len, dropout):
+         super().__init__()
+
+         self.norm = nn.LayerNorm(d_model)
+         self.lgu = LinearizerGatingUnit(d_model, d_ffn, seq_len, dropout)
+         self.ffn = FeedForward(d_model, d_ffn, dropout)
+
+     def forward(self, x):
+         residual = x
+         x = self.norm(x)
+         x = self.lgu(x)
+         x = x + residual
+         residual = x
+         x = self.norm(x)
+         x = self.ffn(x)
+         out = x + residual
+         return out
+
+
+ class Linearizer(nn.Module):
+     # stack of LinearizerBlocks applied to an input of shape (batch, seq_len, d_model)
+     def __init__(self, d_model, d_ffn, seq_len, num_layers, dropout):
+         super().__init__()
+
+         self.model = nn.Sequential(
+             *[LinearizerBlock(d_model, d_ffn, seq_len, dropout) for _ in range(num_layers)]
+         )
+
+     def forward(self, x):
+         return self.model(x)
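
For reference, a minimal usage sketch of the Linearizer module defined above (not part of the committed file). The constructor arguments mirror those used in train.py; the batch size and dropout value here are illustrative assumptions.

import torch
from linearizer import Linearizer

# illustrative hyperparameters; the input's second dimension must equal seq_len
model = Linearizer(d_model=256, d_ffn=512, seq_len=64, num_layers=4, dropout=0.1)

x = torch.randn(2, 64, 256)   # (batch, seq_len, d_model)
y = model(x)                  # output keeps the input shape: (2, 64, 256)
print(y.shape)
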
train.py ADDED
@@ -0,0 +1,197 @@
+ # imports
+
+ import os
+ import csv
+ import torch
+ from torch import nn
+ from torch.utils.data import DataLoader
+ from torchvision import datasets
+ from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
+ from linearizer import Linearizer
+
+ # data transforms
+
+ transform = Compose([
+     RandomCrop(32, padding=4),
+     RandomHorizontalFlip(),
+     ToTensor(),
+     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ])
+
+ training_data = datasets.CIFAR10(
+     root='data',
+     train=True,
+     download=True,
+     transform=transform
+ )
+
+ test_data = datasets.CIFAR10(
+     root='data',
+     train=False,
+     download=True,
+     transform=transform
+ )
+
+ # create dataloaders
+
+ batch_size = 128
+
+ train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
+ test_dataloader = DataLoader(test_data, batch_size=batch_size)
+
+ for X, y in test_dataloader:
+     print(f"Shape of X [N, C, H, W]: {X.shape}")
+     print(f"Shape of y: {y.shape} {y.dtype}")
+     break
+
+ # size checking for loading images
+ def check_sizes(image_size, patch_size):
+     sqrt_num_patches, remainder = divmod(image_size, patch_size)
+     assert remainder == 0, "`image_size` must be divisible by `patch_size`"
+     num_patches = sqrt_num_patches ** 2
+     return num_patches
+
+ # create model
+ # Get cpu or gpu device for training.
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"using {device} device")
+
+ # model definition
+
+ class Linearizer_ImageClassification(Linearizer):
+     def __init__(
+         self,
+         image_size=32,
+         patch_size=4,
+         in_channels=3,
+         num_classes=10,
+         d_model=256,
+         d_ffn=512,
+         seq_len=64,
+         num_layers=4,
+         dropout=0.5
+     ):
+         num_patches = check_sizes(image_size, patch_size)
+         super().__init__(d_model, d_ffn, seq_len, num_layers, dropout)
+         # non-overlapping patch embedding: each patch_size x patch_size patch becomes a d_model vector
+         self.patcher = nn.Conv2d(
+             in_channels, d_model, kernel_size=patch_size, stride=patch_size
+         )
+         self.classifier = nn.Linear(d_model, num_classes)
+
+     def forward(self, x):
+         patches = self.patcher(x)
+         batch_size, num_channels, _, _ = patches.shape
+         patches = patches.permute(0, 2, 3, 1)
+         patches = patches.view(batch_size, -1, num_channels)
+         embedding = self.model(patches)
+         embedding = embedding.mean(dim=1)  # global average pooling
+         out = self.classifier(embedding)
+         return out
+
+ model = Linearizer_ImageClassification().to(device)
+ print(model)
+
+ # Optimizer
+
+ loss_fn = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ # Training Loop
+
+ def train(dataloader, model, loss_fn, optimizer):
+     size = len(dataloader.dataset)
+     num_batches = len(dataloader)
+     model.train()
+     train_loss = 0
+     correct = 0
+     for batch, (X, y) in enumerate(dataloader):
+         X, y = X.to(device), y.to(device)
+
+         # compute prediction error
+         pred = model(X)
+         loss = loss_fn(pred, y)
+
+         # backpropagation
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         train_loss += loss.item()
+         _, labels = torch.max(pred.data, 1)
+         correct += labels.eq(y.data).type(torch.float).sum()
+
+         if batch % 100 == 0:
+             loss, current = loss.item(), batch * len(X)
+             print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
+
+     train_loss /= num_batches
+     train_accuracy = 100. * correct.item() / size
+     print(train_accuracy)
+     return train_loss, train_accuracy
+
+ # Test loop
+
+ def test(dataloader, model, loss_fn):
+     size = len(dataloader.dataset)
+     num_batches = len(dataloader)
+     model.eval()
+     test_loss = 0
+     correct = 0
+     with torch.no_grad():
+         for X, y in dataloader:
+             X, y = X.to(device), y.to(device)
+             pred = model(X)
+             test_loss += loss_fn(pred, y).item()
+             correct += (pred.argmax(1) == y).type(torch.float).sum().item()
+     test_loss /= num_batches
+     correct /= size
+     print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
+     test_accuracy = 100 * correct
+     return test_loss, test_accuracy
+
+ # apply train and test
+
+ logname = "/home/abdullah/Desktop/Proposal_experiments/Linearizer/Experiments_cifar10/logs_linearizer/logs_cifar10.csv"
+ if not os.path.exists(logname):
+     with open(logname, 'w') as logfile:
+         logwriter = csv.writer(logfile, delimiter=',')
+         logwriter.writerow(['epoch', 'train loss', 'train acc',
+                             'test loss', 'test acc'])
+
+ epochs = 100
+ for epoch in range(epochs):
+     print(f"Epoch {epoch+1}\n-----------------------------------")
+     train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
+     # learning rate scheduler
+     #if scheduler is not None:
+     #    scheduler.step()
+     test_loss, test_acc = test(test_dataloader, model, loss_fn)
+     with open(logname, 'a') as logfile:
+         logwriter = csv.writer(logfile, delimiter=',')
+         logwriter.writerow([epoch+1, train_loss, train_acc,
+                             test_loss, test_acc])
+ print("Done!")
+
+ # saving trained model
+
+ path = "/home/abdullah/Desktop/Proposal_experiments/Linearizer/Experiments_cifar10/weights_linearizer"
+ model_name = "linearizerImageClassification_cifar10"
+ torch.save(model.state_dict(), f"{path}/{model_name}.pth")
+ print(f"Saved Model State to {path}/{model_name}.pth")
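
As a quick cross-check of the defaults above (illustrative, not part of the committed script): with image_size=32 and patch_size=4, the Conv2d patcher yields an 8 x 8 grid of patches, i.e. 64 tokens of dimension d_model=256, which is why seq_len defaults to 64.

import torch
from torch import nn

patcher = nn.Conv2d(3, 256, kernel_size=4, stride=4)        # same settings as self.patcher
x = torch.randn(1, 3, 32, 32)                               # one CIFAR-10 sized image
patches = patcher(x)                                        # -> (1, 256, 8, 8)
tokens = patches.permute(0, 2, 3, 1).reshape(1, -1, 256)    # -> (1, 64, 256)
print(patches.shape, tokens.shape)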