Spaces:
Running
Running
charlieoneill
committed on
Commit
•
10845f0
1
Parent(s):
63a794c
Update topk_sae.py
Browse files
- topk_sae.py +1 -108
topk_sae.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
-
import torch.optim as optim
|
5 |
import numpy as np
|
6 |
from torch.utils.data import DataLoader, TensorDataset
|
7 |
from tqdm import tqdm
|
8 |
-
import wandb
|
9 |
import os
|
10 |
import glob
|
11 |
|
@@ -153,109 +151,4 @@ def init_from_data_(ae, data_sample):
|
|
153 |
# encoder as transpose of decoder
|
154 |
ae.encoder.weight.data = ae.decoder.weight.t().clone()
|
155 |
|
156 |
-
nn.init.zeros_(ae.latent_bias)
|
157 |
-
|
158 |
-
def train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=None, save_dir="../models", model_name=""):
|
159 |
-
os.makedirs(save_dir, exist_ok=True)
|
160 |
-
step = 0
|
161 |
-
num_batches = len(train_loader)
|
162 |
-
for epoch in range(epochs):
|
163 |
-
ae.train()
|
164 |
-
total_loss = 0
|
165 |
-
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
|
166 |
-
optimizer.zero_grad()
|
167 |
-
x = batch[0].to(device)
|
168 |
-
recons, info = ae(x)
|
169 |
-
loss, recons_loss, auxk_loss = loss_fn(ae, x, recons, info, auxk_coef, multik_coef)
|
170 |
-
loss.backward()
|
171 |
-
step += 1
|
172 |
-
|
173 |
-
# calculate proportion of dead latents (not fired in last num_batches = 1 epoch)
|
174 |
-
dead_latents_prop = (ae.stats_last_nonzero > num_batches).float().mean().item()
|
175 |
-
|
176 |
-
wandb.log({
|
177 |
-
"total_loss": loss.item(),
|
178 |
-
"reconstruction_loss": recons_loss.item(),
|
179 |
-
"auxiliary_loss": auxk_loss.item(),
|
180 |
-
"dead_latents_proportion": dead_latents_prop,
|
181 |
-
"l0_norm": k,
|
182 |
-
"step": step
|
183 |
-
})
|
184 |
-
|
185 |
-
unit_norm_decoder_grad_adjustment_(ae)
|
186 |
-
|
187 |
-
if clip_grad is not None:
|
188 |
-
torch.nn.utils.clip_grad_norm_(ae.parameters(), clip_grad)
|
189 |
-
|
190 |
-
optimizer.step()
|
191 |
-
unit_norm_decoder_(ae)
|
192 |
-
|
193 |
-
total_loss += loss.item()
|
194 |
-
|
195 |
-
avg_loss = total_loss / len(train_loader)
|
196 |
-
print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
|
197 |
-
|
198 |
-
# Delete previous model saves for this configuration
|
199 |
-
for old_model in glob.glob(os.path.join(save_dir, f"{model_name}_epoch_*.pth")):
|
200 |
-
os.remove(old_model)
|
201 |
-
|
202 |
-
# Save new model
|
203 |
-
save_path = os.path.join(save_dir, f"{model_name}_epoch_{epoch+1}.pth")
|
204 |
-
torch.save(ae.state_dict(), save_path)
|
205 |
-
print(f"Model saved to {save_path}")
|
206 |
-
|
207 |
-
def main():
|
208 |
-
d_model = 1536
|
209 |
-
n_dirs = 3072 #9216
|
210 |
-
k = 64 #64
|
211 |
-
auxk = k*2 #256
|
212 |
-
multik = 128
|
213 |
-
batch_size = 1024
|
214 |
-
lr = 1e-4
|
215 |
-
auxk_coef = 1/32
|
216 |
-
clip_grad = 1.0
|
217 |
-
multik_coef = 0 # turn it off
|
218 |
-
|
219 |
-
csLG = False
|
220 |
-
|
221 |
-
# Create model name
|
222 |
-
model_name = f"{k}_{n_dirs}_{auxk}_auxk" if not csLG else f"{k}_{n_dirs}_{auxk}_auxk_csLG"
|
223 |
-
epochs = 50 if not csLG else 137
|
224 |
-
|
225 |
-
wandb.init(project="saerch", name=model_name, config={
|
226 |
-
"n_dirs": n_dirs,
|
227 |
-
"d_model": d_model,
|
228 |
-
"k": k,
|
229 |
-
"auxk": auxk,
|
230 |
-
"batch_size": batch_size,
|
231 |
-
"lr": lr,
|
232 |
-
"epochs": epochs,
|
233 |
-
"auxk_coef": auxk_coef,
|
234 |
-
"multik_coef": multik_coef,
|
235 |
-
"clip_grad": clip_grad,
|
236 |
-
"device": device.type
|
237 |
-
})
|
238 |
-
|
239 |
-
if not csLG:
|
240 |
-
data = np.load("../data/vector_store_astroPH/abstract_embeddings.npy")
|
241 |
-
print("Doing astro.ph...")
|
242 |
-
else:
|
243 |
-
data = np.load("../data/vector_store_csLG/abstract_embeddings.npy")
|
244 |
-
print("Doing csLG...")
|
245 |
-
data_tensor = torch.from_numpy(data).float()
|
246 |
-
# Print shape
|
247 |
-
print(f"Data shape: {data_tensor.shape}")
|
248 |
-
dataset = TensorDataset(data_tensor)
|
249 |
-
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
250 |
-
|
251 |
-
ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik).to(device)
|
252 |
-
init_from_data_(ae, data_tensor[:10000].to(device))
|
253 |
-
|
254 |
-
optimizer = optim.Adam(ae.parameters(), lr=lr)
|
255 |
-
|
256 |
-
train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=clip_grad, model_name=model_name)
|
257 |
-
|
258 |
-
wandb.finish()
|
259 |
-
|
260 |
-
if __name__ == "__main__":
|
261 |
-
main()
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
|
|
4 |
import numpy as np
|
5 |
from torch.utils.data import DataLoader, TensorDataset
|
6 |
from tqdm import tqdm
|
|
|
7 |
import os
|
8 |
import glob
|
9 |
|
|
|
151 |
# encoder as transpose of decoder
|
152 |
ae.encoder.weight.data = ae.decoder.weight.t().clone()
|
153 |
|
154 |
+
nn.init.zeros_(ae.latent_bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|