File size: 3,833 Bytes
542c815 3f8e328 542c815 37ded0e 542c815 a888400 d6e753e 8a357d1 37ded0e 542c815 37ded0e fbe23c5 37ded0e c0a3a3c 542c815 4f91b95 542c815 0c9e50b 542c815 0c9e50b 1605763 70974c3 542c815 70974c3 542c815 70974c3 542c815 70974c3 542c815 70974c3 542c815 70974c3 542c815 d909bca 542c815 d909bca 542c815 d909bca 8cb0f2e d909bca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
# from foo import hello
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# import git # pip install gitpython
# hello()
# git.Git(".").clone("https://huggingface.co/briaai/RMBG-1.4")
# git.Git(".").clone("git@hf.co:briaai/RMBG-1.4")
# Build the background-removal network once at import time and load its
# weights from a local checkpoint.
# NOTE(review): torch.load unpickles the file, so only point model_path at a
# trusted checkpoint.
net = BriaRMBG()
model_path = "./model.pth"
if not torch.cuda.is_available():
    # CPU-only host: remap any tensors that were saved from a GPU onto the CPU.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
# Evaluation mode for inference.
net.eval()
def image_size_by_min_resolution(
    image: "Image.Image",
    resolution: Tuple,
    resample=None,
):
    """Return the size that rescales *image* so its shorter side matches *resolution*.

    The shorter image side is mapped to ``min(resolution)`` and the other side
    is scaled by the same factor (aspect ratio preserved), truncated to an int.
    Images smaller than the target are scaled up, larger ones down.

    Args:
        image: Object with a ``size`` attribute giving ``(width, height)``,
            e.g. a PIL image.
        resolution: Sequence of target dimensions; only its minimum is used.
        resample: Unused; accepted only for call compatibility.

    Returns:
        ``(new_width, new_height)`` as ints.

    Raises:
        ZeroDivisionError: If either image dimension is 0.
    """
    w, h = image.size
    image_min = min(w, h)
    resolution_min = min(resolution)
    # Multiply first, then floor-divide, instead of dividing by a float scale
    # factor: with a non-power-of-two target, ``w // (image_min / resolution_min)``
    # can round the short side down by one pixel (e.g. ``1 // 0.1 == 9.0``).
    resize_to: Tuple[int, int] = (
        int(w * resolution_min // image_min),
        int(h * resolution_min // image_min),
    )
    return resize_to
def resize_image(image):
    """Return an RGB copy of *image* rescaled so its shorter side is 1024 px.

    The target size comes from ``image_size_by_min_resolution`` with a
    ``(1024, 1024)`` resolution, and bilinear resampling is used.
    """
    rgb_image = image.convert('RGB')
    target_size = image_size_by_min_resolution(image=rgb_image, resolution=(1024, 1024))
    return rgb_image.resize(target_size, Image.BILINEAR)
def process(image):
    """Remove the background from *image* using the module-level ``net``.

    Args:
        image: The picture handed over by the Gradio ``"image"`` input, as an
            array accepted by ``Image.fromarray`` (presumably ``(H, W, 3)``
            uint8 RGB -- confirm against the Gradio version in use), or
            ``None`` when nothing was submitted.

    Returns:
        ``[orig_image, new_im]``: the original picture as a PIL image, and an
        RGBA PIL image of the same size where the original is pasted through
        the predicted mask (unmasked areas stay transparent).

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    # HWC uint8 -> 1xCxHxW float32 in [0, 1], then subtract 0.5 per channel
    # (std 1.0), so values end up in [-0.5, 0.5].
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference: no gradients are needed, so don't build an autograd graph
    with torch.no_grad():
        result = net(im_tensor)

    # post process: resize result[0][0] (a 4-D tensor, as required by bilinear
    # interpolate) back to the original (h, w) and drop the leading batch dim.
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Min-max normalise to [0, 1]. Clamp the range so a constant prediction
    # yields an all-zero mask instead of NaNs from 0 / 0.
    result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original image through the mask onto a transparent canvas
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return [orig_image, new_im]
# block = gr.Blocks().queue()
# with block:
# gr.Markdown("## BRIA RMBG 1.4")
# gr.HTML('''
# <p style="margin-bottom: 10px; font-size: 94%">
# This is a demo for BRIA RMBG 1.4 that using
# <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
# </p>
# ''')
# with gr.Row():
# with gr.Column():
# input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
# # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
# run_button = gr.Button(value="Run")
# with gr.Column():
# result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
# ips = [input_image]
# run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
# block.launch(debug = True)
# Gradio UI: a single image input, shown as a before/after slider on output.
title = "background_removal"
description = "remove image background"
examples = [['./input.jpg']]
output = ImageSlider(
    position=0.5,
    label='Image without background slider-view',
    type="pil",
)
demo = gr.Interface(
    fn=process,
    inputs="image",
    outputs=output,
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)