import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X

if __name__ == '__main__':
    wandb.init(
        project="gated-state-space",
        entity="naxalpha",
    )

    # frozen GPT-2 XL used as the teacher for the distillation steps
    gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
    gpt_2.requires_grad_(False)
    gpt_2 = gpt_2.cuda()

    # student: gated state-space LM with the same embedding width as GPT-2 XL
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
    wandb.watch(model)

    # tie the student's (frozen) input embedding and output projection
    # to the pretrained GPT-2 token embedding
    emb = gpt_2.state_dict()['transformer.wte.weight']
    model.net.token_emb.weight.requires_grad_(False)
    model.net.token_emb.weight.copy_(emb)

    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits.weight.copy_(emb)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )

    model = model.cuda()
    optim = AdamW(model.parameters(), lr=2e-5)

    bs = 8      # batch size
    kk = 128    # training sequence length
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=16,
    )

    k = 4       # gradient accumulation steps
    prog = tqdm(dlx)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.cuda()

        if i % 2 == 0:  # distil: match the teacher's next-token distribution
            batch = batch[:, :-1]
            with torch.no_grad():
                logits = gpt_2(batch).logits
                probs = logits.softmax(dim=-1)
            out = model.net(batch)
            los = F.cross_entropy(
                out.flatten(0, 1),
                probs.flatten(0, 1),
            )
        else:  # scratch: ordinary autoregressive language-modelling loss
            los = model(batch)

        (los / k).backward()

        if (i + 1) % k == 0:
            clip_grad_norm_(
                model.parameters(),
                max_norm=1.,
            )
            optim.step()
            optim.zero_grad()

        if i % 1000 == 0:
            # periodically sample a few continuations and checkpoint the model
            b, n = 4, 512
            init = torch.tensor([[50256]] * b).cuda()  # <|endoftext|> as the prompt
            prd = model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                wandb.log(dict(
                    text=wandb.Html(
                        '<hr>'.join(
                            p.replace('\n', '<br>') for p in prd
                        )
                    )), step=i)
            except Exception as ex:
                print('Failed to log to W&B...', ex)
            torch.save(model.state_dict(), 'model.pt')

        wandb.log(dict(
            loss=los.item(),
        ), step=i)
        prog.set_postfix(loss=los.item())