arcAman07 commited on
Commit
a3c6bef
·
1 Parent(s): 7da7e5d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -58
app.py DELETED
@@ -1,58 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import numpy as np
6
- from model import Transformer
7
-
8
- # hyperparameters
9
- batch_size = 16 # how many independent sequences will we process in parallel?
10
- block_size = 64 # what is the maximum context length for predictions?
11
- max_iters = 5000
12
- eval_interval = 100
13
- learning_rate = 1e-3
14
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
- eval_iters = 200
16
- n_embd = 128
17
- n_head = 8
18
- n_layer = 4
19
- dropout = 0.0
20
- vocab = 101
21
- # ------------
22
-
23
- with open('/Users/deepaksharma/Documents/Python/Kaggle/GenerateKanyeLyrics/Kanye West Lyrics.txt','r',encoding='utf-8') as f:
24
- text = f.read()
25
-
26
- chars = sorted(list(set(text)))
27
-
28
- stoi = {ch:i for i,ch in enumerate(chars)}
29
- itos = {i:ch for i,ch in enumerate(chars)}
30
-
31
- encode = lambda s: [stoi[c] for c in s]
32
- decode = lambda l: ''.join([itos[c] for c in l])
33
-
34
-
35
- model = Transformer(n_embd,n_layer)
36
- model.load_state_dict(torch.load('model_weights.pth'))
37
- model.eval()
38
-
39
- def generate_kanye_lyrics(text, max_tokens=500):
40
- if len(text)<64:
41
- initial_text = ""
42
- padding = 64-len(text)
43
- initial_list = []
44
- for i in range(0, padding):
45
- initial_list.append(0)
46
- context = initial_list + encode(text)
47
- else:
48
- padding = 0
49
- initial_text = text[0:len(text)-block_size]
50
- context = text[-block_size:]
51
- context = encode(context)
52
- context = torch.tensor(context, dtype=torch.long)
53
- lyrics = torch.stack([context for _ in range(1)], dim=0)
54
- return initial_text + decode(model.generate(lyrics, max_tokens=int(max_tokens))[0].tolist())[padding:]
55
-
56
- demo = gr.Interface(fn=generate_kanye_lyrics, inputs=[gr.Textbox(lines=2, placeholder="Enter Starting lyrics ..."),gr.Number()], outputs="text")
57
-
58
- demo.launch()