|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from model import Transformer |
|
|
|
# Load the corpus and build a character-level codec over its unique symbols.
with open('/Users/deepaksharma/Documents/Python/Kaggle/GenerateKanyeLyrics/Kanye West Lyrics.txt','r',encoding='utf-8') as f:
    text = f.read()

# Vocabulary: every distinct character in the corpus, in sorted (stable) order.
chars = sorted(set(text))

# Bidirectional char <-> integer-id lookup tables.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer token ids (one id per character)."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of integer token ids back to the corresponding string."""
    return ''.join(itos[i] for i in l)


# Tokenize the entire corpus once as a 1-D long tensor.
data = torch.tensor(encode(text), dtype=torch.long)

# Hold out the final 10% for validation. len(data) == len(text) (one token
# per character), but split on the tensor we actually index for clarity.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
|
|
|
def get_batch(split):
    """Sample a random mini-batch from the requested split.

    Returns ``(x, y)`` where ``y`` is ``x`` shifted one position to the
    right (next-character prediction targets). Uses the module-level
    ``train_data``/``val_data`` tensors and ``batch_size``/``block_size``
    hyperparameters.
    """
    if split not in ('train', 'val'):
        raise ValueError("Invalid split")
    source = train_data if split == 'train' else val_data

    # Random starting offsets, each leaving room for a full block plus target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs, targets
|
|
|
|
|
# ---- Hyperparameters ----
batch_size = 16     # sequences per mini-batch (used by get_batch)
block_size = 64     # context length: characters per training sequence
max_iters = 5000    # intended training iterations (NOTE(review): the loop below runs 20000 — confirm which is intended)
eval_interval = 100 # intended logging/eval cadence (the loop below hard-codes 100)
learning_rate = 1e-3  # intended optimizer LR (the optimizer below hard-codes 1e-3)
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # NOTE(review): computed but never used in this file — model/batches stay on CPU
eval_iters = 200    # presumably batches per loss estimate — not referenced in visible code; verify against model/eval helpers
n_embd = 128        # embedding width passed to Transformer
n_head = 8          # attention heads — not referenced here; presumably consumed inside model.Transformer, verify
n_layer = 4         # transformer blocks, passed to Transformer
dropout = 0.0       # dropout rate — not referenced here; presumably consumed inside model.Transformer, verify
vocab = len(chars)  # vocabulary size — not referenced here; presumably consumed inside model.Transformer, verify
|
|
|
|
|
|
|
# Build the model and report its parameter count.
# NOTE(review): `device` is configured above but the model and batches are
# never moved to it — training runs on CPU even when CUDA is available.
model = Transformer(n_embd, n_layer)

print("Total params: ", sum(p.numel() for p in model.parameters()))

# Use the configured learning rate instead of duplicating the 1e-3 literal.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Train for the configured number of iterations (the loop previously
# hard-coded range(20000), contradicting max_iters above), logging the
# batch loss every eval_interval steps.
for step in range(max_iters):
    x, y = get_batch('train')
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % eval_interval == 0:
        print("Step: ", step, " Loss: ", loss.item())
|
|
|
|
|
# Dump parameter names/shapes and optimizer state for inspection, then
# persist the trained weights.
print("Model's state_dict:")
# Hoisted: state_dict() rebuilds the whole mapping on every call, so fetch
# it once instead of once per loop iteration.
model_state = model.state_dict()
for param_tensor in model_state:
    print(param_tensor, "\t", model_state[param_tensor].size())


print("Optimizer's state_dict:")
optimizer_state = optimizer.state_dict()
for var_name in optimizer_state:
    print(var_name, "\t", optimizer_state[var_name])

torch.save(model.state_dict(), 'kanye_weights.pth')
|
|
|
# Encode a seed prompt and generate a continuation from the trained model.
lyrics = encode("Bitch I am back on my comma , sipping on my CocaCola, driving on a hangover ")
lyrics = torch.tensor(lyrics, dtype=torch.long)

# Add the batch dimension directly: unsqueeze(0) replaces the previous
# torch.stack([... for _ in range(1)], dim=0) round-trip, producing the
# identical (1, T) tensor.
lyrics = lyrics.unsqueeze(0)

print(decode(model.generate(lyrics, max_tokens=1000)[0].tolist()))
|
|