File size: 3,239 Bytes
db1a0dc
125d133
db1a0dc
58e8f4d
 
8037bf7
58e8f4d
8037bf7
 
58e8f4d
 
8037bf7
 
 
 
db1a0dc
 
 
 
8037bf7
db1a0dc
8037bf7
 
 
 
 
b69abfd
 
125d133
b69abfd
 
00c41dd
db1a0dc
b69abfd
 
 
8037bf7
db1a0dc
b69abfd
db1a0dc
b69abfd
8037bf7
db1a0dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248607a
db1a0dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8037bf7
db1a0dc
8037bf7
58e8f4d
59b2e32
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from flask import Flask, request, jsonify, render_template
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# The Flask application; the route handlers below register themselves on it.
app = Flask(__name__)

# CLIPSeg checkpoint shared by the processor and the model. Both are loaded once,
# at import time, and reused by every request.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

def _to_base64_png(img):
    """Encode a PIL image as PNG and return the bytes as a base64 ASCII string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of *image* described by *prompt* with CLIPSeg.

    Args:
        image: PIL image to segment. NOTE(review): the model is fed this image
            directly; callers are expected to pass RGB.
        prompt: Text describing the object/region to find.
        threshold: Cut-off applied to the min-max normalised mask (values in
            [0, 1]); pixels strictly above it belong to the region.
        alpha_value: Opacity of the heat-map overlay drawn on the figure.
        draw_rectangles: If truthy, draw a yellow bounding box around every
            outer connected region of the thresholded mask.

    Returns:
        ``(fig, result_mask, result_output)``:
        ``fig`` is a matplotlib Figure showing the image with the heat-map
        overlay; ``result_mask`` is the binary mask as a base64-encoded PNG;
        ``result_output`` is the input image cut out by the mask (transparent
        elsewhere) as a base64-encoded RGBA PNG.
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps every
    # figure it creates alive until explicitly closed, which leaks memory in a
    # long-running server. A bare Figure is freed once unreferenced.
    from matplotlib.figure import Figure

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # predict
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop singleton axes so a single-prompt prediction is one 2-D map.
    # NOTE(review): assumes exactly one prompt per call.
    probs = np.squeeze(torch.sigmoid(outputs.logits).cpu().numpy())

    # Probability map -> 8-bit greyscale, upsampled to the input image size.
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    mask = np.asarray(low_res.resize(image.size), dtype=np.float64)

    # Min-max normalise to [0, 1]. A flat map has no contrast (and would divide
    # by zero, filling the mask with NaN), so treat it as "nothing found".
    lo, hi = mask.min(), mask.max()
    if hi > lo:
        mask = (mask - lo) / (hi - lo)
    else:
        mask = np.zeros_like(mask)

    # Threshold, and blank the heat map wherever the binary mask is off so the
    # overlay shows exactly the pixels that end up in the returned mask.
    bmask = mask > threshold
    mask[~bmask] = 0

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")

    if draw_rectangles:
        # RETR_EXTERNAL returns only outer outlines; RETR_TREE would also return
        # the outlines of holes inside a region, each getting its own box.
        contours, _ = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )

    ax.axis("off")
    fig.tight_layout()

    # Binary mask as an 8-bit image (0 / 255) and the masked cut-out on a
    # fully transparent RGBA canvas.
    bmask_img = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask_img)

    return fig, _to_base64_png(bmask_img), _to_base64_png(output_image)

@app.route('/')
def index():
    """Render the ``index.html`` template as the landing page."""
    page = "index.html"
    return render_template(page)

@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment a base64-encoded image by text prompt.

    Expects a JSON object body with:
        base64_image: the image as a data URL
            (``data:image/...;base64,<payload>``) or as bare base64.
        prompt: text describing the region to segment (default ``''``).
        threshold: mask cut-off, passed to ``process_image`` (default 0.4).
        alpha_value: overlay opacity in [0, 1] (default 0.5).
        draw_rectangles: whether to draw bounding boxes (default False).

    Returns:
        200 with JSON ``{'result_mask': ..., 'result_output': ...}`` (both
        base64-encoded PNGs), or 400 with ``{'error': ...}`` when the body,
        the numeric parameters, or the image cannot be parsed.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    draw_rectangles = data.get('draw_rectangles', False)
    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold and alpha_value must be numbers'}), 400
    if not 0.0 <= alpha_value <= 1.0:
        return jsonify({'error': 'alpha_value must be between 0 and 1'}), 400

    if not isinstance(base64_image, str):
        return jsonify({'error': 'base64_image must be a string'}), 400
    # Accept both a data URL and bare base64: rpartition returns the whole
    # string as its last element when there is no comma.
    payload = base64_image.rpartition(',')[2]
    try:
        image_data = base64.b64decode(payload)
        # Image.open is lazy; convert() forces decoding here so bad bytes are
        # caught, and normalises RGBA/greyscale/palette uploads to RGB.
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        # binascii.Error (bad base64) subclasses ValueError; PIL signals
        # undecodable image data with OSError / UnidentifiedImageError.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    # Process the image
    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The preview figure is not returned by this endpoint; close it so figures
    # created through pyplot are not kept alive across requests.
    plt.close(fig)

    return jsonify({'result_mask': result_mask, 'result_output': result_output})

if __name__ == '__main__':
    import os

    # Werkzeug's interactive debugger permits arbitrary code execution, so it
    # must not be on by default for a server listening on all interfaces.
    # Opt in for local development with FLASK_DEBUG=1; PORT overrides 7860.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        debug=os.environ.get('FLASK_DEBUG') == '1',
    )