File size: 3,239 Bytes
db1a0dc 125d133 db1a0dc 58e8f4d 8037bf7 58e8f4d 8037bf7 58e8f4d 8037bf7 db1a0dc 8037bf7 db1a0dc 8037bf7 b69abfd 125d133 b69abfd 00c41dd db1a0dc b69abfd 8037bf7 db1a0dc b69abfd db1a0dc b69abfd 8037bf7 db1a0dc 248607a db1a0dc 8037bf7 db1a0dc 8037bf7 58e8f4d 59b2e32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from flask import Flask, request, jsonify, render_template
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Hugging Face checkpoint shared by the processor and the segmentation model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Flask application plus the CLIPSeg processor/model pair, created once at import
# time so that every request handled below reuses the same loaded weights.
app = Flask(__name__)
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def _png_base64(img):
    """Encode a PIL image as PNG and return it as a base64 ASCII string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of ``image`` described by ``prompt`` with CLIPSeg.

    Args:
        image: PIL image to segment.
        prompt: Text prompt passed to the CLIPSeg processor.
        threshold: Cut-off compared against the min-max normalised mask
            (values in [0, 1]); pixels strictly above it form the binary mask.
        alpha_value: Opacity of the heat-map overlay drawn on the figure.
        draw_rectangles: If true, draw a yellow bounding box around each
            contour of the binary mask on the figure.

    Returns:
        ``(fig, result_mask, result_output)``:
        - ``fig``: matplotlib figure showing ``image`` with the heat-map overlay.
          The caller owns it and should ``plt.close`` it when done.
        - ``result_mask``: the binary mask as a base64-encoded PNG.
        - ``result_output``: a base64-encoded RGBA PNG holding only the masked
          pixels of ``image``; everything else is fully transparent.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop any singleton batch/channel dims so the probability map is 2-D.
    probs = np.squeeze(torch.sigmoid(outputs.logits).cpu().numpy())

    # Upsample the low-resolution prediction to the input image size.
    mask_img = Image.fromarray(np.uint8(probs * 255), "L").resize(image.size)
    mask = np.asarray(mask_img, dtype=np.float64)

    # Min-max normalise to [0, 1]. A constant map has no contrast; treat it as
    # empty instead of dividing by zero (which would fill the mask with NaN).
    lo, hi = mask.min(), mask.max()
    if hi > lo:
        mask = (mask - lo) / (hi - lo)
    else:
        mask = np.zeros_like(mask)

    bmask = mask > threshold
    # Keep the heat-map only where it reaches the threshold.
    mask[mask < threshold] = 0

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")
    if draw_rectangles:
        # findContours returns (image, contours, hierarchy) on OpenCV 3 and
        # (contours, hierarchy) on OpenCV 4; [-2] is the contours in both.
        contours = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )
    ax.axis("off")
    # Lay out *this* figure: plt.tight_layout() targets pyplot's current figure,
    # which may belong to a different request in a threaded server.
    fig.tight_layout()

    bmask_img = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    # Transparent canvas; only pixels inside the binary mask are copied from image.
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask_img)

    return fig, _png_base64(bmask_img), _png_base64(output_image)
# HTTP routes; the API endpoint delegates to process_image() defined above.
@app.route('/')
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment a base64-encoded image with a text prompt.

    Expects a JSON object body with:
        base64_image: a data URL (``data:image/png;base64,<payload>``) or bare
            base64 image data (required).
        prompt: text prompt for the segmentation (default ``''``).
        threshold: mask cut-off, numeric (default 0.4).
        alpha_value: overlay opacity, numeric (default 0.5).
        draw_rectangles: whether to draw bounding boxes (default False).

    Returns:
        JSON ``{"result_mask": ..., "result_output": ...}`` with both images as
        base64 PNGs, or ``{"error": ...}`` with HTTP 400 for malformed input.
    """
    # silent=True: a missing/non-JSON body yields None rather than raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    draw_rectangles = data.get('draw_rectangles', False)
    try:
        # Clients may send numbers as strings; numpy comparisons need real floats.
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold and alpha_value must be numbers'}), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image is required'}), 400

    # Accept both data URLs ("data:...;base64,<payload>") and bare base64.
    payload = base64_image.split(',', 1)[-1]
    try:
        image = Image.open(BytesIO(base64.b64decode(payload)))
        # Decode now so corrupt data is reported as a 400 instead of failing later.
        image.load()
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The preview figure is not part of this response. Close it, otherwise pyplot
    # keeps every request's figure alive and memory grows without bound.
    plt.close(fig)
    return jsonify({'result_mask': result_mask, 'result_output': result_output})
if __name__ == '__main__':
    import os

    # Werkzeug's interactive debugger allows arbitrary code execution. Because the
    # server listens on all interfaces, enable it only when explicitly requested.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug)
|