Real vectorized trimap generation

#2
by PartyParrot - opened

In Listing 1 of your paper, you give the implementation of "an efficient vectorized version of the confidence trimap generation algorithm".

I think your implementation could still be improved a bit. It is not yet fully vectorized, since it steps through various thresholds in the while loop in line 33.

I do not know the range of pred in your paper, but here is a fully vectorized version with a benchmark on some random image for which it is much faster. Maybe it is useful to you.

# https://arxiv.org/pdf/2501.06230
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import time
import urllib.request

def generate_trimap_vectorized(pred, min_unknown_pixels=60_000):
    # Quantize predictions to 256 levels and count them
    values = pred.sigmoid() * 256
    histc = values.view(-1).histc(bins=256, min=0, max=255)
    counts = pred.numel() - (histc + histc.flip(0)).cumsum(0)
    # Get first index where count of unknown pixels exceeds threshold
    index = (counts > min_unknown_pixels).to(int).argmin()
    # If you want to, you could clamp the index to some reasonable range
    trimap = (values > 255 - index) * 127.0 + (values >= index) * 128.0
    return trimap

def generate_trimap(pred):
    min_pixels = 60000
    t_high = 0.90
    t_low = 0.10
    t_min = 0.03  # Lower bound
    t_max = 0.97  # Upper bound
    step = 0.001  # Adjustment size
    mask = pred.sigmoid()  # Apply sigmoid to prediction

    # Generate initial trimap
    trimap = torch.where(
        mask >= t_high,
        mask.new_tensor(255.0),
        torch.where(
            mask <= t_low,
            mask.new_tensor(0.0),
            mask.new_tensor(128.0)
        )
    )

    # Count gray pixels and adjust thresholds
    n_gray = (trimap == 128).sum().item()
    while n_gray < min_pixels:
        t_low = max(t_low - step, t_min)
        t_high = min(t_high + step, t_max)

        if t_low <= t_min and t_high >= t_max:
            break  # Exit if bounds reached

        trimap = torch.where(
            mask >= t_high,
            trimap.new_tensor(255.0),
            torch.where(
                mask <= t_low,
                trimap.new_tensor(0.0),
                trimap.new_tensor(128.0)
            )
        )
        n_gray = (trimap == 128).sum().item()

    return trimap

def main():
    # Download test file
    url = "https://raw.githubusercontent.com/frcs/alternative-matting-laplacian/master/result-alpha-GT04.png"
    if not os.path.isfile("alpha.png"):
        urllib.request.urlretrieve(url, "alpha.png")

    alpha = np.array(Image.open("alpha.png").convert("L"))
    alpha = alpha.astype(np.float32) / 255.0
    alpha = torch.from_numpy(alpha)

    pred = alpha * 10 - 5

    def sync():
        # pred lives on the CPU here, so only synchronize if CUDA is available
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(10):
        sync()
        t = time.perf_counter()
        trimap = generate_trimap(pred)
        sync()
        dt1 = time.perf_counter() - t

        sync()
        t = time.perf_counter()
        trimap_new = generate_trimap_vectorized(pred)
        sync()
        dt2 = time.perf_counter() - t

        print(f"{dt1 * 1000:7.3f} ms for generate_trimap")
        print(f"{dt2 * 1000:7.3f} ms for generate_trimap_vectorized")
        print()

    for i, img in enumerate([alpha, trimap, trimap_new]):
        plt.subplot(1, 3, 1 + i)
        plt.imshow(img.detach().cpu().numpy(), cmap="gray")
    plt.show()

if __name__ == "__main__":
    main()
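
If staying on the exact threshold grid of the while loop matters more than the 256-level quantization, the same idea can be sketched with a sort and torch.searchsorted: the gray-pixel count for every candidate threshold pair is computed at once, and the first pair that reaches the minimum is selected. The function name trimap_from_threshold_grid and its keyword arguments are my own, not from the paper. Results may differ from the loop in floating-point edge cases, and when the minimum is never reached this sketch falls back to the widest thresholds, whereas the loop keeps the trimap from the step before.

```python
import torch

def trimap_from_threshold_grid(pred, min_pixels=60_000, t_low=0.10, t_high=0.90,
                               t_min=0.03, t_max=0.97, step=0.001):
    # Hypothetical helper: evaluates all threshold pairs of the loop at once.
    mask = pred.sigmoid()

    # Candidate thresholds for step k = 0, 1, ..., n_steps (clamped to the bounds)
    n_steps = int(round(max(t_low - t_min, t_max - t_high) / step))
    k = torch.arange(n_steps + 1, dtype=mask.dtype, device=mask.device)
    lows = (t_low - k * step).clamp(min=t_min)
    highs = (t_high + k * step).clamp(max=t_max)

    # Gray pixels satisfy low < mask < high; count them via two rank lookups
    sorted_vals = mask.flatten().sort().values
    n_le_low = torch.searchsorted(sorted_vals, lows, right=True)     # mask <= low
    n_lt_high = torch.searchsorted(sorted_vals, highs, right=False)  # mask < high
    n_gray = n_lt_high - n_le_low

    # First step with enough gray pixels, else the widest thresholds
    enough = n_gray >= min_pixels
    idx = int(enough.to(torch.int64).argmax()) if bool(enough.any()) else n_steps

    trimap = torch.full_like(mask, 128.0)
    trimap[mask >= highs[idx]] = 255.0
    trimap[mask <= lows[idx]] = 0.0
    return trimap
```

The sort costs O(n log n), so this is not guaranteed to beat the histogram version above; it only removes the Python-level loop while keeping the loop's thresholds.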
Prama LLC org

Hello, thank you for your comment. We are still preparing the paper for peer review, so outside perspectives are greatly appreciated. The range of pred is 0 to 1, since it is the result of sigmoid(logits). This algorithm was used for BEN, whereas our BEN2 models employ a fixed range. This avoids stepping through the thresholds entirely:
# BEN2 trimap generation
min_low_threshold = 0.01  # Set minimum limit for low_threshold
max_high_threshold = 0.99  # Set maximum limit for high_threshold

mask = predicted_output.sigmoid()  # Use the sigmoid output directly

# Start with initial trimap
trimap = torch.where(
    mask >= max_high_threshold,
    mask.new_tensor(255.0),
    torch.where(
        mask <= min_low_threshold,
        mask.new_tensor(0.0),
        mask.new_tensor(128.0)
    )
)

We found better model generalization results this way. The purpose of the PyTorch implementation was more for research and demonstration.
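
For anyone who wants to try the fixed-threshold variant in isolation, here is a minimal self-contained sketch. The function name fixed_threshold_trimap and the parameters low and high are illustrative only and not part of the BEN2 code; the default values are the two thresholds quoted above.

```python
import torch

def fixed_threshold_trimap(logits, low=0.01, high=0.99):
    # Hypothetical wrapper around the fixed-threshold rule shown above
    mask = logits.sigmoid()
    trimap = torch.full_like(mask, 128.0)  # unknown (gray) by default
    trimap[mask >= high] = 255.0           # confident foreground
    trimap[mask <= low] = 0.0              # confident background
    return trimap

logits = torch.tensor([[-10.0, 0.0, 10.0]])
print(fixed_threshold_trimap(logits).tolist())  # [[0.0, 128.0, 255.0]]
```

Since there is a single pass with no data-dependent control flow, the cost is a few elementwise operations regardless of how many gray pixels result.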

MaxwellMeyer changed discussion status to closed
