sigyllly committed on
Commit
db1a0dc
1 Parent(s): 125d133

Update app.py

Files changed (1)
  1. app.py +67 -22
app.py CHANGED
@@ -1,6 +1,6 @@
- from flask import Flask, request, jsonify
- import base64
+ from flask import Flask, request, jsonify, render_template
  from PIL import Image
+ import base64
  from io import BytesIO
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
  import torch
@@ -13,23 +13,12 @@ app = Flask(__name__)
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

- @app.route('/api/mask_image', methods=['POST'])
- def mask_image_api():
-     data = request.get_json()
-
-     base64_image = data.get('base64_image', '')
-     prompt = data.get('prompt', '')
-     threshold = data.get('threshold', 0.4)
-     alpha_value = data.get('alpha_value', 0.5)
-     draw_rectangles = data.get('draw_rectangles', False)
-
-     # Decode base64 image
-     image_data = base64.b64decode(base64_image)
-
-     # Process the image
-     image = Image.open(BytesIO(image_data))
-     inputs = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
+ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
+     inputs = processor(
+         text=prompt, images=image, padding="max_length", return_tensors="pt"
+     )

+     # predict
      with torch.no_grad():
          outputs = model(**inputs)
      preds = outputs.logits
@@ -41,19 +30,75 @@ def mask_image_api():
      mask = mask.resize(image.size)
      mask = np.array(mask)[:, :, 0]

+     # normalize the mask
      mask_min = mask.min()
      mask_max = mask.max()
      mask = (mask - mask_min) / (mask_max - mask_min)

+     # threshold the mask
      bmask = mask > threshold
+     # zero out values below the threshold
      mask[mask < threshold] = 0

-     # Convert the output mask to base64
+     fig, ax = plt.subplots()
+     ax.imshow(image)
+     ax.imshow(mask, alpha=alpha_value, cmap="jet")
+
+     if draw_rectangles:
+         contours, hierarchy = cv2.findContours(
+             bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+         )
+         for contour in contours:
+             x, y, w, h = cv2.boundingRect(contour)
+             rect = plt.Rectangle(
+                 (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
+             )
+             ax.add_patch(rect)
+
+     ax.axis("off")
+     plt.tight_layout()
+
+     bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
+     output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
+     output_image.paste(image, mask=bmask)
+
+     # Convert mask to base64
      buffered_mask = BytesIO()
-     mask.save(buffered_mask, format="PNG")
-     base64_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
+     bmask.save(buffered_mask, format="PNG")
+     result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
+
+     # Convert output image to base64
+     buffered_output = BytesIO()
+     output_image.save(buffered_output, format="PNG")
+     result_output = base64.b64encode(buffered_output.getvalue()).decode('utf-8')
+
+     return fig, result_mask, result_output
+
+ # Existing process_image function, copy it here
+ # ...
+
+ @app.route('/')
+ def index():
+     return render_template('index.html')
+
+ @app.route('/api/mask_image', methods=['POST'])
+ def mask_image_api():
+     data = request.get_json()
+
+     base64_image = data.get('base64_image', '')
+     prompt = data.get('prompt', '')
+     threshold = data.get('threshold', 0.4)
+     alpha_value = data.get('alpha_value', 0.5)
+     draw_rectangles = data.get('draw_rectangles', False)
+
+     # Decode base64 image
+     image_data = base64.b64decode(base64_image.split(',')[1])
+     image = Image.open(BytesIO(image_data))
+
+     # Process the image
+     _, result_mask, result_output = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

-     return jsonify({'base64_mask': base64_mask})
+     return jsonify({'result_mask': result_mask, 'result_output': result_output})

  if __name__ == '__main__':
      app.run(debug=True)
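
For context, a minimal client sketch against the updated endpoint. It is not part of the commit: the local URL (Flask's default http://127.0.0.1:5000), the dog.jpg input file, and the "dog" prompt are placeholder assumptions. The new handler strips a data-URL prefix via split(','), so the image should be sent as a data URL, and the response now carries two base64-encoded PNGs, result_mask and result_output, instead of the old base64_mask field.

# Example client for /api/mask_image (assumed local dev server and sample inputs).
import base64
from io import BytesIO

import requests
from PIL import Image

# Encode a local image as a data URL; the endpoint splits on ',' and decodes the rest.
with open("dog.jpg", "rb") as f:  # placeholder input file
    data_url = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:5000/api/mask_image",  # assumed Flask dev address
    json={
        "base64_image": data_url,
        "prompt": "dog",           # text prompt for CLIPSeg
        "threshold": 0.4,          # mask threshold
        "alpha_value": 0.5,        # overlay alpha used by the server-side figure
        "draw_rectangles": True,   # draw bounding boxes on the server-side figure
    },
)
resp.raise_for_status()
result = resp.json()

# Decode the two base64 PNGs returned by the new handler.
mask = Image.open(BytesIO(base64.b64decode(result["result_mask"])))
cutout = Image.open(BytesIO(base64.b64decode(result["result_output"])))
mask.save("mask.png")
cutout.save("cutout.png")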