|
from flask import Flask, request, jsonify |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
|
|
app = Flask(__name__)

# The processor (text tokenizer + image preprocessing) and the segmentation
# model are built once, at import time, from the same checkpoint identifier so
# that every request reuses the already-loaded objects.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
|
def _first_heatmap(probs):
    """Return the first 2-D map from a CLIPSeg output array.

    NOTE(review): Hugging Face CLIPSeg returns logits shaped (B, H, W), or
    (H, W) when a single prompt is squeezed, while the original CLIPSeg repo
    returns (B, 1, H, W). Folding every leading axis into one handles all
    three layouts without guessing which one we received.
    """
    return probs.reshape(-1, *probs.shape[-2:])[0]


def _min_max_normalize(values):
    """Linearly rescale *values* to [0, 1]; a constant array maps to all zeros.

    The constant case would otherwise divide by zero and yield NaNs.
    """
    values = values.astype(np.float32)
    low = values.min()
    span = values.max() - low
    if span <= 0:
        return np.zeros_like(values)
    return (values - low) / span


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of *image* described by *prompt* into a binary mask.

    The model logits go through a sigmoid, are min-max normalised to [0, 1],
    and every pixel strictly above *threshold* is kept.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region of interest.
        threshold: Cut-off applied to the normalised heatmap (coerced with
            ``float``, so numeric strings are accepted).
        alpha_value: Currently unused; kept for interface compatibility.
        draw_rectangles: Currently unused; kept for interface compatibility.

    Returns:
        A mode ``"L"`` PIL image with 255 inside the mask and 0 elsewhere.
        NOTE(review): the mask has the model's output resolution, which is
        not necessarily the size of *image*; resize on the caller side if
        a pixel-aligned mask is required.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.sigmoid(outputs.logits).cpu().numpy()
    # Work on the numpy array: a PIL Image has no min()/max() or arithmetic,
    # which is what made the previous version fail.
    heatmap = _min_max_normalize(_first_heatmap(probs))

    binary = heatmap > float(threshold)
    return Image.fromarray(binary.astype(np.uint8) * 255, "L")
|
|
|
|
|
@app.route('/')
def index():
    """Respond to the root URL with a fixed plain-text greeting."""
    greeting = "Hello, World! clipseg2"
    return greeting
|
|
|
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Return a base64-encoded PNG mask for the region named by a text prompt.

    Expected JSON body:
        base64_image: image bytes as base64, optionally prefixed with a
            data-URL header such as ``data:image/png;base64,``.
        prompt: text describing the region to segment (default ``''``).
        threshold: numeric cut-off for the normalised heatmap (default 0.4).
        alpha_value, draw_rectangles: forwarded unchanged to
            ``process_image`` (defaults 0.5 and False).

    Returns:
        ``{"result_mask": <base64 PNG>}`` on success, or ``{"error": <msg>}``
        with HTTP 400 when the request body or image cannot be parsed.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so the client gets a JSON 400 rather than a crash.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)

    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold must be a number'}), 400

    if not isinstance(prompt, str):
        return jsonify({'error': 'prompt must be a string'}), 400
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image must be a non-empty string'}), 400

    # Accept both a data URL ("data:image/png;base64,....") and a bare base64
    # payload; indexing split(',')[1] raised IndexError on the latter.
    payload = base64_image.split(',', 1)[-1]

    try:
        image_data = base64.b64decode(payload)
        # Image.open is lazy; convert() forces decoding here so corrupt data
        # is caught now, and normalises palette/alpha/grayscale inputs to RGB.
        image = Image.open(BytesIO(image_data)).convert('RGB')
    except (ValueError, OSError):
        # binascii.Error (bad base64) subclasses ValueError;
        # PIL.UnidentifiedImageError and truncated-file errors subclass OSError.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')

    return jsonify({'result_mask': result_mask})
|
|
|
if __name__ == '__main__':
    # Development entry point: bind every interface on port 7860, no debugger.
    app.run(debug=False, port=7860, host='0.0.0.0')
|
|