|
from flask import Flask, request, jsonify |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
|
|
# Flask application object; routes below are registered on it via @app.route.
app = Flask(__name__)

# Single checkpoint identifier shared by the processor and the model so the
# two can never drift apart.
MODEL_ID = "CIDAS/clipseg-rd64-refined"

# Both are loaded once at import time and reused by process_image on every
# request.
processor = CLIPSegProcessor.from_pretrained(MODEL_ID)
model = CLIPSegForImageSegmentation.from_pretrained(MODEL_ID)
|
|
|
def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of ``image`` described by ``prompt`` with CLIPSeg.

    The model's sigmoid probabilities are quantised to 8 bits, min-max
    normalised to [0, 1] and thresholded into a binary mask.

    Args:
        image: PIL image to segment.
        prompt: Text description of the region to segment.
        threshold: Cut-off applied to the min-max normalised probabilities;
            pixels strictly above it become foreground. Coerced with float().
        alpha_value: Currently unused; accepted for API compatibility.
        draw_rectangles: Currently unused; accepted for API compatibility.

    Returns:
        A mode "L" PIL image with foreground pixels at 255 and background at
        0, at the resolution of the model output (which is not necessarily
        the resolution of ``image``).
    """
    threshold = float(threshold)

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.sigmoid(outputs.logits).cpu().numpy()

    # Depending on the transformers version and batch size the logits may be
    # (H, W), (B, H, W) or (B, 1, H, W). Unconditionally indexing [0] would
    # return the first *row* of a 2-D map, so peel leading axes until only
    # the spatial map remains.
    while probs.ndim > 2:
        probs = probs[0]

    # Quantise to 8 bits (truncating, as np.uint8 does) so thresholding
    # operates on the same 256-level values as before.
    quantised = np.uint8(probs * 255)

    lo = int(quantised.min())
    hi = int(quantised.max())
    if hi == lo:
        # Min-max normalisation is undefined for a constant map (0 / 0 gives
        # NaN); treat it as containing no foreground.
        foreground = np.zeros(quantised.shape, dtype=bool)
    else:
        normalised = (quantised.astype(np.float64) - lo) / (hi - lo)
        foreground = normalised > threshold

    # A 2-D uint8 array is converted by Pillow to a mode "L" image.
    return Image.fromarray((foreground * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
@app.route('/')
def index():
    """Root endpoint: respond with a fixed plain-text greeting."""
    greeting = "Hello, World! clipseg2"
    return greeting
|
|
|
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Return a binary segmentation mask for a base64-encoded image.

    Expects a JSON object body with:
        base64_image: Required. The image as base64, either bare or as a data
            URL such as ``data:image/png;base64,<payload>``.
        prompt: Text description of the region to segment (default '').
        threshold: Numeric foreground cut-off passed to process_image
            (default 0.4).
        alpha_value, draw_rectangles: Forwarded unchanged to process_image.

    Returns:
        ``{"result_mask": <base64-encoded PNG>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 when the body is not a JSON
        object, the threshold is not numeric, or the image cannot be decoded.
    """
    # silent=True returns None instead of raising on a missing/invalid JSON
    # body, so the client gets a 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)

    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold must be a number'}), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image is required'}), 400

    # Accept both data URLs and bare base64: rpartition returns the whole
    # string as the last element when there is no comma.
    payload = base64_image.rpartition(',')[2]
    try:
        image_data = base64.b64decode(payload)
        # Image.open is lazy; convert() forces decoding here so corrupt data
        # is caught below, and guarantees a 3-channel RGB image even for
        # RGBA, greyscale or palette uploads.
        image = Image.open(BytesIO(image_data)).convert('RGB')
    except (ValueError, OSError, Image.DecompressionBombError):
        # binascii.Error (bad base64) subclasses ValueError;
        # PIL.UnidentifiedImageError subclasses OSError.
        return jsonify({'error': 'base64_image is not a valid encoded image'}), 400

    output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')

    return jsonify({'result_mask': result_mask})
|
|
|
if __name__ == '__main__':
    # Listen on every network interface, port 7860, with Flask debug mode off.
    app.run(debug=False, host='0.0.0.0', port=7860)
|
|