Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Commit
•
5681620
1
Parent(s):
ed37c31
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import re
|
|
11 |
import tqdm
|
12 |
|
13 |
import gradio as gr
|
|
|
14 |
|
15 |
from x_transformer_1_23_2 import *
|
16 |
import random
|
@@ -116,20 +117,14 @@ print('Loading model checkpoint...')
|
|
116 |
|
117 |
model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'
|
118 |
|
119 |
-
model.load_state_dict(torch.load(model_path, map_location=
|
120 |
-
|
121 |
-
print('=' * 70)
|
122 |
-
|
123 |
-
model.to(DEVICE)
|
124 |
-
model.eval()
|
125 |
-
|
126 |
-
ctx = torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)
|
127 |
|
128 |
print('Done!')
|
129 |
print('=' * 70)
|
130 |
|
131 |
#====================================================================================
|
132 |
|
|
|
133 |
def Generate_POP_Section(input_comp_section,
|
134 |
input_mode_time,
|
135 |
input_mode_dur,
|
@@ -167,9 +162,13 @@ def Generate_POP_Section(input_comp_section,
|
|
167 |
|
168 |
seq += input_seq
|
169 |
|
|
|
|
|
|
|
170 |
x = torch.LongTensor(seq).to(DEVICE)
|
171 |
|
172 |
-
with
|
|
|
173 |
out = model.generate(x,
|
174 |
512-len(seq),
|
175 |
temperature=0.9,
|
|
|
11 |
import tqdm
|
12 |
|
13 |
import gradio as gr
|
14 |
+
import spaces
|
15 |
|
16 |
from x_transformer_1_23_2 import *
|
17 |
import random
|
|
|
117 |
|
118 |
model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'
|
119 |
|
120 |
+
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
print('Done!')
|
123 |
print('=' * 70)
|
124 |
|
125 |
#====================================================================================
|
126 |
|
127 |
+
@spaces.GPU
|
128 |
def Generate_POP_Section(input_comp_section,
|
129 |
input_mode_time,
|
130 |
input_mode_dur,
|
|
|
162 |
|
163 |
seq += input_seq
|
164 |
|
165 |
+
model.to(DEVICE)
|
166 |
+
model.eval()
|
167 |
+
|
168 |
x = torch.LongTensor(seq).to(DEVICE)
|
169 |
|
170 |
+
with torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16):
|
171 |
+
|
172 |
out = model.generate(x,
|
173 |
512-len(seq),
|
174 |
temperature=0.9,
|