File size: 1,906 Bytes
37aeb5b 5a3e910 37aeb5b f38a22d 37aeb5b 2fc8dce 37aeb5b 8bfc447 37aeb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from rembg import remove
from gradio_app.utils import change_rgba_bg, rgba_to_rgb
from gradio_app.custom_models.utils import load_pipeline
from scripts.all_typing import *
from scripts.utils import session, simple_preprocess
# Model configuration and the UNet weights used to build the
# image -> multi-view-image pipeline.
training_config = "gradio_app/custom_models/image2mvimage.yaml"
checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"

# Loaded once at module import; `predict` below reuses both objects for
# every request.
trainer, pipeline = load_pipeline(
    training_config,
    checkpoint_path,
)
def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
    """Run the multi-view pipeline on one or more conditioning images.

    Args:
        img_list: A list of PIL images, or a single ``Image.Image`` (treated
            as a one-element list). Images in mode ``'RGBA'`` are passed
            through ``rgba_to_rgb`` before inference.
        guidance_scale: Guidance scale forwarded to
            ``trainer.pipeline_forward``.
        **kwargs: Any further options (e.g. ``generator``, ``width``,
            ``height``, ``num_inference_steps``), forwarded unchanged to
            ``trainer.pipeline_forward``.

    Returns:
        A flat list that concatenates the ``.images`` of every forward pass,
        in input order.
    """
    # NOTE(review): this is invoked on every call; enabling CPU offload once
    # right after load_pipeline may be cheaper — confirm before moving it.
    pipeline.enable_model_cpu_offload()

    batch = [img_list] if isinstance(img_list, Image.Image) else img_list
    # Convert every input up front, before any forward pass is started.
    conditioning = [
        rgba_to_rgb(picture) if picture.mode == 'RGBA' else picture
        for picture in batch
    ]

    outputs = []
    for picture in conditioning:
        forward_result = trainer.pipeline_forward(
            pipeline=pipeline,
            image=picture,
            guidance_scale=guidance_scale,
            **kwargs
        )
        outputs.extend(forward_result.images)
    return outputs
def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
    """Generate multi-view images from a single input image.

    Args:
        input_image: Input PIL image. If it has no alpha band, or its alpha
            band is fully opaque, background removal is forced.
        remove_bg: Whether to run rembg background removal on the input.
            Overridden to True when the image carries no usable alpha.
        guidance_scale: Guidance scale forwarded to ``predict``.
        seed: Seed for a CUDA ``torch.Generator``. A negative value or
            ``None`` runs without a seeded generator.

    Returns:
        Tuple ``(rgb_pils, single_image)``: the list returned by ``predict``
        and the preprocessed conditioning image fed to it.
    """
    # simple_preprocess requires an RGBA image. An image with no "A" band
    # (RGB, L, P, CMYK, ...) or with an all-opaque alpha has no foreground
    # mask, so rembg must run regardless of `remove_bg`. Checking only
    # mode == 'RGB' let e.g. grayscale 'L' images through, and indexing
    # [..., -1] on them read a pixel column rather than an alpha channel.
    has_alpha = "A" in input_image.getbands()
    # For 8-bit alpha, min() == 255 is equivalent to mean() == 255.
    if not has_alpha or np.asarray(input_image.getchannel("A")).min() == 255:
        # still do remove using rembg, since simple_preprocess requires RGBA image
        print("RGB image not RGBA! still remove bg!")
        remove_bg = True
    if remove_bg:
        input_image = remove(input_image, session=session)
    # make front_pil RGBA with white bg
    input_image = change_rgba_bg(input_image, "white")
    single_image = simple_preprocess(input_image)

    # Seeded CUDA generator for reproducible sampling. Guard against None,
    # which previously raised TypeError on the `>= 0` comparison.
    if seed is not None and seed >= 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = None

    rgb_pils = predict(
        single_image,
        generator=generator,
        guidance_scale=guidance_scale,
        width=256,
        height=256,
        num_inference_steps=30,
    )
    return rgb_pils, single_image
|