ZhengPeng7 commited on
Commit
4420101
1 Parent(s): f70bf31

Fix a bug.

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -60,7 +60,7 @@ def predict(
60
  resolution: str,
61
  weights_file: Optional[str]
62
  ) -> Tuple[np.ndarray, np.ndarray]:
63
- # global birefnet
64
  # Load BiRefNet with chosen weights
65
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
66
  print('Using weights:', _weights_file)
@@ -94,6 +94,8 @@ def predict(
94
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
95
  image_pred = (pred * np.array(image_pil)).astype(np.uint8)
96
 
 
 
97
  return image, image_pred
98
 
99
 
 
60
  resolution: str,
61
  weights_file: Optional[str]
62
  ) -> Tuple[np.ndarray, np.ndarray]:
63
+ global birefnet
64
  # Load BiRefNet with chosen weights
65
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
66
  print('Using weights:', _weights_file)
 
94
  pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
95
  image_pred = (pred * np.array(image_pil)).astype(np.uint8)
96
 
97
+ torch.cuda.empty_cache()
98
+
99
  return image, image_pred
100
 
101