p1atdev commited on
Commit
d8a9dbd
·
verified ·
1 Parent(s): d5b4063

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -8,6 +8,8 @@ import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
10
 
 
 
11
  TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
13
  Demo for the WaifuDiffusion tagger models
@@ -129,7 +131,7 @@ class Predictor:
129
  self.general_indexes = sep_tags[2]
130
  self.character_indexes = sep_tags[3]
131
 
132
- model = rt.InferenceSession(model_path)
133
  _, height, width, _ = model.get_inputs()[0].shape
134
  self.model_target_size = height
135
 
@@ -167,6 +169,7 @@ class Predictor:
167
 
168
  return np.expand_dims(image_array, axis=0)
169
 
 
170
  def predict(
171
  self,
172
  image,
 
8
  import pandas as pd
9
  from PIL import Image
10
 
11
+ import spaces
12
+
13
  TITLE = "WaifuDiffusion Tagger"
14
  DESCRIPTION = """
15
  Demo for the WaifuDiffusion tagger models
 
131
  self.general_indexes = sep_tags[2]
132
  self.character_indexes = sep_tags[3]
133
 
134
+ model = rt.InferenceSession(model_path, providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])
135
  _, height, width, _ = model.get_inputs()[0].shape
136
  self.model_target_size = height
137
 
 
169
 
170
  return np.expand_dims(image_array, axis=0)
171
 
172
+ @spaces.GPU
173
  def predict(
174
  self,
175
  image,