ZhengPeng7 commited on
Commit
a10635a
1 Parent(s): d38161e

Add weights option to BiRefNet trained in all different settings.

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -36,11 +36,19 @@ class ImagePreprocessor():
36
 
37
 
38
  from transformers import AutoModelForImageSegmentation
39
- weights_path = 'zhengpeng7/BiRefNet'
40
- birefnet = AutoModelForImageSegmentation.from_pretrained(weights_path, trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
 
 
 
 
 
 
 
 
 
44
 
45
  # def predict(image_1, image_2):
46
  # images = [image_1, image_2]
@@ -50,10 +58,10 @@ def predict(image, resolution, weights_file):
50
  global birefnet
51
  if weights_file != weights_path:
52
  # Load BiRefNet with chosen weights
53
- birefnet = AutoModelForImageSegmentation.from_pretrained(weights_file if weights_file is not None else 'zhengpeng7/BiRefNet', trust_remote_code=True)
54
  birefnet.to(device)
55
  birefnet.eval()
56
- # weights_path = weights_file
57
 
58
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
59
  # Image is a RGB numpy array.
@@ -97,7 +105,7 @@ demo = gr.Interface(
97
  inputs=[
98
  'image',
99
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution"),
100
- gr.Radio(['zhengpeng7/BiRefNet', 'zhengpeng7/BiRefNet-portrait'], label="Weights", info="Choose the weights you want.")
101
  ],
102
  outputs=ImageSlider(),
103
  examples=examples,
 
36
 
37
 
38
  from transformers import AutoModelForImageSegmentation
39
+ weights_path = 'BiRefNet'
40
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', weights_path)), trust_remote_code=True)
41
  birefnet.to(device)
42
  birefnet.eval()
43
 
44
+ usage_to_weights_file = {
45
+ 'General': 'BiRefNet',
46
+ 'Portrait': 'BiRefNet-portrait',
47
+ 'DIS': 'BiRefNet-DIS5K',
48
+ 'HRSOD': 'BiRefNet-HRSOD',
49
+ 'COD': 'BiRefNet-COD',
50
+ 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
51
+ }
52
 
53
  # def predict(image_1, image_2):
54
  # images = [image_1, image_2]
 
58
  global birefnet
59
  if weights_file != weights_path:
60
  # Load BiRefNet with chosen weights
61
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else 'BiRefNet')), trust_remote_code=True)
62
  birefnet.to(device)
63
  birefnet.eval()
64
+ weights_path = weights_file
65
 
66
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
67
  # Image is a RGB numpy array.
 
105
  inputs=[
106
  'image',
107
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution"),
108
+ gr.Radio(list(usage_to_weights_file.keys()), label="Weights", info="Choose the weights you want.")
109
  ],
110
  outputs=ImageSlider(),
111
  examples=examples,