File size: 2,552 Bytes
248607a
8037bf7
58e8f4d
 
 
8037bf7
58e8f4d
8037bf7
 
58e8f4d
 
8037bf7
 
 
 
248607a
8037bf7
 
 
 
 
 
 
 
 
 
f02dee0
 
 
 
 
 
 
8037bf7
00c41dd
 
 
8037bf7
00c41dd
 
 
8037bf7
 
00c41dd
8037bf7
00c41dd
8037bf7
00c41dd
8037bf7
248607a
 
3c486be
f02dee0
00c41dd
58e8f4d
 
 
8037bf7
58e8f4d
 
 
8037bf7
58e8f4d
 
 
 
 
8037bf7
58e8f4d
 
 
8037bf7
58e8f4d
248607a
8037bf7
248607a
 
 
 
8037bf7
248607a
8037bf7
58e8f4d
248607a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Hugging Face checkpoint identifier used for both the processor and the model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Both objects are created once at import time and reused by every request.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment ``image`` with CLIPSeg for ``prompt`` and return a binary mask.

    The model logits are passed through a sigmoid, quantised to 8 bits,
    min-max normalised to the range [0, 1] and then thresholded.

    Args:
        image: Image handed to the CLIPSeg processor (the API handler in this
            file passes a PIL image).
        prompt: Text describing the region to segment.
        threshold: Cut-off applied to the normalised response. Anything
            accepted by ``float()`` is allowed.
        alpha_value: Currently unused; kept so existing callers keep working.
        draw_rectangles: Currently unused; kept so existing callers keep
            working.

    Returns:
        A single-channel 8-bit PIL image that is 255 where the normalised
        response is strictly greater than ``threshold`` and 0 elsewhere. The
        mask has the resolution of the model output; it is not resized to the
        input image.

    Raises:
        TypeError, ValueError: If ``threshold`` cannot be converted to float.
        ValueError: If the model output has fewer than two dimensions.
    """
    # JSON clients may send the threshold as a string; comparing a NumPy array
    # against a str fails, so coerce up front.
    threshold = float(threshold)

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    heat = torch.sigmoid(logits).cpu().numpy()

    # Peel leading (batch / channel) axes until a single 2-D map remains.
    # This covers 4-D and 3-D outputs as before and additionally leaves an
    # already 2-D output intact instead of slicing out one row of it.
    while heat.ndim > 2:
        heat = heat[0]
    if heat.ndim != 2:
        raise ValueError(
            f"Expected a 2-D segmentation map, got shape {heat.shape!r}"
        )

    # Quantise to 8 bits first so the normalisation sees the same values the
    # previous PIL round-trip produced.
    quantized = (heat * 255).astype(np.uint8)
    low = int(quantized.min())
    high = int(quantized.max())

    if high == low:
        # A perfectly flat response carries no localisation signal, and
        # min-max scaling would divide by zero. Report an empty mask.
        selected = np.zeros(quantized.shape, dtype=bool)
    else:
        normalised = (quantized.astype(np.float64) - low) / (high - low)
        selected = normalised > threshold

    # A 2-D uint8 array is interpreted by PIL as a mode "L" image.
    return Image.fromarray(np.where(selected, 255, 0).astype(np.uint8))




@app.route('/')
def index():
    """Answer the root route with a fixed plain-text greeting."""
    greeting = "Hello, World! clipseg2"
    return greeting

@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Return a base64-encoded PNG segmentation mask for a posted image.

    Expected JSON body:
        base64_image: Base64 image data, either bare or as a data URL
            (``data:image/png;base64,<payload>``). Required.
        prompt: Text describing what to segment (default ``''``).
        threshold: Numeric mask cut-off (default ``0.4``).
        alpha_value: Forwarded to ``process_image`` (default ``0.5``).
        draw_rectangles: Forwarded to ``process_image`` (default ``False``).

    Returns:
        On success, a JSON object ``{'result_mask': <base64 PNG>}``.
        On malformed input, a JSON object ``{'error': <message>}`` with HTTP
        status 400.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so every malformed request takes the same 400 path below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    base64_image = data.get('base64_image', '')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "Field 'base64_image' is required."}), 400

    prompt = data.get('prompt', '')
    alpha_value = data.get('alpha_value', 0.5)
    draw_rectangles = data.get('draw_rectangles', False)

    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return jsonify({'error': "Field 'threshold' must be a number."}), 400

    # Accept both a data URL and bare base64: keep what follows the first
    # comma when there is one, otherwise use the whole string.
    _, comma, payload = base64_image.partition(',')
    if not comma:
        payload = base64_image

    try:
        image_data = base64.b64decode(payload)
        # convert() forces a full decode, so corrupt image data fails here
        # rather than later inside the model pipeline. It also hands
        # palette, grayscale and alpha uploads on as 3-channel RGB.
        image = Image.open(BytesIO(image_data)).convert('RGB')
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL's unidentified/truncated image
        # errors are OSError subclasses.
        return jsonify({'error': 'Could not decode the supplied image.'}), 400

    # Process the image
    output_mask = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )

    # Serialise the mask as PNG and base64-encode it for the JSON response.
    buffered_mask = BytesIO()
    output_mask.save(buffered_mask, format="PNG")
    result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')

    return jsonify({'result_mask': result_mask})

if __name__ == '__main__':
    # Start Flask's built-in server on all interfaces, port 7860, with the
    # debugger switched off. This only runs when the file is executed
    # directly.
    app.run(
        debug=False,
        port=7860,
        host='0.0.0.0',
    )