Ashoka74 commited on
Commit
5fd8731
1 Parent(s): a433024

Update inference_i2mv_sdxl.py

Browse files
Files changed (1) hide show
  1. inference_i2mv_sdxl.py +2 -2
inference_i2mv_sdxl.py CHANGED
@@ -101,8 +101,8 @@ def remove_bg(image, net, transform, device):
101
  image_size = image.size
102
  input_images = transform(image).unsqueeze(0).to(device)
103
  with torch.no_grad():
104
- #preds = net(input_images)[-1].sigmoid().cpu()
105
- preds = net(input_images)[-1] if isinstance(net(input_images), list) else net(input_images)
106
  pred = preds[0].squeeze()
107
  pred_pil = transforms.ToPILImage()(pred)
108
  mask = pred_pil.resize(image_size)
 
101
  image_size = image.size
102
  input_images = transform(image).unsqueeze(0).to(device)
103
  with torch.no_grad():
104
+ preds = net(input_images)[-1][0].sigmoid().cpu()
105
+ #preds = net(input_images)[-1] if isinstance(net(input_images), list) else net(input_images)
106
  pred = preds[0].squeeze()
107
  pred_pil = transforms.ToPILImage()(pred)
108
  mask = pred_pil.resize(image_size)