"""Flask service that returns a binary CLIPSeg segmentation mask for a text prompt.

Endpoints:
    GET  /                -> plain-text liveness message.
    POST /api/mask_image  -> JSON in, base64-encoded PNG mask out.
"""

import base64
from io import BytesIO

from flask import Flask, request, jsonify
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): unused in this file -- confirm before removing
import cv2  # NOTE(review): unused in this file -- confirm before removing

app = Flask(__name__)

# Loaded once at import time so every request reuses the same processor and weights.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def _first_map(pred):
    """Return the (height, width) map for the first image/prompt as a NumPy array.

    Depending on batch size and transformers version, CLIPSeg logits may be shaped
    (H, W), (B, H, W) or (B, C, H, W); all three are accepted.

    Raises:
        ValueError: if *pred* has any other rank.
    """
    if pred.ndim == 4:
        pred = pred[0, 0]
    elif pred.ndim == 3:
        pred = pred[0]
    elif pred.ndim != 2:
        raise ValueError(f"Unexpected CLIPSeg logits shape: {tuple(pred.shape)}")
    return pred.cpu().numpy()


def _min_max_normalize(mat):
    """Linearly rescale *mat* to [0, 1]; a constant map becomes all zeros (avoids 0/0)."""
    mat = mat.astype(np.float32)
    lo = float(mat.min())
    span = float(mat.max()) - lo
    if span <= 0:
        return np.zeros_like(mat)
    return (mat - lo) / span


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment *image* with CLIPSeg for *prompt* and return a binary mask.

    Args:
        image: PIL image to segment; converted to RGB if it is in another mode.
        prompt: Text describing the region to segment.
        threshold: Cut-off applied to the min-max normalised probability map
            (values in [0, 1]); pixels strictly above it become 255, all others 0.
        alpha_value: Accepted for API compatibility; currently unused.
        draw_rectangles: Accepted for API compatibility; currently unused.

    Returns:
        A grayscale PIL image with values 0/255 at the spatial resolution of the
        model logits (it is not resized back to the input image's size).
    """
    if image.mode != "RGB":
        # The CLIP image processor expects 3-channel input; RGBA/palette/grayscale
        # uploads would otherwise be mis-handled.
        image = image.convert("RGB")

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    probs = _first_map(torch.sigmoid(outputs.logits))
    # Normalise first so `threshold` is relative to this image's own probability
    # range rather than an absolute confidence value.
    norm = _min_max_normalize(probs)
    binary = np.where(norm > threshold, 255, 0).astype(np.uint8)
    return Image.fromarray(binary)


@app.route('/')
def index():
    """Liveness check."""
    return "Hello, World! clipseg2"


@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Return a base64-encoded PNG mask for the posted image and prompt.

    Expects a JSON object with:
        base64_image: base64 image data, raw or as a ``data:...;base64,`` URI (required).
        prompt: text prompt (default "").
        threshold: number applied to the normalised map (default 0.4).
        alpha_value, draw_rectangles: forwarded to process_image; currently unused.

    Returns:
        200 with ``{"result_mask": "<base64 PNG>"}``, or 400 with ``{"error": "..."}``
        when the request body, threshold, or image cannot be parsed.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold must be a number'}), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image is required'}), 400

    # Accept both raw base64 and data URIs ("data:image/png;base64,<payload>").
    # The base64 alphabet never contains ',', so the last segment is the payload.
    payload = base64_image.split(',', 1)[-1]
    try:
        image = Image.open(BytesIO(base64.b64decode(payload)))
        image.load()  # Image.open is lazy; force decoding so errors surface here.
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL.UnidentifiedImageError is an OSError.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
    return jsonify({'result_mask': result_mask})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)