Charbel Malo commited on
Commit
20f67f1
1 Parent(s): 533fe99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -83,9 +83,19 @@ FACE_ENHANCER_LIST.extend(cv2_interpolations)
83
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
84
  # Note: Non CUDA users may change settings here
85
 
86
- PROVIDER = ["CPUExecutionProvider"] # Default to CPU provider
87
- device = "cpu"
 
 
 
 
 
 
 
 
 
88
 
 
89
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
90
 
91
  ## ------------------------------ LOAD MODELS ------------------------------
@@ -93,7 +103,7 @@ EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
93
  def load_face_analyser_model(name="buffalo_l"):
94
  global FACE_ANALYSER
95
  if FACE_ANALYSER is None:
96
- FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
97
  FACE_ANALYSER.prepare(
98
  ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
99
  )
@@ -103,7 +113,7 @@ def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"
103
  global FACE_SWAPPER
104
  if FACE_SWAPPER is None:
105
  batch = int(BATCH_SIZE) if device == "cuda" else 1
106
- FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
107
 
108
 
109
  def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
@@ -114,7 +124,7 @@ def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
114
  def load_nsfw_detector_model(path="./assets/pretrained_models/open-nsfw.onnx"):
115
  global NSFW_DETECTOR
116
  if NSFW_DETECTOR is None:
117
- NSFW_DETECTOR = NSFWChecker(model_path=path, providers=PROVIDER)
118
 
119
 
120
  load_face_analyser_model()
@@ -170,7 +180,7 @@ def process(
170
  FACE_ENHANCER = None
171
  FACE_PARSER = None
172
  NSFW_DETECTOR = None
173
-
174
  ## ------------------------------ GUI UPDATE FUNC ------------------------------
175
 
176
  def ui_before():
@@ -930,7 +940,7 @@ if __name__ == "__main__":
930
  if USE_COLAB:
931
  print("Running in colab mode")
932
 
933
- interface.queue().launch()
934
 
935
 
936
  #### APP.PY CODE END ###
 
83
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
84
  # Note: Non CUDA users may change settings here
85
 
86
+ if USE_CUDA:
87
+ available_providers = onnxruntime.get_available_providers()
88
+ if "CUDAExecutionProvider" in available_providers:
89
+ print("\n********** Running on CUDA **********\n")
90
+ PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
91
+ else:
92
+ USE_CUDA = False
93
+ print("\n********** CUDA unavailable running on CPU **********\n")
94
+ else:
95
+ USE_CUDA = False
96
+ print("\n********** Running on CPU **********\n")
97
 
98
+ device = "cuda" if USE_CUDA else "cpu"
99
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
100
 
101
  ## ------------------------------ LOAD MODELS ------------------------------
 
103
  def load_face_analyser_model(name="buffalo_l"):
104
  global FACE_ANALYSER
105
  if FACE_ANALYSER is None:
106
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
107
  FACE_ANALYSER.prepare(
108
  ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
109
  )
 
113
  global FACE_SWAPPER
114
  if FACE_SWAPPER is None:
115
  batch = int(BATCH_SIZE) if device == "cuda" else 1
116
+ FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
117
 
118
 
119
  def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
 
124
  def load_nsfw_detector_model(path="./assets/pretrained_models/open-nsfw.onnx"):
125
  global NSFW_DETECTOR
126
  if NSFW_DETECTOR is None:
127
+ NSFW_DETECTOR = NSFWChecker(model_path=path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
128
 
129
 
130
  load_face_analyser_model()
 
180
  FACE_ENHANCER = None
181
  FACE_PARSER = None
182
  NSFW_DETECTOR = None
183
+
184
  ## ------------------------------ GUI UPDATE FUNC ------------------------------
185
 
186
  def ui_before():
 
940
  if USE_COLAB:
941
  print("Running in colab mode")
942
 
943
+ interface.launch()
944
 
945
 
946
  #### APP.PY CODE END ###