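# Train a Gated State Space LM on C4, alternating between distillation from a
# frozen GPT-2 XL teacher and ordinary next-token prediction.
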
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# C4X: local dataset wrapper, assumed to stream fixed-length GPT-2 token-id
# sequences from C4 and to expose a .decode() method for sampling logs.
from c4x import C4X


if __name__ == '__main__':
    wandb.init(
        project="gated-state-space", 
        entity="naxalpha",
    )
    
    # Frozen GPT-2 XL teacher: supplies soft targets for distillation and the
    # pretrained token embeddings reused by the student below.
    gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
    gpt_2.requires_grad_(False)
    gpt_2 = gpt_2.cuda()

    # Student: 24-layer Gated State Space LM sized to GPT-2 XL's embedding
    # width (1600) so the teacher's token embeddings can be copied in directly.
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
    wandb.watch(model)

    # Reuse GPT-2 XL's token embedding matrix for both the input embedding and
    # the output projection of the student; both copies stay frozen.
    emb = gpt_2.state_dict()['transformer.wte.weight']

    model.net.token_emb.weight.requires_grad_(False)
    model.net.token_emb.weight.copy_(emb)

    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits.weight.copy_(emb)

    # Normalize the final hidden states before projecting onto the frozen
    # embedding matrix (the LayerNorm itself is trainable).
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )

    model = model.cuda()
    # Frozen parameters (token embedding and the tied output projection) never
    # receive gradients, so AdamW only updates the GSS blocks and the LayerNorm.
    optim = AdamW(model.parameters(), lr=2e-5)

    # Each C4 sample is kk + 1 token ids: kk inputs plus one extra token for
    # the shifted next-token targets.
    bs = 8
    kk = 128
    dsx = C4X(kk+1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=16,
    )

    # Accumulate gradients over k batches (effective batch size = bs * k = 32).
    k = 4
    prog = tqdm(dlx)
    optim.zero_grad()

    for i, batch in enumerate(prog):
        batch = batch.cuda()
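        # Alternate objectives: even batches distill from the GPT-2 XL teacher,
        # odd batches train on the plain language-modeling loss.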
        if i % 2 == 0:  # distil
            batch = batch[:, :-1]
            with torch.no_grad():
                logits = gpt_2(batch).logits
                probs = logits.softmax(dim=-1)
            out = model.net(batch)
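            # Soft-target cross-entropy against the teacher's full distribution
            # (F.cross_entropy accepts probability targets since PyTorch 1.10).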
            los = F.cross_entropy(
                out.flatten(0,1),
                probs.flatten(0,1),
            )
        else:  # scratch: ordinary next-token prediction on C4
            los = model(batch)
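            # The AutoregressiveWrapper shifts inputs/targets internally and
            # returns the next-token cross-entropy loss.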

        # Scale for gradient accumulation; clip and step every k batches.
        (los / k).backward()
        if (i+1) % k == 0:
            clip_grad_norm_(
                model.parameters(),
                max_norm=1.,
            )
            optim.step()
            optim.zero_grad()

        if i % 1000 == 0:
            # Periodically sample from the student, log the generations to W&B,
            # and checkpoint the weights.
            b, n = 4, 512
            init = torch.tensor([[50256]]*b).cuda()  # <|endoftext|> as prompt
            prd = model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                wandb.log(dict(
                    text=wandb.Html(
                        '<hr>'.join(
                            p.replace('\n', '<br>') for p in prd
                        )
                    ),
                ), step=i)
            except Exception as ex:
                print('Failed to log to W&B...', ex)
            torch.save(model.state_dict(), 'model.pt')

        wandb.log(dict(
            loss=los.item(),
        ), step=i)
        prog.set_postfix(loss=los.item())