sigyllly commited on
Commit
f02dee0
·
verified ·
1 Parent(s): 3c486be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -24,8 +24,13 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
24
  preds = outputs.logits
25
 
26
  pred = torch.sigmoid(preds)
27
- mat = pred.cpu().numpy()
28
- mask = Image.fromarray(np.uint8(mat[0, 0] * 255), "L") # Access the first channel of the output
 
 
 
 
 
29
 
30
  # normalize the mask
31
  mask_min = mask.min()
@@ -42,6 +47,7 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
42
  return bmask
43
 
44
 
 
45
  @app.route('/')
46
  def index():
47
  return "Hello, World! clipseg2"
 
24
  preds = outputs.logits
25
 
26
  pred = torch.sigmoid(preds)
27
+
28
+ if len(pred.shape) == 4: # Check if the shape is (batch_size, channels, height, width)
29
+ mat = pred[0, 0].cpu().numpy() # Access the first channel of the first batch
30
+ else:
31
+ mat = pred[0].cpu().numpy() # If the shape is (channels, height, width)
32
+
33
+ mask = Image.fromarray(np.uint8(mat * 255), "L") # Convert to PIL Image
34
 
35
  # normalize the mask
36
  mask_min = mask.min()
 
47
  return bmask
48
 
49
 
50
+
51
  @app.route('/')
52
  def index():
53
  return "Hello, World! clipseg2"