from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np

app = Flask(__name__)

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    # alpha_value and draw_rectangles are accepted for API compatibility but are
    # not used when only the binary mask is returned.
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # predict
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits

    pred = torch.sigmoid(preds)
    # squeeze away any leading batch dimension so the array is 2D (H, W)
    mat = pred.squeeze().cpu().numpy()

    # convert the low-resolution prediction to a grayscale image and resize it
    # to the original image size
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.resize(image.size)
    mask = np.array(mask).astype(np.float32)

    # normalize the mask to [0, 1] (guard against a constant mask)
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / max(mask_max - mask_min, 1e-8)

    # threshold the mask
    bmask = mask > threshold
    # zero out values below the threshold
    mask[mask < threshold] = 0

    bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    return bmask


@app.route('/')
def index():
    return "Hello, World! clipseg2"


@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    data = request.get_json()

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    threshold = data.get('threshold', 0.4)
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)

    # Decode the base64 image (strip a "data:image/...;base64," prefix if present)
    if ',' in base64_image:
        base64_image = base64_image.split(',')[1]
    image_data = base64.b64decode(base64_image)
    image = Image.open(BytesIO(image_data)).convert("RGB")

    # Process the image
    output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

    # Convert the output mask to base64
    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')

    return jsonify({'result_mask': result_mask})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)
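

# Example client call for /api/mask_image (a minimal sketch, not part of the
# server itself). It assumes the server is running locally on port 7860 and
# that "cat.png" exists on disk; the filenames and prompt are illustrative.
#
# import base64
# import requests
#
# with open("cat.png", "rb") as f:
#     b64 = base64.b64encode(f.read()).decode("utf-8")
#
# resp = requests.post(
#     "http://localhost:7860/api/mask_image",
#     json={
#         "base64_image": "data:image/png;base64," + b64,
#         "prompt": "a cat",
#         "threshold": 0.4,
#     },
# )
#
# # The response carries the PNG mask as base64 in "result_mask".
# with open("mask.png", "wb") as f:
#     f.write(base64.b64decode(resp.json()["result_mask"]))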