SAM2 Click Agent — RL-Based Interactive Segmentation Refinement

Overview

A PPO-trained reinforcement learning (RL) agent that automates corrective click placement for SAM2-based interactive medical image segmentation. After an initial user click, the agent places additional refinement clicks with the goal of improving the segmentation mask.

Architecture

  • Environment: SAM2.1-hiera-base-plus (frozen) as segmentation backbone
  • Agent: CNN-based PPO policy (Stable-Baselines3)
  • Observation: 6-channel image (RGB + mask + fg/bg click heatmaps), 128×128
  • Action: Discrete(2048) = 32×32 grid positions × 2 click types (fg/bg)
  • Reward: Delta Dice + boundary-aware bonus (BS-IRIS inspired)
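
To make the action space above concrete, here is a minimal sketch of how a Discrete(2048) index could be mapped to a click on the 128×128 observation. The function name `decode_action` is hypothetical, and the index layout (foreground clicks first, then background, row-major within the grid, click placed at the cell center) is an assumption; the actual ordering inside `SAM2ClickEnv` may differ.

```python
GRID_SIZE = 32   # 32x32 grid of candidate click positions
OBS_SIZE = 128   # observation side length in pixels


def decode_action(action, grid_size=GRID_SIZE, obs_size=OBS_SIZE):
    """Map a flat action index to (x, y, is_foreground).

    Assumed layout (not confirmed by the source):
    indices [0, grid_size**2) are foreground clicks,
    indices [grid_size**2, 2 * grid_size**2) are background clicks,
    and cells are enumerated row by row.
    """
    action = int(action)
    n_cells = grid_size * grid_size
    if not 0 <= action < 2 * n_cells:
        raise ValueError(f"action {action} outside Discrete({2 * n_cells})")

    is_foreground = action < n_cells
    row, col = divmod(action % n_cells, grid_size)

    # Each grid cell covers stride x stride pixels; click at the cell center.
    stride = obs_size // grid_size
    x = col * stride + stride // 2
    y = row * stride + stride // 2
    return x, y, is_foreground
```

With these defaults each grid cell spans 4×4 pixels, so the agent chooses click positions at a quarter of the observation resolution.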

Training Details

  • Dataset: Kvasir-SEG Augmented (4,800 training samples, polyp segmentation)
  • Total timesteps: 500,000
  • Training time: 185.7 minutes
  • Parameters: 1,886,625
  • PPO config: learning rate = 0.00025, clip range = 0.1, entropy coefficient = 0.02, batch size = 128
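
The hyperparameters above map onto Stable-Baselines3 `PPO` keyword arguments roughly as follows. This is a sketch, not the original training script: the `"CnnPolicy"` policy name is an assumption (the source only says the policy is CNN-based), and every setting not listed is left at the library default.

```python
# Reported hyperparameters expressed as Stable-Baselines3 PPO keyword arguments.
ppo_kwargs = dict(
    policy="CnnPolicy",    # assumption: the card only states "CNN-based PPO policy"
    learning_rate=2.5e-4,  # lr
    clip_range=0.1,        # clip
    ent_coef=0.02,         # ent
    batch_size=128,        # batch
)
TOTAL_TIMESTEPS = 500_000

# Hypothetical training call, given an environment `env`:
#   from stable_baselines3 import PPO
#   model = PPO(env=env, **ppo_kwargs)
#   model.learn(total_timesteps=TOTAL_TIMESTEPS)
```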

Results (on test set, 100 samples)

Oracle Baseline (deterministic heuristic — center of largest error region)

  • step_0: Dice = 0.0000 ± 0.0000
  • step_1: Dice = 0.7482 ± 0.3160
  • step_2: Dice = 0.8480 ± 0.2142
  • step_3: Dice = 0.8545 ± 0.2173
  • step_4: Dice = 0.8942 ± 0.1692
  • step_5: Dice = 0.9170 ± 0.1281
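
For illustration, the oracle heuristic named above (click the center of the largest error region) could look like the sketch below. It is not the repository's implementation. The helper names are hypothetical, and it assumes that false-negative and false-positive regions are compared separately, that regions are 4-connected, and that "center" means the region pixel closest to the centroid, so the click always lands inside the region.

```python
from collections import deque

import numpy as np


def largest_component(mask):
    """Return an (N, 2) array of (row, col) pixels of the largest 4-connected region."""
    h, w = mask.shape
    visited = np.zeros((h, w), dtype=bool)
    best = []
    for r0, c0 in zip(*np.nonzero(mask)):
        if visited[r0, c0]:
            continue
        # Breadth-first flood fill starting from an unvisited error pixel.
        visited[r0, c0] = True
        queue, region = deque([(r0, c0)]), []
        while queue:
            r, c = queue.popleft()
            region.append((r, c))
            for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                if 0 <= nr < h and 0 <= nc < w and mask[nr, nc] and not visited[nr, nc]:
                    visited[nr, nc] = True
                    queue.append((nr, nc))
        if len(region) > len(best):
            best = region
    return np.array(best, dtype=int).reshape(-1, 2)


def oracle_click(pred, gt):
    """Return (row, col, is_foreground) for the next corrective click, or None."""
    pred, gt = np.asarray(pred, dtype=bool), np.asarray(gt, dtype=bool)
    false_neg = largest_component(gt & ~pred)  # missed object pixels -> fg click
    false_pos = largest_component(pred & ~gt)  # spurious mask pixels -> bg click
    if len(false_neg) == 0 and len(false_pos) == 0:
        return None  # prediction already matches the ground truth

    is_foreground = len(false_neg) >= len(false_pos)
    region = false_neg if is_foreground else false_pos

    # Pick the region pixel closest to the centroid (assumed meaning of "center").
    centroid = region.mean(axis=0)
    idx = np.argmin(((region - centroid) ** 2).sum(axis=1))
    row, col = region[idx]
    return int(row), int(col), is_foreground
```

RITM-style simulators often pick the point farthest from the region boundary instead of the centroid; either choice fits the description given here.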

RL Click Agent (trained PPO policy)

  • mean_episode_reward: -0.0348
  • step_0: Dice = 0.7482 ± 0.3160 (identical to the oracle's step_1, so this appears to be the mask after the initial click)
  • step_1: Dice = 0.6528 ± 0.3375
  • step_2: Dice = 0.6242 ± 0.3509
  • step_3: Dice = 0.6667 ± 0.3246
  • step_4: Dice = 0.6445 ± 0.3087
  • step_5: Dice = 0.6141 ± 0.3233
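
As a reference for reading the numbers above, here is a minimal sketch of a Dice score on binary masks and of the Delta Dice part of the reward. The names `dice_score` and `delta_dice_rewards` are hypothetical and separate from the repository's `compute_dice`; the boundary-aware bonus is omitted, and scoring two empty masks as 1.0 is an assumed convention.

```python
import numpy as np


def dice_score(pred, gt):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) for two binary masks."""
    pred, gt = np.asarray(pred, dtype=bool), np.asarray(gt, dtype=bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)


def delta_dice_rewards(dice_per_step):
    """Per-click reward as the change in Dice between consecutive steps."""
    return [float(after - before)
            for before, after in zip(dice_per_step, dice_per_step[1:])]
```

Under this definition an empty prediction against a non-empty ground truth scores 0, and a click that lowers Dice yields a negative reward.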

Based On

  • BS-IRIS — Boundary-aware reward design (IEEE TMI 2023)
  • IteR-MRL — Multi-agent RL for interactive segmentation (CVPR 2020)
  • RITM — Oracle click simulation strategy

Usage

from stable_baselines3 import PPO
from sam2_click_env import SAM2ClickEnv, compute_dice

# Load agent
model = PPO.load("click_agent_ppo")

# Create environment with your SAM2 predictor
env = SAM2ClickEnv(
    dataset=your_dataset,
    sam_predictor=your_sam_predictor,
    obs_size=128,
    grid_size=32,
    max_clicks=5,
    use_sam=True,
)

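# Run inference for up to max_clicks refinement steps
obs, info = env.reset()
for step in range(5):
    action, _ = model.predict(obs, deterministic=True)
    # Gymnasium-style step: the episode is over when it terminates or is truncated
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"Step {step+1}: Dice={info['dice']:.4f}")
    if terminated or truncated:
        break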