MarcHabib commited on
Commit
546546e
1 Parent(s): 7e59580

Normalize line endings for Windows

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +20 -0
  2. README.md +1 -1
  3. app.log +0 -0
  4. app.py +797 -0
  5. checkpoints/download_ckpts.sh +59 -0
  6. checkpoints/sam2.1_hiera_large.pt +3 -0
  7. checkpoints/sam2_hiera_large.pt +3 -0
  8. requirements.txt +15 -0
  9. runtime.txt +1 -0
  10. sam2/__init__.py +11 -0
  11. sam2/__pycache__/__init__.cpython-311.pyc +0 -0
  12. sam2/__pycache__/build_sam.cpython-311.pyc +0 -0
  13. sam2/__pycache__/sam2_image_predictor.cpython-311.pyc +0 -0
  14. sam2/automatic_mask_generator.py +454 -0
  15. sam2/build_sam.py +167 -0
  16. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  17. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  18. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  19. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  20. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  21. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  22. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  23. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  24. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  25. sam2/csrc/connected_components.cu +289 -0
  26. sam2/modeling/__init__.py +5 -0
  27. sam2/modeling/__pycache__/__init__.cpython-311.pyc +0 -0
  28. sam2/modeling/__pycache__/memory_attention.cpython-311.pyc +0 -0
  29. sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc +0 -0
  30. sam2/modeling/__pycache__/position_encoding.cpython-311.pyc +0 -0
  31. sam2/modeling/__pycache__/sam2_base.cpython-311.pyc +0 -0
  32. sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc +0 -0
  33. sam2/modeling/backbones/__init__.py +5 -0
  34. sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc +0 -0
  35. sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc +0 -0
  36. sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc +0 -0
  37. sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc +0 -0
  38. sam2/modeling/backbones/hieradet.py +317 -0
  39. sam2/modeling/backbones/image_encoder.py +134 -0
  40. sam2/modeling/backbones/utils.py +95 -0
  41. sam2/modeling/memory_attention.py +169 -0
  42. sam2/modeling/memory_encoder.py +181 -0
  43. sam2/modeling/position_encoding.py +221 -0
  44. sam2/modeling/sam/__init__.py +5 -0
  45. sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc +0 -0
  46. sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc +0 -0
  47. sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc +0 -0
  48. sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc +0 -0
  49. sam2/modeling/sam/mask_decoder.py +295 -0
  50. sam2/modeling/sam/prompt_encoder.py +182 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11.0
2
+
3
+ RUN apt-get update && apt-get install -y \
4
+ libgl1 \
5
+ libglib2.0-0 \
6
+ && rm -rf /var/lib/apt/lists/*
7
+
8
+ RUN useradd -m -u 1000 user
9
+ USER user
10
+ ENV PATH="/home/user/.local/bin:$PATH"
11
+
12
+ WORKDIR /app
13
+
14
+ COPY --chown=user ./requirements.txt requirements.txt
15
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
16
+
17
+ COPY --chown=user . /app
18
+ EXPOSE 7860
19
+
20
+ CMD ["gunicorn","-b","0.0.0.0:7860", "app:app"]
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Segmentation Yolo Sam
3
- emoji: 📉
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: docker
 
1
  ---
2
  title: Segmentation Yolo Sam
3
+ emoji: 🏆
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: docker
app.log ADDED
File without changes
app.py ADDED
@@ -0,0 +1,797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ from flask_socketio import SocketIO
3
+ import os
4
+ import shutil
5
+ import numpy as np
6
+ from PIL import Image
7
+ from utils.predictor import Predictor
8
+ from utils.helpers import (
9
+ blend_mask_with_image,
10
+ save_mask_as_png,
11
+ convert_mask_to_yolo,
12
+ )
13
+ import torch
14
+ from ultralytics import YOLO
15
+ import threading
16
+ from threading import Lock
17
+ import subprocess
18
+ import time
19
+ import logging
20
+ import multiprocessing
21
+
22
+
23
+ # Initialize Flask app and SocketIO
24
+ app = Flask(__name__)
25
+ socketio = SocketIO(app)
26
+
27
+ # Define Base Directory
28
+ BASE_DIR = os.path.abspath(os.path.dirname(__file__))
29
+
30
+ # Folder structure with absolute paths
31
+ UPLOAD_FOLDERS = {
32
+ 'input': os.path.join(BASE_DIR, 'static/uploads/input'),
33
+ 'segmented_voids': os.path.join(BASE_DIR, 'static/uploads/segmented/voids'),
34
+ 'segmented_chips': os.path.join(BASE_DIR, 'static/uploads/segmented/chips'),
35
+ 'mask_voids': os.path.join(BASE_DIR, 'static/uploads/mask/voids'),
36
+ 'mask_chips': os.path.join(BASE_DIR, 'static/uploads/mask/chips'),
37
+ 'automatic_segmented': os.path.join(BASE_DIR, 'static/uploads/segmented/automatic'),
38
+ }
39
+
40
+ HISTORY_FOLDERS = {
41
+ 'images': os.path.join(BASE_DIR, 'static/history/images'),
42
+ 'masks_chip': os.path.join(BASE_DIR, 'static/history/masks/chip'),
43
+ 'masks_void': os.path.join(BASE_DIR, 'static/history/masks/void'),
44
+ }
45
+
46
+ DATASET_FOLDERS = {
47
+ 'train_images': os.path.join(BASE_DIR, 'dataset/train/images'),
48
+ 'train_labels': os.path.join(BASE_DIR, 'dataset/train/labels'),
49
+ 'val_images': os.path.join(BASE_DIR, 'dataset/val/images'),
50
+ 'val_labels': os.path.join(BASE_DIR, 'dataset/val/labels'),
51
+ 'temp_backup': os.path.join(BASE_DIR, 'temp_backup'),
52
+ 'models': os.path.join(BASE_DIR, 'models'),
53
+ 'models_old': os.path.join(BASE_DIR, 'models/old'),
54
+ }
55
+
56
+ # Ensure all folders exist
57
+ for folder_name, folder_path in {**UPLOAD_FOLDERS, **HISTORY_FOLDERS, **DATASET_FOLDERS}.items():
58
+ os.makedirs(folder_path, exist_ok=True)
59
+ logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}")
60
+
61
+ training_process = None
62
+
63
+
64
+ def initialize_training_status():
65
+ """Initialize global training status."""
66
+ global training_status
67
+ training_status = {'running': False, 'cancelled': False}
68
+
69
+ def persist_training_status():
70
+ """Save training status to a file."""
71
+ with open(os.path.join(BASE_DIR, 'training_status.json'), 'w') as status_file:
72
+ json.dump(training_status, status_file)
73
+
74
+ def load_training_status():
75
+ """Load training status from a file."""
76
+ global training_status
77
+ status_path = os.path.join(BASE_DIR, 'training_status.json')
78
+ if os.path.exists(status_path):
79
+ with open(status_path, 'r') as status_file:
80
+ training_status = json.load(status_file)
81
+ else:
82
+ training_status = {'running': False, 'cancelled': False}
83
+
84
+ load_training_status()
85
+
86
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
87
+
88
+ # Initialize SAM Predictor
89
+ MODEL_CFG = r"\sam2\configs\sam2.1\sam2.1_hiera_l.yaml"
90
+ CHECKPOINT = r"\checkpoints\sam2.1_hiera_large.pt"
91
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+ predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE)
93
+
94
+ # Initialize YOLO-seg
95
+ YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt")
96
+ yolo_model = YOLO(YOLO_CFG)
97
+
98
+ # Configure logging
99
+ logging.basicConfig(
100
+ level=logging.INFO,
101
+ format='%(asctime)s [%(levelname)s] %(message)s',
102
+ handlers=[
103
+ logging.StreamHandler(),
104
+ logging.FileHandler(os.path.join(BASE_DIR, "app.log")) # Log to a file
105
+ ]
106
+ )
107
+
108
+
109
+ @app.route('/')
110
+ def index():
111
+ """Serve the main UI."""
112
+ return render_template('index.html')
113
+
114
+ @app.route('/upload', methods=['POST'])
115
+ def upload_image():
116
+ """Handle image uploads."""
117
+ if 'file' not in request.files:
118
+ return jsonify({'error': 'No file uploaded'}), 400
119
+ file = request.files['file']
120
+ if file.filename == '':
121
+ return jsonify({'error': 'No file selected'}), 400
122
+
123
+ # Save the uploaded file to the input folder
124
+ input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename)
125
+ file.save(input_path)
126
+
127
+ # Set the uploaded image in the predictor
128
+ image = np.array(Image.open(input_path).convert("RGB"))
129
+ predictor.set_image(image)
130
+
131
+ # Return a web-accessible URL instead of the file system path
132
+ web_accessible_url = f"/static/uploads/input/{file.filename}"
133
+ print(f"Image uploaded and set for prediction: {input_path}")
134
+ return jsonify({'image_url': web_accessible_url})
135
+
136
+ @app.route('/segment', methods=['POST'])
137
+ def segment():
138
+ """
139
+ Perform segmentation and return the blended image URL.
140
+ """
141
+ try:
142
+ # Extract data from request
143
+ data = request.json
144
+ points = np.array(data.get('points', []))
145
+ labels = np.array(data.get('labels', []))
146
+ current_class = data.get('class', 'voids') # Default to 'voids' if class not provided
147
+
148
+ # Ensure predictor has an image set
149
+ if not predictor.image_set:
150
+ raise ValueError("No image set for prediction.")
151
+
152
+ # Perform SAM prediction
153
+ masks, _, _ = predictor.predict(
154
+ point_coords=points,
155
+ point_labels=labels,
156
+ multimask_output=False
157
+ )
158
+
159
+ # Check if masks exist and have non-zero elements
160
+ if masks is None or masks.size == 0:
161
+ raise RuntimeError("No masks were generated by the predictor.")
162
+
163
+ # Define output paths based on class
164
+ mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}')
165
+ segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}')
166
+
167
+ if not mask_folder or not segmented_folder:
168
+ raise ValueError(f"Invalid class '{current_class}' provided.")
169
+
170
+ os.makedirs(mask_folder, exist_ok=True)
171
+ os.makedirs(segmented_folder, exist_ok=True)
172
+
173
+ # Save the raw mask
174
+ mask_path = os.path.join(mask_folder, 'raw_mask.png')
175
+ save_mask_as_png(masks[0], mask_path)
176
+
177
+ # Generate blended image
178
+ blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255] # Green for voids, blue for chips
179
+ blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color)
180
+
181
+ # Save blended image
182
+ blended_filename = f"blended_{current_class}.png"
183
+ blended_path = os.path.join(segmented_folder, blended_filename)
184
+ Image.fromarray(blended_image).save(blended_path)
185
+
186
+ # Return URL for frontend access
187
+ segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}"
188
+ logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}")
189
+ return jsonify({'segmented_url': segmented_url})
190
+
191
+ except ValueError as ve:
192
+ logging.error(f"Value error during segmentation: {ve}")
193
+ return jsonify({'error': str(ve)}), 400
194
+
195
+ except Exception as e:
196
+ logging.error(f"Unexpected error during segmentation: {e}")
197
+ return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500
198
+
199
+ @app.route('/automatic_segment', methods=['POST'])
200
+ def automatic_segment():
201
+ """Perform automatic segmentation using YOLO."""
202
+ if 'file' not in request.files:
203
+ return jsonify({'error': 'No file uploaded'}), 400
204
+ file = request.files['file']
205
+ if file.filename == '':
206
+ return jsonify({'error': 'No file selected'}), 400
207
+
208
+ input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename)
209
+ file.save(input_path)
210
+
211
+ try:
212
+ # Perform YOLO segmentation
213
+ results = yolo_model.predict(input_path, save=False, save_txt=False)
214
+ output_folder = UPLOAD_FOLDERS['automatic_segmented']
215
+ os.makedirs(output_folder, exist_ok=True)
216
+
217
+ chips_data = []
218
+ chips = []
219
+ voids = []
220
+
221
+ # Process results and save segmented images
222
+ for result in results:
223
+ annotated_image = result.plot()
224
+ result_filename = f"{file.filename.rsplit('.', 1)[0]}_pred.jpg"
225
+ result_path = os.path.join(output_folder, result_filename)
226
+ Image.fromarray(annotated_image).save(result_path)
227
+
228
+ # Separate chips and voids
229
+ for i, label in enumerate(result.boxes.cls): # YOLO labels
230
+ label_name = result.names[int(label)] # Get label name (e.g., 'chip' or 'void')
231
+ box = result.boxes.xyxy[i].cpu().numpy() # Bounding box (x1, y1, x2, y2)
232
+ area = float((box[2] - box[0]) * (box[3] - box[1])) # Calculate area
233
+
234
+ if label_name == 'chip':
235
+ chips.append({'box': box, 'area': area, 'voids': []})
236
+ elif label_name == 'void':
237
+ voids.append({'box': box, 'area': area})
238
+
239
+ # Assign voids to chips based on proximity
240
+ for void in voids:
241
+ void_centroid = [
242
+ (void['box'][0] + void['box'][2]) / 2, # x centroid
243
+ (void['box'][1] + void['box'][3]) / 2 # y centroid
244
+ ]
245
+ for chip in chips:
246
+ # Check if void centroid is within chip bounding box
247
+ if (chip['box'][0] <= void_centroid[0] <= chip['box'][2] and
248
+ chip['box'][1] <= void_centroid[1] <= chip['box'][3]):
249
+ chip['voids'].append(void)
250
+ break
251
+
252
+ # Calculate metrics for each chip
253
+ for idx, chip in enumerate(chips):
254
+ chip_area = chip['area']
255
+ total_void_area = sum([float(void['area']) for void in chip['voids']])
256
+ max_void_area = max([float(void['area']) for void in chip['voids']], default=0)
257
+
258
+ void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0
259
+ max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0
260
+
261
+ chips_data.append({
262
+ "chip_number": int(idx + 1),
263
+ "chip_area": round(chip_area, 2),
264
+ "void_percentage": round(void_percentage, 2),
265
+ "max_void_percentage": round(max_void_percentage, 2)
266
+ })
267
+
268
+ # Return the segmented image URL and table data
269
+ segmented_url = f"/static/uploads/segmented/automatic/{result_filename}"
270
+ return jsonify({
271
+ "segmented_url": segmented_url, # Use the URL for frontend access
272
+ "table_data": {
273
+ "image_name": file.filename,
274
+ "chips": chips_data
275
+ }
276
+ })
277
+
278
+ except Exception as e:
279
+ print(f"Error in automatic segmentation: {e}")
280
+ return jsonify({'error': 'Segmentation failed.'}), 500
281
+
282
+ @app.route('/save_both', methods=['POST'])
283
+ def save_both():
284
+ """Save both the image and masks into the history folders."""
285
+ data = request.json
286
+ image_name = data.get('image_name')
287
+
288
+ if not image_name:
289
+ return jsonify({'error': 'Image name not provided'}), 400
290
+
291
+ try:
292
+ # Ensure image_name is a pure file name
293
+ image_name = os.path.basename(image_name) # Strip any directory path
294
+ print(f"Sanitized Image Name: {image_name}")
295
+
296
+ # Correctly resolve the input image path
297
+ input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name)
298
+ if not os.path.exists(input_image_path):
299
+ print(f"Input image does not exist: {input_image_path}")
300
+ return jsonify({'error': f'Input image not found: {input_image_path}'}), 404
301
+
302
+ # Copy the image to history/images
303
+ image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
304
+ os.makedirs(os.path.dirname(image_history_path), exist_ok=True)
305
+ shutil.copy(input_image_path, image_history_path)
306
+ print(f"Image saved to history: {image_history_path}")
307
+
308
+ # Backup void mask
309
+ void_mask_path = os.path.join(UPLOAD_FOLDERS['mask_voids'], 'raw_mask.png')
310
+ if os.path.exists(void_mask_path):
311
+ void_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
312
+ os.makedirs(os.path.dirname(void_mask_history_path), exist_ok=True)
313
+ shutil.copy(void_mask_path, void_mask_history_path)
314
+ print(f"Voids mask saved to history: {void_mask_history_path}")
315
+ else:
316
+ print(f"Voids mask not found: {void_mask_path}")
317
+
318
+ # Backup chip mask
319
+ chip_mask_path = os.path.join(UPLOAD_FOLDERS['mask_chips'], 'raw_mask.png')
320
+ if os.path.exists(chip_mask_path):
321
+ chip_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
322
+ os.makedirs(os.path.dirname(chip_mask_history_path), exist_ok=True)
323
+ shutil.copy(chip_mask_path, chip_mask_history_path)
324
+ print(f"Chips mask saved to history: {chip_mask_history_path}")
325
+ else:
326
+ print(f"Chips mask not found: {chip_mask_path}")
327
+
328
+ return jsonify({'message': 'Image and masks saved successfully!'}), 200
329
+
330
+ except Exception as e:
331
+ print(f"Error saving files: {e}")
332
+ return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500
333
+
334
+ @app.route('/get_history', methods=['GET'])
335
+ def get_history():
336
+ try:
337
+ saved_images = os.listdir(HISTORY_FOLDERS['images'])
338
+ return jsonify({'status': 'success', 'images': saved_images}), 200
339
+ except Exception as e:
340
+ return jsonify({'status': 'error', 'message': f'Failed to fetch history: {e}'}), 500
341
+
342
+
343
+ @app.route('/delete_history_item', methods=['POST'])
344
+ def delete_history_item():
345
+ data = request.json
346
+ image_name = data.get('image_name')
347
+
348
+ if not image_name:
349
+ return jsonify({'error': 'Image name not provided'}), 400
350
+
351
+ try:
352
+ image_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
353
+ if os.path.exists(image_path):
354
+ os.remove(image_path)
355
+
356
+ void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
357
+ if os.path.exists(void_mask_path):
358
+ os.remove(void_mask_path)
359
+
360
+ chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
361
+ if os.path.exists(chip_mask_path):
362
+ os.remove(chip_mask_path)
363
+
364
+ return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200
365
+ except Exception as e:
366
+ return jsonify({'error': f'Failed to delete files: {e}'}), 500
367
+
368
+ # Lock for training status updates
369
+ status_lock = Lock()
370
+
371
+ def update_training_status(key, value):
372
+ """Thread-safe update for training status."""
373
+ with status_lock:
374
+ training_status[key] = value
375
+
376
+ @app.route('/retrain_model', methods=['POST'])
377
+ def retrain_model():
378
+ """Handle retrain model workflow."""
379
+ global training_status
380
+
381
+ if training_status.get('running', False):
382
+ return jsonify({'error': 'Training is already in progress'}), 400
383
+
384
+ try:
385
+ # Update training status
386
+ update_training_status('running', True)
387
+ update_training_status('cancelled', False)
388
+ logging.info("Training status updated. Starting training workflow.")
389
+
390
+ # Backup masks and images
391
+ backup_masks_and_images()
392
+ logging.info("Backup completed successfully.")
393
+
394
+ # Prepare YOLO labels
395
+ prepare_yolo_labels()
396
+ logging.info("YOLO labels prepared successfully.")
397
+
398
+ # Start YOLO training in a separate thread
399
+ threading.Thread(target=run_yolo_training).start()
400
+ return jsonify({'message': 'Training started successfully!'}), 200
401
+
402
+ except Exception as e:
403
+ logging.error(f"Error during training preparation: {e}")
404
+ update_training_status('running', False)
405
+ return jsonify({'error': f"Failed to start training: {e}"}), 500
406
+
407
+ def prepare_yolo_labels():
408
+ """Convert all masks into YOLO-compatible labels and copy images to the dataset folder."""
409
+ images_folder = HISTORY_FOLDERS['images'] # Use history images as the source
410
+ train_labels_folder = DATASET_FOLDERS['train_labels']
411
+ train_images_folder = DATASET_FOLDERS['train_images']
412
+ val_labels_folder = DATASET_FOLDERS['val_labels']
413
+ val_images_folder = DATASET_FOLDERS['val_images']
414
+
415
+ # Ensure destination directories exist
416
+ os.makedirs(train_labels_folder, exist_ok=True)
417
+ os.makedirs(train_images_folder, exist_ok=True)
418
+ os.makedirs(val_labels_folder, exist_ok=True)
419
+ os.makedirs(val_images_folder, exist_ok=True)
420
+
421
+ try:
422
+ all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))]
423
+ random.shuffle(all_images) # Shuffle the images for randomness
424
+
425
+ # Determine split index
426
+ split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation
427
+
428
+ # Split images into train and validation sets
429
+ train_images = all_images[:split_idx]
430
+ val_images = all_images[split_idx:]
431
+
432
+ # Process training images
433
+ for image_name in train_images:
434
+ process_image_and_mask(
435
+ image_name,
436
+ source_images_folder=images_folder,
437
+ dest_images_folder=train_images_folder,
438
+ dest_labels_folder=train_labels_folder
439
+ )
440
+
441
+ # Process validation images
442
+ for image_name in val_images:
443
+ process_image_and_mask(
444
+ image_name,
445
+ source_images_folder=images_folder,
446
+ dest_images_folder=val_images_folder,
447
+ dest_labels_folder=val_labels_folder
448
+ )
449
+
450
+ logging.info("YOLO labels prepared, and images split into train and validation successfully.")
451
+
452
+ except Exception as e:
453
+ logging.error(f"Error in preparing YOLO labels: {e}")
454
+ raise
455
+
456
+ import random
457
+
458
+ def prepare_yolo_labels():
459
+ """Convert all masks into YOLO-compatible labels and copy images to the dataset folder."""
460
+ images_folder = HISTORY_FOLDERS['images'] # Use history images as the source
461
+ train_labels_folder = DATASET_FOLDERS['train_labels']
462
+ train_images_folder = DATASET_FOLDERS['train_images']
463
+ val_labels_folder = DATASET_FOLDERS['val_labels']
464
+ val_images_folder = DATASET_FOLDERS['val_images']
465
+
466
+ # Ensure destination directories exist
467
+ os.makedirs(train_labels_folder, exist_ok=True)
468
+ os.makedirs(train_images_folder, exist_ok=True)
469
+ os.makedirs(val_labels_folder, exist_ok=True)
470
+ os.makedirs(val_images_folder, exist_ok=True)
471
+
472
+ try:
473
+ all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))]
474
+ random.shuffle(all_images) # Shuffle the images for randomness
475
+
476
+ # Determine split index
477
+ split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation
478
+
479
+ # Split images into train and validation sets
480
+ train_images = all_images[:split_idx]
481
+ val_images = all_images[split_idx:]
482
+
483
+ # Process training images
484
+ for image_name in train_images:
485
+ process_image_and_mask(
486
+ image_name,
487
+ source_images_folder=images_folder,
488
+ dest_images_folder=train_images_folder,
489
+ dest_labels_folder=train_labels_folder
490
+ )
491
+
492
+ # Process validation images
493
+ for image_name in val_images:
494
+ process_image_and_mask(
495
+ image_name,
496
+ source_images_folder=images_folder,
497
+ dest_images_folder=val_images_folder,
498
+ dest_labels_folder=val_labels_folder
499
+ )
500
+
501
+ logging.info("YOLO labels prepared, and images split into train and validation successfully.")
502
+
503
+ except Exception as e:
504
+ logging.error(f"Error in preparing YOLO labels: {e}")
505
+ raise
506
+
507
+
508
+ def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder):
509
+ """
510
+ Process a single image and its masks, saving them in the appropriate YOLO format.
511
+ """
512
+ try:
513
+ image_path = os.path.join(source_images_folder, image_name)
514
+ label_file_path = os.path.join(dest_labels_folder, f"{os.path.splitext(image_name)[0]}.txt")
515
+
516
+ # Copy image to the destination images folder
517
+ shutil.copy(image_path, os.path.join(dest_images_folder, image_name))
518
+
519
+ # Clear the label file if it exists
520
+ if os.path.exists(label_file_path):
521
+ os.remove(label_file_path)
522
+
523
+ # Process void mask
524
+ void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
525
+ if os.path.exists(void_mask_path):
526
+ convert_mask_to_yolo(
527
+ mask_path=void_mask_path,
528
+ image_path=image_path,
529
+ class_id=0, # Void class
530
+ output_path=label_file_path
531
+ )
532
+
533
+ # Process chip mask
534
+ chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
535
+ if os.path.exists(chip_mask_path):
536
+ convert_mask_to_yolo(
537
+ mask_path=chip_mask_path,
538
+ image_path=image_path,
539
+ class_id=1, # Chip class
540
+ output_path=label_file_path,
541
+ append=True # Append chip annotations
542
+ )
543
+
544
+ logging.info(f"Processed {image_name} into YOLO format.")
545
+ except Exception as e:
546
+ logging.error(f"Error processing {image_name}: {e}")
547
+ raise
548
+
549
+ def backup_masks_and_images():
550
+ """Backup current masks and images from history folders."""
551
+ temp_backup_paths = {
552
+ 'voids': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/voids'),
553
+ 'chips': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/chips'),
554
+ 'images': os.path.join(DATASET_FOLDERS['temp_backup'], 'images')
555
+ }
556
+
557
+ # Prepare all backup directories
558
+ for path in temp_backup_paths.values():
559
+ if os.path.exists(path):
560
+ shutil.rmtree(path)
561
+ os.makedirs(path, exist_ok=True)
562
+
563
+ try:
564
+ # Backup images from history
565
+ for file in os.listdir(HISTORY_FOLDERS['images']):
566
+ src_image_path = os.path.join(HISTORY_FOLDERS['images'], file)
567
+ dst_image_path = os.path.join(temp_backup_paths['images'], file)
568
+ shutil.copy(src_image_path, dst_image_path)
569
+
570
+ # Backup void masks from history
571
+ for file in os.listdir(HISTORY_FOLDERS['masks_void']):
572
+ src_void_path = os.path.join(HISTORY_FOLDERS['masks_void'], file)
573
+ dst_void_path = os.path.join(temp_backup_paths['voids'], file)
574
+ shutil.copy(src_void_path, dst_void_path)
575
+
576
+ # Backup chip masks from history
577
+ for file in os.listdir(HISTORY_FOLDERS['masks_chip']):
578
+ src_chip_path = os.path.join(HISTORY_FOLDERS['masks_chip'], file)
579
+ dst_chip_path = os.path.join(temp_backup_paths['chips'], file)
580
+ shutil.copy(src_chip_path, dst_chip_path)
581
+
582
+ logging.info("Masks and images backed up successfully from history.")
583
+ except Exception as e:
584
+ logging.error(f"Error during backup: {e}")
585
+ raise RuntimeError("Backup process failed.")
586
+
587
+ def run_yolo_training(num_epochs=10):
588
+ """Run YOLO training process."""
589
+ global training_process
590
+
591
+ try:
592
+ device = "cuda" if torch.cuda.is_available() else "cpu"
593
+ data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml") # Ensure correct YAML path
594
+
595
+ logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.")
596
+ logging.info(f"Using dataset configuration: {data_cfg_path}")
597
+
598
+ training_command = [
599
+ "yolo",
600
+ "train",
601
+ f"data={data_cfg_path}",
602
+ f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}",
603
+ f"device={device}",
604
+ f"epochs={num_epochs}",
605
+ "project=runs",
606
+ "name=train"
607
+ ]
608
+
609
+ training_process = subprocess.Popen(
610
+ training_command,
611
+ stdout=subprocess.PIPE,
612
+ stderr=subprocess.STDOUT,
613
+ text=True,
614
+ env=os.environ.copy(),
615
+ )
616
+
617
+ # Display and log output in real time
618
+ for line in iter(training_process.stdout.readline, ''):
619
+ print(line.strip())
620
+ logging.info(line.strip())
621
+ socketio.emit('training_update', {'message': line.strip()}) # Send updates to the frontend
622
+
623
+ training_process.wait()
624
+
625
+ if training_process.returncode == 0:
626
+ finalize_training() # Finalize successfully completed training
627
+ else:
628
+ raise RuntimeError("YOLO training process failed. Check logs for details.")
629
+ except Exception as e:
630
+ logging.error(f"Training error: {e}")
631
+ restore_backup() # Restore the dataset and masks
632
+
633
+ # Emit training error event to the frontend
634
+ socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"})
635
+ finally:
636
+ update_training_status('running', False)
637
+ training_process = None # Reset the process
638
+
639
+
640
+ @socketio.on('cancel_training')
641
+ def handle_cancel_training():
642
+ """Cancel the YOLO training process."""
643
+ global training_process, training_status
644
+
645
+ if not training_status.get('running', False):
646
+ socketio.emit('button_update', {'action': 'retrain'}) # Update button to retrain
647
+ return
648
+
649
+ try:
650
+ training_process.terminate()
651
+ training_process.wait()
652
+ training_status['running'] = False
653
+ training_status['cancelled'] = True
654
+
655
+ restore_backup()
656
+ cleanup_train_val_directories()
657
+
658
+ # Emit button state change
659
+ socketio.emit('button_update', {'action': 'retrain'})
660
+ socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
661
+ except Exception as e:
662
+ logging.error(f"Error cancelling training: {e}")
663
+ socketio.emit('training_status', {'status': 'error', 'message': str(e)})
664
+
665
+ def finalize_training():
666
+ """Finalize training by promoting the new model and cleaning up."""
667
+ try:
668
+ # Locate the most recent training directory
669
+ runs_dir = os.path.join(BASE_DIR, 'runs')
670
+ if not os.path.exists(runs_dir):
671
+ raise FileNotFoundError("Training runs directory does not exist.")
672
+
673
+ # Get the latest training run folder
674
+ latest_run = max(
675
+ [os.path.join(runs_dir, d) for d in os.listdir(runs_dir)],
676
+ key=os.path.getmtime
677
+ )
678
+ weights_dir = os.path.join(latest_run, 'weights')
679
+ best_model_path = os.path.join(weights_dir, 'best.pt')
680
+
681
+ if not os.path.exists(best_model_path):
682
+ raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.")
683
+
684
+ # Backup the old model
685
+ old_model_folder = DATASET_FOLDERS['models_old']
686
+ os.makedirs(old_model_folder, exist_ok=True)
687
+ existing_best_model = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
688
+
689
+ if os.path.exists(existing_best_model):
690
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
691
+ shutil.move(existing_best_model, os.path.join(old_model_folder, f"old_{timestamp}.pt"))
692
+ logging.info(f"Old model backed up to {old_model_folder}.")
693
+
694
+ # Move the new model to the models directory
695
+ new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
696
+ shutil.move(best_model_path, new_model_dest)
697
+ logging.info(f"New model saved to {new_model_dest}.")
698
+
699
+ # Notify frontend that training is completed
700
+ socketio.emit('training_status', {
701
+ 'status': 'completed',
702
+ 'message': 'Training completed successfully! Model saved as best.pt.'
703
+ })
704
+
705
+ # Clean up train/val directories
706
+ cleanup_train_val_directories()
707
+ logging.info("Train and validation directories cleaned up successfully.")
708
+
709
+ except Exception as e:
710
+ logging.error(f"Error finalizing training: {e}")
711
+ # Emit error status to the frontend
712
+ socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"})
713
+
714
+ def restore_backup():
715
+ """Restore the dataset and masks from the backup."""
716
+ try:
717
+ temp_backup = DATASET_FOLDERS['temp_backup']
718
+ shutil.copytree(os.path.join(temp_backup, 'masks/voids'), UPLOAD_FOLDERS['mask_voids'], dirs_exist_ok=True)
719
+ shutil.copytree(os.path.join(temp_backup, 'masks/chips'), UPLOAD_FOLDERS['mask_chips'], dirs_exist_ok=True)
720
+ shutil.copytree(os.path.join(temp_backup, 'images'), UPLOAD_FOLDERS['input'], dirs_exist_ok=True)
721
+ logging.info("Backup restored successfully.")
722
+ except Exception as e:
723
+ logging.error(f"Error restoring backup: {e}")
724
+
725
+ @app.route('/cancel_training', methods=['POST'])
726
+ def cancel_training():
727
+ global training_process
728
+
729
+ if training_process is None:
730
+ logging.error("No active training process to terminate.")
731
+ return jsonify({'error': 'No active training process to cancel.'}), 400
732
+
733
+ try:
734
+ training_process.terminate()
735
+ training_process.wait()
736
+ training_process = None # Reset the process after termination
737
+
738
+ # Update training status
739
+ update_training_status('running', False)
740
+ update_training_status('cancelled', True)
741
+
742
+ # Check if the model is already saved as best.pt
743
+ best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
744
+ if os.path.exists(best_model_path):
745
+ logging.info(f"Model already saved as best.pt at {best_model_path}.")
746
+ socketio.emit('button_update', {'action': 'revert'}) # Notify frontend to revert button state
747
+ else:
748
+ logging.info("Training canceled, but no new model was saved.")
749
+
750
+ # Restore backup if needed
751
+ restore_backup()
752
+ cleanup_train_val_directories()
753
+
754
+ # Emit status update to frontend
755
+ socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
756
+ return jsonify({'message': 'Training canceled and data restored successfully.'}), 200
757
+
758
+ except Exception as e:
759
+ logging.error(f"Error cancelling training: {e}")
760
+ return jsonify({'error': f"Failed to cancel training: {e}"}), 500
761
+
762
+ @app.route('/clear_history', methods=['POST'])
763
+ def clear_history():
764
+ try:
765
+ for folder in [HISTORY_FOLDERS['images'], HISTORY_FOLDERS['masks_chip'], HISTORY_FOLDERS['masks_void']]:
766
+ shutil.rmtree(folder, ignore_errors=True)
767
+ os.makedirs(folder, exist_ok=True) # Recreate the empty folder
768
+ return jsonify({'message': 'History cleared successfully!'}), 200
769
+ except Exception as e:
770
+ return jsonify({'error': f'Failed to clear history: {e}'}), 500
771
+
772
+ @app.route('/training_status', methods=['GET'])
773
+ def get_training_status():
774
+ """Return the current training status."""
775
+ if training_status.get('running', False):
776
+ return jsonify({'status': 'running', 'message': 'Training in progress.'}), 200
777
+ elif training_status.get('cancelled', False):
778
+ return jsonify({'status': 'cancelled', 'message': 'Training was cancelled.'}), 200
779
+ return jsonify({'status': 'idle', 'message': 'No training is currently running.'}), 200
780
+
781
+ def cleanup_train_val_directories():
782
+ """Clear the train and validation directories."""
783
+ try:
784
+ for folder in [DATASET_FOLDERS['train_images'], DATASET_FOLDERS['train_labels'],
785
+ DATASET_FOLDERS['val_images'], DATASET_FOLDERS['val_labels']]:
786
+ shutil.rmtree(folder, ignore_errors=True) # Remove folder contents
787
+ os.makedirs(folder, exist_ok=True) # Recreate empty folders
788
+ logging.info("Train and validation directories cleaned up successfully.")
789
+ except Exception as e:
790
+ logging.error(f"Error cleaning up train/val directories: {e}")
791
+
792
+
793
+ if __name__ == '__main__':
794
+ multiprocessing.set_start_method('spawn') # Required for multiprocessing on Windows
795
+ app.run(debug=True, use_reloader=False)
796
+
797
+
checkpoints/download_ckpts.sh ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ # Use either wget or curl to download the checkpoints
10
+ if command -v wget &> /dev/null; then
11
+ CMD="wget"
12
+ elif command -v curl &> /dev/null; then
13
+ CMD="curl -L -O"
14
+ else
15
+ echo "Please install wget or curl to download the checkpoints."
16
+ exit 1
17
+ fi
18
+
19
+ # Define the URLs for SAM 2 checkpoints
20
+ # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
21
+ # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
22
+ # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
23
+ # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
24
+ # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
25
+
26
+ # Download each of the four checkpoints using wget
27
+ # echo "Downloading sam2_hiera_tiny.pt checkpoint..."
28
+ # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
29
+
30
+ # echo "Downloading sam2_hiera_small.pt checkpoint..."
31
+ # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
32
+
33
+ # echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
34
+ # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
35
+
36
+ # echo "Downloading sam2_hiera_large.pt checkpoint..."
37
+ # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
38
+
39
+ # Define the URLs for SAM 2.1 checkpoints
40
+ SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
41
+ sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
42
+ sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
43
+ sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
44
+ sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
45
+
46
+ # SAM 2.1 checkpoints
47
+ echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
48
+ $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
49
+
50
+ echo "Downloading sam2.1_hiera_small.pt checkpoint..."
51
+ $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
52
+
53
+ echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
54
+ $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
55
+
56
+ echo "Downloading sam2.1_hiera_large.pt checkpoint..."
57
+ $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
58
+
59
+ echo "All checkpoints are downloaded successfully."
checkpoints/sam2.1_hiera_large.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318
3
+ size 898083611
checkpoints/sam2_hiera_large.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7442e4e9b732a508f80e141e7c2913437a3610ee0c77381a66658c3a445df87b
3
+ size 897952466
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ decord==0.6.0
2
+ Flask==3.1.0
3
+ Flask_SocketIO==5.4.1
4
+ huggingface_hub==0.24.6
5
+ hydra-core==1.3.2
6
+ iopath==0.1.10
7
+ numpy==2.1.3
8
+ omegaconf==2.3.0
9
+ opencv_python==4.10.0.84
10
+ opencv_python_headless==4.10.0.84
11
+ Pillow==11.0.0
12
+ pycocotools==2.0.8
13
+ torch==2.3.1
14
+ tqdm==4.66.5
15
+ ultralytics==8.3.35
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.11.1
sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+ from hydra.core.global_hydra import GlobalHydra
9
+
10
+ if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("sam2", version_base="1.2")
sam2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (515 Bytes). View file
 
sam2/__pycache__/build_sam.cpython-311.pyc ADDED
Binary file (5.73 kB). View file
 
sam2/__pycache__/sam2_image_predictor.cpython-311.pyc ADDED
Binary file (24 kB). View file
 
sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ **kwargs,
57
+ ) -> None:
58
+ """
59
+ Using a SAM 2 model, generates masks for the entire image.
60
+ Generates a grid of point prompts over the image, then filters
61
+ low quality and duplicate masks. The default settings are chosen
62
+ for SAM 2 with a HieraL backbone.
63
+
64
+ Arguments:
65
+ model (Sam): The SAM 2 model to use for mask prediction.
66
+ points_per_side (int or None): The number of points to be sampled
67
+ along one side of the image. The total number of points is
68
+ points_per_side**2. If None, 'point_grids' must provide explicit
69
+ point sampling.
70
+ points_per_batch (int): Sets the number of points run simultaneously
71
+ by the model. Higher numbers may be faster but use more GPU memory.
72
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
73
+ model's predicted mask quality.
74
+ stability_score_thresh (float): A filtering threshold in [0,1], using
75
+ the stability of the mask under changes to the cutoff used to binarize
76
+ the model's mask predictions.
77
+ stability_score_offset (float): The amount to shift the cutoff when
78
+ calculated the stability score.
79
+ mask_threshold (float): Threshold for binarizing the mask logits
80
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
81
+ suppression to filter duplicate masks.
82
+ crop_n_layers (int): If >0, mask prediction will be run again on
83
+ crops of the image. Sets the number of layers to run, where each
84
+ layer has 2**i_layer number of image crops.
85
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
86
+ suppression to filter duplicate masks between different crops.
87
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
88
+ In the first crop layer, crops will overlap by this fraction of
89
+ the image length. Later layers with more crops scale down this overlap.
90
+ crop_n_points_downscale_factor (int): The number of points-per-side
91
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
92
+ point_grids (list(np.ndarray) or None): A list over explicit grids
93
+ of points used for sampling, normalized to [0,1]. The nth grid in the
94
+ list is used in the nth crop layer. Exclusive with points_per_side.
95
+ min_mask_region_area (int): If >0, postprocessing will be applied
96
+ to remove disconnected regions and holes in masks with area smaller
97
+ than min_mask_region_area. Requires opencv.
98
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
99
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
100
+ For large resolutions, 'binary_mask' may consume large amounts of
101
+ memory.
102
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
103
+ multimask_output (bool): Whether to output multimask at each point of the grid.
104
+ """
105
+
106
+ assert (points_per_side is None) != (
107
+ point_grids is None
108
+ ), "Exactly one of points_per_side or point_grid must be provided."
109
+ if points_per_side is not None:
110
+ self.point_grids = build_all_layer_point_grids(
111
+ points_per_side,
112
+ crop_n_layers,
113
+ crop_n_points_downscale_factor,
114
+ )
115
+ elif point_grids is not None:
116
+ self.point_grids = point_grids
117
+ else:
118
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
119
+
120
+ assert output_mode in [
121
+ "binary_mask",
122
+ "uncompressed_rle",
123
+ "coco_rle",
124
+ ], f"Unknown output_mode {output_mode}."
125
+ if output_mode == "coco_rle":
126
+ try:
127
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
128
+ except ImportError as e:
129
+ print("Please install pycocotools")
130
+ raise e
131
+
132
+ self.predictor = SAM2ImagePredictor(
133
+ model,
134
+ max_hole_area=min_mask_region_area,
135
+ max_sprinkle_area=min_mask_region_area,
136
+ )
137
+ self.points_per_batch = points_per_batch
138
+ self.pred_iou_thresh = pred_iou_thresh
139
+ self.stability_score_thresh = stability_score_thresh
140
+ self.stability_score_offset = stability_score_offset
141
+ self.mask_threshold = mask_threshold
142
+ self.box_nms_thresh = box_nms_thresh
143
+ self.crop_n_layers = crop_n_layers
144
+ self.crop_nms_thresh = crop_nms_thresh
145
+ self.crop_overlap_ratio = crop_overlap_ratio
146
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
147
+ self.min_mask_region_area = min_mask_region_area
148
+ self.output_mode = output_mode
149
+ self.use_m2m = use_m2m
150
+ self.multimask_output = multimask_output
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
154
+ """
155
+ Load a pretrained model from the Hugging Face hub.
156
+
157
+ Arguments:
158
+ model_id (str): The Hugging Face repository ID.
159
+ **kwargs: Additional arguments to pass to the model constructor.
160
+
161
+ Returns:
162
+ (SAM2AutomaticMaskGenerator): The loaded model.
163
+ """
164
+ from sam2.build_sam import build_sam2_hf
165
+
166
+ sam_model = build_sam2_hf(model_id, **kwargs)
167
+ return cls(sam_model, **kwargs)
168
+
169
+ @torch.no_grad()
170
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
171
+ """
172
+ Generates masks for the given image.
173
+
174
+ Arguments:
175
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
176
+
177
+ Returns:
178
+ list(dict(str, any)): A list over records for masks. Each record is
179
+ a dict containing the following keys:
180
+ segmentation (dict(str, any) or np.ndarray): The mask. If
181
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
182
+ is a dictionary containing the RLE.
183
+ bbox (list(float)): The box around the mask, in XYWH format.
184
+ area (int): The area in pixels of the mask.
185
+ predicted_iou (float): The model's own prediction of the mask's
186
+ quality. This is filtered by the pred_iou_thresh parameter.
187
+ point_coords (list(list(float))): The point coordinates input
188
+ to the model to generate this mask.
189
+ stability_score (float): A measure of the mask's quality. This
190
+ is filtered on using the stability_score_thresh parameter.
191
+ crop_box (list(float)): The crop of the image used to generate
192
+ the mask, given in XYWH format.
193
+ """
194
+
195
+ # Generate masks
196
+ mask_data = self._generate_masks(image)
197
+
198
+ # Encode masks
199
+ if self.output_mode == "coco_rle":
200
+ mask_data["segmentations"] = [
201
+ coco_encode_rle(rle) for rle in mask_data["rles"]
202
+ ]
203
+ elif self.output_mode == "binary_mask":
204
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
205
+ else:
206
+ mask_data["segmentations"] = mask_data["rles"]
207
+
208
+ # Write mask records
209
+ curr_anns = []
210
+ for idx in range(len(mask_data["segmentations"])):
211
+ ann = {
212
+ "segmentation": mask_data["segmentations"][idx],
213
+ "area": area_from_rle(mask_data["rles"][idx]),
214
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
215
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
216
+ "point_coords": [mask_data["points"][idx].tolist()],
217
+ "stability_score": mask_data["stability_score"][idx].item(),
218
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
219
+ }
220
+ curr_anns.append(ann)
221
+
222
+ return curr_anns
223
+
224
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
225
+ orig_size = image.shape[:2]
226
+ crop_boxes, layer_idxs = generate_crop_boxes(
227
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
228
+ )
229
+
230
+ # Iterate over image crops
231
+ data = MaskData()
232
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
233
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
234
+ data.cat(crop_data)
235
+
236
+ # Remove duplicate masks between crops
237
+ if len(crop_boxes) > 1:
238
+ # Prefer masks from smaller crops
239
+ scores = 1 / box_area(data["crop_boxes"])
240
+ scores = scores.to(data["boxes"].device)
241
+ keep_by_nms = batched_nms(
242
+ data["boxes"].float(),
243
+ scores,
244
+ torch.zeros_like(data["boxes"][:, 0]), # categories
245
+ iou_threshold=self.crop_nms_thresh,
246
+ )
247
+ data.filter(keep_by_nms)
248
+ data.to_numpy()
249
+ return data
250
+
251
+ def _process_crop(
252
+ self,
253
+ image: np.ndarray,
254
+ crop_box: List[int],
255
+ crop_layer_idx: int,
256
+ orig_size: Tuple[int, ...],
257
+ ) -> MaskData:
258
+ # Crop the image and calculate embeddings
259
+ x0, y0, x1, y1 = crop_box
260
+ cropped_im = image[y0:y1, x0:x1, :]
261
+ cropped_im_size = cropped_im.shape[:2]
262
+ self.predictor.set_image(cropped_im)
263
+
264
+ # Get points for this crop
265
+ points_scale = np.array(cropped_im_size)[None, ::-1]
266
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
267
+
268
+ # Generate masks for this crop in batches
269
+ data = MaskData()
270
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
271
+ batch_data = self._process_batch(
272
+ points, cropped_im_size, crop_box, orig_size, normalize=True
273
+ )
274
+ data.cat(batch_data)
275
+ del batch_data
276
+ self.predictor.reset_predictor()
277
+
278
+ # Remove duplicates within this crop.
279
+ keep_by_nms = batched_nms(
280
+ data["boxes"].float(),
281
+ data["iou_preds"],
282
+ torch.zeros_like(data["boxes"][:, 0]), # categories
283
+ iou_threshold=self.box_nms_thresh,
284
+ )
285
+ data.filter(keep_by_nms)
286
+
287
+ # Return to the original image frame
288
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
289
+ data["points"] = uncrop_points(data["points"], crop_box)
290
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
291
+
292
+ return data
293
+
294
+ def _process_batch(
295
+ self,
296
+ points: np.ndarray,
297
+ im_size: Tuple[int, ...],
298
+ crop_box: List[int],
299
+ orig_size: Tuple[int, ...],
300
+ normalize=False,
301
+ ) -> MaskData:
302
+ orig_h, orig_w = orig_size
303
+
304
+ # Run model on this batch
305
+ points = torch.as_tensor(
306
+ points, dtype=torch.float32, device=self.predictor.device
307
+ )
308
+ in_points = self.predictor._transforms.transform_coords(
309
+ points, normalize=normalize, orig_hw=im_size
310
+ )
311
+ in_labels = torch.ones(
312
+ in_points.shape[0], dtype=torch.int, device=in_points.device
313
+ )
314
+ masks, iou_preds, low_res_masks = self.predictor._predict(
315
+ in_points[:, None, :],
316
+ in_labels[:, None],
317
+ multimask_output=self.multimask_output,
318
+ return_logits=True,
319
+ )
320
+
321
+ # Serialize predictions and store in MaskData
322
+ data = MaskData(
323
+ masks=masks.flatten(0, 1),
324
+ iou_preds=iou_preds.flatten(0, 1),
325
+ points=points.repeat_interleave(masks.shape[1], dim=0),
326
+ low_res_masks=low_res_masks.flatten(0, 1),
327
+ )
328
+ del masks
329
+
330
+ if not self.use_m2m:
331
+ # Filter by predicted IoU
332
+ if self.pred_iou_thresh > 0.0:
333
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
334
+ data.filter(keep_mask)
335
+
336
+ # Calculate and filter by stability score
337
+ data["stability_score"] = calculate_stability_score(
338
+ data["masks"], self.mask_threshold, self.stability_score_offset
339
+ )
340
+ if self.stability_score_thresh > 0.0:
341
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
342
+ data.filter(keep_mask)
343
+ else:
344
+ # One step refinement using previous mask predictions
345
+ in_points = self.predictor._transforms.transform_coords(
346
+ data["points"], normalize=normalize, orig_hw=im_size
347
+ )
348
+ labels = torch.ones(
349
+ in_points.shape[0], dtype=torch.int, device=in_points.device
350
+ )
351
+ masks, ious = self.refine_with_m2m(
352
+ in_points, labels, data["low_res_masks"], self.points_per_batch
353
+ )
354
+ data["masks"] = masks.squeeze(1)
355
+ data["iou_preds"] = ious.squeeze(1)
356
+
357
+ if self.pred_iou_thresh > 0.0:
358
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
359
+ data.filter(keep_mask)
360
+
361
+ data["stability_score"] = calculate_stability_score(
362
+ data["masks"], self.mask_threshold, self.stability_score_offset
363
+ )
364
+ if self.stability_score_thresh > 0.0:
365
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
366
+ data.filter(keep_mask)
367
+
368
+ # Threshold masks and calculate boxes
369
+ data["masks"] = data["masks"] > self.mask_threshold
370
+ data["boxes"] = batched_mask_to_box(data["masks"])
371
+
372
+ # Filter boxes that touch crop boundaries
373
+ keep_mask = ~is_box_near_crop_edge(
374
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
375
+ )
376
+ if not torch.all(keep_mask):
377
+ data.filter(keep_mask)
378
+
379
+ # Compress to RLE
380
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
381
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
382
+ del data["masks"]
383
+
384
+ return data
385
+
386
+ @staticmethod
387
+ def postprocess_small_regions(
388
+ mask_data: MaskData, min_area: int, nms_thresh: float
389
+ ) -> MaskData:
390
+ """
391
+ Removes small disconnected regions and holes in masks, then reruns
392
+ box NMS to remove any new duplicates.
393
+
394
+ Edits mask_data in place.
395
+
396
+ Requires open-cv as a dependency.
397
+ """
398
+ if len(mask_data["rles"]) == 0:
399
+ return mask_data
400
+
401
+ # Filter small disconnected regions and holes
402
+ new_masks = []
403
+ scores = []
404
+ for rle in mask_data["rles"]:
405
+ mask = rle_to_mask(rle)
406
+
407
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
408
+ unchanged = not changed
409
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
410
+ unchanged = unchanged and not changed
411
+
412
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
413
+ # Give score=0 to changed masks and score=1 to unchanged masks
414
+ # so NMS will prefer ones that didn't need postprocessing
415
+ scores.append(float(unchanged))
416
+
417
+ # Recalculate boxes and remove any new duplicates
418
+ masks = torch.cat(new_masks, dim=0)
419
+ boxes = batched_mask_to_box(masks)
420
+ keep_by_nms = batched_nms(
421
+ boxes.float(),
422
+ torch.as_tensor(scores),
423
+ torch.zeros_like(boxes[:, 0]), # categories
424
+ iou_threshold=nms_thresh,
425
+ )
426
+
427
+ # Only recalculate RLEs for masks that have changed
428
+ for i_mask in keep_by_nms:
429
+ if scores[i_mask] == 0.0:
430
+ mask_torch = masks[i_mask].unsqueeze(0)
431
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
432
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
433
+ mask_data.filter(keep_by_nms)
434
+
435
+ return mask_data
436
+
437
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
438
+ new_masks = []
439
+ new_iou_preds = []
440
+
441
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
442
+ points_per_batch, points, point_labels, low_res_masks
443
+ ):
444
+ best_masks, best_iou_preds, _ = self.predictor._predict(
445
+ cur_points[:, None, :],
446
+ cur_point_labels[:, None],
447
+ mask_input=low_res_mask[:, None, :],
448
+ multimask_output=False,
449
+ return_logits=True,
450
+ )
451
+ new_masks.append(best_masks)
452
+ new_iou_preds.append(best_iou_preds)
453
+ masks = torch.cat(new_masks, dim=0)
454
+ return masks, torch.cat(new_iou_preds, dim=0)
sam2/build_sam.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+ from hydra import compose
12
+ from hydra.utils import instantiate
13
+ from omegaconf import OmegaConf
14
+
15
+ import sam2
16
+
17
+ # Check if the user is running Python from the parent directory of the sam2 repo
18
+ # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
+ # it could shadow the sam2 package and cause issues.
20
+ if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
+ # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
22
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
+ # This typically happens because the user is running Python from the parent directory
24
+ # that contains the sam2 repo they cloned.
25
+ raise RuntimeError(
26
+ "You're likely running Python from the parent directory of the sam2 repository "
27
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
+ "This is not supported since the `sam2` Python package could be shadowed by the "
29
+ "repository name (the repository is also named `sam2` and contains the Python package "
30
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
32
+ )
33
+
34
+
35
+ HF_MODEL_ID_TO_FILENAMES = {
36
+ "facebook/sam2-hiera-tiny": (
37
+ "configs/sam2/sam2_hiera_t.yaml",
38
+ "sam2_hiera_tiny.pt",
39
+ ),
40
+ "facebook/sam2-hiera-small": (
41
+ "configs/sam2/sam2_hiera_s.yaml",
42
+ "sam2_hiera_small.pt",
43
+ ),
44
+ "facebook/sam2-hiera-base-plus": (
45
+ "configs/sam2/sam2_hiera_b+.yaml",
46
+ "sam2_hiera_base_plus.pt",
47
+ ),
48
+ "facebook/sam2-hiera-large": (
49
+ "configs/sam2/sam2_hiera_l.yaml",
50
+ "sam2_hiera_large.pt",
51
+ ),
52
+ "facebook/sam2.1-hiera-tiny": (
53
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
54
+ "sam2.1_hiera_tiny.pt",
55
+ ),
56
+ "facebook/sam2.1-hiera-small": (
57
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
58
+ "sam2.1_hiera_small.pt",
59
+ ),
60
+ "facebook/sam2.1-hiera-base-plus": (
61
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
62
+ "sam2.1_hiera_base_plus.pt",
63
+ ),
64
+ "facebook/sam2.1-hiera-large": (
65
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
66
+ "sam2.1_hiera_large.pt",
67
+ ),
68
+ }
69
+
70
+
71
+ def build_sam2(
72
+ config_file,
73
+ ckpt_path=None,
74
+ device="cuda",
75
+ mode="eval",
76
+ hydra_overrides_extra=[],
77
+ apply_postprocessing=True,
78
+ **kwargs,
79
+ ):
80
+
81
+ if apply_postprocessing:
82
+ hydra_overrides_extra = hydra_overrides_extra.copy()
83
+ hydra_overrides_extra += [
84
+ # dynamically fall back to multi-mask if the single mask is not stable
85
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
+ ]
89
+ # Read config and init model
90
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
+ OmegaConf.resolve(cfg)
92
+ model = instantiate(cfg.model, _recursive_=True)
93
+ _load_checkpoint(model, ckpt_path)
94
+ model = model.to(device)
95
+ if mode == "eval":
96
+ model.eval()
97
+ return model
98
+
99
+
100
+ def build_sam2_video_predictor(
101
+ config_file,
102
+ ckpt_path=None,
103
+ device="cuda",
104
+ mode="eval",
105
+ hydra_overrides_extra=[],
106
+ apply_postprocessing=True,
107
+ **kwargs,
108
+ ):
109
+ hydra_overrides = [
110
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
111
+ ]
112
+ if apply_postprocessing:
113
+ hydra_overrides_extra = hydra_overrides_extra.copy()
114
+ hydra_overrides_extra += [
115
+ # dynamically fall back to multi-mask if the single mask is not stable
116
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
117
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
118
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
119
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
120
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
121
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
122
+ "++model.fill_hole_area=8",
123
+ ]
124
+ hydra_overrides.extend(hydra_overrides_extra)
125
+
126
+ # Read config and init model
127
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
128
+ OmegaConf.resolve(cfg)
129
+ model = instantiate(cfg.model, _recursive_=True)
130
+ _load_checkpoint(model, ckpt_path)
131
+ model = model.to(device)
132
+ if mode == "eval":
133
+ model.eval()
134
+ return model
135
+
136
+
137
+ def _hf_download(model_id):
138
+ from huggingface_hub import hf_hub_download
139
+
140
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
141
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
142
+ return config_name, ckpt_path
143
+
144
+
145
+ def build_sam2_hf(model_id, **kwargs):
146
+ config_name, ckpt_path = _hf_download(model_id)
147
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
148
+
149
+
150
+ def build_sam2_video_predictor_hf(model_id, **kwargs):
151
+ config_name, ckpt_path = _hf_download(model_id)
152
+ return build_sam2_video_predictor(
153
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
154
+ )
155
+
156
+
157
+ def _load_checkpoint(model, ckpt_path):
158
+ if ckpt_path is not None:
159
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
160
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
161
+ if missing_keys:
162
+ logging.error(missing_keys)
163
+ raise RuntimeError()
164
+ if unexpected_keys:
165
+ logging.error(unexpected_keys)
166
+ raise RuntimeError()
167
+ logging.info("Loaded checkpoint sucessfully")
sam2/configs/sam2.1/sam2.1_hiera_b+.yaml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ no_obj_embed_spatial: true
93
+ # use high-resolution feature map in the SAM mask decoder
94
+ use_high_res_features_in_sam: true
95
+ # output 3 masks on the first click on initial conditioning frames
96
+ multimask_output_in_sam: true
97
+ # SAM heads
98
+ iou_prediction_use_sigmoid: True
99
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100
+ use_obj_ptrs_in_encoder: true
101
+ add_tpos_enc_to_obj_ptrs: true
102
+ proj_tpos_enc_in_obj_ptrs: true
103
+ use_signed_tpos_enc_to_obj_ptrs: true
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
sam2/configs/sam2.1/sam2.1_hiera_l.yaml ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ compile_image_encoder: False
sam2/configs/sam2.1/sam2.1_hiera_s.yaml ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ no_obj_embed_spatial: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: true
105
+ proj_tpos_enc_in_obj_ptrs: true
106
+ use_signed_tpos_enc_to_obj_ptrs: true
107
+ only_obj_ptrs_in_the_past_for_eval: true
108
+ # object occlusion prediction
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ # multimask tracking settings
113
+ multimask_output_for_tracking: true
114
+ use_multimask_token_for_obj_ptr: true
115
+ multimask_min_pt_num: 0
116
+ multimask_max_pt_num: 1
117
+ use_mlp_for_obj_ptr_proj: true
118
+ # Compilation flag
119
+ compile_image_encoder: False
sam2/configs/sam2.1/sam2.1_hiera_t.yaml ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ # HieraT does not currently support compilation, should always be set to False
121
+ compile_image_encoder: False
sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 8
8
+ max_num_objects: 3
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 40
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: null # PATH to MOSE JPEGImages folder
17
+ gt_folder: null # PATH to MOSE Annotations folder
18
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
+ multiplier: 2
20
+
21
+ # Video transforms
22
+ vos:
23
+ train_transforms:
24
+ - _target_: training.dataset.transforms.ComposeAPI
25
+ transforms:
26
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
27
+ consistent_transform: True
28
+ - _target_: training.dataset.transforms.RandomAffine
29
+ degrees: 25
30
+ shear: 20
31
+ image_interpolation: bilinear
32
+ consistent_transform: True
33
+ - _target_: training.dataset.transforms.RandomResizeAPI
34
+ sizes: ${scratch.resolution}
35
+ square: true
36
+ consistent_transform: True
37
+ - _target_: training.dataset.transforms.ColorJitter
38
+ consistent_transform: True
39
+ brightness: 0.1
40
+ contrast: 0.03
41
+ saturation: 0.03
42
+ hue: null
43
+ - _target_: training.dataset.transforms.RandomGrayscale
44
+ p: 0.05
45
+ consistent_transform: True
46
+ - _target_: training.dataset.transforms.ColorJitter
47
+ consistent_transform: False
48
+ brightness: 0.1
49
+ contrast: 0.05
50
+ saturation: 0.05
51
+ hue: null
52
+ - _target_: training.dataset.transforms.ToTensorAPI
53
+ - _target_: training.dataset.transforms.NormalizeAPI
54
+ mean: [0.485, 0.456, 0.406]
55
+ std: [0.229, 0.224, 0.225]
56
+
57
+ trainer:
58
+ _target_: training.trainer.Trainer
59
+ mode: train_only
60
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
61
+ accelerator: cuda
62
+ seed_value: 123
63
+
64
+ model:
65
+ _target_: training.model.sam2.SAM2Train
66
+ image_encoder:
67
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
+ scalp: 1
69
+ trunk:
70
+ _target_: sam2.modeling.backbones.hieradet.Hiera
71
+ embed_dim: 112
72
+ num_heads: 2
73
+ drop_path_rate: 0.1
74
+ neck:
75
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
+ position_encoding:
77
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
+ num_pos_feats: 256
79
+ normalize: true
80
+ scale: null
81
+ temperature: 10000
82
+ d_model: 256
83
+ backbone_channel_list: [896, 448, 224, 112]
84
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
85
+ fpn_interp_model: nearest
86
+
87
+ memory_attention:
88
+ _target_: sam2.modeling.memory_attention.MemoryAttention
89
+ d_model: 256
90
+ pos_enc_at_input: true
91
+ layer:
92
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
+ activation: relu
94
+ dim_feedforward: 2048
95
+ dropout: 0.1
96
+ pos_enc_at_attn: false
97
+ self_attention:
98
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
99
+ rope_theta: 10000.0
100
+ feat_sizes: [32, 32]
101
+ embedding_dim: 256
102
+ num_heads: 1
103
+ downsample_rate: 1
104
+ dropout: 0.1
105
+ d_model: 256
106
+ pos_enc_at_cross_attn_keys: true
107
+ pos_enc_at_cross_attn_queries: false
108
+ cross_attention:
109
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
110
+ rope_theta: 10000.0
111
+ feat_sizes: [32, 32]
112
+ rope_k_repeat: True
113
+ embedding_dim: 256
114
+ num_heads: 1
115
+ downsample_rate: 1
116
+ dropout: 0.1
117
+ kv_in_dim: 64
118
+ num_layers: 4
119
+
120
+ memory_encoder:
121
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
+ out_dim: 64
123
+ position_encoding:
124
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
+ num_pos_feats: 64
126
+ normalize: true
127
+ scale: null
128
+ temperature: 10000
129
+ mask_downsampler:
130
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
+ kernel_size: 3
132
+ stride: 2
133
+ padding: 1
134
+ fuser:
135
+ _target_: sam2.modeling.memory_encoder.Fuser
136
+ layer:
137
+ _target_: sam2.modeling.memory_encoder.CXBlock
138
+ dim: 256
139
+ kernel_size: 7
140
+ padding: 3
141
+ layer_scale_init_value: 1e-6
142
+ use_dwconv: True # depth-wise convs
143
+ num_layers: 2
144
+
145
+ num_maskmem: 7
146
+ image_size: ${scratch.resolution}
147
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
148
+ sigmoid_scale_for_mem_enc: 20.0
149
+ sigmoid_bias_for_mem_enc: -10.0
150
+ use_mask_input_as_output_without_sam: true
151
+ # Memory
152
+ directly_add_no_mem_embed: true
153
+ no_obj_embed_spatial: true
154
+ # use high-resolution feature map in the SAM mask decoder
155
+ use_high_res_features_in_sam: true
156
+ # output 3 masks on the first click on initial conditioning frames
157
+ multimask_output_in_sam: true
158
+ # SAM heads
159
+ iou_prediction_use_sigmoid: True
160
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
161
+ use_obj_ptrs_in_encoder: true
162
+ add_tpos_enc_to_obj_ptrs: true
163
+ proj_tpos_enc_in_obj_ptrs: true
164
+ use_signed_tpos_enc_to_obj_ptrs: true
165
+ only_obj_ptrs_in_the_past_for_eval: true
166
+ # object occlusion prediction
167
+ pred_obj_scores: true
168
+ pred_obj_scores_mlp: true
169
+ fixed_no_obj_ptr: true
170
+ # multimask tracking settings
171
+ multimask_output_for_tracking: true
172
+ use_multimask_token_for_obj_ptr: true
173
+ multimask_min_pt_num: 0
174
+ multimask_max_pt_num: 1
175
+ use_mlp_for_obj_ptr_proj: true
176
+ # Compilation flag
177
+ # compile_image_encoder: False
178
+
179
+ ####### Training specific params #######
180
+ # box/point input and corrections
181
+ prob_to_use_pt_input_for_train: 0.5
182
+ prob_to_use_pt_input_for_eval: 0.0
183
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
184
+ prob_to_use_box_input_for_eval: 0.0
185
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
186
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
187
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
188
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
189
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
190
+ # maximum 2 initial conditioning frames
191
+ num_init_cond_frames_for_train: 2
192
+ rand_init_cond_frames_for_train: True # random 1~2
193
+ num_correction_pt_per_frame: 7
194
+ use_act_ckpt_iterative_pt_sampling: false
195
+
196
+
197
+
198
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
199
+ forward_backbone_per_frame_for_eval: True
200
+
201
+
202
+ data:
203
+ train:
204
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
205
+ phases_per_epoch: ${scratch.phases_per_epoch}
206
+ batch_sizes:
207
+ - ${scratch.train_batch_size}
208
+
209
+ datasets:
210
+ - _target_: training.dataset.utils.RepeatFactorWrapper
211
+ dataset:
212
+ _target_: training.dataset.utils.ConcatDataset
213
+ datasets:
214
+ - _target_: training.dataset.vos_dataset.VOSDataset
215
+ transforms: ${vos.train_transforms}
216
+ training: true
217
+ video_dataset:
218
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
219
+ img_folder: ${dataset.img_folder}
220
+ gt_folder: ${dataset.gt_folder}
221
+ file_list_txt: ${dataset.file_list_txt}
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: ${dataset.multiplier}
227
+ shuffle: True
228
+ num_workers: ${scratch.num_train_workers}
229
+ pin_memory: True
230
+ drop_last: True
231
+ collate_fn:
232
+ _target_: training.utils.data_utils.collate_fn
233
+ _partial_: true
234
+ dict_key: all
235
+
236
+ optim:
237
+ amp:
238
+ enabled: True
239
+ amp_dtype: bfloat16
240
+
241
+ optimizer:
242
+ _target_: torch.optim.AdamW
243
+
244
+ gradient_clip:
245
+ _target_: training.optimizer.GradientClipper
246
+ max_norm: 0.1
247
+ norm_type: 2
248
+
249
+ param_group_modifiers:
250
+ - _target_: training.optimizer.layer_decay_param_modifier
251
+ _partial_: True
252
+ layer_decay_value: 0.9
253
+ apply_to: 'image_encoder.trunk'
254
+ overrides:
255
+ - pattern: '*pos_embed*'
256
+ value: 1.0
257
+
258
+ options:
259
+ lr:
260
+ - scheduler:
261
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
262
+ start_value: ${scratch.base_lr}
263
+ end_value: ${divide:${scratch.base_lr},10}
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.vision_lr}
267
+ end_value: ${divide:${scratch.vision_lr},10}
268
+ param_names:
269
+ - 'image_encoder.*'
270
+ weight_decay:
271
+ - scheduler:
272
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
273
+ value: 0.1
274
+ - scheduler:
275
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
276
+ value: 0.0
277
+ param_names:
278
+ - '*bias*'
279
+ module_cls_names: ['torch.nn.LayerNorm']
280
+
281
+ loss:
282
+ all:
283
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
284
+ weight_dict:
285
+ loss_mask: 20
286
+ loss_dice: 1
287
+ loss_iou: 1
288
+ loss_class: 1
289
+ supervise_all_iou: true
290
+ iou_use_l1_loss: true
291
+ pred_obj_scores: true
292
+ focal_gamma_obj_score: 0.0
293
+ focal_alpha_obj_score: -1.0
294
+
295
+ distributed:
296
+ backend: nccl
297
+ find_unused_parameters: True
298
+
299
+ logging:
300
+ tensorboard_writer:
301
+ _target_: training.utils.logger.make_tensorboard_logger
302
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
303
+ flush_secs: 120
304
+ should_log: True
305
+ log_dir: ${launcher.experiment_log_dir}/logs
306
+ log_freq: 10
307
+
308
+ # initialize from a SAM 2 checkpoint
309
+ checkpoint:
310
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
311
+ save_freq: 0 # 0 only last checkpoint is saved.
312
+ model_weight_initializer:
313
+ _partial_: True
314
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
315
+ strict: True
316
+ ignore_unexpected_keys: null
317
+ ignore_missing_keys: null
318
+
319
+ state_dict:
320
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
+ ckpt_state_dict_keys: ['model']
323
+
324
+ launcher:
325
+ num_nodes: 1
326
+ gpus_per_node: 8
327
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
328
+
329
+ # SLURM args if running on a cluster
330
+ submitit:
331
+ partition: null
332
+ account: null
333
+ qos: null
334
+ cpus_per_task: 10
335
+ use_cluster: false
336
+ timeout_hour: 24
337
+ name: null
338
+ port_range: [10000, 65000]
339
+
sam2/configs/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
sam2/configs/sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
sam2/configs/sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
sam2/configs/sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
sam2/csrc/connected_components.cu ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // adapted from https://github.com/zsef123/Connected_components_PyTorch
8
+ // with license found in the LICENSE_cctorch file in the root directory.
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/script.h>
14
+ #include <vector>
15
+
16
+ // 2d
17
+ #define BLOCK_ROWS 16
18
+ #define BLOCK_COLS 16
19
+
20
+ namespace cc2d {
21
+
22
+ template <typename T>
23
+ __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24
+ return (bitmap >> pos) & 1;
25
+ }
26
+
27
+ __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28
+ while (s_buf[n] != n)
29
+ n = s_buf[n];
30
+ return n;
31
+ }
32
+
33
+ __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34
+ const int32_t id = n;
35
+ while (s_buf[n] != n) {
36
+ n = s_buf[n];
37
+ s_buf[id] = n;
38
+ }
39
+ return n;
40
+ }
41
+
42
+ __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43
+ bool done;
44
+ do {
45
+ a = find(s_buf, a);
46
+ b = find(s_buf, b);
47
+
48
+ if (a < b) {
49
+ int32_t old = atomicMin(s_buf + b, a);
50
+ done = (old == b);
51
+ b = old;
52
+ } else if (b < a) {
53
+ int32_t old = atomicMin(s_buf + a, b);
54
+ done = (old == a);
55
+ a = old;
56
+ } else
57
+ done = true;
58
+
59
+ } while (!done);
60
+ }
61
+
62
+ __global__ void
63
+ init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66
+ const uint32_t idx = row * W + col;
67
+
68
+ if (row < H && col < W)
69
+ label[idx] = idx;
70
+ }
71
+
72
+ __global__ void
73
+ merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76
+ const uint32_t idx = row * W + col;
77
+
78
+ if (row >= H || col >= W)
79
+ return;
80
+
81
+ uint32_t P = 0;
82
+
83
+ if (img[idx])
84
+ P |= 0x777;
85
+ if (row + 1 < H && img[idx + W])
86
+ P |= 0x777 << 4;
87
+ if (col + 1 < W && img[idx + 1])
88
+ P |= 0x777 << 1;
89
+
90
+ if (col == 0)
91
+ P &= 0xEEEE;
92
+ if (col + 1 >= W)
93
+ P &= 0x3333;
94
+ else if (col + 2 >= W)
95
+ P &= 0x7777;
96
+
97
+ if (row == 0)
98
+ P &= 0xFFF0;
99
+ if (row + 1 >= H)
100
+ P &= 0xFF;
101
+
102
+ if (P > 0) {
103
+ // If need check about top-left pixel(if flag the first bit) and hit the
104
+ // top-left pixel
105
+ if (hasBit(P, 0) && img[idx - W - 1]) {
106
+ union_(label, idx, idx - 2 * W - 2); // top left block
107
+ }
108
+
109
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110
+ union_(label, idx, idx - 2 * W); // top bottom block
111
+
112
+ if (hasBit(P, 3) && img[idx + 2 - W])
113
+ union_(label, idx, idx - 2 * W + 2); // top right block
114
+
115
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116
+ union_(label, idx, idx - 2); // just left block
117
+ }
118
+ }
119
+
120
+ __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123
+ const uint32_t idx = row * W + col;
124
+
125
+ if (row < H && col < W)
126
+ find_n_compress(label, idx);
127
+ }
128
+
129
+ __global__ void final_labeling(
130
+ const uint8_t* img,
131
+ int32_t* label,
132
+ const int32_t W,
133
+ const int32_t H) {
134
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136
+ const uint32_t idx = row * W + col;
137
+
138
+ if (row >= H || col >= W)
139
+ return;
140
+
141
+ int32_t y = label[idx] + 1;
142
+
143
+ if (img[idx])
144
+ label[idx] = y;
145
+ else
146
+ label[idx] = 0;
147
+
148
+ if (col + 1 < W) {
149
+ if (img[idx + 1])
150
+ label[idx + 1] = y;
151
+ else
152
+ label[idx + 1] = 0;
153
+
154
+ if (row + 1 < H) {
155
+ if (img[idx + W + 1])
156
+ label[idx + W + 1] = y;
157
+ else
158
+ label[idx + W + 1] = 0;
159
+ }
160
+ }
161
+
162
+ if (row + 1 < H) {
163
+ if (img[idx + W])
164
+ label[idx + W] = y;
165
+ else
166
+ label[idx + W] = 0;
167
+ }
168
+ }
169
+
170
+ __global__ void init_counting(
171
+ const int32_t* label,
172
+ int32_t* count_init,
173
+ const int32_t W,
174
+ const int32_t H) {
175
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177
+ const uint32_t idx = row * W + col;
178
+
179
+ if (row >= H || col >= W)
180
+ return;
181
+
182
+ int32_t y = label[idx];
183
+ if (y > 0) {
184
+ int32_t count_idx = y - 1;
185
+ atomicAdd(count_init + count_idx, 1);
186
+ }
187
+ }
188
+
189
+ __global__ void final_counting(
190
+ const int32_t* label,
191
+ const int32_t* count_init,
192
+ int32_t* count_final,
193
+ const int32_t W,
194
+ const int32_t H) {
195
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197
+ const uint32_t idx = row * W + col;
198
+
199
+ if (row >= H || col >= W)
200
+ return;
201
+
202
+ int32_t y = label[idx];
203
+ if (y > 0) {
204
+ int32_t count_idx = y - 1;
205
+ count_final[idx] = count_init[count_idx];
206
+ } else {
207
+ count_final[idx] = 0;
208
+ }
209
+ }
210
+
211
+ } // namespace cc2d
212
+
213
+ std::vector<torch::Tensor> get_connected_componnets(
214
+ const torch::Tensor& inputs) {
215
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217
+ AT_ASSERTM(
218
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219
+
220
+ const uint32_t N = inputs.size(0);
221
+ const uint32_t C = inputs.size(1);
222
+ const uint32_t H = inputs.size(2);
223
+ const uint32_t W = inputs.size(3);
224
+
225
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
227
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
228
+
229
+ // label must be uint32_t
230
+ auto label_options =
231
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235
+
236
+ dim3 grid = dim3(
237
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240
+ dim3 grid_count =
241
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244
+
245
+ for (int n = 0; n < N; n++) {
246
+ uint32_t offset = n * H * W;
247
+
248
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
249
+ labels.data_ptr<int32_t>() + offset, W, H);
250
+ cc2d::merge<<<grid, block, 0, stream>>>(
251
+ inputs.data_ptr<uint8_t>() + offset,
252
+ labels.data_ptr<int32_t>() + offset,
253
+ W,
254
+ H);
255
+ cc2d::compression<<<grid, block, 0, stream>>>(
256
+ labels.data_ptr<int32_t>() + offset, W, H);
257
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
258
+ inputs.data_ptr<uint8_t>() + offset,
259
+ labels.data_ptr<int32_t>() + offset,
260
+ W,
261
+ H);
262
+
263
+ // get the counting of each pixel
264
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265
+ labels.data_ptr<int32_t>() + offset,
266
+ counts_init.data_ptr<int32_t>() + offset,
267
+ W,
268
+ H);
269
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270
+ labels.data_ptr<int32_t>() + offset,
271
+ counts_init.data_ptr<int32_t>() + offset,
272
+ counts_final.data_ptr<int32_t>() + offset,
273
+ W,
274
+ H);
275
+ }
276
+
277
+ // returned values are [labels, counts]
278
+ std::vector<torch::Tensor> outputs;
279
+ outputs.push_back(labels);
280
+ outputs.push_back(counts_final);
281
+ return outputs;
282
+ }
283
+
284
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285
+ m.def(
286
+ "get_connected_componnets",
287
+ &get_connected_componnets,
288
+ "get_connected_componnets");
289
+ }
sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (167 Bytes). View file
 
sam2/modeling/__pycache__/memory_attention.cpython-311.pyc ADDED
Binary file (7.47 kB). View file
 
sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc ADDED
Binary file (8.75 kB). View file
 
sam2/modeling/__pycache__/position_encoding.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
sam2/modeling/__pycache__/sam2_base.cpython-311.pyc ADDED
Binary file (32.9 kB). View file
 
sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc ADDED
Binary file (5.85 kB). View file
 
sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.72 kB). View file
 
sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import partial
9
+ from typing import List, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from iopath.common.file_io import g_pathmgr
15
+
16
+ from sam2.modeling.backbones.utils import (
17
+ PatchEmbed,
18
+ window_partition,
19
+ window_unpartition,
20
+ )
21
+
22
+ from sam2.modeling.sam2_utils import DropPath, MLP
23
+
24
+
25
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26
+ if pool is None:
27
+ return x
28
+ # (B, H, W, C) -> (B, C, H, W)
29
+ x = x.permute(0, 3, 1, 2)
30
+ x = pool(x)
31
+ # (B, C, H', W') -> (B, H', W', C)
32
+ x = x.permute(0, 2, 3, 1)
33
+ if norm:
34
+ x = norm(x)
35
+
36
+ return x
37
+
38
+
39
+ class MultiScaleAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ dim_out: int,
44
+ num_heads: int,
45
+ q_pool: nn.Module = None,
46
+ ):
47
+ super().__init__()
48
+
49
+ self.dim = dim
50
+ self.dim_out = dim_out
51
+ self.num_heads = num_heads
52
+ self.q_pool = q_pool
53
+ self.qkv = nn.Linear(dim, dim_out * 3)
54
+ self.proj = nn.Linear(dim_out, dim_out)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ B, H, W, _ = x.shape
58
+ # qkv with shape (B, H * W, 3, nHead, C)
59
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60
+ # q, k, v with shape (B, H * W, nheads, C)
61
+ q, k, v = torch.unbind(qkv, 2)
62
+
63
+ # Q pooling (for downsample at stage changes)
64
+ if self.q_pool:
65
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66
+ H, W = q.shape[1:3] # downsampled shape
67
+ q = q.reshape(B, H * W, self.num_heads, -1)
68
+
69
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70
+ x = F.scaled_dot_product_attention(
71
+ q.transpose(1, 2),
72
+ k.transpose(1, 2),
73
+ v.transpose(1, 2),
74
+ )
75
+ # Transpose back
76
+ x = x.transpose(1, 2)
77
+ x = x.reshape(B, H, W, -1)
78
+
79
+ x = self.proj(x)
80
+
81
+ return x
82
+
83
+
84
+ class MultiScaleBlock(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ dim_out: int,
89
+ num_heads: int,
90
+ mlp_ratio: float = 4.0,
91
+ drop_path: float = 0.0,
92
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
93
+ q_stride: Tuple[int, int] = None,
94
+ act_layer: nn.Module = nn.GELU,
95
+ window_size: int = 0,
96
+ ):
97
+ super().__init__()
98
+
99
+ if isinstance(norm_layer, str):
100
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101
+
102
+ self.dim = dim
103
+ self.dim_out = dim_out
104
+ self.norm1 = norm_layer(dim)
105
+
106
+ self.window_size = window_size
107
+
108
+ self.pool, self.q_stride = None, q_stride
109
+ if self.q_stride:
110
+ self.pool = nn.MaxPool2d(
111
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
112
+ )
113
+
114
+ self.attn = MultiScaleAttention(
115
+ dim,
116
+ dim_out,
117
+ num_heads=num_heads,
118
+ q_pool=self.pool,
119
+ )
120
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121
+
122
+ self.norm2 = norm_layer(dim_out)
123
+ self.mlp = MLP(
124
+ dim_out,
125
+ int(dim_out * mlp_ratio),
126
+ dim_out,
127
+ num_layers=2,
128
+ activation=act_layer,
129
+ )
130
+
131
+ if dim != dim_out:
132
+ self.proj = nn.Linear(dim, dim_out)
133
+
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ shortcut = x # B, H, W, C
136
+ x = self.norm1(x)
137
+
138
+ # Skip connection
139
+ if self.dim != self.dim_out:
140
+ shortcut = do_pool(self.proj(x), self.pool)
141
+
142
+ # Window partition
143
+ window_size = self.window_size
144
+ if window_size > 0:
145
+ H, W = x.shape[1], x.shape[2]
146
+ x, pad_hw = window_partition(x, window_size)
147
+
148
+ # Window Attention + Q Pooling (if stage change)
149
+ x = self.attn(x)
150
+ if self.q_stride:
151
+ # Shapes have changed due to Q pooling
152
+ window_size = self.window_size // self.q_stride[0]
153
+ H, W = shortcut.shape[1:3]
154
+
155
+ pad_h = (window_size - H % window_size) % window_size
156
+ pad_w = (window_size - W % window_size) % window_size
157
+ pad_hw = (H + pad_h, W + pad_w)
158
+
159
+ # Reverse window partition
160
+ if self.window_size > 0:
161
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
162
+
163
+ x = shortcut + self.drop_path(x)
164
+ # MLP
165
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
166
+ return x
167
+
168
+
169
+ class Hiera(nn.Module):
170
+ """
171
+ Reference: https://arxiv.org/abs/2306.00989
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ embed_dim: int = 96, # initial embed dim
177
+ num_heads: int = 1, # initial number of heads
178
+ drop_path_rate: float = 0.0, # stochastic depth
179
+ q_pool: int = 3, # number of q_pool stages
180
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
183
+ head_mul: float = 2.0, # head_mul factor at stage shift
184
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185
+ # window size per stage, when not using global att.
186
+ window_spec: Tuple[int, ...] = (
187
+ 8,
188
+ 4,
189
+ 14,
190
+ 7,
191
+ ),
192
+ # global attn in these blocks
193
+ global_att_blocks: Tuple[int, ...] = (
194
+ 12,
195
+ 16,
196
+ 20,
197
+ ),
198
+ weights_path=None,
199
+ return_interm_layers=True, # return feats from every stage
200
+ ):
201
+ super().__init__()
202
+
203
+ assert len(stages) == len(window_spec)
204
+ self.window_spec = window_spec
205
+
206
+ depth = sum(stages)
207
+ self.q_stride = q_stride
208
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
210
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211
+ self.return_interm_layers = return_interm_layers
212
+
213
+ self.patch_embed = PatchEmbed(
214
+ embed_dim=embed_dim,
215
+ )
216
+ # Which blocks have global att?
217
+ self.global_att_blocks = global_att_blocks
218
+
219
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
+ self.pos_embed = nn.Parameter(
222
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
+ )
224
+ self.pos_embed_window = nn.Parameter(
225
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
+ )
227
+
228
+ dpr = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
+ ] # stochastic depth decay rule
231
+
232
+ cur_stage = 1
233
+ self.blocks = nn.ModuleList()
234
+
235
+ for i in range(depth):
236
+ dim_out = embed_dim
237
+ # lags by a block, so first block of
238
+ # next stage uses an initial window size
239
+ # of previous stage and final window size of current stage
240
+ window_size = self.window_spec[cur_stage - 1]
241
+
242
+ if self.global_att_blocks is not None:
243
+ window_size = 0 if i in self.global_att_blocks else window_size
244
+
245
+ if i - 1 in self.stage_ends:
246
+ dim_out = int(embed_dim * dim_mul)
247
+ num_heads = int(num_heads * head_mul)
248
+ cur_stage += 1
249
+
250
+ block = MultiScaleBlock(
251
+ dim=embed_dim,
252
+ dim_out=dim_out,
253
+ num_heads=num_heads,
254
+ drop_path=dpr[i],
255
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
256
+ window_size=window_size,
257
+ )
258
+
259
+ embed_dim = dim_out
260
+ self.blocks.append(block)
261
+
262
+ self.channel_list = (
263
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
+ if return_interm_layers
265
+ else [self.blocks[-1].dim_out]
266
+ )
267
+
268
+ if weights_path is not None:
269
+ with g_pathmgr.open(weights_path, "rb") as f:
270
+ chkpt = torch.load(f, map_location="cpu")
271
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
272
+
273
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274
+ h, w = hw
275
+ window_embed = self.pos_embed_window
276
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277
+ pos_embed = pos_embed + window_embed.tile(
278
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279
+ )
280
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
281
+ return pos_embed
282
+
283
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284
+ x = self.patch_embed(x)
285
+ # x: (B, H, W, C)
286
+
287
+ # Add pos embed
288
+ x = x + self._get_pos_embed(x.shape[1:3])
289
+
290
+ outputs = []
291
+ for i, blk in enumerate(self.blocks):
292
+ x = blk(x)
293
+ if (i == self.stage_ends[-1]) or (
294
+ i in self.stage_ends and self.return_interm_layers
295
+ ):
296
+ feats = x.permute(0, 3, 1, 2)
297
+ outputs.append(feats)
298
+
299
+ return outputs
300
+
301
+ def get_layer_id(self, layer_name):
302
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303
+ num_layers = self.get_num_layers()
304
+
305
+ if layer_name.find("rel_pos") != -1:
306
+ return num_layers + 1
307
+ elif layer_name.find("pos_embed") != -1:
308
+ return 0
309
+ elif layer_name.find("patch_embed") != -1:
310
+ return 0
311
+ elif layer_name.find("blocks") != -1:
312
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313
+ else:
314
+ return num_layers + 1
315
+
316
+ def get_num_layers(self) -> int:
317
+ return len(self.blocks)
sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param trunk: the backbone
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param neck_norm: the normalization to use
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ self.d_model = d_model
75
+ for dim in backbone_channel_list:
76
+ current = nn.Sequential()
77
+ current.add_module(
78
+ "conv",
79
+ nn.Conv2d(
80
+ in_channels=dim,
81
+ out_channels=d_model,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding,
85
+ ),
86
+ )
87
+
88
+ self.convs.append(current)
89
+ self.fpn_interp_model = fpn_interp_model
90
+ assert fuse_type in ["sum", "avg"]
91
+ self.fuse_type = fuse_type
92
+
93
+ # levels to have top-down features in its outputs
94
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95
+ # have top-down propagation, while outputs of level 0 and level 1 have only
96
+ # lateral features from the same backbone level.
97
+ if fpn_top_down_levels is None:
98
+ # default is to have top-down features on all levels
99
+ fpn_top_down_levels = range(len(self.convs))
100
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
101
+
102
+ def forward(self, xs: List[torch.Tensor]):
103
+
104
+ out = [None] * len(self.convs)
105
+ pos = [None] * len(self.convs)
106
+ assert len(xs) == len(self.convs)
107
+ # fpn forward pass
108
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109
+ prev_features = None
110
+ # forward in top-down order (from low to high resolution)
111
+ n = len(self.convs) - 1
112
+ for i in range(n, -1, -1):
113
+ x = xs[i]
114
+ lateral_features = self.convs[n - i](x)
115
+ if i in self.fpn_top_down_levels and prev_features is not None:
116
+ top_down_features = F.interpolate(
117
+ prev_features.to(dtype=torch.float32),
118
+ scale_factor=2.0,
119
+ mode=self.fpn_interp_model,
120
+ align_corners=(
121
+ None if self.fpn_interp_model == "nearest" else False
122
+ ),
123
+ antialias=False,
124
+ )
125
+ prev_features = lateral_features + top_down_features
126
+ if self.fuse_type == "avg":
127
+ prev_features /= 2
128
+ else:
129
+ prev_features = lateral_features
130
+ x_out = prev_features
131
+ out[i] = x_out
132
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133
+
134
+ return out, pos
sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = (
36
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ )
38
+ return windows, (Hp, Wp)
39
+
40
+
41
+ def window_unpartition(windows, window_size, pad_hw, hw):
42
+ """
43
+ Window unpartition into original sequences and removing padding.
44
+ Args:
45
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46
+ window_size (int): window size.
47
+ pad_hw (Tuple): padded height and width (Hp, Wp).
48
+ hw (Tuple): original height and width (H, W) before padding.
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(
56
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57
+ )
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ class PatchEmbed(nn.Module):
66
+ """
67
+ Image to Patch Embedding.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ kernel_size: Tuple[int, ...] = (7, 7),
73
+ stride: Tuple[int, ...] = (4, 4),
74
+ padding: Tuple[int, ...] = (3, 3),
75
+ in_chans: int = 3,
76
+ embed_dim: int = 768,
77
+ ):
78
+ """
79
+ Args:
80
+ kernel_size (Tuple): kernel size of the projection layer.
81
+ stride (Tuple): stride of the projection layer.
82
+ padding (Tuple): padding size of the projection layer.
83
+ in_chans (int): Number of input image channels.
84
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
85
+ """
86
+ super().__init__()
87
+ self.proj = nn.Conv2d(
88
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ x = self.proj(x)
93
+ # B C H W -> B H W C
94
+ x = x.permute(0, 2, 3, 1)
95
+ return x
sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # sigmoid, so that less domain shift from gt masks which are bool
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class PositionEmbeddingSine(nn.Module):
17
+ """
18
+ This is a more standard version of the position embedding, very similar to the one
19
+ used by the Attention Is All You Need paper, generalized to work on images.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ num_pos_feats,
25
+ temperature: int = 10000,
26
+ normalize: bool = True,
27
+ scale: Optional[float] = None,
28
+ ):
29
+ super().__init__()
30
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
31
+ self.num_pos_feats = num_pos_feats // 2
32
+ self.temperature = temperature
33
+ self.normalize = normalize
34
+ if scale is not None and normalize is False:
35
+ raise ValueError("normalize should be True if scale is passed")
36
+ if scale is None:
37
+ scale = 2 * math.pi
38
+ self.scale = scale
39
+
40
+ self.cache = {}
41
+
42
+ def _encode_xy(self, x, y):
43
+ # The positions are expected to be normalized
44
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
45
+ x_embed = x * self.scale
46
+ y_embed = y * self.scale
47
+
48
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
49
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
50
+
51
+ pos_x = x_embed[:, None] / dim_t
52
+ pos_y = y_embed[:, None] / dim_t
53
+ pos_x = torch.stack(
54
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
55
+ ).flatten(1)
56
+ pos_y = torch.stack(
57
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
58
+ ).flatten(1)
59
+ return pos_x, pos_y
60
+
61
+ @torch.no_grad()
62
+ def encode_boxes(self, x, y, w, h):
63
+ pos_x, pos_y = self._encode_xy(x, y)
64
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
65
+ return pos
66
+
67
+ encode = encode_boxes # Backwards compatibility
68
+
69
+ @torch.no_grad()
70
+ def encode_points(self, x, y, labels):
71
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
72
+ assert bx == by and nx == ny and bx == bl and nx == nl
73
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
74
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
75
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
76
+ return pos
77
+
78
+ @torch.no_grad()
79
+ def forward(self, x: torch.Tensor):
80
+ cache_key = (x.shape[-2], x.shape[-1])
81
+ if cache_key in self.cache:
82
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
83
+ y_embed = (
84
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
85
+ .view(1, -1, 1)
86
+ .repeat(x.shape[0], 1, x.shape[-1])
87
+ )
88
+ x_embed = (
89
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
90
+ .view(1, 1, -1)
91
+ .repeat(x.shape[0], x.shape[-2], 1)
92
+ )
93
+
94
+ if self.normalize:
95
+ eps = 1e-6
96
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
97
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
98
+
99
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101
+
102
+ pos_x = x_embed[:, :, :, None] / dim_t
103
+ pos_y = y_embed[:, :, :, None] / dim_t
104
+ pos_x = torch.stack(
105
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
106
+ ).flatten(3)
107
+ pos_y = torch.stack(
108
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
109
+ ).flatten(3)
110
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
111
+ self.cache[cache_key] = pos[0]
112
+ return pos
113
+
114
+
115
+ class PositionEmbeddingRandom(nn.Module):
116
+ """
117
+ Positional encoding using random spatial frequencies.
118
+ """
119
+
120
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
121
+ super().__init__()
122
+ if scale is None or scale <= 0.0:
123
+ scale = 1.0
124
+ self.register_buffer(
125
+ "positional_encoding_gaussian_matrix",
126
+ scale * torch.randn((2, num_pos_feats)),
127
+ )
128
+
129
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
130
+ """Positionally encode points that are normalized to [0,1]."""
131
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
132
+ coords = 2 * coords - 1
133
+ coords = coords @ self.positional_encoding_gaussian_matrix
134
+ coords = 2 * np.pi * coords
135
+ # outputs d_1 x ... x d_n x C shape
136
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
137
+
138
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
139
+ """Generate positional encoding for a grid of the specified size."""
140
+ h, w = size
141
+ device: Any = self.positional_encoding_gaussian_matrix.device
142
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
143
+ y_embed = grid.cumsum(dim=0) - 0.5
144
+ x_embed = grid.cumsum(dim=1) - 0.5
145
+ y_embed = y_embed / h
146
+ x_embed = x_embed / w
147
+
148
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
149
+ return pe.permute(2, 0, 1) # C x H x W
150
+
151
+ def forward_with_coords(
152
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
153
+ ) -> torch.Tensor:
154
+ """Positionally encode points that are not normalized to [0,1]."""
155
+ coords = coords_input.clone()
156
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
157
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
158
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
159
+
160
+
161
+ # Rotary Positional Encoding, adapted from:
162
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
163
+ # 2. https://github.com/naver-ai/rope-vit
164
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
165
+
166
+
167
+ def init_t_xy(end_x: int, end_y: int):
168
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
169
+ t_x = (t % end_x).float()
170
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
171
+ return t_x, t_y
172
+
173
+
174
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
175
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
176
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
177
+
178
+ t_x, t_y = init_t_xy(end_x, end_y)
179
+ freqs_x = torch.outer(t_x, freqs_x)
180
+ freqs_y = torch.outer(t_y, freqs_y)
181
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
182
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
183
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
184
+
185
+
186
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
187
+ ndim = x.ndim
188
+ assert 0 <= 1 < ndim
189
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
190
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
191
+ return freqs_cis.view(*shape)
192
+
193
+
194
+ def apply_rotary_enc(
195
+ xq: torch.Tensor,
196
+ xk: torch.Tensor,
197
+ freqs_cis: torch.Tensor,
198
+ repeat_freqs_k: bool = False,
199
+ ):
200
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
201
+ xk_ = (
202
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
203
+ if xk.shape[-2] != 0
204
+ else None
205
+ )
206
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
207
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
208
+ if xk_ is None:
209
+ # no keys to rotate, due to dropout
210
+ return xq_out.type_as(xq).to(xq.device), xk
211
+ # repeat freqs along seq_len dim to match k seq_len
212
+ if repeat_freqs_k:
213
+ r = xk_.shape[-2] // xq_.shape[-2]
214
+ if freqs_cis.is_cuda:
215
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
216
+ else:
217
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
218
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
219
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
220
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
221
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc ADDED
Binary file (9.84 kB). View file
 
sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
+ class MaskDecoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ transformer_dim: int,
20
+ transformer: nn.Module,
21
+ num_multimask_outputs: int = 3,
22
+ activation: Type[nn.Module] = nn.GELU,
23
+ iou_head_depth: int = 3,
24
+ iou_head_hidden_dim: int = 256,
25
+ use_high_res_features: bool = False,
26
+ iou_prediction_use_sigmoid=False,
27
+ dynamic_multimask_via_stability=False,
28
+ dynamic_multimask_stability_delta=0.05,
29
+ dynamic_multimask_stability_thresh=0.98,
30
+ pred_obj_scores: bool = False,
31
+ pred_obj_scores_mlp: bool = False,
32
+ use_multimask_token_for_obj_ptr: bool = False,
33
+ ) -> None:
34
+ """
35
+ Predicts masks given an image and prompt embeddings, using a
36
+ transformer architecture.
37
+
38
+ Arguments:
39
+ transformer_dim (int): the channel dimension of the transformer
40
+ transformer (nn.Module): the transformer used to predict masks
41
+ num_multimask_outputs (int): the number of masks to predict
42
+ when disambiguating masks
43
+ activation (nn.Module): the type of activation to use when
44
+ upscaling masks
45
+ iou_head_depth (int): the depth of the MLP used to predict
46
+ mask quality
47
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
48
+ used to predict mask quality
49
+ """
50
+ super().__init__()
51
+ self.transformer_dim = transformer_dim
52
+ self.transformer = transformer
53
+
54
+ self.num_multimask_outputs = num_multimask_outputs
55
+
56
+ self.iou_token = nn.Embedding(1, transformer_dim)
57
+ self.num_mask_tokens = num_multimask_outputs + 1
58
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59
+
60
+ self.pred_obj_scores = pred_obj_scores
61
+ if self.pred_obj_scores:
62
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
63
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
+
65
+ self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(
67
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
+ ),
69
+ LayerNorm2d(transformer_dim // 4),
70
+ activation(),
71
+ nn.ConvTranspose2d(
72
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
+ ),
74
+ activation(),
75
+ )
76
+ self.use_high_res_features = use_high_res_features
77
+ if use_high_res_features:
78
+ self.conv_s0 = nn.Conv2d(
79
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
+ )
81
+ self.conv_s1 = nn.Conv2d(
82
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
+ )
84
+
85
+ self.output_hypernetworks_mlps = nn.ModuleList(
86
+ [
87
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
+ for i in range(self.num_mask_tokens)
89
+ ]
90
+ )
91
+
92
+ self.iou_prediction_head = MLP(
93
+ transformer_dim,
94
+ iou_head_hidden_dim,
95
+ self.num_mask_tokens,
96
+ iou_head_depth,
97
+ sigmoid_output=iou_prediction_use_sigmoid,
98
+ )
99
+ if self.pred_obj_scores:
100
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101
+ if pred_obj_scores_mlp:
102
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103
+
104
+ # When outputting a single mask, optionally we can dynamically fall back to the best
105
+ # multimask output token if the single mask output token gives low stability scores.
106
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ sparse_prompt_embeddings: torch.Tensor,
115
+ dense_prompt_embeddings: torch.Tensor,
116
+ multimask_output: bool,
117
+ repeat_image: bool,
118
+ high_res_features: Optional[List[torch.Tensor]] = None,
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """
121
+ Predict masks given image and prompt embeddings.
122
+
123
+ Arguments:
124
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
125
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128
+ multimask_output (bool): Whether to return multiple masks or a single
129
+ mask.
130
+
131
+ Returns:
132
+ torch.Tensor: batched predicted masks
133
+ torch.Tensor: batched predictions of mask quality
134
+ torch.Tensor: batched SAM token for mask output
135
+ """
136
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137
+ image_embeddings=image_embeddings,
138
+ image_pe=image_pe,
139
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
140
+ dense_prompt_embeddings=dense_prompt_embeddings,
141
+ repeat_image=repeat_image,
142
+ high_res_features=high_res_features,
143
+ )
144
+
145
+ # Select the correct mask or masks for output
146
+ if multimask_output:
147
+ masks = masks[:, 1:, :, :]
148
+ iou_pred = iou_pred[:, 1:]
149
+ elif self.dynamic_multimask_via_stability and not self.training:
150
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151
+ else:
152
+ masks = masks[:, 0:1, :, :]
153
+ iou_pred = iou_pred[:, 0:1]
154
+
155
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
156
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157
+ else:
158
+ # Take the mask output token. Here we *always* use the token for single mask output.
159
+ # At test time, even if we track after 1-click (and using multimask_output=True),
160
+ # we still take the single mask token here. The rationale is that we always track
161
+ # after multiple clicks during training, so the past tokens seen during training
162
+ # are always the single mask token (and we'll let it be the object-memory token).
163
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164
+
165
+ # Prepare output
166
+ return masks, iou_pred, sam_tokens_out, object_score_logits
167
+
168
+ def predict_masks(
169
+ self,
170
+ image_embeddings: torch.Tensor,
171
+ image_pe: torch.Tensor,
172
+ sparse_prompt_embeddings: torch.Tensor,
173
+ dense_prompt_embeddings: torch.Tensor,
174
+ repeat_image: bool,
175
+ high_res_features: Optional[List[torch.Tensor]] = None,
176
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
177
+ """Predicts masks. See 'forward' for more details."""
178
+ # Concatenate output tokens
179
+ s = 0
180
+ if self.pred_obj_scores:
181
+ output_tokens = torch.cat(
182
+ [
183
+ self.obj_score_token.weight,
184
+ self.iou_token.weight,
185
+ self.mask_tokens.weight,
186
+ ],
187
+ dim=0,
188
+ )
189
+ s = 1
190
+ else:
191
+ output_tokens = torch.cat(
192
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
193
+ )
194
+ output_tokens = output_tokens.unsqueeze(0).expand(
195
+ sparse_prompt_embeddings.size(0), -1, -1
196
+ )
197
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198
+
199
+ # Expand per-image data in batch direction to be per-mask
200
+ if repeat_image:
201
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202
+ else:
203
+ assert image_embeddings.shape[0] == tokens.shape[0]
204
+ src = image_embeddings
205
+ src = src + dense_prompt_embeddings
206
+ assert (
207
+ image_pe.size(0) == 1
208
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210
+ b, c, h, w = src.shape
211
+
212
+ # Run the transformer
213
+ hs, src = self.transformer(src, pos_src, tokens)
214
+ iou_token_out = hs[:, s, :]
215
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
216
+
217
+ # Upscale mask embeddings and predict masks using the mask tokens
218
+ src = src.transpose(1, 2).view(b, c, h, w)
219
+ if not self.use_high_res_features:
220
+ upscaled_embedding = self.output_upscaling(src)
221
+ else:
222
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
223
+ feat_s0, feat_s1 = high_res_features
224
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
225
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
226
+
227
+ hyper_in_list: List[torch.Tensor] = []
228
+ for i in range(self.num_mask_tokens):
229
+ hyper_in_list.append(
230
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231
+ )
232
+ hyper_in = torch.stack(hyper_in_list, dim=1)
233
+ b, c, h, w = upscaled_embedding.shape
234
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
235
+
236
+ # Generate mask quality predictions
237
+ iou_pred = self.iou_prediction_head(iou_token_out)
238
+ if self.pred_obj_scores:
239
+ assert s == 1
240
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
241
+ else:
242
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
243
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
244
+
245
+ return masks, iou_pred, mask_tokens_out, object_score_logits
246
+
247
+ def _get_stability_scores(self, mask_logits):
248
+ """
249
+ Compute stability scores of the mask logits based on the IoU between upper and
250
+ lower thresholds.
251
+ """
252
+ mask_logits = mask_logits.flatten(-2)
253
+ stability_delta = self.dynamic_multimask_stability_delta
254
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
255
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
256
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
257
+ return stability_scores
258
+
259
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
260
+ """
261
+ When outputting a single mask, if the stability score from the current single-mask
262
+ output (based on output token 0) falls below a threshold, we instead select from
263
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
264
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
265
+ """
266
+ # The best mask from multimask output tokens (1~3)
267
+ multimask_logits = all_mask_logits[:, 1:, :, :]
268
+ multimask_iou_scores = all_iou_scores[:, 1:]
269
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270
+ batch_inds = torch.arange(
271
+ multimask_iou_scores.size(0), device=all_iou_scores.device
272
+ )
273
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
275
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
276
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
277
+
278
+ # The mask from singlemask output token 0 and its stability score
279
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
280
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
281
+ stability_scores = self._get_stability_scores(singlemask_logits)
282
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
283
+
284
+ # Dynamically fall back to best multimask output upon low stability scores.
285
+ mask_logits_out = torch.where(
286
+ is_stable[..., None, None].expand_as(singlemask_logits),
287
+ singlemask_logits,
288
+ best_multimask_logits,
289
+ )
290
+ iou_scores_out = torch.where(
291
+ is_stable.expand_as(singlemask_iou_scores),
292
+ singlemask_iou_scores,
293
+ best_multimask_iou_scores,
294
+ )
295
+ return mask_logits_out, iou_scores_out
sam2/modeling/sam/prompt_encoder.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ mask_in_chans: int,
24
+ activation: Type[nn.Module] = nn.GELU,
25
+ ) -> None:
26
+ """
27
+ Encodes prompts for input to SAM's mask decoder.
28
+
29
+ Arguments:
30
+ embed_dim (int): The prompts' embedding dimension
31
+ image_embedding_size (tuple(int, int)): The spatial size of the
32
+ image embedding, as (H, W).
33
+ input_image_size (int): The padded size of the image as input
34
+ to the image encoder, as (H, W).
35
+ mask_in_chans (int): The number of hidden channels used for
36
+ encoding input masks.
37
+ activation (nn.Module): The activation to use when encoding
38
+ input masks.
39
+ """
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.input_image_size = input_image_size
43
+ self.image_embedding_size = image_embedding_size
44
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
+
46
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
+ point_embeddings = [
48
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
+ ]
50
+ self.point_embeddings = nn.ModuleList(point_embeddings)
51
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
+
53
+ self.mask_input_size = (
54
+ 4 * image_embedding_size[0],
55
+ 4 * image_embedding_size[1],
56
+ )
57
+ self.mask_downscaling = nn.Sequential(
58
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59
+ LayerNorm2d(mask_in_chans // 4),
60
+ activation(),
61
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62
+ LayerNorm2d(mask_in_chans),
63
+ activation(),
64
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
+ )
66
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
67
+
68
+ def get_dense_pe(self) -> torch.Tensor:
69
+ """
70
+ Returns the positional encoding used to encode point prompts,
71
+ applied to a dense set of points the shape of the image encoding.
72
+
73
+ Returns:
74
+ torch.Tensor: Positional encoding with shape
75
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
76
+ """
77
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
+
79
+ def _embed_points(
80
+ self,
81
+ points: torch.Tensor,
82
+ labels: torch.Tensor,
83
+ pad: bool,
84
+ ) -> torch.Tensor:
85
+ """Embeds point prompts."""
86
+ points = points + 0.5 # Shift to center of pixel
87
+ if pad:
88
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
+ points = torch.cat([points, padding_point], dim=1)
91
+ labels = torch.cat([labels, padding_label], dim=1)
92
+ point_embedding = self.pe_layer.forward_with_coords(
93
+ points, self.input_image_size
94
+ )
95
+ point_embedding[labels == -1] = 0.0
96
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
97
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
98
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
99
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
100
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
101
+ return point_embedding
102
+
103
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104
+ """Embeds box prompts."""
105
+ boxes = boxes + 0.5 # Shift to center of pixel
106
+ coords = boxes.reshape(-1, 2, 2)
107
+ corner_embedding = self.pe_layer.forward_with_coords(
108
+ coords, self.input_image_size
109
+ )
110
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112
+ return corner_embedding
113
+
114
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115
+ """Embeds mask inputs."""
116
+ mask_embedding = self.mask_downscaling(masks)
117
+ return mask_embedding
118
+
119
+ def _get_batch_size(
120
+ self,
121
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122
+ boxes: Optional[torch.Tensor],
123
+ masks: Optional[torch.Tensor],
124
+ ) -> int:
125
+ """
126
+ Gets the batch size of the output given the batch size of the input prompts.
127
+ """
128
+ if points is not None:
129
+ return points[0].shape[0]
130
+ elif boxes is not None:
131
+ return boxes.shape[0]
132
+ elif masks is not None:
133
+ return masks.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Embeds different types of prompts, returning both sparse and dense
148
+ embeddings.
149
+
150
+ Arguments:
151
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152
+ and labels to embed.
153
+ boxes (torch.Tensor or none): boxes to embed
154
+ masks (torch.Tensor or none): masks to embed
155
+
156
+ Returns:
157
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
158
+ BxNx(embed_dim), where N is determined by the number of input points
159
+ and boxes.
160
+ torch.Tensor: dense embeddings for the masks, in the shape
161
+ Bx(embed_dim)x(embed_H)x(embed_W)
162
+ """
163
+ bs = self._get_batch_size(points, boxes, masks)
164
+ sparse_embeddings = torch.empty(
165
+ (bs, 0, self.embed_dim), device=self._get_device()
166
+ )
167
+ if points is not None:
168
+ coords, labels = points
169
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
170
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
171
+ if boxes is not None:
172
+ box_embeddings = self._embed_boxes(boxes)
173
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
174
+
175
+ if masks is not None:
176
+ dense_embeddings = self._embed_masks(masks)
177
+ else:
178
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
179
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
180
+ )
181
+
182
+ return sparse_embeddings, dense_embeddings