seosnaps commited on
Commit
a7c6564
1 Parent(s): c3521e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -122
app.py CHANGED
@@ -1,122 +1,55 @@
1
- from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
2
- import gradio as gr
3
- from PIL import Image
4
- import torch
5
- import matplotlib.pyplot as plt
6
- import cv2
7
- import torch
8
- import numpy as np
9
-
10
- processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
11
- model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
12
-
13
-
14
- def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
15
- inputs = processor(
16
- text=prompt, images=image, padding="max_length", return_tensors="pt"
17
- )
18
-
19
- # predict
20
- with torch.no_grad():
21
- outputs = model(**inputs)
22
- preds = outputs.logits
23
-
24
- pred = torch.sigmoid(preds)
25
- mat = pred.cpu().numpy()
26
- mask = Image.fromarray(np.uint8(mat * 255), "L")
27
- mask = mask.convert("RGB")
28
- mask = mask.resize(image.size)
29
- mask = np.array(mask)[:, :, 0]
30
-
31
- # normalize the mask
32
- mask_min = mask.min()
33
- mask_max = mask.max()
34
- mask = (mask - mask_min) / (mask_max - mask_min)
35
-
36
- # threshold the mask
37
- bmask = mask > threhsold
38
- # zero out values below the threshold
39
- mask[mask < threhsold] = 0
40
-
41
- fig, ax = plt.subplots()
42
- ax.imshow(image)
43
- ax.imshow(mask, alpha=alpha_value, cmap="jet")
44
-
45
- if draw_rectangles:
46
- contours, hierarchy = cv2.findContours(
47
- bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
48
- )
49
- for contour in contours:
50
- x, y, w, h = cv2.boundingRect(contour)
51
- rect = plt.Rectangle(
52
- (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
53
- )
54
- ax.add_patch(rect)
55
-
56
- ax.axis("off")
57
- plt.tight_layout()
58
-
59
- bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
60
- output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
61
- output_image.paste(image, mask=bmask)
62
-
63
- return fig, mask, output_image
64
-
65
-
66
- title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
67
- description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'. Results will show up in a few seconds."
68
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
69
-
70
-
71
- with gr.Blocks() as demo:
72
- gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
73
- gr.Markdown(article)
74
- gr.Markdown(description)
75
- gr.Markdown(
76
- "*Example images are taken from the [ImageNet-A](https://paperswithcode.com/dataset/imagenet-a) dataset*"
77
- )
78
-
79
- with gr.Row():
80
- with gr.Column():
81
- input_image = gr.Image(type="pil")
82
- input_prompt = gr.Textbox(label="Please describe what you want to identify")
83
- input_slider_T = gr.Slider(
84
- minimum=0, maximum=1, value=0.4, label="Threshold"
85
- )
86
- input_slider_A = gr.Slider(minimum=0, maximum=1, value=0.5, label="Alpha")
87
- draw_rectangles = gr.Checkbox(label="Draw rectangles")
88
- btn_process = gr.Button(label="Process")
89
-
90
- with gr.Column():
91
- output_plot = gr.Plot(label="Segmentation Result")
92
- output_mask = gr.Image(label="Mask")
93
- output_image = gr.Image(label="Output Image")
94
-
95
- btn_process.click(
96
- process_image,
97
- inputs=[
98
- input_image,
99
- input_prompt,
100
- input_slider_T,
101
- input_slider_A,
102
- draw_rectangles,
103
- ],
104
- outputs=[output_plot, output_mask, output_image],
105
- )
106
-
107
- gr.Examples(
108
- [
109
- ["0.003473_cliff _ cliff_0.51112.jpg", "dog", 0.5, 0.5, True],
110
- ["0.001861_submarine _ submarine_0.9862991.jpg", "beacon", 0.55, 0.4, True],
111
- ["0.004658_spatula _ spatula_0.35416836.jpg", "banana", 0.4, 0.5, True],
112
- ],
113
- inputs=[
114
- input_image,
115
- input_prompt,
116
- input_slider_T,
117
- input_slider_A,
118
- draw_rectangles,
119
- ],
120
- )
121
-
122
- demo.launch()
 
1
+ from flask import Flask, request, send_from_directory
2
+ import os
3
+
4
+ app = Flask(__name__)
5
+
6
+ UPLOAD_FOLDER = 'uploads'
7
+ if not os.path.exists(UPLOAD_FOLDER):
8
+ os.makedirs(UPLOAD_FOLDER)
9
+
10
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
11
+
12
+ @app.route('/', methods=['GET', 'POST'])
13
+ def upload_file():
14
+ if request.method == 'POST':
15
+ if 'file' not in request.files:
16
+ return 'No file part'
17
+
18
+ file = request.files['file']
19
+
20
+ if file.filename == '':
21
+ return 'No selected file'
22
+
23
+ if file:
24
+ filename = file.filename
25
+ file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
26
+ return 'File uploaded successfully'
27
+
28
+ return '''
29
+ <!doctype html>
30
+ <title>Upload new File</title>
31
+ <h1>Upload new File</h1>
32
+ <form method=post enctype=multipart/form-data>
33
+ <input type=file name=file>
34
+ <input type=submit value=Upload>
35
+ </form>
36
+ '''
37
+
38
+ @app.route('/files')
39
+ def list_files():
40
+ files = os.listdir(app.config['UPLOAD_FOLDER'])
41
+ return '''
42
+ <!doctype html>
43
+ <title>Uploaded files</title>
44
+ <h1>Uploaded files</h1>
45
+ <ul>
46
+ ''' + ''.join(['<li><a href="/download/{}">{}</a></li>'.format(f, f) for f in files]) + '''
47
+ </ul>
48
+ '''
49
+
50
+ @app.route('/download/<filename>')
51
+ def download_file(filename):
52
+ return send_from_directory(app.config['UPLOAD_FOLDER'], filename, as_attachment=True)
53
+
54
+ if __name__ == '__main__':
55
+ app.run(debug=True, port=7860)