Boundary-aware Supervoxel-level Iteratively Refined Interactive 3D Image Segmentation with Multi-agent Reinforcement Learning
Paper: arXiv 2303.10692
A PPO-trained RL agent that automates corrective click placement for SAM2-based interactive medical image segmentation. After an initial user click, the agent automatically places additional refinement clicks to improve the segmentation mask.
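Mask quality is reported as the Dice coefficient, 2|A ∩ B| / (|A| + |B|). The card imports a `compute_dice` helper but does not document its signature, so the following is a standalone sketch of the standard formula, not that helper. The per-click reward shown (change in Dice after a click) is a hypothetical illustration of how refinement could be scored. The card does not state the reward the agent was actually trained with.

```python
import numpy as np


def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """Standard Dice score 2|A ∩ B| / (|A| + |B|) for two binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        # Both masks are empty: treat as perfect agreement.
        return 1.0
    return float(2.0 * intersection / total)


def dice_improvement_reward(prev_mask: np.ndarray, new_mask: np.ndarray,
                            gt_mask: np.ndarray) -> float:
    """Hypothetical per-click reward: Dice after the click minus Dice before it."""
    return dice_coefficient(new_mask, gt_mask) - dice_coefficient(prev_mask, gt_mask)


# Toy example: a 2x2 ground-truth square in a 4x4 image.
gt = np.zeros((4, 4), dtype=bool)
gt[:2, :2] = True
before = np.zeros((4, 4), dtype=bool)
before[0, 0] = True  # prediction covers only one of the four target pixels
print(dice_coefficient(before, gt))
print(dice_improvement_reward(before, gt, gt))
```

A difference-of-Dice reward is one common way to credit each click only for the improvement it causes, so a click that leaves the mask unchanged earns nothing.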
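```python
from stable_baselines3 import PPO
from sam2_click_env import SAM2ClickEnv, compute_dice

# Load the trained agent (SB3 appends ".zip" when no extension is given)
model = PPO.load("click_agent_ppo")

# Create the environment with your SAM2 predictor
# (your_dataset and your_sam_predictor are placeholders you must supply)
env = SAM2ClickEnv(
    dataset=your_dataset,
    sam_predictor=your_sam_predictor,
    obs_size=128,
    grid_size=32,
    max_clicks=5,
    use_sam=True,
)

# Run inference: one iteration per refinement click (5 matches max_clicks)
obs, info = env.reset()
for step in range(5):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    print(f"Step {step + 1}: Dice={info['dice']:.4f}")
    # Stop when the episode either terminates or is truncated
    if done or truncated:
        break
```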
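The environment takes `obs_size=128` and `grid_size=32`, which suggests (but the card does not confirm) that each action selects one cell of a 32 × 32 grid laid over a 128 × 128 observation. The sketch below shows one hypothetical way to decode such a discrete action into an (x, y) pixel click in the original image. The function `action_to_click`, the row-major ordering, and the cell-centre placement are all assumptions for illustration, not the `SAM2ClickEnv` implementation. Whether an action also encodes click polarity (positive or negative) is not stated.

```python
from typing import Tuple


def action_to_click(action: int, grid_size: int = 32, obs_size: int = 128,
                    image_hw: Tuple[int, int] = (128, 128)) -> Tuple[int, int]:
    """Hypothetical decoding of a discrete grid action into an (x, y) pixel click.

    Assumes the action indexes a grid_size x grid_size grid in row-major order,
    that the click sits at the centre of the chosen cell, and that observation
    coordinates are rescaled linearly to the original image size (height, width).
    """
    n_cells = grid_size * grid_size
    if not 0 <= action < n_cells:
        raise ValueError(f"action must be in [0, {n_cells}), got {action}")
    cell = obs_size / grid_size            # observation pixels per grid cell
    row, col = divmod(action, grid_size)   # row-major cell index
    x_obs = (col + 0.5) * cell             # cell centre in observation space
    y_obs = (row + 0.5) * cell
    h, w = image_hw
    x = int(x_obs * w / obs_size)          # rescale to original image width
    y = int(y_obs * h / obs_size)          # rescale to original image height
    return x, y


print(action_to_click(0))                        # top-left cell
print(action_to_click(33))                       # second row, second column
print(action_to_click(0, image_hw=(512, 256)))   # same cell on a 512x256 image
```

Decoding to a cell centre keeps the action space small and discrete (1,024 choices at these settings), which suits PPO's categorical policies. The coordinates are returned in (x, y) order, the order SAM-style predictors expect for point prompts.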