Mar2Ding committed on
Commit
a13369a
1 Parent(s): 690f9c4

Upload 5 files

Files changed (4)
  1. LICENSE +201 -0
  2. LICENSE_cctorch +29 -0
  3. app.py +609 -4
  4. requirements.txt +10 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
LICENSE_cctorch ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
app.py CHANGED
@@ -1,7 +1,612 @@
- def greet(name):
-     return "Hello " + name + "!!"
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import subprocess
+ import re
+ from typing import List, Tuple, Optional
+
+ # Define the command to be executed
+ command = ["python", "setup.py", "build_ext", "--inplace"]
+
+ # Execute the command
+ result = subprocess.run(command, capture_output=True, text=True)
+
+ # Print the output and error (if any)
+ print("Output:\n", result.stdout)
+ print("Errors:\n", result.stderr)
+
+ # Check if the command was successful
+ if result.returncode == 0:
+     print("Command executed successfully.")
+ else:
+     print("Command failed with return code:", result.returncode)
+
  import gradio as gr
+ from datetime import datetime
+ import os
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
+ import torch
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageFilter
+ from sam2.build_sam import build_sam2_video_predictor
+
+ from moviepy.editor import ImageSequenceClip
+
+ def get_video_fps(video_path):
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return None
+
+     # Get the FPS of the video
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     return fps
+
+ def clear_points(image):
+     # we clean all
+     return [
+         image, # first_frame_path
+         gr.State([]), # tracking_points
+         gr.State([]), # trackings_input_label
+         image, # points_map
+         # gr.State() # stored_inference_state
+     ]
+
+ def preprocess_video_in(video_path):
+
+     # Generate a unique ID based on the current date and time
+     unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
+
+     # Set directory with this ID to store video frames
+     extracted_frames_output_dir = f'frames_{unique_id}'
+
+     # Create the output directory
+     os.makedirs(extracted_frames_output_dir, exist_ok=True)
+
+     ### Process video frames ###
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return None
+
+     # Get the frames per second (FPS) of the video
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     # Calculate the number of frames to process (10 seconds of video)
+     max_frames = int(fps * 10)
+
+     frame_number = 0
+     first_frame = None
+
+     while True:
+         ret, frame = cap.read()
+         if not ret or frame_number >= max_frames:
+             break
+
+         # Format the frame filename as '00000.jpg'
+         frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
+
+         # Save the frame as a JPEG file
+         cv2.imwrite(frame_filename, frame)
+
+         # Store the first frame
+         if frame_number == 0:
+             first_frame = frame_filename
+
+         frame_number += 1
+
+     # Release the video capture object
+     cap.release()
+
+     # scan all the JPEG frame names in this directory
+     scanned_frames = [
+         p for p in os.listdir(extracted_frames_output_dir)
+         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+     ]
+     scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
+     # print(f"SCANNED_FRAMES: {scanned_frames}")
+
+     return [
+         first_frame, # first_frame_path
+         gr.State([]), # tracking_points
+         gr.State([]), # trackings_input_label
+         first_frame, # input_first_frame_image
+         first_frame, # points_map
+         extracted_frames_output_dir, # video_frames_dir
+         scanned_frames, # scanned_frames
+         None, # stored_inference_state
+         None, # stored_frame_names
+         gr.update(open=False) # video_in_drawer
+     ]
+
+ def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
+     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+     tracking_points.value.append(evt.index)
+     print(f"TRACKING POINT: {tracking_points.value}")
+
+     if point_type == "include":
+         trackings_input_label.value.append(1)
+     elif point_type == "exclude":
+         trackings_input_label.value.append(0)
+     print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
+
+     # Open the image and get its dimensions
+     transparent_background = Image.open(input_first_frame_image).convert('RGBA')
+     w, h = transparent_background.size
+
+     # Define the circle radius as a fraction of the smaller dimension
+     fraction = 0.02 # You can adjust this value as needed
+     radius = int(fraction * min(w, h))
+
+     # Create a transparent layer to draw on
+     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+     for index, track in enumerate(tracking_points.value):
+         if trackings_input_label.value[index] == 1:
+             cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+         else:
+             cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+     # Convert the transparent layer back to an image
+     transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
+     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+     return tracking_points, trackings_input_label, selected_point_map
+
+ # use bfloat16 for the entire notebook
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+ if torch.cuda.get_device_properties(0).major >= 8:
+     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ def show_mask(mask, ax, obj_id=None, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         cmap = plt.get_cmap("tab10")
+         cmap_idx = 0 if obj_id is None else obj_id
+         color = np.array([*cmap(cmap_idx)[:3], 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_points(coords, labels, ax, marker_size=200):
+     pos_points = coords[labels==1]
+     neg_points = coords[labels==0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+
+ def load_model(checkpoint):
+     # Load model according to the user's choice
+     if checkpoint == "tiny":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
+         return [sam2_checkpoint, model_cfg]
+     elif checkpoint == "small":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
+         return [sam2_checkpoint, model_cfg]
+     elif checkpoint == "base-plus":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+         return [sam2_checkpoint, model_cfg]
+     # elif checkpoint == "large":
+     #     sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
+     #     model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+     #     return [sam2_checkpoint, model_cfg]
+
+
+
+ def get_mask_sam_process(
+     stored_inference_state,
+     input_first_frame_image,
+     checkpoint,
+     tracking_points,
+     trackings_input_label,
+     video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
+     scanned_frames,
+     working_frame: str = None, # current frame being added points
+     available_frames_to_check: List[str] = [],
+     # progress=gr.Progress(track_tqdm=True)
+ ):
+
+     # get model and model config paths
+     print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
+     sam2_checkpoint, model_cfg = load_model(checkpoint)
+     print("MODEL LOADED")
+
+     # set predictor
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+     print("PREDICTOR READY")
+
+     # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
+     # print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
+     video_dir = video_frames_dir
+
+     # scan all the JPEG frame names in this directory
+     frame_names = scanned_frames
+
+     # print(f"STORED INFERENCE STEP: {stored_inference_state}")
+     if stored_inference_state is None:
+         # Init SAM2 inference_state
+         inference_state = predictor.init_state(video_path=video_dir)
+         inference_state['num_pathway'] = 3
+         inference_state['iou_thre'] = 0.3
+         inference_state['uncertainty'] = 2
+         print("NEW INFERENCE_STATE INITIATED")
+     else:
+         inference_state = stored_inference_state
+
+     # segment and track one object
+     # predictor.reset_state(inference_state) # if any previous tracking, reset
+
+
+     ### HANDLING WORKING FRAME
+     # new_working_frame = None
+     # Add new point
+     if working_frame is None:
+         ann_frame_idx = 0 # the frame index we interact with, 0 if it is the first frame
+         working_frame = "00000.jpg"
+     else:
+         # Use a regular expression to find the integer
+         match = re.search(r'frame_(\d+)', working_frame)
+         if match:
+             # Extract the integer from the match
+             frame_number = int(match.group(1))
+             ann_frame_idx = frame_number
+
+     print(f"NEW_WORKING_FRAME PATH: {working_frame}")
+
+     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
+
+     # Let's add a positive click at (x, y) = (210, 350) to get started
+     points = np.array(tracking_points.value, dtype=np.float32)
+     # for labels, `1` means positive click and `0` means negative click
+     labels = np.array(trackings_input_label.value, np.int32)
+     _, out_obj_ids, out_mask_logits = predictor.add_new_points(
+         inference_state=inference_state,
+         frame_idx=ann_frame_idx,
+         obj_id=ann_obj_id,
+         points=points,
+         labels=labels,
+     )
+
+     # Create the plot
+     plt.figure(figsize=(12, 8))
+     plt.title(f"frame {ann_frame_idx}")
+     plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
+     show_points(points, labels, plt.gca())
+     show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
+
+     # Save the plot as a JPG file
+     first_frame_output_filename = "output_first_frame.jpg"
+     plt.savefig(first_frame_output_filename, format='jpg')
+     plt.close()
+     torch.cuda.empty_cache()
+
+     # Assuming available_frames_to_check.value is a list
+     if working_frame not in available_frames_to_check:
+         available_frames_to_check.append(working_frame)
+     print(available_frames_to_check)
+
+     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
+     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
+
+ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
+     #### PROPAGATION ####
+     sam2_checkpoint, model_cfg = load_model(checkpoint)
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+
+     inference_state = stored_inference_state
+     frame_names = stored_frame_names
+     video_dir = video_frames_dir
+
+     # Define a directory to save the JPEG images
+     frames_output_dir = "frames_output_images"
+     os.makedirs(frames_output_dir, exist_ok=True)
+
+     # Initialize a list to store file paths of saved images
+     jpeg_images = []
+
+     # run propagation throughout the video and collect the results in a dict
+     video_segments = {} # video_segments contains the per-frame segmentation results
+     # for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
+     #     video_segments[out_frame_idx] = {
+     #         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+     #         for i, out_obj_id in enumerate(out_obj_ids)
+     #     }
+
+     out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False,)
+     print(out_obj_ids)
+     for frame_idx in range(0, inference_state['num_frames']):
+
+         video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[frame_idx] > 0.0).cpu().numpy()}
+         # output_scores_per_object[object_id][frame_idx] = out_mask_logits[frame_idx].cpu().numpy()
+
+     # render the segmentation results every few frames
+     if vis_frame_type == "check":
+         vis_frame_stride = 15
+     elif vis_frame_type == "render":
+         vis_frame_stride = 1
+
+     plt.close("all")
+     for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+         plt.figure(figsize=(6, 4))
+         plt.title(f"frame {out_frame_idx}")
+         plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+         for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+             show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+         # Define the output filename and save the figure as a JPEG file
+         output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
+         plt.savefig(output_filename, format='jpg')
+
+         # Close the plot
+         plt.close()
+
+         # Append the file path to the list
+         jpeg_images.append(output_filename)
+
+         if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
+             available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
+
+     torch.cuda.empty_cache()
+     print(f"JPEG_IMAGES: {jpeg_images}")
+
+     if vis_frame_type == "check":
+         return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
+     elif vis_frame_type == "render":
+         # Create a video clip from the image sequence
+         original_fps = get_video_fps(video_in)
+         fps = original_fps # Frames per second
+         total_frames = len(jpeg_images)
+         clip = ImageSequenceClip(jpeg_images, fps=fps)
+         # Write the result to a file
+         final_vid_output_path = "output_video.mp4"
+
+         # Write the result to a file
+         clip.write_videofile(
+             final_vid_output_path,
+             codec='libx264'
+         )
+
+         return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
+
+ def update_ui(vis_frame_type):
+     if vis_frame_type == "check":
+         return gr.update(visible=True), gr.update(visible=False)
+     elif vis_frame_type == "render":
+         return gr.update(visible=False), gr.update(visible=True)
+
+ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
+     new_working_frame = None
+     if working_frame is None:
+         new_working_frame = os.path.join(video_frames_dir, scanned_frames[0])
+
+     else:
+         # Use a regular expression to find the integer
+         match = re.search(r'frame_(\d+)', working_frame)
+         if match:
+             # Extract the integer from the match
+             frame_number = int(match.group(1))
+             ann_frame_idx = frame_number
+             new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
+     return gr.State([]), gr.State([]), new_working_frame, new_working_frame
+
+ def reset_propagation(first_frame_path, predictor, stored_inference_state):
+
+     predictor.reset_state(stored_inference_state)
+     # print(f"RESET State: {stored_inference_state} ")
+     return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
+
+ css="""
+ div#component-18, div#component-25, div#component-35, div#component-41{
+     align-items: stretch!important;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     first_frame_path = gr.State()
+     tracking_points = gr.State([])
+     trackings_input_label = gr.State([])
+     video_frames_dir = gr.State()
+     scanned_frames = gr.State()
+     loaded_predictor = gr.State()
+     stored_inference_state = gr.State()
+     stored_frame_names = gr.State()
+     available_frames_to_check = gr.State([])
+     with gr.Column():
+         gr.Markdown(
+             """
+             <h1 style="text-align: center;">🔥 SAM2Long Demo 🔥</h1>
+             """
+         )
+         gr.Markdown(
+             """
+             This is a simple demo for video segmentation with [SAM2Long](https://github.com/Mark12Ding/SAM2Long).
+             """
+         )
+         gr.Markdown(
+             """
+             ### 📋 Instructions:
+
+             It is largely built on the [SAM2-Video-Predictor](https://huggingface.co/spaces/fffiloni/SAM2-Video-Predictor).
+
+             1. **Upload your video** [MP4-24fps]
+             2. With **'include' point type** selected, click on the object to mask on the first frame
+             3. Switch to **'exclude' point type** if you want to specify an area to avoid
+             4. **Get Mask!**
+             5. **Check Propagation** every 15 frames
+             6. **Propagate with "render"** to render the final masked video
+             7. **Hit Reset** button if you want to refresh and start again
+
+             *Note: Input video will be processed for up to 10 seconds only for demo purposes.*
+             """
+         )
+     with gr.Row():
+
+         with gr.Column():
+             with gr.Group():
+                 with gr.Group():
+                     with gr.Row():
+                         point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
+                         clear_points_btn = gr.Button("Clear Points", scale=1)
+
+                 input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
+
+                 points_map = gr.Image(
+                     label="Point n Click map",
+                     type="filepath",
+                     interactive=False
+                 )
+
+                 with gr.Group():
+                     with gr.Row():
+                         checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus"], value="tiny")
+                         submit_btn = gr.Button("Get Mask", size="lg")
+
+             with gr.Accordion("Your video IN", open=True) as video_in_drawer:
+                 video_in = gr.Video(label="Video IN", format="mp4")
+
+             gr.HTML(f"""
+
+             <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true">
+                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
+             </a> to skip queue and avoid OOM errors from heavy public load
+             """)
+
+         with gr.Column():
+             with gr.Group():
+                 # with gr.Group():
+                 #     with gr.Row():
+                 working_frame = gr.Dropdown(label="working frame ID", choices=[""], value="frame_0.jpg", visible=False, allow_custom_value=False, interactive=True)
+                 # change_current = gr.Button("change current", visible=False)
+                 # working_frame = []
+                 output_result = gr.Image(label="current working mask ref")
+             with gr.Group():
+                 with gr.Row():
+                     vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
+                     propagate_btn = gr.Button("Propagate", scale=1)
+             reset_prpgt_brn = gr.Button("Reset", visible=False)
+             output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
+             output_video = gr.Video(visible=False)
+             # output_result_mask = gr.Image()
+
+
+
+     # When new video is uploaded
+     video_in.upload(
+         fn = preprocess_video_in,
+         inputs = [video_in],
+         outputs = [
+             first_frame_path,
+             tracking_points, # update Tracking Points in the gr.State([]) object
+             trackings_input_label, # update Tracking Labels in the gr.State([]) object
+             input_first_frame_image, # hidden component used as ref when clearing points
+             points_map, # Image component where we add new tracking points
+             video_frames_dir, # Array where frames from video_in are deep stored
+             scanned_frames, # Scanned frames by SAM2
+             stored_inference_state, # Sam2 inference state
+             stored_frame_names, #
+             video_in_drawer, # Accordion to hide uploaded video player
+         ],
+         queue = False
+     )
+
+
+     # triggered when we click on image to add new points
+     points_map.select(
+         fn = get_point,
+         inputs = [
+             point_type, # "include" or "exclude"
+             tracking_points, # get tracking_points values
+             trackings_input_label, # get tracking label values
+             input_first_frame_image, # gr.State() first frame path
+         ],
+         outputs = [
+             tracking_points, # updated with new points
+             trackings_input_label, # updated with corresponding labels
+             points_map, # updated image with points
+         ],
+         queue = False
+     )
+
+     # Clear every points clicked and added to the map
+     clear_points_btn.click(
+         fn = clear_points,
+         inputs = input_first_frame_image, # we get the untouched hidden image
+         outputs = [
+             first_frame_path,
+             tracking_points,
+             trackings_input_label,
+             points_map,
+             # stored_inference_state,
+         ],
+         queue=False
+     )
+
+
+     # change_current.click(
+     #     fn = switch_working_frame,
+     #     inputs = [working_frame, scanned_frames, video_frames_dir],
+     #     outputs = [tracking_points, trackings_input_label, input_first_frame_image, points_map],
+     #     queue=False
+     # )
+
+
+     submit_btn.click(
+         fn = get_mask_sam_process,
+         inputs = [
+             stored_inference_state,
+             input_first_frame_image,
+             checkpoint,
+             tracking_points,
+             trackings_input_label,
+             video_frames_dir,
+             scanned_frames,
+             working_frame,
+             available_frames_to_check,
+         ],
+         outputs = [
+             output_result,
+             stored_frame_names,
+             loaded_predictor,
+             stored_inference_state,
+             working_frame,
+         ],
+         queue=False
+     )
+
+     reset_prpgt_brn.click(
+         fn = reset_propagation,
+         inputs = [first_frame_path, loaded_predictor, stored_inference_state],
+         outputs = [points_map, tracking_points, trackings_input_label, output_propagated, stored_inference_state, output_result, available_frames_to_check, input_first_frame_image, working_frame, reset_prpgt_brn],
+         queue=False
+     )
  
+     propagate_btn.click(
+         fn = update_ui,
+         inputs = [vis_frame_type],
+         outputs = [output_propagated, output_video],
+         queue=False
+     ).then(
+         fn = propagate_to_all,
+         inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
+         outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
+     )
  
+ demo.queue().launch(show_api=False, show_error=True, share=True, server_name="0.0.0.0", server_port=11111)
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ numpy>=1.24.4
+ tqdm>=4.66.1
+ hydra-core>=1.3.2
+ iopath>=0.1.10
+ pillow>=9.4.0
+ matplotlib>=3.9.1
+ opencv-python>=4.7.0
+ moviepy