sigyllly commited on
Commit
125d133
1 Parent(s): b69abfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -38
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from flask import Flask, request, jsonify
2
- from PIL import Image
3
  import base64
 
4
  from io import BytesIO
5
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
6
  import torch
@@ -13,12 +13,23 @@ app = Flask(__name__)
13
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
 
16
- def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
17
- inputs = processor(
18
- text=prompt, images=image, padding="max_length", return_tensors="pt"
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # predict
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
  preds = outputs.logits
@@ -26,50 +37,23 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
26
  pred = torch.sigmoid(preds)
27
  mat = pred.cpu().numpy()
28
  mask = Image.fromarray(np.uint8(mat * 255), "L")
 
29
  mask = mask.resize(image.size)
30
  mask = np.array(mask)[:, :, 0]
31
 
32
- # normalize the mask
33
  mask_min = mask.min()
34
  mask_max = mask.max()
35
  mask = (mask - mask_min) / (mask_max - mask_min)
36
 
37
- # threshold the mask
38
  bmask = mask > threshold
39
- # zero out values below the threshold
40
  mask[mask < threshold] = 0
41
 
42
- bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
43
-
44
- return bmask
45
-
46
- @app.route('/')
47
- def index():
48
- return "Hello, World! clipseg2"
49
-
50
- @app.route('/api/mask_image', methods=['POST'])
51
- def mask_image_api():
52
- data = request.get_json()
53
-
54
- base64_image = data.get('base64_image', '')
55
- prompt = data.get('prompt', '')
56
- threshold = data.get('threshold', 0.4)
57
- alpha_value = data.get('alpha_value', 0.5)
58
- draw_rectangles = data.get('draw_rectangles', False)
59
-
60
- # Decode base64 image
61
- image_data = base64.b64decode(base64_image.split(',')[1])
62
- image = Image.open(BytesIO(image_data))
63
-
64
- # Process the image
65
- output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
66
-
67
  # Convert the output mask to base64
68
  buffered_mask = BytesIO()
69
- output_mask.save(buffered_mask, format="PNG")
70
- result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
71
 
72
- return jsonify({'result_mask': result_mask})
73
 
74
  if __name__ == '__main__':
75
- app.run(host='0.0.0.0', port=7860, debug=False)
 
1
  from flask import Flask, request, jsonify
 
2
  import base64
3
+ from PIL import Image
4
  from io import BytesIO
5
  from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
6
  import torch
 
13
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
 
16
+ @app.route('/api/mask_image', methods=['POST'])
17
+ def mask_image_api():
18
+ data = request.get_json()
19
+
20
+ base64_image = data.get('base64_image', '')
21
+ prompt = data.get('prompt', '')
22
+ threshold = data.get('threshold', 0.4)
23
+ alpha_value = data.get('alpha_value', 0.5)
24
+ draw_rectangles = data.get('draw_rectangles', False)
25
+
26
+ # Decode base64 image
27
+ image_data = base64.b64decode(base64_image)
28
+
29
+ # Process the image
30
+ image = Image.open(BytesIO(image_data))
31
+ inputs = processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
32
 
 
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
  preds = outputs.logits
 
37
  pred = torch.sigmoid(preds)
38
  mat = pred.cpu().numpy()
39
  mask = Image.fromarray(np.uint8(mat * 255), "L")
40
+ mask = mask.convert("RGB")
41
  mask = mask.resize(image.size)
42
  mask = np.array(mask)[:, :, 0]
43
 
 
44
  mask_min = mask.min()
45
  mask_max = mask.max()
46
  mask = (mask - mask_min) / (mask_max - mask_min)
47
 
 
48
  bmask = mask > threshold
 
49
  mask[mask < threshold] = 0
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Convert the output mask to base64
52
  buffered_mask = BytesIO()
53
+ mask.save(buffered_mask, format="PNG")
54
+ base64_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
55
 
56
+ return jsonify({'base64_mask': base64_mask})
57
 
58
  if __name__ == '__main__':
59
+ app.run(debug=True)