diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4fd0e74be0aee2cd2b6c976123eefa81270817e2 Binary files /dev/null and b/.DS_Store differ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ac23ed037d3fcdeff1faa6396c2bfe880d9f8506 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +# Utiliser une image de base Python légère +FROM python:3.9-slim + +# Définir le répertoire de travail +WORKDIR /app + +# Copier les fichiers nécessaires dans le conteneur +COPY . /app + +# Installer les dépendances +RUN pip install --no-cache-dir -r requirements.txt + +# Exposer le port 7860 pour le serveur Flask +EXPOSE 7860 + +# Commande pour démarrer Flask +CMD ["python", "app.py"] \ No newline at end of file diff --git a/README.md b/README.md index 4fdb8750c7eedff4649aaba84df47013be6c6066..c247b19030f898e97dc3f4458673b608ee240804 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- -title: Project -emoji: 📊 -colorFrom: indigo -colorTo: blue +title: Segmentation Project +emoji: 😻 +colorFrom: red +colorTo: purple sdk: docker pinned: false --- diff --git a/app.log b/app.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e67b777b4f91e2459e2fab078d222f2ebd5521dd --- /dev/null +++ b/app.py @@ -0,0 +1,824 @@ +from flask import Flask, render_template, request, jsonify +from flask_socketio import SocketIO +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import shutil +import numpy as np +from PIL import Image +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +class Predictor: + def __init__(self, model_cfg, checkpoint, device): + self.device = device + self.model = build_sam2(model_cfg, checkpoint, device=device) + self.predictor = SAM2ImagePredictor(self.model) + self.image_set = False + + def set_image(self, image): + """Set the image for SAM prediction.""" + self.image = image + self.predictor.set_image(image) + self.image_set = True + + def predict(self, point_coords, point_labels, multimask_output=False): + """Run SAM prediction.""" + if not self.image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + return self.predictor.predict( + point_coords=point_coords, + point_labels=point_labels, + multimask_output=multimask_output + ) +from utils.helpers import ( + blend_mask_with_image, + save_mask_as_png, + convert_mask_to_yolo, +) +import torch +from ultralytics import YOLO +import threading +from threading import Lock +import subprocess +import time +import logging +import multiprocessing +import json + + +# Initialize Flask app and SocketIO +app = Flask(__name__) +socketio = SocketIO(app) + +# Define Base Directory +BASE_DIR = os.path.abspath(os.path.dirname(__file__)) + +# Folder structure with absolute paths +UPLOAD_FOLDERS = { + 'input': os.path.join(BASE_DIR, 'static/uploads/input'), + 'segmented_voids': os.path.join(BASE_DIR, 'static/uploads/segmented/voids'), + 'segmented_chips': os.path.join(BASE_DIR, 'static/uploads/segmented/chips'), + 'mask_voids': os.path.join(BASE_DIR, 'static/uploads/mask/voids'), + 'mask_chips': os.path.join(BASE_DIR, 'static/uploads/mask/chips'), + 'automatic_segmented': os.path.join(BASE_DIR, 'static/uploads/segmented/automatic'), +} + +HISTORY_FOLDERS = { + 'images': os.path.join(BASE_DIR, 'static/history/images'), + 'masks_chip': os.path.join(BASE_DIR, 'static/history/masks/chip'), + 'masks_void': os.path.join(BASE_DIR, 'static/history/masks/void'), +} + +DATASET_FOLDERS = { + 'train_images': os.path.join(BASE_DIR, 'dataset/train/images'), + 'train_labels': os.path.join(BASE_DIR, 'dataset/train/labels'), + 'val_images': os.path.join(BASE_DIR, 'dataset/val/images'), + 'val_labels': os.path.join(BASE_DIR, 'dataset/val/labels'), + 'temp_backup': os.path.join(BASE_DIR, 'temp_backup'), + 'models': os.path.join(BASE_DIR, 'models'), + 'models_old': os.path.join(BASE_DIR, 'models/old'), +} + +# Ensure all folders exist +for folder_name, folder_path in {**UPLOAD_FOLDERS, **HISTORY_FOLDERS, **DATASET_FOLDERS}.items(): + os.makedirs(folder_path, exist_ok=True) + logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}") + +training_process = None + + +def initialize_training_status(): + """Initialize global training status.""" + global training_status + training_status = {'running': False, 'cancelled': False} + +def persist_training_status(): + """Save training status to a file.""" + with open(os.path.join(BASE_DIR, 'training_status.json'), 'w') as status_file: + json.dump(training_status, status_file) + +def load_training_status(): + """Load training status from a file.""" + global training_status + status_path = os.path.join(BASE_DIR, 'training_status.json') + if os.path.exists(status_path): + with open(status_path, 'r') as status_file: + training_status = json.load(status_file) + else: + training_status = {'running': False, 'cancelled': False} + +load_training_status() + +os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0" + +# Initialize SAM Predictor +MODEL_CFG = r"sam2/sam2_hiera_l.yaml" +CHECKPOINT = r"sam2/checkpoints/sam2.1_hiera_large.pt" +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE) + +# Initialize YOLO-seg +YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt") +yolo_model = YOLO(YOLO_CFG) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler(os.path.join(BASE_DIR, "app.log")) # Log to a file + ] +) + + +@app.route('/') +def index(): + """Serve the main UI.""" + return render_template('index.html') + +@app.route('/upload', methods=['POST']) +def upload_image(): + """Handle image uploads.""" + if 'file' not in request.files: + return jsonify({'error': 'No file uploaded'}), 400 + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'No file selected'}), 400 + + # Save the uploaded file to the input folder + input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename) + file.save(input_path) + + # Set the uploaded image in the predictor + image = np.array(Image.open(input_path).convert("RGB")) + predictor.set_image(image) + + # Return a web-accessible URL instead of the file system path + web_accessible_url = f"/static/uploads/input/{file.filename}" + print(f"Image uploaded and set for prediction: {input_path}") + return jsonify({'image_url': web_accessible_url}) + +@app.route('/segment', methods=['POST']) +def segment(): + """ + Perform segmentation and return the blended image URL. + """ + try: + # Extract data from request + data = request.json + points = np.array(data.get('points', [])) + labels = np.array(data.get('labels', [])) + current_class = data.get('class', 'voids') # Default to 'voids' if class not provided + + # Ensure predictor has an image set + if not predictor.image_set: + raise ValueError("No image set for prediction.") + + # Perform SAM prediction + masks, _, _ = predictor.predict( + point_coords=points, + point_labels=labels, + multimask_output=False + ) + + # Check if masks exist and have non-zero elements + if masks is None or masks.size == 0: + raise RuntimeError("No masks were generated by the predictor.") + + # Define output paths based on class + mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}') + segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}') + + if not mask_folder or not segmented_folder: + raise ValueError(f"Invalid class '{current_class}' provided.") + + os.makedirs(mask_folder, exist_ok=True) + os.makedirs(segmented_folder, exist_ok=True) + + # Save the raw mask + mask_path = os.path.join(mask_folder, 'raw_mask.png') + save_mask_as_png(masks[0], mask_path) + + # Generate blended image + blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255] # Green for voids, blue for chips + blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color) + + # Save blended image + blended_filename = f"blended_{current_class}.png" + blended_path = os.path.join(segmented_folder, blended_filename) + Image.fromarray(blended_image).save(blended_path) + + # Return URL for frontend access + segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}" + logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}") + return jsonify({'segmented_url': segmented_url}) + + except ValueError as ve: + logging.error(f"Value error during segmentation: {ve}") + return jsonify({'error': str(ve)}), 400 + + except Exception as e: + logging.error(f"Unexpected error during segmentation: {e}") + return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500 + +@app.route('/automatic_segment', methods=['POST']) +def automatic_segment(): + """Perform automatic segmentation using YOLO.""" + if 'file' not in request.files: + return jsonify({'error': 'No file uploaded'}), 400 + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'No file selected'}), 400 + + input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename) + file.save(input_path) + + try: + # Perform YOLO segmentation + results = yolo_model.predict(input_path, save=False, save_txt=False) + output_folder = UPLOAD_FOLDERS['automatic_segmented'] + os.makedirs(output_folder, exist_ok=True) + + chips_data = [] + chips = [] + voids = [] + + # Process results and save segmented images + for result in results: + annotated_image = result.plot() + result_filename = f"{file.filename.rsplit('.', 1)[0]}_pred.jpg" + result_path = os.path.join(output_folder, result_filename) + Image.fromarray(annotated_image).save(result_path) + + # Separate chips and voids + for i, label in enumerate(result.boxes.cls): # YOLO labels + label_name = result.names[int(label)] # Get label name (e.g., 'chip' or 'void') + box = result.boxes.xyxy[i].cpu().numpy() # Bounding box (x1, y1, x2, y2) + area = float((box[2] - box[0]) * (box[3] - box[1])) # Calculate area + + if label_name == 'chip': + chips.append({'box': box, 'area': area, 'voids': []}) + elif label_name == 'void': + voids.append({'box': box, 'area': area}) + + # Assign voids to chips based on proximity + for void in voids: + void_centroid = [ + (void['box'][0] + void['box'][2]) / 2, # x centroid + (void['box'][1] + void['box'][3]) / 2 # y centroid + ] + for chip in chips: + # Check if void centroid is within chip bounding box + if (chip['box'][0] <= void_centroid[0] <= chip['box'][2] and + chip['box'][1] <= void_centroid[1] <= chip['box'][3]): + chip['voids'].append(void) + break + + # Calculate metrics for each chip + for idx, chip in enumerate(chips): + chip_area = chip['area'] + total_void_area = sum([float(void['area']) for void in chip['voids']]) + max_void_area = max([float(void['area']) for void in chip['voids']], default=0) + + void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0 + max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0 + + chips_data.append({ + "chip_number": int(idx + 1), + "chip_area": round(chip_area, 2), + "void_percentage": round(void_percentage, 2), + "max_void_percentage": round(max_void_percentage, 2) + }) + + # Return the segmented image URL and table data + segmented_url = f"/static/uploads/segmented/automatic/{result_filename}" + return jsonify({ + "segmented_url": segmented_url, # Use the URL for frontend access + "table_data": { + "image_name": file.filename, + "chips": chips_data + } + }) + + except Exception as e: + print(f"Error in automatic segmentation: {e}") + return jsonify({'error': 'Segmentation failed.'}), 500 + +@app.route('/save_both', methods=['POST']) +def save_both(): + """Save both the image and masks into the history folders.""" + data = request.json + image_name = data.get('image_name') + + if not image_name: + return jsonify({'error': 'Image name not provided'}), 400 + + try: + # Ensure image_name is a pure file name + image_name = os.path.basename(image_name) # Strip any directory path + print(f"Sanitized Image Name: {image_name}") + + # Correctly resolve the input image path + input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name) + if not os.path.exists(input_image_path): + print(f"Input image does not exist: {input_image_path}") + return jsonify({'error': f'Input image not found: {input_image_path}'}), 404 + + # Copy the image to history/images + image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name) + os.makedirs(os.path.dirname(image_history_path), exist_ok=True) + shutil.copy(input_image_path, image_history_path) + print(f"Image saved to history: {image_history_path}") + + # Backup void mask + void_mask_path = os.path.join(UPLOAD_FOLDERS['mask_voids'], 'raw_mask.png') + if os.path.exists(void_mask_path): + void_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") + os.makedirs(os.path.dirname(void_mask_history_path), exist_ok=True) + shutil.copy(void_mask_path, void_mask_history_path) + print(f"Voids mask saved to history: {void_mask_history_path}") + else: + print(f"Voids mask not found: {void_mask_path}") + + # Backup chip mask + chip_mask_path = os.path.join(UPLOAD_FOLDERS['mask_chips'], 'raw_mask.png') + if os.path.exists(chip_mask_path): + chip_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") + os.makedirs(os.path.dirname(chip_mask_history_path), exist_ok=True) + shutil.copy(chip_mask_path, chip_mask_history_path) + print(f"Chips mask saved to history: {chip_mask_history_path}") + else: + print(f"Chips mask not found: {chip_mask_path}") + + return jsonify({'message': 'Image and masks saved successfully!'}), 200 + + except Exception as e: + print(f"Error saving files: {e}") + return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500 + +@app.route('/get_history', methods=['GET']) +def get_history(): + try: + saved_images = os.listdir(HISTORY_FOLDERS['images']) + return jsonify({'status': 'success', 'images': saved_images}), 200 + except Exception as e: + return jsonify({'status': 'error', 'message': f'Failed to fetch history: {e}'}), 500 + + +@app.route('/delete_history_item', methods=['POST']) +def delete_history_item(): + data = request.json + image_name = data.get('image_name') + + if not image_name: + return jsonify({'error': 'Image name not provided'}), 400 + + try: + image_path = os.path.join(HISTORY_FOLDERS['images'], image_name) + if os.path.exists(image_path): + os.remove(image_path) + + void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") + if os.path.exists(void_mask_path): + os.remove(void_mask_path) + + chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") + if os.path.exists(chip_mask_path): + os.remove(chip_mask_path) + + return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200 + except Exception as e: + return jsonify({'error': f'Failed to delete files: {e}'}), 500 + +# Lock for training status updates +status_lock = Lock() + +def update_training_status(key, value): + """Thread-safe update for training status.""" + with status_lock: + training_status[key] = value + +@app.route('/retrain_model', methods=['POST']) +def retrain_model(): + """Handle retrain model workflow.""" + global training_status + + if training_status.get('running', False): + return jsonify({'error': 'Training is already in progress'}), 400 + + try: + # Update training status + update_training_status('running', True) + update_training_status('cancelled', False) + logging.info("Training status updated. Starting training workflow.") + + # Backup masks and images + backup_masks_and_images() + logging.info("Backup completed successfully.") + + # Prepare YOLO labels + prepare_yolo_labels() + logging.info("YOLO labels prepared successfully.") + + # Start YOLO training in a separate thread + threading.Thread(target=run_yolo_training).start() + return jsonify({'message': 'Training started successfully!'}), 200 + + except Exception as e: + logging.error(f"Error during training preparation: {e}") + update_training_status('running', False) + return jsonify({'error': f"Failed to start training: {e}"}), 500 + +def prepare_yolo_labels(): + """Convert all masks into YOLO-compatible labels and copy images to the dataset folder.""" + images_folder = HISTORY_FOLDERS['images'] # Use history images as the source + train_labels_folder = DATASET_FOLDERS['train_labels'] + train_images_folder = DATASET_FOLDERS['train_images'] + val_labels_folder = DATASET_FOLDERS['val_labels'] + val_images_folder = DATASET_FOLDERS['val_images'] + + # Ensure destination directories exist + os.makedirs(train_labels_folder, exist_ok=True) + os.makedirs(train_images_folder, exist_ok=True) + os.makedirs(val_labels_folder, exist_ok=True) + os.makedirs(val_images_folder, exist_ok=True) + + try: + all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))] + random.shuffle(all_images) # Shuffle the images for randomness + + # Determine split index + split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation + + # Split images into train and validation sets + train_images = all_images[:split_idx] + val_images = all_images[split_idx:] + + # Process training images + for image_name in train_images: + process_image_and_mask( + image_name, + source_images_folder=images_folder, + dest_images_folder=train_images_folder, + dest_labels_folder=train_labels_folder + ) + + # Process validation images + for image_name in val_images: + process_image_and_mask( + image_name, + source_images_folder=images_folder, + dest_images_folder=val_images_folder, + dest_labels_folder=val_labels_folder + ) + + logging.info("YOLO labels prepared, and images split into train and validation successfully.") + + except Exception as e: + logging.error(f"Error in preparing YOLO labels: {e}") + raise + +import random + +def prepare_yolo_labels(): + """Convert all masks into YOLO-compatible labels and copy images to the dataset folder.""" + images_folder = HISTORY_FOLDERS['images'] # Use history images as the source + train_labels_folder = DATASET_FOLDERS['train_labels'] + train_images_folder = DATASET_FOLDERS['train_images'] + val_labels_folder = DATASET_FOLDERS['val_labels'] + val_images_folder = DATASET_FOLDERS['val_images'] + + # Ensure destination directories exist + os.makedirs(train_labels_folder, exist_ok=True) + os.makedirs(train_images_folder, exist_ok=True) + os.makedirs(val_labels_folder, exist_ok=True) + os.makedirs(val_images_folder, exist_ok=True) + + try: + all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))] + random.shuffle(all_images) # Shuffle the images for randomness + + # Determine split index + split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation + + # Split images into train and validation sets + train_images = all_images[:split_idx] + val_images = all_images[split_idx:] + + # Process training images + for image_name in train_images: + process_image_and_mask( + image_name, + source_images_folder=images_folder, + dest_images_folder=train_images_folder, + dest_labels_folder=train_labels_folder + ) + + # Process validation images + for image_name in val_images: + process_image_and_mask( + image_name, + source_images_folder=images_folder, + dest_images_folder=val_images_folder, + dest_labels_folder=val_labels_folder + ) + + logging.info("YOLO labels prepared, and images split into train and validation successfully.") + + except Exception as e: + logging.error(f"Error in preparing YOLO labels: {e}") + raise + + +def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder): + """ + Process a single image and its masks, saving them in the appropriate YOLO format. + """ + try: + image_path = os.path.join(source_images_folder, image_name) + label_file_path = os.path.join(dest_labels_folder, f"{os.path.splitext(image_name)[0]}.txt") + + # Copy image to the destination images folder + shutil.copy(image_path, os.path.join(dest_images_folder, image_name)) + + # Clear the label file if it exists + if os.path.exists(label_file_path): + os.remove(label_file_path) + + # Process void mask + void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") + if os.path.exists(void_mask_path): + convert_mask_to_yolo( + mask_path=void_mask_path, + image_path=image_path, + class_id=0, # Void class + output_path=label_file_path + ) + + # Process chip mask + chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") + if os.path.exists(chip_mask_path): + convert_mask_to_yolo( + mask_path=chip_mask_path, + image_path=image_path, + class_id=1, # Chip class + output_path=label_file_path, + append=True # Append chip annotations + ) + + logging.info(f"Processed {image_name} into YOLO format.") + except Exception as e: + logging.error(f"Error processing {image_name}: {e}") + raise + +def backup_masks_and_images(): + """Backup current masks and images from history folders.""" + temp_backup_paths = { + 'voids': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/voids'), + 'chips': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/chips'), + 'images': os.path.join(DATASET_FOLDERS['temp_backup'], 'images') + } + + # Prepare all backup directories + for path in temp_backup_paths.values(): + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + + try: + # Backup images from history + for file in os.listdir(HISTORY_FOLDERS['images']): + src_image_path = os.path.join(HISTORY_FOLDERS['images'], file) + dst_image_path = os.path.join(temp_backup_paths['images'], file) + shutil.copy(src_image_path, dst_image_path) + + # Backup void masks from history + for file in os.listdir(HISTORY_FOLDERS['masks_void']): + src_void_path = os.path.join(HISTORY_FOLDERS['masks_void'], file) + dst_void_path = os.path.join(temp_backup_paths['voids'], file) + shutil.copy(src_void_path, dst_void_path) + + # Backup chip masks from history + for file in os.listdir(HISTORY_FOLDERS['masks_chip']): + src_chip_path = os.path.join(HISTORY_FOLDERS['masks_chip'], file) + dst_chip_path = os.path.join(temp_backup_paths['chips'], file) + shutil.copy(src_chip_path, dst_chip_path) + + logging.info("Masks and images backed up successfully from history.") + except Exception as e: + logging.error(f"Error during backup: {e}") + raise RuntimeError("Backup process failed.") + +def run_yolo_training(num_epochs=10): + """Run YOLO training process.""" + global training_process + + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml") # Ensure correct YAML path + + logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.") + logging.info(f"Using dataset configuration: {data_cfg_path}") + + training_command = [ + "yolo", + "train", + f"data={data_cfg_path}", + f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}", + f"device={device}", + f"epochs={num_epochs}", + "project=runs", + "name=train" + ] + + training_process = subprocess.Popen( + training_command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=os.environ.copy(), + ) + + # Display and log output in real time + for line in iter(training_process.stdout.readline, ''): + print(line.strip()) + logging.info(line.strip()) + socketio.emit('training_update', {'message': line.strip()}) # Send updates to the frontend + + training_process.wait() + + if training_process.returncode == 0: + finalize_training() # Finalize successfully completed training + else: + raise RuntimeError("YOLO training process failed. Check logs for details.") + except Exception as e: + logging.error(f"Training error: {e}") + restore_backup() # Restore the dataset and masks + + # Emit training error event to the frontend + socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"}) + finally: + update_training_status('running', False) + training_process = None # Reset the process + + +@socketio.on('cancel_training') +def handle_cancel_training(): + """Cancel the YOLO training process.""" + global training_process, training_status + + if not training_status.get('running', False): + socketio.emit('button_update', {'action': 'retrain'}) # Update button to retrain + return + + try: + training_process.terminate() + training_process.wait() + training_status['running'] = False + training_status['cancelled'] = True + + restore_backup() + cleanup_train_val_directories() + + # Emit button state change + socketio.emit('button_update', {'action': 'retrain'}) + socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'}) + except Exception as e: + logging.error(f"Error cancelling training: {e}") + socketio.emit('training_status', {'status': 'error', 'message': str(e)}) + +def finalize_training(): + """Finalize training by promoting the new model and cleaning up.""" + try: + # Locate the most recent training directory + runs_dir = os.path.join(BASE_DIR, 'runs') + if not os.path.exists(runs_dir): + raise FileNotFoundError("Training runs directory does not exist.") + + # Get the latest training run folder + latest_run = max( + [os.path.join(runs_dir, d) for d in os.listdir(runs_dir)], + key=os.path.getmtime + ) + weights_dir = os.path.join(latest_run, 'weights') + best_model_path = os.path.join(weights_dir, 'best.pt') + + if not os.path.exists(best_model_path): + raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.") + + # Backup the old model + old_model_folder = DATASET_FOLDERS['models_old'] + os.makedirs(old_model_folder, exist_ok=True) + existing_best_model = os.path.join(DATASET_FOLDERS['models'], 'best.pt') + + if os.path.exists(existing_best_model): + timestamp = time.strftime("%Y%m%d_%H%M%S") + shutil.move(existing_best_model, os.path.join(old_model_folder, f"old_{timestamp}.pt")) + logging.info(f"Old model backed up to {old_model_folder}.") + + # Move the new model to the models directory + new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt') + shutil.move(best_model_path, new_model_dest) + logging.info(f"New model saved to {new_model_dest}.") + + # Notify frontend that training is completed + socketio.emit('training_status', { + 'status': 'completed', + 'message': 'Training completed successfully! Model saved as best.pt.' + }) + + # Clean up train/val directories + cleanup_train_val_directories() + logging.info("Train and validation directories cleaned up successfully.") + + except Exception as e: + logging.error(f"Error finalizing training: {e}") + # Emit error status to the frontend + socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"}) + +def restore_backup(): + """Restore the dataset and masks from the backup.""" + try: + temp_backup = DATASET_FOLDERS['temp_backup'] + shutil.copytree(os.path.join(temp_backup, 'masks/voids'), UPLOAD_FOLDERS['mask_voids'], dirs_exist_ok=True) + shutil.copytree(os.path.join(temp_backup, 'masks/chips'), UPLOAD_FOLDERS['mask_chips'], dirs_exist_ok=True) + shutil.copytree(os.path.join(temp_backup, 'images'), UPLOAD_FOLDERS['input'], dirs_exist_ok=True) + logging.info("Backup restored successfully.") + except Exception as e: + logging.error(f"Error restoring backup: {e}") + +@app.route('/cancel_training', methods=['POST']) +def cancel_training(): + global training_process + + if training_process is None: + logging.error("No active training process to terminate.") + return jsonify({'error': 'No active training process to cancel.'}), 400 + + try: + training_process.terminate() + training_process.wait() + training_process = None # Reset the process after termination + + # Update training status + update_training_status('running', False) + update_training_status('cancelled', True) + + # Check if the model is already saved as best.pt + best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt') + if os.path.exists(best_model_path): + logging.info(f"Model already saved as best.pt at {best_model_path}.") + socketio.emit('button_update', {'action': 'revert'}) # Notify frontend to revert button state + else: + logging.info("Training canceled, but no new model was saved.") + + # Restore backup if needed + restore_backup() + cleanup_train_val_directories() + + # Emit status update to frontend + socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'}) + return jsonify({'message': 'Training canceled and data restored successfully.'}), 200 + + except Exception as e: + logging.error(f"Error cancelling training: {e}") + return jsonify({'error': f"Failed to cancel training: {e}"}), 500 + +@app.route('/clear_history', methods=['POST']) +def clear_history(): + try: + for folder in [HISTORY_FOLDERS['images'], HISTORY_FOLDERS['masks_chip'], HISTORY_FOLDERS['masks_void']]: + shutil.rmtree(folder, ignore_errors=True) + os.makedirs(folder, exist_ok=True) # Recreate the empty folder + return jsonify({'message': 'History cleared successfully!'}), 200 + except Exception as e: + return jsonify({'error': f'Failed to clear history: {e}'}), 500 + +@app.route('/training_status', methods=['GET']) +def get_training_status(): + """Return the current training status.""" + if training_status.get('running', False): + return jsonify({'status': 'running', 'message': 'Training in progress.'}), 200 + elif training_status.get('cancelled', False): + return jsonify({'status': 'cancelled', 'message': 'Training was cancelled.'}), 200 + return jsonify({'status': 'idle', 'message': 'No training is currently running.'}), 200 + +def cleanup_train_val_directories(): + """Clear the train and validation directories.""" + try: + for folder in [DATASET_FOLDERS['train_images'], DATASET_FOLDERS['train_labels'], + DATASET_FOLDERS['val_images'], DATASET_FOLDERS['val_labels']]: + shutil.rmtree(folder, ignore_errors=True) # Remove folder contents + os.makedirs(folder, exist_ok=True) # Recreate empty folders + logging.info("Train and validation directories cleaned up successfully.") + except Exception as e: + logging.error(f"Error cleaning up train/val directories: {e}") + + +if __name__ == '__main__': + multiprocessing.set_start_method('spawn') # Required for multiprocessing on Windows + app.run(debug=True, use_reloader=False) + + diff --git a/dataset/.DS_Store b/dataset/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..45eee4e220351d04640aaba67c8aca91b515295b Binary files /dev/null and b/dataset/.DS_Store differ diff --git a/dataset/images/.DS_Store b/dataset/images/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e84e195143810d17e6148f587bc19a630d3a9048 Binary files /dev/null and b/dataset/images/.DS_Store differ diff --git a/dataset/images/train/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg b/dataset/images/train/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8c85c5cc11e9c183748a7d35e3c36ecb73d39ce Binary files /dev/null and b/dataset/images/train/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg differ diff --git a/dataset/images/train/03_JPG.rf.2ca107348e11cdefab68044dba66388d.jpg b/dataset/images/train/03_JPG.rf.2ca107348e11cdefab68044dba66388d.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f896e21dd791ed02fa488ef7c10dc2777118962 Binary files /dev/null and b/dataset/images/train/03_JPG.rf.2ca107348e11cdefab68044dba66388d.jpg differ diff --git a/dataset/images/train/04_JPG.rf.b0b546ecbc6b70149b8932018e69fef0.jpg b/dataset/images/train/04_JPG.rf.b0b546ecbc6b70149b8932018e69fef0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25373da9f2cd50afbc8ee8f7bc799aac96a40220 Binary files /dev/null and b/dataset/images/train/04_JPG.rf.b0b546ecbc6b70149b8932018e69fef0.jpg differ diff --git a/dataset/images/train/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg b/dataset/images/train/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3579081555eed5deffe74a0c437f943bdc15c6d6 Binary files /dev/null and b/dataset/images/train/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg differ diff --git a/dataset/images/train/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg b/dataset/images/train/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8c3fee5d55e85fbac7e5e21cda3751efcd5ca171 Binary files /dev/null and b/dataset/images/train/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg differ diff --git a/dataset/images/train/09_JPG.rf.9119efd8c174f968457a893669209835.jpg b/dataset/images/train/09_JPG.rf.9119efd8c174f968457a893669209835.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9a889c4aa0a10f9b5bcd4286b13fa35e7c7535ea Binary files /dev/null and b/dataset/images/train/09_JPG.rf.9119efd8c174f968457a893669209835.jpg differ diff --git a/dataset/images/train/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg b/dataset/images/train/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg new file mode 100644 index 0000000000000000000000000000000000000000..38978a38b1cb3774a934411b3f30271b18f56830 Binary files /dev/null and b/dataset/images/train/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg differ diff --git a/dataset/images/train/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg b/dataset/images/train/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac452b8c15bf70a138cbc334e670fb75f2c4900e Binary files /dev/null and b/dataset/images/train/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg differ diff --git a/dataset/images/train/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg b/dataset/images/train/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..076496fd0fc67d90ad57e8d1489b4b74325d6a95 Binary files /dev/null and b/dataset/images/train/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg differ diff --git a/dataset/images/train/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg b/dataset/images/train/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c9d8aab04b28cbbede78a5acf683b613aa14dd9 Binary files /dev/null and b/dataset/images/train/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg differ diff --git a/dataset/images/train/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg b/dataset/images/train/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cb1ec47ff7e172bcd876582462f9667c0118c19 Binary files /dev/null and b/dataset/images/train/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg differ diff --git a/dataset/images/train/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg b/dataset/images/train/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3612ea8bad3b881d64ba82a7bcb3734fd0d3c301 Binary files /dev/null and b/dataset/images/train/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg differ diff --git a/dataset/images/train/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg b/dataset/images/train/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7769c2afbed542791dc7aa6c99b8f731b993c3d Binary files /dev/null and b/dataset/images/train/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg differ diff --git a/dataset/images/train/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg b/dataset/images/train/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b11c87c8db424bf376c804444f39271f52c3b71f Binary files /dev/null and b/dataset/images/train/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg differ diff --git a/dataset/images/train/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg b/dataset/images/train/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a23e88255583db9dad0ae7d84c04aee4a17d7a1 Binary files /dev/null and b/dataset/images/train/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg differ diff --git a/dataset/images/train/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg b/dataset/images/train/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg new file mode 100644 index 0000000000000000000000000000000000000000..368daef342ad0515e373f0c051f72717bdeb2d47 Binary files /dev/null and b/dataset/images/train/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg differ diff --git a/dataset/images/train/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg b/dataset/images/train/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1e24609b75ed127858eb6fb66eba9d2274a09917 Binary files /dev/null and b/dataset/images/train/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg differ diff --git a/dataset/images/train/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg b/dataset/images/train/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg new file mode 100644 index 0000000000000000000000000000000000000000..860d0363ad46861aea837591d26295477a608eac Binary files /dev/null and b/dataset/images/train/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg differ diff --git a/dataset/images/train/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg b/dataset/images/train/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b357e7c557db3b68472b6f40accc04f469f33dd Binary files /dev/null and b/dataset/images/train/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg differ diff --git a/dataset/images/train/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg b/dataset/images/train/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9182db336422bdf0ac364af9fad89d9e0aa7ff4f Binary files /dev/null and b/dataset/images/train/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg differ diff --git a/dataset/images/train/31_jpg.rf.f31137f793efde0462ed560d426dcd24.jpg b/dataset/images/train/31_jpg.rf.f31137f793efde0462ed560d426dcd24.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94357eeb8c07e5d9b99e842c33e7ed46db3f4643 Binary files /dev/null and b/dataset/images/train/31_jpg.rf.f31137f793efde0462ed560d426dcd24.jpg differ diff --git a/dataset/images/train/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg b/dataset/images/train/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9472eeee73b114d8aca4191a6d71caa94b3fb499 Binary files /dev/null and b/dataset/images/train/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg differ diff --git a/dataset/images/train/LU-F_mod_jpg.rf.fc594179772346639512f891960969bb.jpg b/dataset/images/train/LU-F_mod_jpg.rf.fc594179772346639512f891960969bb.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6db21088484d5b0e1a559b9570231c59ee20a2d0 Binary files /dev/null and b/dataset/images/train/LU-F_mod_jpg.rf.fc594179772346639512f891960969bb.jpg differ diff --git a/dataset/images/train/Solder_Voids_jpg.rf.d40f1b71d8a801f084067fde7f33fb08.jpg b/dataset/images/train/Solder_Voids_jpg.rf.d40f1b71d8a801f084067fde7f33fb08.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3bf2f808895d944306b7b31de2584eb3646c7704 Binary files /dev/null and b/dataset/images/train/Solder_Voids_jpg.rf.d40f1b71d8a801f084067fde7f33fb08.jpg differ diff --git a/dataset/images/train/gc10_lake_voids_260-31_jpg.rf.479f3d9dda8dd22097d3d93c78f7e11d.jpg b/dataset/images/train/gc10_lake_voids_260-31_jpg.rf.479f3d9dda8dd22097d3d93c78f7e11d.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d51dd64d43f84d62af073e86b317d1963ebfe231 Binary files /dev/null and b/dataset/images/train/gc10_lake_voids_260-31_jpg.rf.479f3d9dda8dd22097d3d93c78f7e11d.jpg differ diff --git a/dataset/images/train/images_jpg.rf.675b31c5e1ba2b77f0fa5ca92e2391b0.jpg b/dataset/images/train/images_jpg.rf.675b31c5e1ba2b77f0fa5ca92e2391b0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..045cca42b55430453acb0c6ab30cb76fe9380e43 Binary files /dev/null and b/dataset/images/train/images_jpg.rf.675b31c5e1ba2b77f0fa5ca92e2391b0.jpg differ diff --git a/dataset/images/train/qfn-voiding_0_jpg.rf.2945527db158e9ff4943febaf9cd3eab.jpg b/dataset/images/train/qfn-voiding_0_jpg.rf.2945527db158e9ff4943febaf9cd3eab.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05a2de971994b18d531c7c29c823462999205b12 Binary files /dev/null and b/dataset/images/train/qfn-voiding_0_jpg.rf.2945527db158e9ff4943febaf9cd3eab.jpg differ diff --git a/dataset/images/train/techtips_3_jpg.rf.ad88af637816f0999f4df0b18dfef293.jpg b/dataset/images/train/techtips_3_jpg.rf.ad88af637816f0999f4df0b18dfef293.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35fef586ccd3c420a1f5ccd95f492e3ddd15c06e Binary files /dev/null and b/dataset/images/train/techtips_3_jpg.rf.ad88af637816f0999f4df0b18dfef293.jpg differ diff --git a/dataset/images/val/025_JPG.rf.b2cdc2d984adff593dc985f555b8d280.jpg b/dataset/images/val/025_JPG.rf.b2cdc2d984adff593dc985f555b8d280.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eebcff406fa21c88379bddc14e7bbd558348d63c Binary files /dev/null and b/dataset/images/val/025_JPG.rf.b2cdc2d984adff593dc985f555b8d280.jpg differ diff --git a/dataset/images/val/06_jpg.rf.a94e0a678df372f5ea1395a8d888a388.jpg b/dataset/images/val/06_jpg.rf.a94e0a678df372f5ea1395a8d888a388.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5a89d78f8b015959d982c19c8e09e0cb08b11f0a Binary files /dev/null and b/dataset/images/val/06_jpg.rf.a94e0a678df372f5ea1395a8d888a388.jpg differ diff --git a/dataset/images/val/07_JPG.rf.324d17a87726bd2a9614536c687c6e68.jpg b/dataset/images/val/07_JPG.rf.324d17a87726bd2a9614536c687c6e68.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a5e04ae573b74fe077143d2c7bb2cb5c20cf855 Binary files /dev/null and b/dataset/images/val/07_JPG.rf.324d17a87726bd2a9614536c687c6e68.jpg differ diff --git a/dataset/images/val/23_jpg.rf.8e9afa6b3b471e10c26637d47700f28b.jpg b/dataset/images/val/23_jpg.rf.8e9afa6b3b471e10c26637d47700f28b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9f3239f149d116ab41222cfd7101aa46cf7d186 Binary files /dev/null and b/dataset/images/val/23_jpg.rf.8e9afa6b3b471e10c26637d47700f28b.jpg differ diff --git a/dataset/images/val/24_jpg.rf.4caa996d97e35f6ce4f27a527ea43465.jpg b/dataset/images/val/24_jpg.rf.4caa996d97e35f6ce4f27a527ea43465.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b9d307699fb320af436e1e0662e0d561902eceb Binary files /dev/null and b/dataset/images/val/24_jpg.rf.4caa996d97e35f6ce4f27a527ea43465.jpg differ diff --git a/dataset/images/val/27_jpg.rf.3475fce31d283058f46d9f349c04cb1a.jpg b/dataset/images/val/27_jpg.rf.3475fce31d283058f46d9f349c04cb1a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aa62a4a294da74b9e245a78ecdbb98746ffa6971 Binary files /dev/null and b/dataset/images/val/27_jpg.rf.3475fce31d283058f46d9f349c04cb1a.jpg differ diff --git a/dataset/images/val/28_jpg.rf.50e348d807d35667583137c9a6c162ca.jpg b/dataset/images/val/28_jpg.rf.50e348d807d35667583137c9a6c162ca.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4bfbef4475236719405a9786a9477f5b85b71112 Binary files /dev/null and b/dataset/images/val/28_jpg.rf.50e348d807d35667583137c9a6c162ca.jpg differ diff --git a/dataset/images/val/30_jpg.rf.ed72622e97cf0d884997585686cfe40a.jpg b/dataset/images/val/30_jpg.rf.ed72622e97cf0d884997585686cfe40a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5bdd14697d8cfb2781ab9ad74f96609c0d5d0b0d Binary files /dev/null and b/dataset/images/val/30_jpg.rf.ed72622e97cf0d884997585686cfe40a.jpg differ diff --git a/dataset/test/.DS_Store b/dataset/test/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dac909f73e865879c40c05a7378d3b6e42a255b5 Binary files /dev/null and b/dataset/test/.DS_Store differ diff --git a/dataset/test/images/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.jpg b/dataset/test/images/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..484e06e4dd8b45c5402b8216c7b359e052757f7b Binary files /dev/null and b/dataset/test/images/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.jpg differ diff --git a/dataset/test/images/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.jpg b/dataset/test/images/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5fd1580a96568c55c8a2e8c807b3a2bd64a760a Binary files /dev/null and b/dataset/test/images/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.jpg differ diff --git a/dataset/test/images/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.jpg b/dataset/test/images/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99481b46e91d12dee9f513935645ff969e728adf Binary files /dev/null and b/dataset/test/images/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.jpg differ diff --git a/dataset/test/images/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.jpg b/dataset/test/images/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1fa35cbc1155a92288144bae6ce6022cef079da1 Binary files /dev/null and b/dataset/test/images/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.jpg differ diff --git a/dataset/test/images/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.jpg b/dataset/test/images/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aea9d6c4fc2c5d0fe8fe8d66aebbf1d565d92eaa Binary files /dev/null and b/dataset/test/images/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.jpg differ diff --git a/dataset/test/labels/.DS_Store b/dataset/test/labels/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/dataset/test/labels/.DS_Store differ diff --git a/dataset/test/labels/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.txt b/dataset/test/labels/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.txt new file mode 100644 index 0000000000000000000000000000000000000000..40994b2332a91b26245faed345cfb81b58d2698a --- /dev/null +++ b/dataset/test/labels/17_jpg.rf.ec31940ea72d0cf8b9f38dba68789fcf.txt @@ -0,0 +1,6 @@ +0 0.6859375 0.8625 0.6578125 0.8671875 0.615625 0.8390625 0.55 0.85 0.5625 0.8578125 0.5125 0.884375 0.5578125 0.8984375 0.590625 0.871875 0.575 0.9109375 0.596875 0.9109375 0.6015625 0.865625 0.6140625 0.8984375 0.6328125 0.8515625 0.6390625 0.8734375 0.6765625 0.871875 0.68125 0.8984375 +0 0.7859375 0.6375 0.4046875 0.675 0.4109375 0.75625 0.2875 0.6875 0.2875 0.75 0.2125 0.7125 0.15 0.784375 0.3234375 0.9109375 0.3296875 0.8375 +0 0.55 0.5375 0.55625 0.540625 0.55 0.5453125 0.55 0.5859375 0.596875 0.584375 0.60625 0.5984375 0.615625 0.58125 0.634375 0.5859375 0.6375 0.5734375 0.6484375 0.5734375 0.6484375 0.55625 0.6359375 0.5515625 0.634375 0.5375 0.5859375 0.5375 0.5796875 0.54375 0.5703125 0.5421875 0.578125 0.5375 +0 0.0015625 0.5375 0.0015625 0.6734375 0.0546875 0.6875 0.0625 0.7234375 0.0859375 0.696875 0.078125 0.7109375 0.09375 0.7109375 0.1109375 0.65 0.0609375 0.65 0.046875 0.615625 0.075 0.5921875 0.1109375 0.5984375 0.096875 0.5375 +0 0.7234375 0.3015625 0.671875 0.2625 0.44375 0.3015625 0.1125 0.275 0.0625 0.334375 0.265625 0.4984375 0.521875 0.4609375 +0 0.525 0.1375 0.525 0.1734375 0.5125 0.175 0.5125 0.184375 0.525 0.1828125 0.528125 0.1984375 0.5484375 0.1984375 0.5578125 0.20625 0.5515625 0.2109375 0.5625 0.2046875 0.5984375 0.2109375 0.5953125 0.2 0.6109375 0.1953125 0.6109375 0.1796875 0.6046875 0.1828125 0.6109375 0.1390625 diff --git a/dataset/test/labels/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.txt b/dataset/test/labels/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.txt new file mode 100644 index 0000000000000000000000000000000000000000..ffdd5463c0935cbcff496fd541d3c1235fa52889 --- /dev/null +++ b/dataset/test/labels/19_jpg.rf.2c5ffd63bd0ce6b9b0c80fef69d101dc.txt @@ -0,0 +1,3 @@ +0 0.625 0.6625 0.625 0.6921875 0.634375 0.6984375 0.6578125 0.6984375 0.6578125 0.69375 0.6734375 0.6984375 0.6671875 0.6953125 0.6765625 0.684375 0.7125 0.6859375 0.70625 0.684375 0.7171875 0.68125 0.715625 0.6859375 0.7234375 0.6859375 0.71875 0.675 0.7234375 0.665625 0.6703125 0.6625 0.66875 0.6671875 0.6671875 0.6625 +0 0.5625 0.6375 0.5625 0.6484375 0.5671875 0.6484375 0.5625 0.6515625 0.5625 0.665625 0.571875 0.6640625 0.5625 0.6734375 0.575 0.6734375 0.578125 0.66875 0.584375 0.671875 0.58125 0.6734375 0.5890625 0.6734375 0.5875 0.66875 0.5984375 0.6734375 0.59375 0.665625 0.5984375 0.6671875 0.6 0.659375 0.6109375 0.659375 0.6046875 0.6578125 0.6109375 0.6515625 0.6046875 0.640625 0.6109375 0.6375 +0 0.525 0.55 0.5265625 0.5609375 0.53125 0.5609375 0.525 0.565625 0.525 0.5734375 0.5265625 0.5703125 0.5328125 0.5734375 0.53125 0.56875 0.5390625 0.5703125 0.5390625 0.5734375 0.55625 0.5734375 0.5609375 0.5703125 0.5609375 0.559375 0.5578125 0.5578125 0.559375 0.553125 0.5609375 0.55625 0.5609375 0.55 0.5421875 0.55 0.5390625 0.553125 0.5359375 0.55 0.5359375 0.5546875 0.5328125 0.55 0.5328125 0.5609375 diff --git a/dataset/test/labels/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.txt b/dataset/test/labels/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.txt new file mode 100644 index 0000000000000000000000000000000000000000..85a8948968627cc3c10f095de093a374c6e59b5a --- /dev/null +++ b/dataset/test/labels/32_jpg.rf.f3e33dcf611a8754c0765224f7873d8b.txt @@ -0,0 +1,4 @@ +0 0.15 0.7125 0.2515625 0.8234375 0.31875 0.821875 0.3546875 0.7546875 0.3671875 0.7984375 0.39375 0.7453125 0.5 0.7734375 0.5109375 0.8328125 0.5734375 0.759375 0.4609375 0.6375 0.2796875 0.625 +0 0.85 0.2140625 0.8546875 0.225 0.85 0.2359375 0.85625 0.2265625 0.8640625 0.2484375 0.903125 0.2484375 0.9046875 0.2421875 0.9109375 0.2484375 0.909375 0.2359375 0.9234375 0.234375 0.91875 0.23125 0.9234375 0.23125 0.9234375 0.2140625 0.9078125 0.209375 0.9109375 0.2 0.9046875 0.2 0.90625 0.209375 0.903125 0.2046875 0.8734375 0.203125 0.8671875 0.2125 0.8640625 0.209375 0.871875 0.2 0.8625 0.2 0.8640625 0.2125 +0 0.6984375 0.1359375 0.628125 0.1125 0.4640625 0.18125 0.3515625 0.1375 0.3 0.2859375 0.375 0.3484375 0.559375 0.3734375 0.56875 0.2578125 +0 0.775 0.0875 0.78125 0.090625 0.775 0.0953125 0.778125 0.1015625 0.7703125 0.1 0.778125 0.1046875 0.76875 0.1078125 0.7671875 0.1 0.7640625 0.1109375 0.7734375 0.10625 0.78125 0.1109375 0.7828125 0.10625 0.7921875 0.1109375 0.7890625 0.115625 0.796875 0.1140625 0.796875 0.121875 0.7890625 0.1234375 0.809375 0.1234375 0.809375 0.1125 0.8015625 0.1140625 0.7921875 0.1046875 0.7984375 0.1015625 0.79375 0.1 0.796875 0.0953125 0.809375 0.0953125 0.809375 0.0890625 diff --git a/dataset/test/labels/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.txt b/dataset/test/labels/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.txt new file mode 100644 index 0000000000000000000000000000000000000000..b736c451cd079b3ca97dead276c92cdca4f09db2 --- /dev/null +++ b/dataset/test/labels/normal-reflow_jpg.rf.2c4fbc1fda915b821b85689ae257e116.txt @@ -0,0 +1,5 @@ +0 0.3875 0.6875 0.3921875 0.7328125 0.3875 0.7359375 0.39375 0.73125 0.403125 0.740625 0.4015625 0.7484375 0.4140625 0.7421875 0.4234375 0.7484375 0.4328125 0.7359375 0.4265625 0.7484375 0.4359375 0.7484375 0.434375 0.7359375 0.4484375 0.734375 0.440625 0.7296875 0.4484375 0.7234375 0.446875 0.690625 +0 0.5125 0.525 0.51875 0.528125 0.515625 0.5359375 0.5 0.5390625 0.5 0.590625 0.509375 0.5890625 0.5 0.5984375 0.5140625 0.5984375 0.509375 0.590625 0.5171875 0.5984375 0.559375 0.5984375 0.5609375 0.5390625 0.5515625 0.5453125 0.553125 0.5375 0.5203125 0.5328125 0.5203125 0.525 0.5359375 0.528125 +0 0.8328125 0.5125 0.7625 0.5125 0.675 0.59375 0.6875 0.6734375 0.746875 0.6734375 0.8484375 0.5984375 0.8609375 0.5515625 +0 0.3125 0.5 0.2875 0.5375 0.2875 0.5953125 0.3140625 0.6234375 0.396875 0.6359375 0.4359375 0.609375 0.4484375 0.55625 0.421875 0.5125 +0 0.3515625 0.1125 0.3484375 0.128125 0.3125 0.15 0.3125 0.1984375 0.3265625 0.2 0.325 0.2234375 0.3359375 0.215625 0.3421875 0.2359375 0.353125 0.2265625 0.346875 0.2359375 0.396875 0.2359375 0.4234375 0.20625 0.4234375 0.1390625 0.4109375 0.1359375 0.40625 0.1125 diff --git a/dataset/test/labels/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.txt b/dataset/test/labels/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0ea5e014fd76bef3dadb2b1b6d184b48da17e61 --- /dev/null +++ b/dataset/test/labels/techtips_31_jpg.rf.673cd3c7c8511e534766e6dbc3171b39.txt @@ -0,0 +1,7 @@ +0 0.4234375 0.8015625 0.41875 0.8046875 0.4171875 0.8 0.4125 0.8015625 0.4140625 0.8109375 0.409375 0.815625 0.40625 0.815625 0.40625 0.8125 0.4 0.815625 0.4 0.8203125 0.4046875 0.81875 0.4 0.825 0.403125 0.825 0.4 0.828125 0.403125 0.8296875 0.4 0.8328125 0.403125 0.8359375 0.403125 0.8265625 0.4078125 0.8296875 0.40625 0.8359375 0.4109375 0.834375 0.4078125 0.8296875 0.4140625 0.8171875 0.415625 0.8234375 0.41875 0.8234375 0.41875 0.81875 0.421875 0.8234375 0.4203125 0.8140625 0.4234375 0.80625 0.41875 0.80625 +0 0.2984375 0.7125 0.275 0.71875 0.28125 0.7125 0.2640625 0.7 0.2625 0.71875 0.2125 0.71875 0.2125 0.784375 0.2265625 0.7875 0.225 0.7984375 0.23125 0.7875 0.2328125 0.7984375 0.284375 0.7984375 0.2875 0.7796875 0.296875 0.784375 0.3109375 0.7546875 0.309375 0.728125 0.2921875 0.725 +0 0.225 0.475 0.2296875 0.4859375 0.2125 0.4890625 0.2203125 0.49375 0.2125 0.5234375 0.2265625 0.5359375 0.234375 0.53125 0.228125 0.5484375 0.2859375 0.5484375 0.28125 0.5359375 0.2984375 0.5359375 0.2921875 0.53125 0.2984375 0.4875 0.278125 0.484375 0.284375 0.475 +0 0.8125 0.3125 0.8125 0.334375 0.81875 0.3328125 0.8125 0.359375 0.8296875 0.3625 0.825 0.4234375 0.8359375 0.4234375 0.8359375 0.3734375 0.8296875 0.3734375 0.8359375 0.365625 0.8328125 0.371875 0.828125 0.365625 0.846875 0.3609375 0.8484375 0.35 0.834375 0.3578125 0.8359375 0.3375 0.83125 0.346875 0.83125 0.3375 0.8234375 0.3390625 0.8234375 0.315625 +0 0.15 0.3140625 0.15 0.3328125 0.15625 0.3296875 0.15625 0.325 0.1640625 0.3328125 0.15 0.3390625 0.15 0.35625 0.1546875 0.3609375 0.1578125 0.3578125 0.1703125 0.3609375 0.1734375 0.3484375 0.1671875 0.3609375 0.1625 0.35625 0.1734375 0.3421875 0.1703125 0.3390625 0.1734375 0.33125 0.1671875 0.321875 0.16875 0.31875 0.1734375 0.325 0.171875 0.3125 +0 0.7140625 0.125 0.7125 0.146875 0.71875 0.1484375 0.7125 0.1609375 0.728125 0.1625 0.7125 0.165625 0.721875 0.1703125 0.7125 0.175 0.7140625 0.1859375 0.7296875 0.1859375 0.7390625 0.1765625 0.7484375 0.1796875 0.7453125 0.175 0.7296875 0.178125 0.734375 0.1625 0.7484375 0.159375 0.7484375 0.125 0.7453125 0.1296875 0.7421875 0.125 +0 0.6328125 0.125 0.65 0.23125 0.2140625 0.25 0.4140625 0.6015625 0.3390625 0.7734375 0.7984375 0.746875 0.8109375 0.3640625 diff --git a/dataset/train/labels.cache b/dataset/train/labels.cache new file mode 100644 index 0000000000000000000000000000000000000000..cda1012ebd4f054570fc6fdbc2ee0954e0d569fc Binary files /dev/null and b/dataset/train/labels.cache differ diff --git a/dataset/val/labels.cache b/dataset/val/labels.cache new file mode 100644 index 0000000000000000000000000000000000000000..288f98d1817a617d2cf4f39fac418a2d26822e0d Binary files /dev/null and b/dataset/val/labels.cache differ diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ef0cae7a41b0ec90a90f0c6c56c15a10455c3e88 Binary files /dev/null and b/models/.DS_Store differ diff --git a/models/best.pt b/models/best.pt new file mode 100644 index 0000000000000000000000000000000000000000..a41ca2fa9d12eef772f44654d92e21b37a7594e7 --- /dev/null +++ b/models/best.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c1a69f8b7ebf89dff6c8512342a91782165b0536aafdc8ffa40ded5bd647843 +size 6012957 diff --git a/models/data.yaml b/models/data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..023c9eab43e613c9133ec384417f04e7d9c60718 --- /dev/null +++ b/models/data.yaml @@ -0,0 +1,7 @@ +train: C:/Users/marcv/Desktop/project/segmentation_tool/dataset/train/images +val: C:/Users/marcv/Desktop/project/segmentation_tool/dataset/val/images + +#test: dataset/test/images # Optional, for testing + +nc: 2 # Number of classes +names: ['void', 'chip'] # Class names diff --git a/models/old/old.pt b/models/old/old.pt new file mode 100644 index 0000000000000000000000000000000000000000..a41ca2fa9d12eef772f44654d92e21b37a7594e7 --- /dev/null +++ b/models/old/old.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c1a69f8b7ebf89dff6c8512342a91782165b0536aafdc8ffa40ded5bd647843 +size 6012957 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..df0033cf1f83809d424e00d8e1901955f5ace199 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +Flask +Pillow +numpy<2 +Flask-SocketIO +Pillow +torch +torchaudio +torchvision +ultralytics +hydra-core>=1.3.2 diff --git a/runs/.DS_Store b/runs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..42f24cd135e6f096fa9d63a97e722c7a39da6257 Binary files /dev/null and b/runs/.DS_Store differ diff --git a/runtime.txt b/runtime.txt new file mode 100644 index 0000000000000000000000000000000000000000..cd6f13073e4940080f1692d0e6b2e5fdb98c3a51 --- /dev/null +++ b/runtime.txt @@ -0,0 +1 @@ +python-3.11.1 \ No newline at end of file diff --git a/sam2/.DS_Store b/sam2/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b2506461939e867dc0fdfb8f3f88d203aa5a55ad Binary files /dev/null and b/sam2/.DS_Store differ diff --git a/sam2/__init__.py b/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec230f636cc8632cafa59f3652d52004362178aa --- /dev/null +++ b/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("sam2", version_base="1.2") diff --git a/sam2/__pycache__/__init__.cpython-311.pyc b/sam2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..576661585a2cc26cf83457a8b6c99bcceb9307dd Binary files /dev/null and b/sam2/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/__pycache__/build_sam.cpython-311.pyc b/sam2/__pycache__/build_sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46a7be3857de70a415c388a19209dc33b8047aea Binary files /dev/null and b/sam2/__pycache__/build_sam.cpython-311.pyc differ diff --git a/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f95e7c46c6e10fc8dffe2cc33bfb50ac3b7e42c Binary files /dev/null and b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc differ diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..68af8fe3fd01be1706c4ca57ddffa1b2a392b4ee --- /dev/null +++ b/sam2/automatic_mask_generator.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + **kwargs, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/sam2/build_sam.py b/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6bdd1209f1587be562eea6327866181698c539 --- /dev/null +++ b/sam2/build_sam.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +import sam2 + +# Check if the user is running Python from the parent directory of the sam2 repo +# (i.e. the directory where this repo is cloned into) -- this is not supported since +# it could shadow the sam2 package and cause issues. +if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): + # If the user has "sam2/sam2" in their path, they are likey importing the repo itself + # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). + # This typically happens because the user is running Python from the parent directory + # that contains the sam2 repo they cloned. + raise RuntimeError( + "You're likely running Python from the parent directory of the sam2 repository " + "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " + "This is not supported since the `sam2` Python package could be shadowed by the " + "repository name (the repository is also named `sam2` and contains the Python package " + "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " + "rather than its parent dir, or from your home directory) after installing SAM 2." + ) + + +HF_MODEL_ID_TO_FILENAMES = { + "facebook/sam2-hiera-tiny": ( + "configs/sam2/sam2_hiera_t.yaml", + "sam2_hiera_tiny.pt", + ), + "facebook/sam2-hiera-small": ( + "configs/sam2/sam2_hiera_s.yaml", + "sam2_hiera_small.pt", + ), + "facebook/sam2-hiera-base-plus": ( + "configs/sam2/sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ( + "configs/sam2/sam2_hiera_l.yaml", + "sam2_hiera_large.pt", + ), + "facebook/sam2.1-hiera-tiny": ( + "configs/sam2.1/sam2.1_hiera_t.yaml", + "sam2.1_hiera_tiny.pt", + ), + "facebook/sam2.1-hiera-small": ( + "configs/sam2.1/sam2.1_hiera_s.yaml", + "sam2.1_hiera_small.pt", + ), + "facebook/sam2.1-hiera-base-plus": ( + "configs/sam2.1/sam2.1_hiera_b+.yaml", + "sam2.1_hiera_base_plus.pt", + ), + "facebook/sam2.1-hiera-large": ( + "configs/sam2.1/sam2.1_hiera_l.yaml", + "sam2.1_hiera_large.pt", + ), +} + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def _hf_download(model_id): + from huggingface_hub import hf_hub_download + + config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return config_name, ckpt_path + + +def build_sam2_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + config_name, ckpt_path = _hf_download(model_id) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebed680f99eb32d4967511cf979b63953d080332 --- /dev/null +++ b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..451b78cd6399c7f4ec1c34f4e06dfb12cd027904 --- /dev/null +++ b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1adaae6fda669ff04ecd06947810ae2f62e6c57a --- /dev/null +++ b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -0,0 +1,119 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b1e73d20f7c125618171229e3520fb05ebbdf666 --- /dev/null +++ b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -0,0 +1,121 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b31ed0204f92fba372367abc5d021b42a6fb1bcd --- /dev/null +++ b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml @@ -0,0 +1,339 @@ +# @package _global_ + +scratch: + resolution: 1024 + train_batch_size: 1 + num_train_workers: 10 + num_frames: 8 + max_num_objects: 3 + base_lr: 5.0e-6 + vision_lr: 3.0e-06 + phases_per_epoch: 1 + num_epochs: 40 + +dataset: + # PATHS to Dataset + img_folder: null # PATH to MOSE JPEGImages folder + gt_folder: null # PATH to MOSE Annotations folder + file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training + multiplier: 2 + +# Video transforms +vos: + train_transforms: + - _target_: training.dataset.transforms.ComposeAPI + transforms: + - _target_: training.dataset.transforms.RandomHorizontalFlip + consistent_transform: True + - _target_: training.dataset.transforms.RandomAffine + degrees: 25 + shear: 20 + image_interpolation: bilinear + consistent_transform: True + - _target_: training.dataset.transforms.RandomResizeAPI + sizes: ${scratch.resolution} + square: true + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: True + brightness: 0.1 + contrast: 0.03 + saturation: 0.03 + hue: null + - _target_: training.dataset.transforms.RandomGrayscale + p: 0.05 + consistent_transform: True + - _target_: training.dataset.transforms.ColorJitter + consistent_transform: False + brightness: 0.1 + contrast: 0.05 + saturation: 0.05 + hue: null + - _target_: training.dataset.transforms.ToTensorAPI + - _target_: training.dataset.transforms.NormalizeAPI + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +trainer: + _target_: training.trainer.Trainer + mode: train_only + max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} + accelerator: cuda + seed_value: 123 + + model: + _target_: training.model.sam2.SAM2Train + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + drop_path_rate: 0.1 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: ${scratch.resolution} + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # compile_image_encoder: False + + ####### Training specific params ####### + # box/point input and corrections + prob_to_use_pt_input_for_train: 0.5 + prob_to_use_pt_input_for_eval: 0.0 + prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points + prob_to_use_box_input_for_eval: 0.0 + prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors + num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) + num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame + rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2 + add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) + # maximum 2 initial conditioning frames + num_init_cond_frames_for_train: 2 + rand_init_cond_frames_for_train: True # random 1~2 + num_correction_pt_per_frame: 7 + use_act_ckpt_iterative_pt_sampling: false + + + + num_init_cond_frames_for_eval: 1 # only mask on the first frame + forward_backbone_per_frame_for_eval: True + + + data: + train: + _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset + phases_per_epoch: ${scratch.phases_per_epoch} + batch_sizes: + - ${scratch.train_batch_size} + + datasets: + - _target_: training.dataset.utils.RepeatFactorWrapper + dataset: + _target_: training.dataset.utils.ConcatDataset + datasets: + - _target_: training.dataset.vos_dataset.VOSDataset + transforms: ${vos.train_transforms} + training: true + video_dataset: + _target_: training.dataset.vos_raw_dataset.PNGRawDataset + img_folder: ${dataset.img_folder} + gt_folder: ${dataset.gt_folder} + file_list_txt: ${dataset.file_list_txt} + sampler: + _target_: training.dataset.vos_sampler.RandomUniformSampler + num_frames: ${scratch.num_frames} + max_num_objects: ${scratch.max_num_objects} + multiplier: ${dataset.multiplier} + shuffle: True + num_workers: ${scratch.num_train_workers} + pin_memory: True + drop_last: True + collate_fn: + _target_: training.utils.data_utils.collate_fn + _partial_: true + dict_key: all + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: training.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: training.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: 0.9 + apply_to: 'image_encoder.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.base_lr} + end_value: ${divide:${scratch.base_lr},10} + - scheduler: + _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: ${scratch.vision_lr} + end_value: ${divide:${scratch.vision_lr},10} + param_names: + - 'image_encoder.*' + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.1 + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + loss: + all: + _target_: training.loss_fns.MultiStepMultiMasksAndIous + weight_dict: + loss_mask: 20 + loss_dice: 1 + loss_iou: 1 + loss_class: 1 + supervise_all_iou: true + iou_use_l1_loss: true + pred_obj_scores: true + focal_gamma_obj_score: 0.0 + focal_alpha_obj_score: -1.0 + + distributed: + backend: nccl + find_unused_parameters: True + + logging: + tensorboard_writer: + _target_: training.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + log_dir: ${launcher.experiment_log_dir}/logs + log_freq: 10 + + # initialize from a SAM 2 checkpoint + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + model_weight_initializer: + _partial_: True + _target_: training.utils.checkpoint_utils.load_state_dict_into_model + strict: True + ignore_unexpected_keys: null + ignore_missing_keys: null + + state_dict: + _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels + checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint + ckpt_state_dict_keys: ['model'] + +launcher: + num_nodes: 1 + gpus_per_node: 8 + experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} + +# SLURM args if running on a cluster +submitit: + partition: null + account: null + qos: null + cpus_per_task: 10 + use_cluster: false + timeout_hour: 24 + name: null + port_range: [10000, 65000] + diff --git a/sam2/configs/sam2/sam2_hiera_b+.yaml b/sam2/configs/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9fdcfa4054b6d03a159c4fad01515fd3153d23d8 --- /dev/null +++ b/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2/sam2_hiera_l.yaml b/sam2/configs/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ef6cc3a4c7656f2791d12575b70b3dbb665bb25 --- /dev/null +++ b/sam2/configs/sam2/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2/sam2_hiera_s.yaml b/sam2/configs/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6ebeeae747874ba1938ffdf69876202e7a98c0a --- /dev/null +++ b/sam2/configs/sam2/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2/sam2_hiera_t.yaml b/sam2/configs/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c0b3a36094f1cee9b6d320eacd1c5774e019fb2 --- /dev/null +++ b/sam2/configs/sam2/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e3fbee0eba762c7198ace220660ceadd13d7402 --- /dev/null +++ b/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx + 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b085341957f3709b01eb2dfb0df17f03f394a3fe Binary files /dev/null and b/sam2/modeling/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2acb8c8fe8fd43e9690ac88a6891d89b2b28cf8 Binary files /dev/null and b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b86f076f762a4881e13883f3e6dd1a249c1a0d2c Binary files /dev/null and b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2331f2874f648a9fab0b282e4b70a610e8fedfeb Binary files /dev/null and b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73d8c51879e7ca944114d7c03de6c1691a3b707 Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dbd5bc12bc4141a251afcfa4c849935d0ea4fe1 Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d2677b2f153197365c386135a4f87b32f8ea54 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9fe1e80d0731353f1583a9b2e6c45ba7d3c489c Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d319b0d534c0c9bea592f50c12096bfc5f4d38bf Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2fca4b634472f56ace6e0d4add3a8c4ad8cc2e9 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..cb57bf745c5ab973245d8f52020823043895c340 --- /dev/null +++ b/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + weights_path=None, + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..25853dff2add5afc6abe704e9f00aba1d518cfd0 --- /dev/null +++ b/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + self.d_model = d_model + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b7807b275256f83a83e5d1baa6c045ad6c124807 --- /dev/null +++ b/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4adb5cf0a335f7013835dd31e3863b6b04e738 --- /dev/null +++ b/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..83f98a2544b225f5bdb6e9a046380b8df5887a30 --- /dev/null +++ b/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..75aac3cb4f113352fb385ef0a25aec97833cd34a --- /dev/null +++ b/sam2/modeling/position_encoding.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86fe6e81605fee0ac58953ca0afe29a87c2093f1 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1156893d0d7b754c0ccba204aa45e1530dcfac3e Binary files /dev/null and b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18fe45445517b8c34c4f67970a005a97c21a38df Binary files /dev/null and b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc b/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbf1212c7720b3409d538a437b97bffefd900e19 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc differ diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e2cb927dd691f1e07c1db34ed380ff305a0dd3c5 --- /dev/null +++ b/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..91d9952ca8078bedd04fdc2ea0d900529e432528 --- /dev/null +++ b/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..252773b023905c4a11a53fa479daf4e259b6f70c --- /dev/null +++ b/sam2/modeling/sam/transformer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f4305d04eb07a8ab51ac52472bdab6f2d08c5adf --- /dev/null +++ b/sam2/modeling/sam2_base.py @@ -0,0 +1,907 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = image_encoder.neck.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. + stride = 1 if self.training else self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5634bc5fa1df56d7fd5b94e75ca3226316df5128 --- /dev/null +++ b/sam2/modeling/sam2_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.utils.misc import mask_to_box + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/sam2/sam2_hiera_b+.yaml b/sam2/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..998d9c98c9ff4e8ddd55deff72aa0d9067977418 --- /dev/null +++ b/sam2/sam2_hiera_b+.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_b+.yaml \ No newline at end of file diff --git a/sam2/sam2_hiera_l.yaml b/sam2/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0e7e58e1951d5c55a3a3ebe6b803dd814cf9d86 --- /dev/null +++ b/sam2/sam2_hiera_l.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_l.yaml \ No newline at end of file diff --git a/sam2/sam2_hiera_s.yaml b/sam2/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41896a26beb2aa831d18b0bf3c349ed43deeef68 --- /dev/null +++ b/sam2/sam2_hiera_s.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_s.yaml \ No newline at end of file diff --git a/sam2/sam2_hiera_t.yaml b/sam2/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71ff3abbb1e11f8b82100a0a1d63cb267eefe52a --- /dev/null +++ b/sam2/sam2_hiera_t.yaml @@ -0,0 +1 @@ +configs/sam2/sam2_hiera_t.yaml \ No newline at end of file diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..93d25d5e5597a71f9fe133e241c8c710233091d5 --- /dev/null +++ b/sam2/sam2_image_predictor.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcc2f8b48e2521a3429b38af94d2bc5049f7d2f --- /dev/null +++ b/sam2/sam2_video_predictor.py @@ -0,0 +1,1172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_prompts_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.reset_state(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_prompts_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/utils/__pycache__/__init__.cpython-311.pyc b/sam2/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b3b50c160247e6dff38ad5d21ff8ad618c162c Binary files /dev/null and b/sam2/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/utils/__pycache__/misc.cpython-311.pyc b/sam2/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f652e958c32dd43a39432408ce66464b63937221 Binary files /dev/null and b/sam2/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/sam2/utils/__pycache__/transforms.cpython-311.pyc b/sam2/utils/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69a2e1abaf103402e95d4cc74bef711622f8df2e Binary files /dev/null and b/sam2/utils/__pycache__/transforms.cpython-311.pyc differ diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..08a56abdec44fef5539f97c0f532fab1e3aefe68 --- /dev/null +++ b/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8951732899fa048f276dbf7f9f67208c9049bf01 --- /dev/null +++ b/sam2/utils/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8ceaf0346ac7721336b65791eeeebcc92f18de --- /dev/null +++ b/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/static/.DS_Store b/static/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..374388b3f4d13894b475f96b15f944145e8a1939 Binary files /dev/null and b/static/.DS_Store differ diff --git a/static/css/styles.css b/static/css/styles.css new file mode 100644 index 0000000000000000000000000000000000000000..a1da6afee50b4f19cb9676362b219bd3a698f8d2 --- /dev/null +++ b/static/css/styles.css @@ -0,0 +1,292 @@ +/* Import Google Fonts */ +@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap'); + +/* General Body Styling */ +body { + font-family: 'Inter', sans-serif; + background-color: #f4f6f8; + margin: 0; + padding: 0; + min-height: 100vh; +} + +/* Container Styling */ +.container { + max-width: 1400px; /* Widened to better fit the boxes */ + background: white; + padding: 30px; + border-radius: 16px; + box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1); + margin: 20px auto; +} + +/* Section Headers */ +h1, h2 { + font-size: 2rem; + font-weight: 700; + color: #333; + text-align: center; + margin-bottom: 20px; +} + +/* Tool Sections */ +.tool-section { + background: #fff; + border: 1px solid #ddd; + border-radius: 16px; + padding: 20px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + margin-bottom: 30px; + width: 100%; + max-width: 1200px; /* Increased to align with SAM boxes */ + margin: 20px auto; +} + +/* Buttons Styling */ +.btn-group { + display: flex; + justify-content: center; + margin-bottom: 20px; + gap: 10px; +} + +.btn-group button { + font-size: 16px; + padding: 10px 20px; + border-radius: 8px; + cursor: pointer; + transition: all 0.3s ease-in-out; + width: 50%; /* Equal button width */ +} + +.btn-group button.active { + background-color: #007bff; + color: white; + border: 1px solid #0056b3; +} + +.btn-group button:hover { + background-color: #0056b3; + color: white; +} + +/* Tool Layout */ +.tool-container { + display: flex; + flex-direction: column; + gap: 40px; + align-items: center; +} + +/* Canvas Containers */ +.canvas-container { + width: 100%; + max-width: 600px; /* Adjusted for better alignment */ + aspect-ratio: 1 / 1; + border: 2px solid #ddd; + border-radius: 16px; + background-color: #fff; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + display: flex; + justify-content: center; + align-items: center; + padding: 10px; + margin: 0 auto; +} + +/* File Upload Fields */ +input[type="file"] { + margin-top: 20px; + padding: 10px; + font-size: 16px; + border-radius: 8px; + border: 1px solid #ddd; + box-shadow: inset 0px 1px 3px rgba(0, 0, 0, 0.1); + width: 100%; + max-width: 600px; +} + +input[type="file"]:focus { + border: 1px solid #007bff; + outline: none; +} + +/* Processed Image Styling */ +#automaticProcessedImage { + border: 2px solid #ddd; + border-radius: 16px; + background-color: #fff; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + width: 100%; + height: auto; + margin-top: 20px; +} + +/* Clear Points Button */ +#clearPoints { + margin-top: 20px; + font-size: 14px; + padding: 10px 15px; + border-radius: 5px; + color: white; + background-color: #dc3545; + border: none; + cursor: pointer; +} + +#clearPoints:hover { + background-color: #b52b37; +} + +/* Table Styling */ +.table { + border: 1px solid #ddd; + border-radius: 10px; + overflow: hidden; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); + width: 100%; +} + +.table th { + background-color: #007bff; /* Blue column headers */ + color: white; + font-weight: bold; + text-align: center; + padding: 10px; +} + +.table td { + padding: 10px; + vertical-align: middle; + background-color: #f9f9f9; + text-align: center; +} + +.table-responsive { + border-radius: 10px; + overflow: hidden; +} + +/* Buttons Styling */ +#clearTableButton, #exportTableButton { + font-size: 16px; + padding: 10px 20px; + border-radius: 8px; + color: white; + background-color: #007bff; + border: none; + transition: all 0.3s ease-in-out; +} + +#clearTableButton:hover, #exportTableButton:hover { + background-color: #0056b3; + cursor: pointer; +} + +/* Responsive Design */ +@media screen and (max-width: 768px) { + .tool-section { + flex-direction: column; + align-items: center; + } + + .canvas-container { + margin-bottom: 20px; + } + + .btn-group button { + width: 100%; + } + + .table { + font-size: 14px; + } +} + +/* Additional Adjustments */ +.tool-section img, canvas { + max-width: 100%; + height: auto; + border-radius: 8px; +} + +#historyButton { + font-size: 16px; + padding: 10px 20px; + border-radius: 8px; + color: white; + background-color: #6c757d; /* Bootstrap secondary color */ + border: none; +} + +#historyButton:hover { + background-color: #5a6268; + cursor: pointer; +} + +/* Modal Body Styling */ +.modal-body { + padding: 15px; + width: 100%; + max-width: 600px; /* Adjust width to accommodate delete buttons */ + margin: auto; + display: flex; + flex-direction: column; + align-items: center; +} +/* History Modal List Styling */ +#historyList { + padding: 0; + margin: 0; + list-style: none; /* Remove default bullets */ + width: 100%; /* Expand list width */ +} + +/* List Group Item Styling */ +#historyList .list-group-item { + display: flex; /* Flexbox for alignment */ + justify-content: space-between; /* Space between filename and delete button */ + align-items: center; /* Vertically align items */ + padding: 10px 15px; /* Add consistent padding */ + border: 1px solid #ddd; /* Optional: border for clarity */ + border-radius: 8px; /* Rounded corners */ + background-color: #fff; /* White background for contrast */ + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow for depth */ + margin-bottom: 5px; /* Space between list items */ + width: 100%; /* Ensure full width inside modal */ +} + +/* Filename Styling */ +#historyList .filename { + flex-grow: 1; /* Allow filename to take up available space */ + text-overflow: ellipsis; /* Truncate long names */ + white-space: nowrap; /* Prevent wrapping */ + overflow: hidden; /* Hide overflowing text */ + padding-right: 15px; /* Space between filename and delete button */ +} + +/* Delete Button Styling */ +#historyList .btn-danger { + flex-shrink: 0; /* Prevent button from shrinking */ + padding: 5px 15px; + font-size: 14px; + border-radius: 8px; /* Rounded corners */ + background-color: #dc3545; /* Standard Bootstrap danger color */ + border: none; + transition: all 0.3s ease-in-out; +} + +#historyList .btn-danger:hover { + background-color: #b52b37; /* Slightly darker red on hover */ + color: white; + cursor: pointer; +} + +/* Modal Adjustments */ +.modal-dialog { + max-width: 700px; /* Wider modal dialog to fit the expanded list and delete buttons */ +} + +.modal-content { + padding: 10px; +} diff --git a/static/history/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg b/static/history/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac452b8c15bf70a138cbc334e670fb75f2c4900e Binary files /dev/null and b/static/history/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg differ diff --git a/static/history/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg b/static/history/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cb1ec47ff7e172bcd876582462f9667c0118c19 Binary files /dev/null and b/static/history/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg differ diff --git a/static/history/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg b/static/history/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3612ea8bad3b881d64ba82a7bcb3734fd0d3c301 Binary files /dev/null and b/static/history/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg differ diff --git a/static/history/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg b/static/history/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7769c2afbed542791dc7aa6c99b8f731b993c3d Binary files /dev/null and b/static/history/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg differ diff --git a/static/history/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg b/static/history/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a23e88255583db9dad0ae7d84c04aee4a17d7a1 Binary files /dev/null and b/static/history/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg differ diff --git a/static/history/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg b/static/history/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg new file mode 100644 index 0000000000000000000000000000000000000000..368daef342ad0515e373f0c051f72717bdeb2d47 Binary files /dev/null and b/static/history/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg differ diff --git a/static/history/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg b/static/history/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1e24609b75ed127858eb6fb66eba9d2274a09917 Binary files /dev/null and b/static/history/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg differ diff --git a/static/history/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg b/static/history/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg new file mode 100644 index 0000000000000000000000000000000000000000..860d0363ad46861aea837591d26295477a608eac Binary files /dev/null and b/static/history/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg differ diff --git a/static/history/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg b/static/history/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b357e7c557db3b68472b6f40accc04f469f33dd Binary files /dev/null and b/static/history/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg differ diff --git a/static/history/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg b/static/history/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9182db336422bdf0ac364af9fad89d9e0aa7ff4f Binary files /dev/null and b/static/history/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg differ diff --git a/static/history/masks/chip/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/static/history/masks/chip/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..cfe59a8e858cb63a75e133f959eabb1ebdc0c20b Binary files /dev/null and b/static/history/masks/chip/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/static/history/masks/chip/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/static/history/masks/chip/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..9d8126856e3b0a8cbdd32ab02489700065839c4d Binary files /dev/null and b/static/history/masks/chip/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/static/history/masks/chip/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/static/history/masks/chip/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..b54dd093ebc3753379df1d30325ee7232f1a653d Binary files /dev/null and b/static/history/masks/chip/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/static/history/masks/chip/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/static/history/masks/chip/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..e4dbc826edcda87f39dd1998176e2162c379058c Binary files /dev/null and b/static/history/masks/chip/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/static/history/masks/chip/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/static/history/masks/chip/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..b1159aa308dd01a7b4ceb1c15d3fcc27f622f193 Binary files /dev/null and b/static/history/masks/chip/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/static/history/masks/chip/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/static/history/masks/chip/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..685eb010a7a3ed2820a583c6f79669735e8d2a33 Binary files /dev/null and b/static/history/masks/chip/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/static/history/masks/chip/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/static/history/masks/chip/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..8e1c3e1a7e9af6d64b521e132eb24145e5b89bdb Binary files /dev/null and b/static/history/masks/chip/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/static/history/masks/chip/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/static/history/masks/chip/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..49af8a089a365164655cb69ee999698285b16b31 Binary files /dev/null and b/static/history/masks/chip/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/static/history/masks/chip/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/static/history/masks/chip/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..a9ba5fccb3f7f3eb5b2ce8a556812566f7d1770e Binary files /dev/null and b/static/history/masks/chip/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/static/history/masks/chip/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/static/history/masks/chip/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..085353f59f871492391434d027b1f664604dcfc7 Binary files /dev/null and b/static/history/masks/chip/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/static/history/masks/void/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/static/history/masks/void/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..e6fc92f551a7d60c65f2df142a541bec91870f5d Binary files /dev/null and b/static/history/masks/void/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/static/history/masks/void/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/static/history/masks/void/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..a2354e965cc811518a03465c31af0c198869fd95 Binary files /dev/null and b/static/history/masks/void/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/static/history/masks/void/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/static/history/masks/void/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..94ec31dbc470a83ad06750cbfc1dadacdae9e923 Binary files /dev/null and b/static/history/masks/void/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/static/history/masks/void/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/static/history/masks/void/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..b373067ac59737c435e4eada9171f4498416a89f Binary files /dev/null and b/static/history/masks/void/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/static/history/masks/void/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/static/history/masks/void/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..3ccdfc5a7e3612a37167a117828aa10b31bb2322 Binary files /dev/null and b/static/history/masks/void/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/static/history/masks/void/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/static/history/masks/void/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..0151180865d7b664d7ffa51db606086a16f6c889 Binary files /dev/null and b/static/history/masks/void/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/static/history/masks/void/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/static/history/masks/void/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..a1c195f11edd403f3dabac42760e9049808c51e7 Binary files /dev/null and b/static/history/masks/void/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/static/history/masks/void/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/static/history/masks/void/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..ac8e0a95d726259518793040157782d8ce667656 Binary files /dev/null and b/static/history/masks/void/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/static/history/masks/void/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/static/history/masks/void/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..cd9732adb958f28969d2d88c1fe8af3bec378f12 Binary files /dev/null and b/static/history/masks/void/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/static/history/masks/void/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/static/history/masks/void/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..13afab24243cbfaf4c21271b50bc0d41c2b38b51 Binary files /dev/null and b/static/history/masks/void/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/static/js/app.js b/static/js/app.js new file mode 100644 index 0000000000000000000000000000000000000000..7197eb3451c1b7aecdaa5bc2bdc7e006963e2034 --- /dev/null +++ b/static/js/app.js @@ -0,0 +1,689 @@ +// Interactive Segmentation DOM elements +const inputCanvas = document.getElementById('inputCanvas'); +const segmentedCanvas = document.getElementById('segmentedCanvas'); +const imageUpload = document.getElementById('imageUpload'); +const clearPointsButton = document.getElementById('clearPoints'); +const voidsButton = document.getElementById('voidsButton'); +const chipsButton = document.getElementById('chipsButton'); +const retrainModelButton = document.getElementById('retrainModelButton'); +const etaDisplay = document.getElementById('etaDisplay'); + +// Automatic Segmentation DOM elements +const automaticImageUpload = document.getElementById('automaticImageUpload'); +const automaticProcessedImage = document.getElementById('automaticProcessedImage'); +const resultsTableBody = document.getElementById('resultsTableBody'); +const clearTableButton = document.getElementById('clearTableButton'); +const exportTableButton = document.getElementById('exportTableButton'); + +// Constants for consistent canvas and SAM model dimensions +const CANVAS_SIZE = 512; +inputCanvas.width = CANVAS_SIZE; +inputCanvas.height = CANVAS_SIZE; +segmentedCanvas.width = CANVAS_SIZE; +segmentedCanvas.height = CANVAS_SIZE; + +// Interactive segmentation variables +let points = { Voids: [], Chips: [] }; +let labels = { Voids: [], Chips: [] }; +let currentClass = 'Voids'; +let imageUrl = ''; +let originalImageWidth = 0; +let originalImageHeight = 0; +let trainingInProgress = false; + +// Disable right-click menu on canvas +inputCanvas.addEventListener('contextmenu', (event) => event.preventDefault()); + +// Switch between classes +voidsButton.addEventListener('click', () => { + currentClass = 'Voids'; + voidsButton.classList.add('active'); + chipsButton.classList.remove('active'); + clearAndRestorePoints(); +}); + +chipsButton.addEventListener('click', () => { + currentClass = 'Chips'; + chipsButton.classList.add('active'); + voidsButton.classList.remove('active'); + clearAndRestorePoints(); +}); + +// Handle image upload for interactive tool +imageUpload.addEventListener('change', async (event) => { + const file = event.target.files[0]; + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await fetch('/upload', { method: 'POST', body: formData }); + const data = await response.json(); + if (data.error) { + console.error('Error uploading image:', data.error); + return; + } + + imageUrl = data.image_url; + console.log('Uploaded image URL:', imageUrl); + + const img = new Image(); + img.src = imageUrl; + img.onload = () => { + console.log('Image loaded:', img.width, img.height); + originalImageWidth = img.width; + originalImageHeight = img.height; + resizeAndDrawImage(inputCanvas, img); + resizeAndDrawImage(segmentedCanvas, img); + }; + img.onerror = () => { + console.error('Failed to load image from URL:', imageUrl); + }; + } catch (error) { + console.error('Failed to upload image:', error); + } +}); + + +// Handle input canvas clicks +inputCanvas.addEventListener('mousedown', async (event) => { + const rect = inputCanvas.getBoundingClientRect(); + const x = (event.clientX - rect.left) * (originalImageWidth / CANVAS_SIZE); + const y = (event.clientY - rect.top) * (originalImageHeight / CANVAS_SIZE); + + if (event.button === 2) { + points[currentClass].push([x, y]); + labels[currentClass].push(0); // Exclude point (red) + } else if (event.button === 0) { + points[currentClass].push([x, y]); + labels[currentClass].push(1); // Include point (green) + } + + drawPoints(); + await updateSegmentation(); +}); + +// Clear points for current class +clearPointsButton.addEventListener('click', () => { + points[currentClass] = []; + labels[currentClass] = []; + drawPoints(); + resetSegmentation(); +}); + +function resizeAndDrawImage(canvas, img) { + const ctx = canvas.getContext('2d'); + ctx.clearRect(0, 0, canvas.width, canvas.height); // Clear the canvas + + // Scale the image to fit within the canvas + const scale = Math.min(canvas.width / img.width, canvas.height / img.height); + const x = (canvas.width - img.width * scale) / 2; + const y = (canvas.height - img.height * scale) / 2; + + ctx.drawImage(img, x, y, img.width * scale, img.height * scale); +} + + +// Draw points on canvases +function drawPoints() { + [inputCanvas, segmentedCanvas].forEach((canvas) => { + const ctx = canvas.getContext('2d'); + ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE); + + const img = new Image(); + img.src = imageUrl; + img.onload = () => { + resizeAndDrawImage(canvas, img); + + points[currentClass].forEach(([x, y], i) => { + const scaledX = x * (CANVAS_SIZE / originalImageWidth); + const scaledY = y * (CANVAS_SIZE / originalImageHeight); + ctx.beginPath(); + ctx.arc(scaledX, scaledY, 5, 0, 2 * Math.PI); + ctx.fillStyle = labels[currentClass][i] === 1 ? 'green' : 'red'; + ctx.fill(); + }); + }; + img.onerror = () => { + console.error('Error loading image for canvas:', img.src); + }; + }); +} + +async function updateSegmentation() { + try { + const response = await fetch('/segment', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ points: points[currentClass], labels: labels[currentClass], class: currentClass.toLowerCase() }) + }); + + const data = await response.json(); + + if (data.error) { + console.error('Error during segmentation:', data.error); + alert(`Segmentation error: ${data.error}`); + return; + } + + console.log('Segmentation result:', data); + + const img = new Image(); + img.src = `${data.segmented_url}?t=${new Date().getTime()}`; // Add timestamp to prevent caching + img.onload = () => { + console.log('Segmented image loaded successfully:', img.src); + resizeAndDrawImage(segmentedCanvas, img); // Render the segmented image + }; + img.onerror = () => { + console.error('Failed to load segmented image:', img.src); + alert('Failed to load the segmented image.'); + }; + } catch (error) { + console.error('Error updating segmentation:', error); + alert('Failed to process segmentation.'); + } +} + +// Reset segmented canvas +function resetSegmentation() { + const ctx = segmentedCanvas.getContext('2d'); + ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE); + const img = new Image(); + img.src = imageUrl; + img.onload = () => resizeAndDrawImage(segmentedCanvas, img); +} + +// Handle automatic segmentation +automaticImageUpload.addEventListener('change', async (event) => { + const file = event.target.files[0]; + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await fetch('/automatic_segment', { method: 'POST', body: formData }); + const data = await response.json(); + if (data.error) return console.error('Error during automatic segmentation:', data.error); + + // Display the processed image + const processedImage = document.getElementById('automaticProcessedImage'); + processedImage.src = `${data.segmented_url}?t=${new Date().getTime()}`; + processedImage.style.display = 'block'; + + // Optionally append the table data + appendRowToTable(data.table_data); + } catch (error) { + console.error('Failed to process image automatically:', error); + } +}); + +function appendRowToTable(tableData) { + // Remove duplicates based on the image name and chip number + const existingRows = Array.from(resultsTableBody.querySelectorAll('tr')); + const existingIdentifiers = existingRows.map(row => { + const cells = row.querySelectorAll('td'); + return `${cells[0]?.textContent}_${cells[1]?.textContent}`; // Combine Image Name and Chip # + }); + + tableData.chips.forEach((chip, index) => { + const uniqueId = `${tableData.image_name}_${chip.chip_number}`; + if (existingIdentifiers.includes(uniqueId)) return; // Skip if already present + + const row = document.createElement('tr'); + + // Image Name (unchanged for each chip) + const imageNameCell = document.createElement('td'); + imageNameCell.textContent = tableData.image_name; + row.appendChild(imageNameCell); + + // Chip # (1, 2, etc.) + const chipNumberCell = document.createElement('td'); + chipNumberCell.textContent = chip.chip_number; + row.appendChild(chipNumberCell); + + // Chip Area + const chipAreaCell = document.createElement('td'); + chipAreaCell.textContent = chip.chip_area.toFixed(2); + row.appendChild(chipAreaCell); + + // Void % (Total void area / Chip area * 100) + const voidPercentageCell = document.createElement('td'); + voidPercentageCell.textContent = chip.void_percentage.toFixed(2); + row.appendChild(voidPercentageCell); + + // Max Void % (Largest void area / Chip area * 100) + const maxVoidPercentageCell = document.createElement('td'); + maxVoidPercentageCell.textContent = chip.max_void_percentage.toFixed(2); + row.appendChild(maxVoidPercentageCell); + + resultsTableBody.appendChild(row); + }); +} + +// Handle automatic segmentation +automaticImageUpload.addEventListener('change', async (event) => { + const file = event.target.files[0]; + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await fetch('/automatic_segment', { method: 'POST', body: formData }); + const data = await response.json(); + if (data.error) return console.error('Error during automatic segmentation:', data.error); + + automaticProcessedImage.src = `${data.segmented_url}?t=${new Date().getTime()}`; + automaticProcessedImage.style.display = 'block'; + appendRowToTable(data.table_data); // Append new data to the table + } catch (error) { + console.error('Failed to process image automatically:', error); + } +}); + +// Clear table +clearTableButton.addEventListener('click', () => { + resultsTableBody.innerHTML = ''; +}); + +// Export table to CSV +exportTableButton.addEventListener('click', () => { + const rows = Array.from(resultsTableBody.querySelectorAll('tr')); + const csvContent = [ + ['Image Name', 'Chip #', 'Chip Area', 'Void %', 'Max Void %'], + ...rows.map(row => + Array.from(row.children).map(cell => cell.textContent) + ), + ] + .map(row => row.join(',')) + .join('\n'); + + const blob = new Blob([csvContent], { type: 'text/csv' }); + const url = URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = url; + link.download = 'segmentation_results.csv'; + link.click(); + URL.revokeObjectURL(url); +}); +saveBothButton.addEventListener('click', async () => { + const imageName = imageUrl.split('/').pop(); // Extract the image name from the URL + if (!imageName) { + alert("No image to save."); + return; + } + + const confirmSave = confirm("Are you sure you want to save both voids and chips segmentations?"); + if (!confirmSave) return; + + try { + const response = await fetch('/save_both', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ image_name: imageName }) + }); + const result = await response.json(); + if (response.ok) { + alert(result.message); + } else { + alert("Failed to save segmentations."); + } + } catch (error) { + console.error("Error saving segmentations:", error); + alert("Failed to save segmentations."); + } +}); +// Update the "historyButton" click listener to populate the list correctly +document.getElementById('historyButton').addEventListener('click', async () => { + try { + const response = await fetch('/get_history'); // Fetch the saved history + const result = await response.json(); + + if (response.ok) { + const historyList = document.getElementById('historyList'); + historyList.innerHTML = ''; // Clear the list + + if (result.images.length === 0) { + historyList.innerHTML = '
  • No images found in history.
  • '; + return; + } + + result.images.forEach(image => { + const listItem = document.createElement('li'); + listItem.className = 'list-group-item'; + + const imageName = document.createElement('span'); + imageName.textContent = image; + + const deleteButton = document.createElement('button'); + deleteButton.className = 'btn btn-danger btn-sm'; + deleteButton.textContent = 'Delete'; + deleteButton.addEventListener('click', async () => { + if (confirm(`Are you sure you want to delete ${image}?`)) { + await deleteHistoryItem(image, listItem); + } + }); + + listItem.appendChild(imageName); + listItem.appendChild(deleteButton); + historyList.appendChild(listItem); + }); + + new bootstrap.Modal(document.getElementById('historyModal')).show(); + } else { + alert("Failed to fetch history."); + } + } catch (error) { + console.error("Error fetching history:", error); + alert("Failed to fetch history."); + } +}); + +// Function to delete history item +async function deleteHistoryItem(imageName, listItem) { + try { + const response = await fetch('/delete_history_item', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ image_name: imageName }) + }); + const result = await response.json(); + + if (response.ok) { + alert(result.message); + listItem.remove(); // Remove the item from the list + } else { + alert("Failed to delete image."); + } + } catch (error) { + console.error("Error deleting image:", error); + alert("Failed to delete image."); + } +} + +historyButton.addEventListener('click', async () => { + try { + const response = await fetch('/get_history'); + const result = await response.json(); + + if (response.ok) { + const historyList = document.getElementById('historyList'); + historyList.innerHTML = ''; // Clear the list + + if (result.images.length === 0) { + historyList.innerHTML = '
  • No images found in history.
  • '; + return; + } + + result.images.forEach(image => { + const listItem = document.createElement('li'); + listItem.className = 'list-group-item d-flex justify-content-between align-items-center'; + listItem.textContent = image; + + const deleteButton = document.createElement('button'); + deleteButton.className = 'btn btn-danger btn-sm'; + deleteButton.textContent = 'Delete'; + deleteButton.addEventListener('click', async () => { + if (confirm(`Are you sure you want to delete ${image}?`)) { + await deleteHistoryItem(image, listItem); + } + }); + + listItem.appendChild(deleteButton); + historyList.appendChild(listItem); + }); + + new bootstrap.Modal(document.getElementById('historyModal')).show(); + } else { + alert("Failed to fetch history."); + } + } catch (error) { + console.error("Error fetching history:", error); + alert("Failed to fetch history."); + } +}); + +// Function to delete history item +async function deleteHistoryItem(imageName, listItem) { + try { + const response = await fetch('/delete_history_item', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ image_name: imageName }) + }); + const result = await response.json(); + + if (response.ok) { + alert(result.message); + listItem.remove(); // Remove the item from the list + } else { + alert("Failed to delete image."); + } + } catch (error) { + console.error("Error deleting image:", error); + alert("Failed to delete image."); + } +} + +// Handle Retrain Model button click +retrainModelButton.addEventListener('click', async () => { + if (!trainingInProgress) { + const confirmRetrain = confirm("Are you sure you want to retrain the model?"); + if (!confirmRetrain) return; + + try { + const response = await fetch('/retrain_model', { method: 'POST' }); + const result = await response.json(); + + if (response.ok) { + // Update button to "Cancel Training" + trainingInProgress = true; + retrainModelButton.textContent = "Cancel Training"; + retrainModelButton.classList.replace("btn-primary", "btn-danger"); + startTrainingMonitor(); // Start monitoring the training status + } else { + alert(result.error || "Failed to start retraining."); + } + } catch (error) { + console.error("Error starting training:", error); + alert("An error occurred while starting the training process."); + } + } else { + // Handle cancel training + const confirmCancel = confirm("Are you sure you want to cancel the training?"); + if (!confirmCancel) return; + + try { + const response = await fetch('/cancel_training', { method: 'POST' }); + const result = await response.json(); + + if (response.ok) { + // Reset button to "Retrain Model" + trainingInProgress = false; + retrainModelButton.textContent = "Retrain Model"; + retrainModelButton.classList.replace("btn-danger", "btn-primary"); + alert(result.message || "Training canceled successfully."); + } else { + alert(result.error || "Failed to cancel training."); + } + } catch (error) { + console.error("Error canceling training:", error); + alert("An error occurred while canceling the training process."); + } + } +}); + + +function startTrainingMonitor() { + const monitorInterval = setInterval(async () => { + try { + const response = await fetch('/training_status'); + const result = await response.json(); + + const retrainButton = document.getElementById('retrainModelButton'); + const cancelButton = document.getElementById('cancelTrainingButton'); + const etaDisplay = document.getElementById('etaDisplay'); + + if (result.status === 'running') { + // Show training progress + retrainButton.style.display = 'none'; + cancelButton.style.display = 'inline-block'; + etaDisplay.textContent = `Estimated Time Left: ${result.eta || "Calculating..."}`; + } else if (result.status === 'idle' || result.status === 'cancelled') { + // Revert button to "Retrain Model" (blue) + cancelButton.style.display = 'none'; + retrainButton.style.display = 'inline-block'; + retrainButton.textContent = 'Retrain Model'; + retrainButton.classList.replace('btn-danger', 'btn-primary'); + etaDisplay.textContent = ''; + + // Stop monitoring if training is idle + if (result.status === 'idle') { + clearInterval(monitorInterval); + } + } + } catch (error) { + console.error("Error fetching training status:", error); + } + }, 5000); // Poll every 5 seconds +} + +function resetTrainingUI() { + trainingInProgress = false; + retrainModelButton.textContent = "Retrain Model"; + retrainModelButton.classList.replace("btn-danger", "btn-primary"); + etaDisplay.textContent = ""; +} + +clearHistoryButton.addEventListener('click', async () => { + const confirmClear = confirm("Are you sure you want to clear the history? This will delete all images and masks."); + if (!confirmClear) return; + + try { + const response = await fetch('/clear_history', { method: 'POST' }); + const result = await response.json(); + if (response.ok) { + alert(result.message); + // Optionally update UI to reflect the cleared history + const historyList = document.getElementById('historyList'); + if (historyList) historyList.innerHTML = '
  • No images found in history.
  • '; + } else { + alert("Failed to clear history."); + } + } catch (error) { + console.error("Error clearing history:", error); + alert("Failed to clear history."); + } +}); + +// Toggle training progress display +function showTrainingProgress(message = "Initializing...", timeLeft = "Calculating...") { + document.getElementById("trainingProgress").style.display = "block"; + document.getElementById("progressMessage").textContent = message; + document.getElementById("estimatedTimeLeft").textContent = `Estimated Time Left: ${timeLeft}`; +} + +function hideTrainingProgress() { + document.getElementById("trainingProgress").style.display = "none"; +} + +// Toggle Cancel Training Button +function showCancelTrainingButton() { + document.getElementById("cancelTrainingButton").style.display = "inline-block"; + document.getElementById("retrainModelButton").style.display = "none"; +} + +function hideCancelTrainingButton() { + document.getElementById("cancelTrainingButton").style.display = "none"; + document.getElementById("retrainModelButton").style.display = "inline-block"; +} + +// Add event listener to Cancel Training button +document.getElementById("cancelTrainingButton").addEventListener("click", async () => { + const confirmCancel = confirm("Are you sure you want to cancel training?"); + if (!confirmCancel) return; + + try { + const response = await fetch("/cancel_training", { method: "POST" }); + const result = await response.json(); + + if (result.message) { + alert(result.message); + hideTrainingProgress(); + hideCancelTrainingButton(); + } + } catch (error) { + console.error("Error canceling training:", error); + alert("Failed to cancel training."); + } +}); +// Handle training status updates +socket.on('training_status', (data) => { + const trainingButton = document.getElementById('retrainModelButton'); + const cancelButton = document.getElementById('cancelTrainingButton'); + + if (data.status === 'completed') { + // Update UI: change "Cancel Training" to "Retrain Model" + trainingButton.style.display = 'inline-block'; + cancelButton.style.display = 'none'; + + // Show a popup or notification for training completion + alert(data.message || "Training completed successfully!"); + } else if (data.status === 'failed') { + // Update UI: change "Cancel Training" to "Retrain Model" + trainingButton.style.display = 'inline-block'; + cancelButton.style.display = 'none'; + + // Show a popup or notification for training failure + alert(data.message || "Training failed. Please try again."); + } +}); + +socket.on('button_update', (data) => { + const retrainButton = document.getElementById('retrainModelButton'); + const cancelButton = document.getElementById('cancelTrainingButton'); + + if (data.action === 'retrain') { + // Update to "Retrain Model" button + retrainButton.style.display = 'inline-block'; + retrainButton.textContent = 'Retrain Model'; + retrainButton.classList.replace('btn-danger', 'btn-primary'); + cancelButton.style.display = 'none'; + } +}); + +function updateButtonToRetrainModel() { + const button = document.getElementById('retrainModelButton'); + button.innerText = "Retrain Model"; + button.classList.replace("btn-danger", "btn-primary"); + button.disabled = false; +} + + +socket.on('training_status', (data) => { + const retrainButton = document.getElementById('retrainModelButton'); + const cancelButton = document.getElementById('cancelTrainingButton'); + + if (data.status === 'completed') { + retrainButton.style.display = 'inline-block'; // Show retrain button + retrainButton.textContent = "Retrain Model"; + retrainButton.classList.replace("btn-danger", "btn-primary"); + cancelButton.style.display = 'none'; // Hide cancel button + + // Notify user + alert(data.message); + } else if (data.status === 'cancelled') { + retrainButton.style.display = 'inline-block'; + retrainButton.textContent = "Retrain Model"; + retrainButton.classList.replace("btn-danger", "btn-primary"); + cancelButton.style.display = 'none'; + + // Notify user + alert(data.message); + } +}); + +// Ensure the modal backdrop is properly removed when the modal is closed +document.getElementById('historyModal').addEventListener('hidden.bs.modal', function () { + document.body.classList.remove('modal-open'); + const backdrop = document.querySelector('.modal-backdrop'); + if (backdrop) { + backdrop.remove(); + } +}); diff --git a/static/uploads/input/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg b/static/uploads/input/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8c85c5cc11e9c183748a7d35e3c36ecb73d39ce Binary files /dev/null and b/static/uploads/input/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5.jpg differ diff --git a/static/uploads/input/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg b/static/uploads/input/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3579081555eed5deffe74a0c437f943bdc15c6d6 Binary files /dev/null and b/static/uploads/input/05_jpg.rf.46241369ebb0749c40882400f82eb224.jpg differ diff --git a/static/uploads/input/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg b/static/uploads/input/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8c3fee5d55e85fbac7e5e21cda3751efcd5ca171 Binary files /dev/null and b/static/uploads/input/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98.jpg differ diff --git a/static/uploads/input/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg b/static/uploads/input/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg new file mode 100644 index 0000000000000000000000000000000000000000..38978a38b1cb3774a934411b3f30271b18f56830 Binary files /dev/null and b/static/uploads/input/10_JPG.rf.6745a7b3ea828239398b85182acba199.jpg differ diff --git a/static/uploads/input/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg b/static/uploads/input/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac452b8c15bf70a138cbc334e670fb75f2c4900e Binary files /dev/null and b/static/uploads/input/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg differ diff --git a/static/uploads/input/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg b/static/uploads/input/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..076496fd0fc67d90ad57e8d1489b4b74325d6a95 Binary files /dev/null and b/static/uploads/input/12_jpg.rf.357643b374df92f81f9dee7c701b2315.jpg differ diff --git a/static/uploads/input/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg b/static/uploads/input/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c9d8aab04b28cbbede78a5acf683b613aa14dd9 Binary files /dev/null and b/static/uploads/input/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c.jpg differ diff --git a/static/uploads/input/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg b/static/uploads/input/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cb1ec47ff7e172bcd876582462f9667c0118c19 Binary files /dev/null and b/static/uploads/input/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg differ diff --git a/static/uploads/input/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg b/static/uploads/input/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3612ea8bad3b881d64ba82a7bcb3734fd0d3c301 Binary files /dev/null and b/static/uploads/input/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg differ diff --git a/static/uploads/input/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg b/static/uploads/input/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7769c2afbed542791dc7aa6c99b8f731b993c3d Binary files /dev/null and b/static/uploads/input/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg differ diff --git a/static/uploads/input/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg b/static/uploads/input/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b11c87c8db424bf376c804444f39271f52c3b71f Binary files /dev/null and b/static/uploads/input/18_jpg.rf.4d241aab78af17171d83f3a50f1cf1aa.jpg differ diff --git a/static/uploads/input/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg b/static/uploads/input/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a23e88255583db9dad0ae7d84c04aee4a17d7a1 Binary files /dev/null and b/static/uploads/input/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg differ diff --git a/static/uploads/input/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg b/static/uploads/input/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg new file mode 100644 index 0000000000000000000000000000000000000000..368daef342ad0515e373f0c051f72717bdeb2d47 Binary files /dev/null and b/static/uploads/input/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg differ diff --git a/static/uploads/input/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg b/static/uploads/input/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1e24609b75ed127858eb6fb66eba9d2274a09917 Binary files /dev/null and b/static/uploads/input/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg differ diff --git a/static/uploads/input/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg b/static/uploads/input/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg new file mode 100644 index 0000000000000000000000000000000000000000..860d0363ad46861aea837591d26295477a608eac Binary files /dev/null and b/static/uploads/input/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg differ diff --git a/static/uploads/input/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg b/static/uploads/input/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b357e7c557db3b68472b6f40accc04f469f33dd Binary files /dev/null and b/static/uploads/input/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg differ diff --git a/static/uploads/input/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg b/static/uploads/input/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9182db336422bdf0ac364af9fad89d9e0aa7ff4f Binary files /dev/null and b/static/uploads/input/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg differ diff --git a/static/uploads/input/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg b/static/uploads/input/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9472eeee73b114d8aca4191a6d71caa94b3fb499 Binary files /dev/null and b/static/uploads/input/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058.jpg differ diff --git a/static/uploads/mask/chips/05_jpg.rf.46241369ebb0749c40882400f82eb224.png b/static/uploads/mask/chips/05_jpg.rf.46241369ebb0749c40882400f82eb224.png new file mode 100644 index 0000000000000000000000000000000000000000..a942f9262cdec73c5cdfdf8056be9bf00045c8e3 Binary files /dev/null and b/static/uploads/mask/chips/05_jpg.rf.46241369ebb0749c40882400f82eb224.png differ diff --git a/static/uploads/mask/chips/10_JPG.rf.6745a7b3ea828239398b85182acba199.png b/static/uploads/mask/chips/10_JPG.rf.6745a7b3ea828239398b85182acba199.png new file mode 100644 index 0000000000000000000000000000000000000000..3d210ff182291a1a3cc88cfb3340a774a685b6fb Binary files /dev/null and b/static/uploads/mask/chips/10_JPG.rf.6745a7b3ea828239398b85182acba199.png differ diff --git a/static/uploads/mask/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/static/uploads/mask/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..cfe59a8e858cb63a75e133f959eabb1ebdc0c20b Binary files /dev/null and b/static/uploads/mask/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/static/uploads/mask/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/static/uploads/mask/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..9d8126856e3b0a8cbdd32ab02489700065839c4d Binary files /dev/null and b/static/uploads/mask/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/static/uploads/mask/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/static/uploads/mask/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..b54dd093ebc3753379df1d30325ee7232f1a653d Binary files /dev/null and b/static/uploads/mask/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/static/uploads/mask/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/static/uploads/mask/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..e4dbc826edcda87f39dd1998176e2162c379058c Binary files /dev/null and b/static/uploads/mask/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/static/uploads/mask/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/static/uploads/mask/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..b1159aa308dd01a7b4ceb1c15d3fcc27f622f193 Binary files /dev/null and b/static/uploads/mask/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/static/uploads/mask/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/static/uploads/mask/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..685eb010a7a3ed2820a583c6f79669735e8d2a33 Binary files /dev/null and b/static/uploads/mask/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/static/uploads/mask/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/static/uploads/mask/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..8e1c3e1a7e9af6d64b521e132eb24145e5b89bdb Binary files /dev/null and b/static/uploads/mask/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/static/uploads/mask/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/static/uploads/mask/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..49af8a089a365164655cb69ee999698285b16b31 Binary files /dev/null and b/static/uploads/mask/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/static/uploads/mask/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/static/uploads/mask/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..a9ba5fccb3f7f3eb5b2ce8a556812566f7d1770e Binary files /dev/null and b/static/uploads/mask/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/static/uploads/mask/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/static/uploads/mask/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..085353f59f871492391434d027b1f664604dcfc7 Binary files /dev/null and b/static/uploads/mask/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/static/uploads/mask/chips/raw_mask.png b/static/uploads/mask/chips/raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..87663ac428d4cf9c4c0b4b682e93b2817c9fd8c9 Binary files /dev/null and b/static/uploads/mask/chips/raw_mask.png differ diff --git a/static/uploads/mask/voids/05_jpg.rf.46241369ebb0749c40882400f82eb224.png b/static/uploads/mask/voids/05_jpg.rf.46241369ebb0749c40882400f82eb224.png new file mode 100644 index 0000000000000000000000000000000000000000..967594c43334b0f6d6d8cf184108f29e940d6a51 Binary files /dev/null and b/static/uploads/mask/voids/05_jpg.rf.46241369ebb0749c40882400f82eb224.png differ diff --git a/static/uploads/mask/voids/10_JPG.rf.6745a7b3ea828239398b85182acba199.png b/static/uploads/mask/voids/10_JPG.rf.6745a7b3ea828239398b85182acba199.png new file mode 100644 index 0000000000000000000000000000000000000000..535574a5d920250a287286dc4cdbd2198b503bcb Binary files /dev/null and b/static/uploads/mask/voids/10_JPG.rf.6745a7b3ea828239398b85182acba199.png differ diff --git a/static/uploads/mask/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/static/uploads/mask/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..e6fc92f551a7d60c65f2df142a541bec91870f5d Binary files /dev/null and b/static/uploads/mask/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/static/uploads/mask/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/static/uploads/mask/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..a2354e965cc811518a03465c31af0c198869fd95 Binary files /dev/null and b/static/uploads/mask/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/static/uploads/mask/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/static/uploads/mask/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..94ec31dbc470a83ad06750cbfc1dadacdae9e923 Binary files /dev/null and b/static/uploads/mask/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/static/uploads/mask/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/static/uploads/mask/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..b373067ac59737c435e4eada9171f4498416a89f Binary files /dev/null and b/static/uploads/mask/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/static/uploads/mask/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/static/uploads/mask/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..3ccdfc5a7e3612a37167a117828aa10b31bb2322 Binary files /dev/null and b/static/uploads/mask/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/static/uploads/mask/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/static/uploads/mask/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..0151180865d7b664d7ffa51db606086a16f6c889 Binary files /dev/null and b/static/uploads/mask/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/static/uploads/mask/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/static/uploads/mask/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..a1c195f11edd403f3dabac42760e9049808c51e7 Binary files /dev/null and b/static/uploads/mask/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/static/uploads/mask/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/static/uploads/mask/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..ac8e0a95d726259518793040157782d8ce667656 Binary files /dev/null and b/static/uploads/mask/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/static/uploads/mask/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/static/uploads/mask/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..cd9732adb958f28969d2d88c1fe8af3bec378f12 Binary files /dev/null and b/static/uploads/mask/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/static/uploads/mask/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/static/uploads/mask/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..13afab24243cbfaf4c21271b50bc0d41c2b38b51 Binary files /dev/null and b/static/uploads/mask/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/static/uploads/mask/voids/raw_mask.png b/static/uploads/mask/voids/raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..5b197214b627eb2d8ab6de48a1c110b4ae4ad0af Binary files /dev/null and b/static/uploads/mask/voids/raw_mask.png differ diff --git a/static/uploads/segmented/automatic/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5_pred.jpg b/static/uploads/segmented/automatic/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6780664033acc985c09c1ce3c595467e9e3e0a5 Binary files /dev/null and b/static/uploads/segmented/automatic/02_JPG.rf.d6063f8ca200e543da7becc1bf260ed5_pred.jpg differ diff --git a/static/uploads/segmented/automatic/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98_pred.jpg b/static/uploads/segmented/automatic/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..53fa958e7a2dc661029ce443e99b6dec9f47f065 Binary files /dev/null and b/static/uploads/segmented/automatic/08_JPG.rf.1f81e954a3bbfc49dcd30e3ba0eb5e98_pred.jpg differ diff --git a/static/uploads/segmented/automatic/09_JPG.rf.9119efd8c174f968457a893669209835_pred.jpg b/static/uploads/segmented/automatic/09_JPG.rf.9119efd8c174f968457a893669209835_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a235f3b90e22425392701edf2512e7bb6cc7d683 Binary files /dev/null and b/static/uploads/segmented/automatic/09_JPG.rf.9119efd8c174f968457a893669209835_pred.jpg differ diff --git a/static/uploads/segmented/automatic/10_JPG.rf.6745a7b3ea828239398b85182acba199_pred.jpg b/static/uploads/segmented/automatic/10_JPG.rf.6745a7b3ea828239398b85182acba199_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..54c0b310d17de83cf0df43b1fb36b363c14ad4d9 Binary files /dev/null and b/static/uploads/segmented/automatic/10_JPG.rf.6745a7b3ea828239398b85182acba199_pred.jpg differ diff --git a/static/uploads/segmented/automatic/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb_pred.jpg b/static/uploads/segmented/automatic/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8e7cd3c134a1f41a87448bf764a2dc18468315b Binary files /dev/null and b/static/uploads/segmented/automatic/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb_pred.jpg differ diff --git a/static/uploads/segmented/automatic/12_jpg.rf.357643b374df92f81f9dee7c701b2315_pred.jpg b/static/uploads/segmented/automatic/12_jpg.rf.357643b374df92f81f9dee7c701b2315_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9d8b448d79abde4a5eddea506ac44eb0ec7721ec Binary files /dev/null and b/static/uploads/segmented/automatic/12_jpg.rf.357643b374df92f81f9dee7c701b2315_pred.jpg differ diff --git a/static/uploads/segmented/automatic/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c_pred.jpg b/static/uploads/segmented/automatic/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..96f4bc949ef48c41321d654a339ad008cbde48ee Binary files /dev/null and b/static/uploads/segmented/automatic/14_jpg.rf.d91472c724e7c34da4d96ac5e504044c_pred.jpg differ diff --git a/static/uploads/segmented/automatic/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0_pred.jpg b/static/uploads/segmented/automatic/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..65ca5e28eab24cd07c90db185b164c62f9071525 Binary files /dev/null and b/static/uploads/segmented/automatic/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0_pred.jpg differ diff --git a/static/uploads/segmented/automatic/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536_pred.jpg b/static/uploads/segmented/automatic/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c80f32e508a2a2e25a60c8d0cd3710cd05e44158 Binary files /dev/null and b/static/uploads/segmented/automatic/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536_pred.jpg differ diff --git a/static/uploads/segmented/automatic/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba_pred.jpg b/static/uploads/segmented/automatic/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a4ecae049c42c8f160146bedb05d1cb67279e98 Binary files /dev/null and b/static/uploads/segmented/automatic/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba_pred.jpg differ diff --git a/static/uploads/segmented/automatic/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147_pred.jpg b/static/uploads/segmented/automatic/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cea7467ea27a4b62cb024a0d7b67c976c83f5850 Binary files /dev/null and b/static/uploads/segmented/automatic/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147_pred.jpg differ diff --git a/static/uploads/segmented/automatic/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab_pred.jpg b/static/uploads/segmented/automatic/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd37ca630143244d2b5cbfb7a3d7b24dd0abd9f3 Binary files /dev/null and b/static/uploads/segmented/automatic/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab_pred.jpg differ diff --git a/static/uploads/segmented/automatic/29_jpg.rf.931769b58ae20d18d1f09d042bc44176_pred.jpg b/static/uploads/segmented/automatic/29_jpg.rf.931769b58ae20d18d1f09d042bc44176_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7c8091d61aed18b2ba2177175359a893911d752e Binary files /dev/null and b/static/uploads/segmented/automatic/29_jpg.rf.931769b58ae20d18d1f09d042bc44176_pred.jpg differ diff --git a/static/uploads/segmented/automatic/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058_pred.jpg b/static/uploads/segmented/automatic/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058_pred.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac40c44b30a36c7dc1ab15f4b36e8b918ffaf030 Binary files /dev/null and b/static/uploads/segmented/automatic/7-Figure14-1_jpg.rf.1c6cb204ed1f37c8fed44178a02e9058_pred.jpg differ diff --git a/static/uploads/segmented/chips/blended_chips.png b/static/uploads/segmented/chips/blended_chips.png new file mode 100644 index 0000000000000000000000000000000000000000..160bc0e96af90d18a99e7bdc59840fc74a0aa686 Binary files /dev/null and b/static/uploads/segmented/chips/blended_chips.png differ diff --git a/static/uploads/segmented/chips/blended_image.png b/static/uploads/segmented/chips/blended_image.png new file mode 100644 index 0000000000000000000000000000000000000000..1d8fbf1814f362d4cdf5b637de25362befe3ed4d Binary files /dev/null and b/static/uploads/segmented/chips/blended_image.png differ diff --git a/static/uploads/segmented/voids/blended_image.png b/static/uploads/segmented/voids/blended_image.png new file mode 100644 index 0000000000000000000000000000000000000000..45b729389e40cc9aa196311076af6305de5ea98d Binary files /dev/null and b/static/uploads/segmented/voids/blended_image.png differ diff --git a/static/uploads/segmented/voids/blended_voids.png b/static/uploads/segmented/voids/blended_voids.png new file mode 100644 index 0000000000000000000000000000000000000000..9aadb2a2d3bf8b8ca8cbecc64be92a27168a2052 Binary files /dev/null and b/static/uploads/segmented/voids/blended_voids.png differ diff --git a/temp_backup/.DS_Store b/temp_backup/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ad89615ad548a34663c2e49a0e40671136a26d5a Binary files /dev/null and b/temp_backup/.DS_Store differ diff --git a/temp_backup/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg b/temp_backup/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac452b8c15bf70a138cbc334e670fb75f2c4900e Binary files /dev/null and b/temp_backup/images/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.jpg differ diff --git a/temp_backup/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg b/temp_backup/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cb1ec47ff7e172bcd876582462f9667c0118c19 Binary files /dev/null and b/temp_backup/images/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.jpg differ diff --git a/temp_backup/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg b/temp_backup/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3612ea8bad3b881d64ba82a7bcb3734fd0d3c301 Binary files /dev/null and b/temp_backup/images/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.jpg differ diff --git a/temp_backup/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg b/temp_backup/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7769c2afbed542791dc7aa6c99b8f731b993c3d Binary files /dev/null and b/temp_backup/images/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.jpg differ diff --git a/temp_backup/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg b/temp_backup/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a23e88255583db9dad0ae7d84c04aee4a17d7a1 Binary files /dev/null and b/temp_backup/images/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.jpg differ diff --git a/temp_backup/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg b/temp_backup/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg new file mode 100644 index 0000000000000000000000000000000000000000..368daef342ad0515e373f0c051f72717bdeb2d47 Binary files /dev/null and b/temp_backup/images/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.jpg differ diff --git a/temp_backup/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg b/temp_backup/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1e24609b75ed127858eb6fb66eba9d2274a09917 Binary files /dev/null and b/temp_backup/images/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.jpg differ diff --git a/temp_backup/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg b/temp_backup/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg new file mode 100644 index 0000000000000000000000000000000000000000..860d0363ad46861aea837591d26295477a608eac Binary files /dev/null and b/temp_backup/images/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.jpg differ diff --git a/temp_backup/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg b/temp_backup/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b357e7c557db3b68472b6f40accc04f469f33dd Binary files /dev/null and b/temp_backup/images/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.jpg differ diff --git a/temp_backup/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg b/temp_backup/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9182db336422bdf0ac364af9fad89d9e0aa7ff4f Binary files /dev/null and b/temp_backup/images/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.jpg differ diff --git a/temp_backup/masks/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/temp_backup/masks/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..cfe59a8e858cb63a75e133f959eabb1ebdc0c20b Binary files /dev/null and b/temp_backup/masks/chips/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/temp_backup/masks/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/temp_backup/masks/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..9d8126856e3b0a8cbdd32ab02489700065839c4d Binary files /dev/null and b/temp_backup/masks/chips/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/temp_backup/masks/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/temp_backup/masks/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..b54dd093ebc3753379df1d30325ee7232f1a653d Binary files /dev/null and b/temp_backup/masks/chips/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/temp_backup/masks/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/temp_backup/masks/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..e4dbc826edcda87f39dd1998176e2162c379058c Binary files /dev/null and b/temp_backup/masks/chips/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/temp_backup/masks/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/temp_backup/masks/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..b1159aa308dd01a7b4ceb1c15d3fcc27f622f193 Binary files /dev/null and b/temp_backup/masks/chips/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/temp_backup/masks/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/temp_backup/masks/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..685eb010a7a3ed2820a583c6f79669735e8d2a33 Binary files /dev/null and b/temp_backup/masks/chips/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/temp_backup/masks/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/temp_backup/masks/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..8e1c3e1a7e9af6d64b521e132eb24145e5b89bdb Binary files /dev/null and b/temp_backup/masks/chips/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/temp_backup/masks/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/temp_backup/masks/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..49af8a089a365164655cb69ee999698285b16b31 Binary files /dev/null and b/temp_backup/masks/chips/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/temp_backup/masks/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/temp_backup/masks/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..a9ba5fccb3f7f3eb5b2ce8a556812566f7d1770e Binary files /dev/null and b/temp_backup/masks/chips/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/temp_backup/masks/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/temp_backup/masks/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..085353f59f871492391434d027b1f664604dcfc7 Binary files /dev/null and b/temp_backup/masks/chips/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/temp_backup/masks/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png b/temp_backup/masks/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png new file mode 100644 index 0000000000000000000000000000000000000000..e6fc92f551a7d60c65f2df142a541bec91870f5d Binary files /dev/null and b/temp_backup/masks/voids/11_JPG.rf.3aa3109a1838549cf273cdbe8b2cafeb.png differ diff --git a/temp_backup/masks/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png b/temp_backup/masks/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png new file mode 100644 index 0000000000000000000000000000000000000000..a2354e965cc811518a03465c31af0c198869fd95 Binary files /dev/null and b/temp_backup/masks/voids/15_jpg.rf.284413e4432b16253b4cd19f0c4f01e2.png differ diff --git a/temp_backup/masks/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png b/temp_backup/masks/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png new file mode 100644 index 0000000000000000000000000000000000000000..94ec31dbc470a83ad06750cbfc1dadacdae9e923 Binary files /dev/null and b/temp_backup/masks/voids/15r_jpg.rf.2da1990173346311d3a3555e23a9164a.png differ diff --git a/temp_backup/masks/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png b/temp_backup/masks/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png new file mode 100644 index 0000000000000000000000000000000000000000..b373067ac59737c435e4eada9171f4498416a89f Binary files /dev/null and b/temp_backup/masks/voids/16_jpg.rf.9fdb4f56ec8596ddcc31db5bbffc26a0.png differ diff --git a/temp_backup/masks/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png b/temp_backup/masks/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png new file mode 100644 index 0000000000000000000000000000000000000000..3ccdfc5a7e3612a37167a117828aa10b31bb2322 Binary files /dev/null and b/temp_backup/masks/voids/20_jpg.rf.4a45f799ba16b5ff81ab1929f12a12b1.png differ diff --git a/temp_backup/masks/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png b/temp_backup/masks/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png new file mode 100644 index 0000000000000000000000000000000000000000..0151180865d7b664d7ffa51db606086a16f6c889 Binary files /dev/null and b/temp_backup/masks/voids/21_jpg.rf.d1d6dd254d2e5f396589ccc68a3c8536.png differ diff --git a/temp_backup/masks/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png b/temp_backup/masks/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png new file mode 100644 index 0000000000000000000000000000000000000000..a1c195f11edd403f3dabac42760e9049808c51e7 Binary files /dev/null and b/temp_backup/masks/voids/22_jpg.rf.a72964a78ea36c7bebe3a09cf27ef6ba.png differ diff --git a/temp_backup/masks/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png b/temp_backup/masks/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png new file mode 100644 index 0000000000000000000000000000000000000000..ac8e0a95d726259518793040157782d8ce667656 Binary files /dev/null and b/temp_backup/masks/voids/25_jpg.rf.893f4286e0c8a3cef2efb7612f248147.png differ diff --git a/temp_backup/masks/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png b/temp_backup/masks/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png new file mode 100644 index 0000000000000000000000000000000000000000..cd9732adb958f28969d2d88c1fe8af3bec378f12 Binary files /dev/null and b/temp_backup/masks/voids/26_jpg.rf.a03c550707ff22cd50ff7f54bebda7ab.png differ diff --git a/temp_backup/masks/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png b/temp_backup/masks/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png new file mode 100644 index 0000000000000000000000000000000000000000..13afab24243cbfaf4c21271b50bc0d41c2b38b51 Binary files /dev/null and b/temp_backup/masks/voids/29_jpg.rf.931769b58ae20d18d1f09d042bc44176.png differ diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..cbf5c21141d33d37be214a5f598faad99e93d23d --- /dev/null +++ b/templates/index.html @@ -0,0 +1,137 @@ + + + + + + Segmentation Tools + + + + +
    + +

    Segmentation Tools

    + + +
    +

    Interactive Segmentation Tool

    +
    + + +
    +
    +
    +
    Input Image
    +
    + +
    + +
    +
    +
    Segmented Image
    +
    + +
    +
    + + + + + +
    +
    +
    + +
    + + + + + +
    +

    Automatic Segmentation Tool

    +
    +
    +
    Input Image
    + +
    +
    +
    Processed Image
    +
    + +
    +
    +
    +
    + + +
    +

    Segmentation Results

    + + + + + + + + + + + + + +
    Image NameChip #Chip AreaVoid %Max Void %
    +
    + + +
    +
    +
    + + + + + + + + + + diff --git a/utils/.DS_Store b/utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..33928769301887fef432324e9d57c5a2462a9a02 Binary files /dev/null and b/utils/.DS_Store differ diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63242630dbed3629c64041a1c4616a47bcf1fc72 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-311.pyc b/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7471c80806948ab99a064501a26cfe7dc8ccd3ae Binary files /dev/null and b/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/utils/__pycache__/helpers.cpython-310.pyc b/utils/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cbba024f8526a2794045bf1bee36f4baf13956c Binary files /dev/null and b/utils/__pycache__/helpers.cpython-310.pyc differ diff --git a/utils/__pycache__/helpers.cpython-311.pyc b/utils/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..184f8d72303fed8f7c7f52427fe8ea12b9b20f24 Binary files /dev/null and b/utils/__pycache__/helpers.cpython-311.pyc differ diff --git a/utils/__pycache__/predictor.cpython-310.pyc b/utils/__pycache__/predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b4ec1c91d01f671123579e0aac80380a7fd660 Binary files /dev/null and b/utils/__pycache__/predictor.cpython-310.pyc differ diff --git a/utils/__pycache__/predictor.cpython-311.pyc b/utils/__pycache__/predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9b4bcb5833e2f0a6243900d0e56562f56922a16 Binary files /dev/null and b/utils/__pycache__/predictor.cpython-311.pyc differ diff --git a/utils/helpers.py b/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..129147f109dc13541220e126be1ab7e9b43f8f7d --- /dev/null +++ b/utils/helpers.py @@ -0,0 +1,74 @@ +import numpy as np +from PIL import Image +import cv2 +import os +import shutil + +def blend_mask_with_image(image, mask, color): + """Blend the mask with the original image using a transparent color overlay.""" + mask_rgb = np.stack([mask * color[i] for i in range(3)], axis=-1) + blended = (0.7 * image + 0.3 * mask_rgb).astype(np.uint8) + return blended + +def save_mask_as_png(mask, path): + """Save the binary mask as a PNG.""" + mask_image = Image.fromarray((mask * 255).astype(np.uint8)) + mask_image.save(path) + +def convert_mask_to_yolo(mask_path, image_path, class_id, output_path, append=False): + """ + Convert a binary mask to YOLO-compatible segmentation labels. + + Args: + mask_path (str): Path to the binary mask image. + image_path (str): Path to the corresponding image. + class_id (int): Class ID (e.g., 0 for void, 1 for chip). + output_path (str): Path to save the YOLO label (.txt) file. + append (bool): Whether to append labels to the file. + + Returns: + None + """ + try: + # Load the binary mask + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + if mask is None: + raise ValueError(f"Mask not found or invalid: {mask_path}") + + # Load the corresponding image to get dimensions + image = cv2.imread(image_path) + if image is None: + raise ValueError(f"Image not found or invalid: {image_path}") + + h, w = image.shape[:2] # Image height and width + + # Find contours in the mask + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Determine file mode: "w" for overwrite or "a" for append + file_mode = "a" if append else "w" + + # Open the output .txt file + with open(output_path, file_mode) as label_file: + for contour in contours: + # Simplify the contour points to reduce the number of vertices + epsilon = 0.01 * cv2.arcLength(contour, True) # Tolerance for approximation + contour = cv2.approxPolyDP(contour, epsilon, True) + + # Normalize contour points (polygon vertices) + normalized_vertices = [] + for point in contour: + x, y = point[0] # Extract x, y from the point + x_normalized = x / w + y_normalized = y / h + normalized_vertices.extend([x_normalized, y_normalized]) + + # Write the polygon annotation to the label file + if len(normalized_vertices) >= 6: # At least 3 points required for a polygon + label_file.write(f"{class_id} " + " ".join(f"{v:.6f}" for v in normalized_vertices) + "\n") + + print(f"YOLO segmentation label saved: {output_path}") + + except Exception as e: + print(f"Error converting mask to YOLO format: {e}") + raise RuntimeError(f"Failed to convert {mask_path} for class {class_id}: {e}") diff --git a/utils/predictor.py b/utils/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0e358117a71ea711b1d221034341a23599f0f0 --- /dev/null +++ b/utils/predictor.py @@ -0,0 +1,26 @@ +import numpy as np +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +class Predictor: + def __init__(self, model_cfg, checkpoint, device): + self.device = device + self.model = build_sam2(model_cfg, checkpoint, device=device) + self.predictor = SAM2ImagePredictor(self.model) + self.image_set = False + + def set_image(self, image): + """Set the image for SAM prediction.""" + self.image = image + self.predictor.set_image(image) + self.image_set = True + + def predict(self, point_coords, point_labels, multimask_output=False): + """Run SAM prediction.""" + if not self.image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + return self.predictor.predict( + point_coords=point_coords, + point_labels=point_labels, + multimask_output=multimask_output + ) diff --git a/yolo11n.pt b/yolo11n.pt new file mode 100644 index 0000000000000000000000000000000000000000..c7723db027d009343e9682f261370833fd6f0d84 --- /dev/null +++ b/yolo11n.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ebbc80d4a7680d14987a577cd21342b65ecfd94632bd9a8da63ae6417644ee1 +size 5613764