Ashoka74 commited on
Commit
2cf63f7
1 Parent(s): 1d39e72

Update app_3.py

Browse files
Files changed (1) hide show
  1. app_3.py +8 -2
app_3.py CHANGED
@@ -150,8 +150,14 @@ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
150
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
151
  # Load model directly
152
  from transformers import AutoModelForImageSegmentation
153
- rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
154
- rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
 
 
 
 
 
 
155
 
156
  model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
157
  model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
 
150
  unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
151
  # Load model directly
152
  from transformers import AutoModelForImageSegmentation
153
+ # rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
154
+ # rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
155
+
156
+ # remove bg
157
+ rmbg = AutoModelForImageSegmentation.from_pretrained(
158
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
159
+ )
160
+ rmbg = rmbg.to(device)
161
 
162
  model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
163
  model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))