sigyllly committed
Commit 58e8f4d
1 Parent(s): 982d7e0

Update app.py

Files changed (1):
  app.py  +31 −57
app.py CHANGED
@@ -1,16 +1,18 @@
-from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
-import gradio as gr
+from flask import Flask, request, jsonify, render_template
 from PIL import Image
+import base64
+from io import BytesIO
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 import torch
+import numpy as np
 import matplotlib.pyplot as plt
 import cv2
-import torch
-import numpy as np
+
+app = Flask(__name__)
 
 processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 
-
 def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
     inputs = processor(
         text=prompt, images=image, padding="max_length", return_tensors="pt"
@@ -61,62 +63,34 @@ def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
     output_image.paste(image, mask=bmask)
 
     return fig, mask, output_image
+
+@app.route('/')
+def index():
+    return "Hello, World! clipseg2"
 
+@app.route('/api/mask_image', methods=['POST'])
+def mask_image_api():
+    data = request.get_json()
 
-title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
-description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'. Results will show up in a few seconds."
-article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
+    base64_image = data.get('base64_image', '')
+    prompt = data.get('prompt', '')
+    threshold = data.get('threshold', 0.4)
+    alpha_value = data.get('alpha_value', 0.5)
+    draw_rectangles = data.get('draw_rectangles', False)
 
+    # Decode base64 image
+    image_data = base64.b64decode(base64_image.split(',')[1])
+    image = Image.open(BytesIO(image_data))
 
-with gr.Blocks() as demo:
-    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
-    gr.Markdown(article)
-    gr.Markdown(description)
-    gr.Markdown(
-        "*Example images are taken from the [ImageNet-A](https://paperswithcode.com/dataset/imagenet-a) dataset*"
-    )
+    # Process the image
+    _, _, output_image = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
 
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(type="pil")
-            input_prompt = gr.Textbox(label="Please describe what you want to identify")
-            input_slider_T = gr.Slider(
-                minimum=0, maximum=1, value=0.4, label="Threshold"
-            )
-            input_slider_A = gr.Slider(minimum=0, maximum=1, value=0.5, label="Alpha")
-            draw_rectangles = gr.Checkbox(label="Draw rectangles")
-            btn_process = gr.Button(label="Process")
-
-        with gr.Column():
-            output_plot = gr.Plot(label="Segmentation Result")
-            output_mask = gr.Image(label="Mask")
-            output_image = gr.Image(label="Output Image")
-
-    btn_process.click(
-        process_image,
-        inputs=[
-            input_image,
-            input_prompt,
-            input_slider_T,
-            input_slider_A,
-            draw_rectangles,
-        ],
-        outputs=[output_plot, output_mask, output_image], api_name="masking"
-    )
+    # Convert the output image to base64
+    buffered = BytesIO()
+    output_image.save(buffered, format="PNG")
+    result_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
 
-    gr.Examples(
-        [
-            ["0.003473_cliff _ cliff_0.51112.jpg", "dog", 0.5, 0.5, True],
-            ["0.001861_submarine _ submarine_0.9862991.jpg", "beacon", 0.55, 0.4, True],
-            ["0.004658_spatula _ spatula_0.35416836.jpg", "banana", 0.4, 0.5, True],
-        ],
-        inputs=[
-            input_image,
-            input_prompt,
-            input_slider_T,
-            input_slider_A,
-            draw_rectangles,
-        ],
-    )
+    return jsonify({'result_image': result_image})
 
-demo.launch()
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7860, debug=True)
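
The hunk headers show that the body of process_image (old lines 17–60) is unchanged and elided, so only its signature and its return of fig, mask, output_image appear as context. Note the parameter is spelled threhsold in the source; the new endpoint passes its threshold value positionally, so the misspelling is harmless. For orientation, here is a minimal sketch of the CLIPSeg post-processing that typically produces the paste mask used above; the function name and exact steps are illustrative assumptions, not the commit's actual code.

# Sketch only: one common way to turn CLIPSeg logits into the binary mask
# consumed by output_image.paste(image, mask=bmask). Names are illustrative.
import numpy as np
import torch
from PIL import Image

def logits_to_mask(outputs, threshold):
    # CLIPSeg emits a low-resolution logit map; sigmoid maps it into [0, 1]
    preds = torch.sigmoid(outputs.logits).squeeze()
    # Threshold to a binary uint8 image that PIL accepts as a paste mask
    binary = (preds > threshold).cpu().numpy().astype(np.uint8) * 255
    return Image.fromarray(binary, mode="L")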
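The new API expects a JSON body whose base64_image field carries a data URI: mask_image_api splits the string on ',' and decodes only the part after it, so a bare base64 string would fail. A minimal client sketch, assuming the server from this commit is reachable at http://localhost:7860 (the port app.run binds) and using a placeholder file name:

# Hypothetical client for POST /api/mask_image; URL and file names are assumptions.
import base64
import requests

with open("example.jpg", "rb") as f:
    data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:7860/api/mask_image",
    json={
        "base64_image": data_uri,     # prefix required: the server splits on ','
        "prompt": "dog",              # text description of the region to mask
        "threshold": 0.4,
        "alpha_value": 0.5,
        "draw_rectangles": False,
    },
)
resp.raise_for_status()

# The response carries the masked image as plain base64 (no data-URI prefix).
with open("masked.png", "wb") as out:
    out.write(base64.b64decode(resp.json()["result_image"]))

The endpoint returns only result_image; the fig and mask values produced by process_image are discarded on the server side.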