asigalov61 commited on
Commit
5681620
1 Parent(s): ed37c31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -11,6 +11,7 @@ import re
11
  import tqdm
12
 
13
  import gradio as gr
 
14
 
15
  from x_transformer_1_23_2 import *
16
  import random
@@ -116,20 +117,14 @@ print('Loading model checkpoint...')
116
 
117
  model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'
118
 
119
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
120
-
121
- print('=' * 70)
122
-
123
- model.to(DEVICE)
124
- model.eval()
125
-
126
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)
127
 
128
  print('Done!')
129
  print('=' * 70)
130
 
131
  #====================================================================================
132
 
 
133
  def Generate_POP_Section(input_comp_section,
134
  input_mode_time,
135
  input_mode_dur,
@@ -167,9 +162,13 @@ def Generate_POP_Section(input_comp_section,
167
 
168
  seq += input_seq
169
 
 
 
 
170
  x = torch.LongTensor(seq).to(DEVICE)
171
 
172
- with ctx:
 
173
  out = model.generate(x,
174
  512-len(seq),
175
  temperature=0.9,
 
11
  import tqdm
12
 
13
  import gradio as gr
14
+ import spaces
15
 
16
  from x_transformer_1_23_2 import *
17
  import random
 
117
 
118
  model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'
119
 
120
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
 
 
 
 
 
 
 
121
 
122
  print('Done!')
123
  print('=' * 70)
124
 
125
  #====================================================================================
126
 
127
+ @spaces.GPU
128
  def Generate_POP_Section(input_comp_section,
129
  input_mode_time,
130
  input_mode_dur,
 
162
 
163
  seq += input_seq
164
 
165
+ model.to(DEVICE)
166
+ model.eval()
167
+
168
  x = torch.LongTensor(seq).to(DEVICE)
169
 
170
+ with torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16):
171
+
172
  out = model.generate(x,
173
  512-len(seq),
174
  temperature=0.9,