# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
"""Continue training a GatedStateSpacesLM on C4X with Hugging Face Accelerate.

The script rebuilds the model (with an extra LayerNorm in front of the output
projection), loads weights from ``model.pt``, then trains with gradient
accumulation (4 steps) and gradient-norm clipping (1.0).

Progress is logged through Accelerate's W&B tracker:
  * the loss is logged every 10 steps;
  * a generated sample is logged every 1000 steps, at which point a
    checkpoint is also written to ``model2.pt``.

A final checkpoint is written when the data loader is exhausted.
"""
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from c4x import C4X
from accelerate import Accelerator


def _format_samples_html(texts):
    """Join decoded samples into one HTML string suitable for ``wandb.Html``.

    Raw newlines are collapsed by HTML rendering, so each one becomes
    ``<br>``; consecutive samples are separated by a horizontal rule.
    """
    return '<hr>'.join(text.replace('\n', '<br>') for text in texts)


def _sample_and_checkpoint(accelerator, model, dsx, step, num_samples=1, length=2048):
    """Generate text from the unwrapped model, log it to W&B, and save a checkpoint.

    Args:
        accelerator: the ``Accelerator`` driving training.
        model: the prepared (possibly wrapped) model.
        dsx: the dataset; its ``decode`` turns generated token rows into text.
        step: step index used for the W&B log entry.
        num_samples: number of sequences to generate.
        length: number of tokens to generate per sequence.
    """
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)

    # Prompt with a single token per sample; 50256 is presumably GPT-2's
    # <|endoftext|> id (vocab size 50257) — TODO confirm against C4X's tokenizer.
    init = torch.full(
        (num_samples, 1), 50256, dtype=torch.long, device=accelerator.device,
    )
    prd = unwrapped_model.generate(init, length)
    texts = [dsx.decode(p) for p in prd]

    # Sample logging is best-effort: a W&B failure must not stop training.
    try:
        accelerator.log(dict(text=wandb.Html(_format_samples_html(texts))), step=step)
    except Exception as ex:
        accelerator.print('Failed to log to W&B...', ex)

    accelerator.save(unwrapped_model.state_dict(), 'model2.pt')


def main():
    """Build the model, restore ``model.pt``, and run the training loop."""
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")

    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
    # Optional: freeze the input embedding / output projection.
    # model.net.token_emb.weight.requires_grad_(False)
    # model.net.to_logits.weight.requires_grad_(False)

    # Put a LayerNorm in front of the output projection. This happens before
    # load_state_dict, so the module layout matches checkpoints written by this
    # script (presumably model.pt was saved the same way — TODO confirm).
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.to(accelerator.device)
    if accelerator.is_main_process:
        wandb.watch(model)

    # map_location='cpu' stops every process from materialising the tensors
    # on the device they were saved from (e.g. all ranks on cuda:0);
    # load_state_dict then copies them onto the model's device.
    model.load_state_dict(torch.load('model.pt', map_location='cpu'))

    optim = AdamW(model.parameters(), 5e-6)

    bs = 1
    kk = 2048
    # kk + 1 tokens per example — presumably one extra so the wrapper can
    # shift inputs against targets; TODO confirm against C4X.
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=4,
    )

    model = accelerator.prepare(model)
    optim, dlx = accelerator.prepare(optim, dlx)

    # Wrap the *prepared* loader. Iterating the raw DataLoader would give every
    # process identical, un-sharded batches and hide the end of the loader
    # from Accelerate's gradient-accumulation bookkeeping.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        # A no-op when the prepared loader has already placed the batch.
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            # Clip only on steps where gradients are synchronised, i.e. when
            # the optimizer will actually apply them.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()

        if i % 1000 == 0:
            _sample_and_checkpoint(accelerator, model, dsx, step=i)

        if i % 10 == 0:
            loss_value = los.item()
            accelerator.log(dict(loss=loss_value), step=i)
            prog.set_postfix(loss=loss_value)

    # Keep whatever was trained after the last periodic checkpoint.
    accelerator.wait_for_everyone()
    accelerator.save(accelerator.unwrap_model(model).state_dict(), 'model2.pt')
    accelerator.end_training()


if __name__ == '__main__':
    main()