sigyllly commited on
Commit
248607a
·
verified ·
1 Parent(s): aca23f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -34
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from flask import Flask, request, jsonify, render_template
2
  from PIL import Image
3
  import base64
4
  from io import BytesIO
@@ -13,7 +13,7 @@ app = Flask(__name__)
13
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
 
16
- def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
17
  inputs = processor(
18
  text=prompt, images=image, padding="max_length", return_tensors="pt"
19
  )
@@ -26,7 +26,6 @@ def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
26
  pred = torch.sigmoid(preds)
27
  mat = pred.cpu().numpy()
28
  mask = Image.fromarray(np.uint8(mat * 255), "L")
29
- mask = mask.convert("RGB")
30
  mask = mask.resize(image.size)
31
  mask = np.array(mask)[:, :, 0]
32
 
@@ -36,34 +35,14 @@ def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
36
  mask = (mask - mask_min) / (mask_max - mask_min)
37
 
38
  # threshold the mask
39
- bmask = mask > threhsold
40
  # zero out values below the threshold
41
- mask[mask < threhsold] = 0
42
-
43
- fig, ax = plt.subplots()
44
- ax.imshow(image)
45
- ax.imshow(mask, alpha=alpha_value, cmap="jet")
46
-
47
- if draw_rectangles:
48
- contours, hierarchy = cv2.findContours(
49
- bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
50
- )
51
- for contour in contours:
52
- x, y, w, h = cv2.boundingRect(contour)
53
- rect = plt.Rectangle(
54
- (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
55
- )
56
- ax.add_patch(rect)
57
-
58
- ax.axis("off")
59
- plt.tight_layout()
60
 
61
  bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
62
- output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
63
- output_image.paste(image, mask=bmask)
64
 
65
- return fig, mask, output_image
66
-
67
  @app.route('/')
68
  def index():
69
  return "Hello, World! clipseg2"
@@ -83,14 +62,14 @@ def mask_image_api():
83
  image = Image.open(BytesIO(image_data))
84
 
85
  # Process the image
86
- _, _, output_image = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
87
 
88
- # Convert the output image to base64
89
- buffered = BytesIO()
90
- output_image.save(buffered, format="PNG")
91
- result_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
92
 
93
- return jsonify({'result_image': result_image})
94
 
95
  if __name__ == '__main__':
96
- app.run(host='0.0.0.0', port=7860, debug=True)
 
1
+ from flask import Flask, request, jsonify
2
  from PIL import Image
3
  import base64
4
  from io import BytesIO
 
13
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
 
16
+ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
17
  inputs = processor(
18
  text=prompt, images=image, padding="max_length", return_tensors="pt"
19
  )
 
26
  pred = torch.sigmoid(preds)
27
  mat = pred.cpu().numpy()
28
  mask = Image.fromarray(np.uint8(mat * 255), "L")
 
29
  mask = mask.resize(image.size)
30
  mask = np.array(mask)[:, :, 0]
31
 
 
35
  mask = (mask - mask_min) / (mask_max - mask_min)
36
 
37
  # threshold the mask
38
+ bmask = mask > threshold
39
  # zero out values below the threshold
40
+ mask[mask < threshold] = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
 
 
43
 
44
+ return bmask
45
+
46
  @app.route('/')
47
  def index():
48
  return "Hello, World! clipseg2"
 
62
  image = Image.open(BytesIO(image_data))
63
 
64
  # Process the image
65
+ output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
66
 
67
+ # Convert the output mask to base64
68
+ buffered_mask = BytesIO()
69
+ output_mask.save(buffered_mask, format="PNG")
70
+ result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
71
 
72
+ return jsonify({'result_mask': result_mask})
73
 
74
  if __name__ == '__main__':
75
+ app.run(host='0.0.0.0', port=7860, debug=False)