"""Train a Gated State Space LM initialised from GPT-2 XL's token embeddings.

Training alternates between two objectives on consecutive batches:

* even steps -- distillation: the student's logits are trained with
  soft-target cross-entropy against the softmax of a frozen GPT-2 XL
  teacher evaluated on the same tokens;
* odd steps -- "scratch": the scalar loss returned by
  ``AutoregressiveWrapper.forward`` on the raw batch.

Gradients are accumulated over ``k`` micro-batches before each optimiser
step. Every 1000 iterations a few samples are generated, logged to W&B as
HTML, and the model weights are saved to ``model.pt``.
"""
import html

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel

from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X


if __name__ == '__main__':
    wandb.init(
        project="gated-state-space",
        entity="naxalpha",
    )

    # Frozen teacher; it is only run for inference in the distillation branch.
    gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
    gpt_2.requires_grad_(False)
    gpt_2 = gpt_2.cuda()

    # The student width must equal the teacher's embedding width, because
    # the GPT-2 token-embedding matrix is copied into the student below.
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,  # GPT-2 vocabulary size
            dim=f_emb,
            depth=24,
        ),
    )

    # Initialise both the input embedding and the output projection from
    # GPT-2's token embeddings, and keep both frozen.
    emb = gpt_2.state_dict()['transformer.wte.weight']
    with torch.no_grad():
        model.net.token_emb.weight.requires_grad_(False)
        model.net.token_emb.weight.copy_(emb)
        model.net.to_logits.weight.requires_grad_(False)
        model.net.to_logits.weight.copy_(emb)
    # Put a trainable LayerNorm in front of the frozen output projection.
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.cuda()
    # Watch only once the architecture is final. W&B attaches gradient hooks
    # to the parameters that exist at call time, so watching earlier would
    # miss the LayerNorm added above.
    wandb.watch(model)

    optim = AdamW(model.parameters(), 2e-5)

    bs = 8    # micro-batch size
    kk = 128  # training sequence length fed to the models
    # Presumably C4X(n) yields token-id sequences of length n; kk + 1 leaves
    # room to drop one token in the distillation branch -- TODO confirm.
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=16,
    )

    k = 4  # gradient-accumulation steps per optimiser update
    prog = tqdm(dlx)
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.cuda()
        if i % 2 == 0:
            # Distillation: match the frozen teacher's next-token
            # distribution at every position. F.cross_entropy accepts
            # class probabilities as the target (soft labels).
            batch = batch[:, :-1]
            with torch.no_grad():
                logits = gpt_2(batch).logits
                probs = logits.softmax(dim=-1)
            out = model.net(batch)
            los = F.cross_entropy(
                out.flatten(0, 1),
                probs.flatten(0, 1),
            )
        else:
            # Scratch: the wrapper's own scalar training loss on the raw batch.
            los = model(batch)
        # Scale by k so the accumulated gradient is an average over k batches.
        (los / k).backward()
        if (i + 1) % k == 0:
            clip_grad_norm_(
                model.parameters(),
                max_norm=1.,
            )
            optim.step()
            optim.zero_grad()
        if i % 1000 == 0:
            b, n = 4, 512
            # Start every sample from GPT-2's <|endoftext|> token (id 50256).
            init = torch.tensor([[50256]] * b).cuda()
            prd = model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                # Escape first so stray '<' / '&' in generated text cannot
                # break the markup. Then render newlines as <br/> and
                # separate the samples with <hr/>.
                sample_html = '<hr/>'.join(
                    html.escape(p).replace('\n', '<br/>')
                    for p in prd
                )
                wandb.log(dict(
                    text=wandb.Html(sample_html),
                ), step=i)
            except Exception as ex:
                # Best effort: a sample-logging failure must not stop training.
                print('Failed to log to W&B...', ex)
            torch.save(model.state_dict(), 'model.pt')
        # A single .item() call means a single device sync per step.
        loss_value = los.item()
        wandb.log(dict(
            loss=loss_value,
        ), step=i)
        prog.set_postfix(loss=loss_value)