import argparse
import os
import platform
import time
from glob import glob
from typing import Optional, Tuple, List

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
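# Gradio app for BiRefNet (Bilateral Reference for High-Resolution Dichotomous
# Image Segmentation): removes image backgrounds using a selectable set of
# pretrained weights, in single-image or batch-folder mode, and saves RGBA
# cutouts to an output folder.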

def parse_args():
    parser = argparse.ArgumentParser(description="Run the image segmentation app")
    parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
    return parser.parse_args()


# Allow TF32 matmul kernels for faster float32 inference on supported GPUs.
torch.set_float32_matmul_precision('high')

# Make torch.jit.script a no-op so any scripted functions in the remotely
# loaded model code run as plain Python instead of being JIT-compiled.
torch.jit.script = lambda f: f

# Ensure HOME points at the user's home directory (it may be unset, e.g. on Windows).
os.environ['HOME'] = os.path.expanduser('~')

device = "cuda" if torch.cuda.is_available() else "cpu"

def open_folder():
    # Open the results folder in the system file browser (Windows and Linux only).
    open_folder_path = os.path.abspath("results")
    if platform.system() == "Windows":
        os.startfile(open_folder_path)
    elif platform.system() == "Linux":
        os.system(f'xdg-open "{open_folder_path}"')

class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # The caller resizes the image to `resolution` before calling proc();
        # here we only convert to a tensor and normalize with ImageNet statistics.
        self.transform_image = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def proc(self, image: Image.Image) -> torch.Tensor:
        image = image.convert('RGB')
        image = self.transform_image(image)
        return self.normalize(image)

# Dropdown label -> model repository name under the `zhengpeng7` namespace on the Hugging Face Hub.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_T',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
}

# Load the default 'General' weights at startup; predict() reloads whichever weights the user selects.
birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
birefnet.to(device)
birefnet.eval()

def process_single_image(image_path: str, resolution: str, output_folder: str) -> Tuple[str, str, float]:
    start_time = time.time()

    image = Image.open(image_path).convert('RGBA')

    # Default to the original image resolution; snap both dimensions down to multiples of 32.
    if resolution == '':
        resolution = f"{image.width}x{image.height}"
    resolution = [int(reso) // 32 * 32 for reso in resolution.strip().split('x')]

    # (height, width) of the original image, for restoring the mask to full size.
    image_shape = image.size[::-1]
    image_pil = image.resize(tuple(resolution))

    image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
    image_proc = image_preprocessor.proc(image_pil)
    image_proc = image_proc.unsqueeze(0)

    # Run the model; the last output holds the final segmentation prediction.
    with torch.no_grad():
        scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()

    if device == 'cuda':
        scaled_pred_tensor = scaled_pred_tensor.cpu()

    # Upsample the predicted mask back to the original resolution.
    pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()

    # Build an RGBA mask and apply it to the input image so the background becomes transparent.
    pred_rgba = np.zeros((*pred.shape, 4), dtype=np.uint8)
    pred_rgba[..., :3] = (pred[..., np.newaxis] * 255).astype(np.uint8)
    pred_rgba[..., 3] = (pred * 255).astype(np.uint8)

    image_array = np.array(image)
    image_pred = image_array * (pred_rgba / 255.0)

    output_image = Image.fromarray(image_pred.astype(np.uint8), 'RGBA')

    # Save as PNG, adding a numeric suffix instead of overwriting existing files.
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_folder, f"{base_filename}.png")

    counter = 1
    while os.path.exists(output_path):
        output_path = os.path.join(output_folder, f"{base_filename}_{counter:04d}.png")
        counter += 1

    output_image.save(output_path)

    processing_time = time.time() - start_time
    print(f"Processed {image_path} in {processing_time:.4f} seconds")
    return image_path, output_path, processing_time
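# Example (hypothetical paths): process one file at its native resolution and
# save the RGBA cutout to ./results:
#   process_single_image("input/photo.jpg", "", "results")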

def predict(
    image: str,
    resolution: str,
    weights_file: Optional[str],
    batch_folder: Optional[str] = None,
    output_folder: Optional[str] = None,
    is_batch: bool = False
) -> Tuple[str, List[Tuple[str, str]]]:
    global birefnet
    # Load the requested weights (falling back to 'General') before processing.
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
    print('Using weights:', _weights_file)
    birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
    birefnet.to(device)
    birefnet.eval()

    if not output_folder:
        output_folder = 'results'
    os.makedirs(output_folder, exist_ok=True)

    results = []

    if is_batch and batch_folder:
        image_files = glob(os.path.join(batch_folder, '*'))
        total_images = len(image_files)
        processed_images = 0
        start_time = time.time()

        for img_path in image_files:
            try:
                input_path, output_path, proc_time = process_single_image(img_path, resolution, output_folder)
                results.append((output_path, f"{proc_time:.4f} seconds"))
                processed_images += 1
                elapsed_time = time.time() - start_time
                avg_time_per_image = elapsed_time / processed_images
                estimated_time_left = avg_time_per_image * (total_images - processed_images)

                status = f"Processed {processed_images}/{total_images} images. Estimated time left: {estimated_time_left:.2f} seconds"
                print(status)
            except Exception as e:
                print(f"Error processing {img_path}: {str(e)}")
                continue

        return f"Batch processing complete. Processed {processed_images}/{total_images} images.", results
    else:
        input_path, output_path, proc_time = process_single_image(image, resolution, output_folder)
        results.append((output_path, f"{proc_time:.4f} seconds"))
        return "Single image processing complete.", results

def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## SECourses Improved BiRefNet V2 'Bilateral Reference for High-Resolution Dichotomous Image Segmentation' APP - SOTA Background Remover")
        gr.Markdown("## Most advanced and latest version: https://www.patreon.com/posts/109913645")

        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image", height=512)
            output_image = gr.Gallery(label="Output Image", elem_id="gallery", height=512)

        with gr.Row():
            resolution = gr.Textbox(label="Resolution", placeholder="e.g. 1024x1024 - optional; leave blank to use the original image resolution (higher resolutions use more VRAM but stay fast even with shared VRAM)")
            weights_file = gr.Dropdown(choices=list(usage_to_weights_file.keys()), value="General", label="Weights File")
            btn_open_outputs = gr.Button("Open Results Folder")
            btn_open_outputs.click(fn=open_folder)

        with gr.Row():
            batch_folder = gr.Textbox(label="Batch Folder Path")
            output_folder = gr.Textbox(label="Output Folder Path", value="results")

        with gr.Row():
            submit_button = gr.Button("Single Image Process")
            batch_button = gr.Button("Batch Process Images in Given Folder")

        output_text = gr.Textbox(label="Processing Status")

        # Hidden flags tell predict() whether to run in single-image or batch mode.
        single_mode = gr.Checkbox(value=False, visible=False)
        batch_mode = gr.Checkbox(value=True, visible=False)

        submit_button.click(
            predict,
            inputs=[input_image, resolution, weights_file, batch_folder, output_folder, single_mode],
            outputs=[output_text, output_image]
        )

        batch_button.click(
            predict,
            inputs=[input_image, resolution, weights_file, batch_folder, output_folder, batch_mode],
            outputs=[output_text, output_image]
        )

    return demo

if __name__ == "__main__":
    args = parse_args()
    demo = create_interface()
    demo.launch(inbrowser=True, share=args.share)
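# Launch example (assuming this file is saved as app.py):
#   python app.py            # local only
#   python app.py --share    # also create a public Gradio share link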