naxalpha commited on
Commit
c9ebb32
1 Parent(s): e14438a
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -31,8 +31,8 @@ def main():
31
  depth=24,
32
  ),
33
  )
34
- model.net.token_emb.weight.requires_grad_(False)
35
- model.net.to_logits.weight.requires_grad_(False)
36
  model.net.to_logits = nn.Sequential(
37
  nn.LayerNorm(f_emb),
38
  model.net.to_logits,
@@ -43,7 +43,7 @@ def main():
43
  wandb.watch(model)
44
 
45
  model.load_state_dict(torch.load('model.pt'))
46
- optim = AdamW(model.parameters(), 2e-5)
47
 
48
  bs = 1
49
  kk = 2048
@@ -77,7 +77,7 @@ def main():
77
  if i % 1000 == 0:
78
  accelerator.wait_for_everyone()
79
  unwrapped_model = accelerator.unwrap_model(model)
80
- b, n = 4, 512
81
  init = torch.tensor([[50256]]*b).to(accelerator.device)
82
  prd = unwrapped_model.generate(init, n)
83
  prd = [dsx.decode(p) for p in prd]
 
31
  depth=24,
32
  ),
33
  )
34
+ # model.net.token_emb.weight.requires_grad_(False)
35
+ # model.net.to_logits.weight.requires_grad_(False)
36
  model.net.to_logits = nn.Sequential(
37
  nn.LayerNorm(f_emb),
38
  model.net.to_logits,
 
43
  wandb.watch(model)
44
 
45
  model.load_state_dict(torch.load('model.pt'))
46
+ optim = AdamW(model.parameters(), 5e-6)
47
 
48
  bs = 1
49
  kk = 2048
 
77
  if i % 1000 == 0:
78
  accelerator.wait_for_everyone()
79
  unwrapped_model = accelerator.unwrap_model(model)
80
+ b, n = 1, 2048
81
  init = torch.tensor([[50256]]*b).to(accelerator.device)
82
  prd = unwrapped_model.generate(init, n)
83
  prd = [dsx.decode(p) for p in prd]