arcAman07 commited on
Commit
79934ad
·
1 Parent(s): 18f733c

add gradio model

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from model import Transformer
7
+
8
+ # hyperparameters
9
+ batch_size = 16 # how many independent sequences will we process in parallel?
10
+ block_size = 64 # what is the maximum context length for predictions?
11
+ max_iters = 5000
12
+ eval_interval = 100
13
+ learning_rate = 1e-3
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ eval_iters = 200
16
+ n_embd = 128
17
+ n_head = 8
18
+ n_layer = 4
19
+ dropout = 0.0
20
+ vocab = 101
21
+ # ------------
22
+
23
+ with open('/Users/deepaksharma/Documents/Python/Kaggle/GenerateKanyeLyrics/Kanye West Lyrics.txt','r',encoding='utf-8') as f:
24
+ text = f.read()
25
+
26
+ chars = sorted(list(set(text)))
27
+
28
+ stoi = {ch:i for i,ch in enumerate(chars)}
29
+ itos = {i:ch for i,ch in enumerate(chars)}
30
+
31
+ encode = lambda s: [stoi[c] for c in s]
32
+ decode = lambda l: ''.join([itos[c] for c in l])
33
+
34
+
35
+ model = Transformer(n_embd,n_layer)
36
+ model.load_state_dict(torch.load('model_weights.pth'))
37
+ model.eval()
38
+
39
+ def generate_kanye_lyrics(text, max_tokens=500):
40
+ if len(text)<64:
41
+ initial_text = ""
42
+ padding = 64-len(text)
43
+ initial_list = []
44
+ for i in range(0, padding):
45
+ initial_list.append(0)
46
+ context = initial_list + encode(text)
47
+ else:
48
+ padding = 0
49
+ initial_text = text[0:len(text)-block_size]
50
+ context = text[-block_size:]
51
+ context = encode(context)
52
+ context = torch.tensor(context, dtype=torch.long)
53
+ lyrics = torch.stack([context for _ in range(1)], dim=0)
54
+ return initial_text + decode(model.generate(lyrics, max_tokens=int(max_tokens))[0].tolist())[padding:]
55
+
56
+ demo = gr.Interface(fn=generate_kanye_lyrics, inputs=[gr.Textbox(lines=2, placeholder="Enter Starting lyrics ..."),gr.Number()], outputs="text")
57
+
58
+ demo.launch()