File size: 2,418 Bytes
248607a 8037bf7 58e8f4d 8037bf7 58e8f4d 8037bf7 58e8f4d 8037bf7 248607a 8037bf7 f02dee0 8037bf7 248607a 8037bf7 248607a 8037bf7 248607a 3c486be f02dee0 58e8f4d 8037bf7 58e8f4d 8037bf7 58e8f4d 8037bf7 58e8f4d 8037bf7 58e8f4d 248607a 8037bf7 248607a 8037bf7 248607a 8037bf7 58e8f4d 248607a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

app = Flask(__name__)

# The processor and model are created once at import time and reused by every
# request handled by this module.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of ``image`` described by ``prompt`` with CLIPSeg.

    Args:
        image: PIL image to segment. Images not in RGB mode (e.g. RGBA or
            palette PNGs) are converted to RGB before inference.
        prompt: Text prompt describing the region to segment.
        threshold: Cut-off applied to the min-max normalized prediction;
            pixels strictly above it are kept. Coerced with ``float()``.
        alpha_value: Currently unused; kept for API compatibility.
        draw_rectangles: Currently unused; kept for API compatibility.

    Returns:
        A grayscale (mode "L") PIL image with the same size as ``image``:
        255 where the normalized prediction exceeds ``threshold``, 0
        elsewhere. If the prediction is constant there is nothing to
        normalize against, and the mask is all zeros.
    """
    threshold = float(threshold)
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # predict
    with torch.no_grad():
        outputs = model(**inputs)
    pred = torch.sigmoid(outputs.logits)

    # Depending on the transformers version and batch size, the logits may be
    # (H, W), (batch, H, W) or (batch, channels, H, W). Peel leading axes
    # until a single 2-D probability map remains.
    mat = pred.cpu().numpy()
    while mat.ndim > 2:
        mat = mat[0]

    # The model predicts at its own fixed resolution. Scale the probability
    # map back to the input size so the mask lines up with the image.
    prob_image = Image.fromarray(np.uint8(mat * 255)).resize(image.size)
    mask = np.asarray(prob_image, dtype=np.float32)

    # Min-max normalize to [0, 1], then threshold. A constant map has zero
    # span; return an empty mask instead of dividing by zero.
    mask_min = mask.min()
    span = mask.max() - mask_min
    if span == 0:
        bmask = np.zeros(mask.shape, dtype=bool)
    else:
        bmask = (mask - mask_min) / span > threshold

    # 2-D uint8 arrays map to mode "L".
    return Image.fromarray(bmask.astype(np.uint8) * 255)
@app.route('/')
def index():
    """Root endpoint: respond with a fixed plain-text greeting."""
    greeting = "Hello, World! clipseg2"
    return greeting
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment a base64-encoded image with a text prompt and return the mask.

    Expects a JSON object body with:
        base64_image: the image as base64, either bare or as a data URL
            (``data:image/png;base64,<payload>``). Required.
        prompt: text prompt describing the region to segment (default "").
        threshold: numeric cut-off for the normalized prediction (default 0.4).
        alpha_value: forwarded to ``process_image`` (default 0.5).
        draw_rectangles: forwarded to ``process_image`` (default False).

    Returns:
        200 with ``{"result_mask": <base64-encoded PNG>}`` on success.
        400 with ``{"error": <message>}`` when the body is not a JSON object,
        ``threshold`` is not numeric, or ``base64_image`` is missing or does
        not decode to an image.
    """
    # silent=True yields None (instead of raising) for missing/invalid JSON.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)
    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold must be a number'}), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image is required'}), 400

    # Accept both data URLs and bare base64. Base64 never contains a comma,
    # so the payload is everything after the last one (or the whole string).
    payload = base64_image.rpartition(',')[2]
    try:
        image_data = base64.b64decode(payload)
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy. convert() forces the full decode so corrupt data
        # is caught here, and yields the 3-channel input the model expects.
        image = image.convert('RGB')
    except (ValueError, OSError):
        # binascii.Error (bad base64) subclasses ValueError; PIL's
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    # Process the image
    output_mask = process_image(image, prompt, threshold, alpha_value, draw_rectangles)

    # Convert the output mask to base64
    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
    return jsonify({'result_mask': result_mask})
if __name__ == '__main__':
    # Start Flask's built-in server on all interfaces, port 7860, debug off.
    app.run(debug=False, host='0.0.0.0', port=7860)
|