working
app.py CHANGED
@@ -31,8 +31,8 @@ def main():
             depth=24,
         ),
     )
-    model.net.token_emb.weight.requires_grad_(False)
-    model.net.to_logits.weight.requires_grad_(False)
+    # model.net.token_emb.weight.requires_grad_(False)
+    # model.net.to_logits.weight.requires_grad_(False)
     model.net.to_logits = nn.Sequential(
         nn.LayerNorm(f_emb),
         model.net.to_logits,
@@ -43,7 +43,7 @@ def main():
     wandb.watch(model)
 
     model.load_state_dict(torch.load('model.pt'))
-    optim = AdamW(model.parameters(),
+    optim = AdamW(model.parameters(), 5e-6)
 
     bs = 1
     kk = 2048
@@ -77,7 +77,7 @@ def main():
         if i % 1000 == 0:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
-            b, n =
+            b, n = 1, 2048
            init = torch.tensor([[50256]]*b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
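For context, a minimal self-contained sketch of the pattern this commit applies, using a toy stand-in for the real model in app.py (f_emb's actual value, the transformer class, dsx, wandb, and accelerate are assumptions or omissions here): the two requires_grad_(False) calls are commented out so the token embedding and output head stay trainable, the output head is wrapped in a LayerNorm, training resumes from model.pt with AdamW at a 5e-6 learning rate, and periodic sampling starts from token 50256.

import torch
from torch import nn
from torch.optim import AdamW

f_emb = 512          # assumed embedding width; app.py defines its own value
vocab_size = 50257   # GPT-2-sized vocabulary, implied by the 50256 start token

# Toy stand-in for model.net: a token embedding plus an output head ("to_logits").
net = nn.ModuleDict({
    "token_emb": nn.Embedding(vocab_size, f_emb),
    "to_logits": nn.Linear(f_emb, vocab_size),
})

# The commit comments out both requires_grad_(False) calls, so the embedding and
# the output head remain trainable during fine-tuning:
# net["token_emb"].weight.requires_grad_(False)
# net["to_logits"].weight.requires_grad_(False)

# Wrap the existing head so logits are computed from normalized features, mirroring
# model.net.to_logits = nn.Sequential(nn.LayerNorm(f_emb), model.net.to_logits).
net["to_logits"] = nn.Sequential(nn.LayerNorm(f_emb), net["to_logits"])

# Resume from a checkpoint and fine-tune with the small 5e-6 learning rate from the diff.
# net.load_state_dict(torch.load("model.pt"))     # app.py loads its saved weights here
optim = AdamW(net.parameters(), 5e-6)

# Periodic sampling in app.py starts generation from the GPT-2 end-of-text token (50256):
b, n = 1, 2048
init = torch.tensor([[50256]] * b)                # app.py also moves this to accelerator.device
print(init.shape, sum(p.numel() for p in net.parameters() if p.requires_grad))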