Spaces:
Runtime error
Runtime error
import streamlit as st | |
from diffusers import StableDiffusionInpaintPipeline | |
import os | |
from tqdm import tqdm | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
import warnings | |
from huggingface_hub import hf_hub_download | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from data.base_dataset import Normalize_image | |
from utils.saving_utils import load_checkpoint_mgpu | |
from networks import U2NET | |
import argparse | |
from enum import Enum | |
from rembg import remove | |
from dataclasses import dataclass | |
class StableFashionCLIArgs:
    """Mutable bag of options consumed by ``process_image``.

    Every field defaults to ``None``; callers assign the real values
    attribute by attribute before passing the object on.
    """

    image = None  # path or file-like object handed to ``PIL.Image.open``
    part = None  # body-part name, upper-cased and looked up on ``Parts``
    resolution = None  # side length (pixels) of the square working/output image
    prompt = None  # text description of the garment to paint in
    # Legacy misspelling of ``prompt``. Kept so any external code that
    # touches ``promt`` keeps working; this module only uses ``prompt``.
    promt = None
    num_steps = None  # number of diffusion inference steps
    guidance_scale = None  # guidance scale forwarded to the inpainting pipeline
    rembg = None  # truthy -> remove the background before segmentation
class Parts:
    """Segmentation label ids for the body parts that can be re-dressed.

    ``process_image`` compares the per-pixel class predicted by the U2NET
    network against one of these ids to build the inpainting mask.
    """

    UPPER, LOWER = 1, 2
def load_u2net():
    """Fetch the cloth-segmentation checkpoint and return a ready U2NET model.

    The weights ``cloth_segm_u2net_latest.pth`` are downloaded from the
    ``maiti/cloth-segmentation`` Hub repo. They are loaded into a U2NET with 3
    input and 4 output channels, placed on CUDA when available (CPU
    otherwise), and switched to eval mode.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    weights = hf_hub_download(
        repo_id="maiti/cloth-segmentation",
        filename="cloth_segm_u2net_latest.pth",
    )
    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights)
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(target).eval()
def change_bg_color(rgba_image, color):
    """Flatten ``rgba_image`` onto a solid ``color`` background and return RGB.

    The image is pasted at the top-left of a same-sized canvas, using the
    image itself as the paste mask so its alpha decides what shows through.
    """
    canvas = Image.new("RGBA", rgba_image.size, color)
    canvas.paste(rgba_image, (0, 0), mask=rgba_image)
    return canvas.convert("RGB")
def load_inpainting_pipeline():
    """Load the Stable Diffusion inpainting pipeline onto CUDA or CPU.

    The ``fp16`` revision of ``runwayml/stable-diffusion-inpainting`` is
    requested. Weights are materialised as ``float16`` when CUDA is available
    and as ``float32`` otherwise.

    The access token is read from the ``hf_auth_token`` environment variable.
    When that variable is unset, ``None`` is passed instead of raising
    ``KeyError``. The Hub client then presumably falls back to a locally saved
    login or anonymous access (verify against the installed diffusers /
    huggingface_hub versions).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        use_auth_token=os.environ.get("hf_auth_token"),
    )
    return inpainting_pipeline.to(device)
def _cloth_mask(net, image, part, device):
    """Return an "L"-mode mask: 255 where U2NET predicts ``part``, 0 elsewhere.

    ``part`` is matched case-insensitively against the attributes of ``Parts``
    (e.g. "Upper" / "Lower").

    Raises:
        ValueError: if ``part`` does not name an attribute of ``Parts``.
    """
    # getattr replaces eval(f"Parts.{...}"): same lookup, but no code
    # execution on a caller-supplied string.
    mask_code = getattr(Parts, str(part).upper(), None)
    if mask_code is None:
        raise ValueError(f"Unknown body part {part!r}; expected 'Upper' or 'Lower'")

    transform_rgb = transforms.Compose([
        transforms.ToTensor(),
        Normalize_image(0.5, 0.5),
    ])
    image_tensor = transform_rgb(image).unsqueeze(0).to(device)

    # Inference only: no_grad stops autograd from recording the forward pass.
    with torch.no_grad(), torch.autocast(device_type=device):
        outputs = net(image_tensor)

    # Take the first network output and pick the most likely class per pixel
    # along dim=1; the two squeezes drop the batch and keepdim axes, leaving
    # a 2-D label map (assumes batch size 1, as built above).
    output_tensor = F.log_softmax(outputs[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_arr = output_tensor.squeeze(0).squeeze(0).cpu().numpy()

    mask_arr = np.where(output_arr == mask_code, 255, 0).astype("uint8")
    return Image.fromarray(mask_arr, mode="L")


def process_image(args, inpainting_pipeline, net):
    """Repaint the clothing on one body part of a photo from a text prompt.

    Pipeline:
      1. Load the image as RGB and resize it to ``resolution`` x ``resolution``.
      2. If ``args.rembg``: cut the person out with ``rembg`` and flatten onto green.
      3. Segment the requested body part with U2NET to get the inpainting mask.
      4. Inpaint the masked region with the diffusion pipeline using ``prompt``.
      5. Cut the result out again and flatten it onto white.

    Args:
        args: ``StableFashionCLIArgs``-like object; reads ``image``, ``part``,
            ``resolution``, ``rembg``, ``prompt``, ``guidance_scale`` and
            ``num_steps``.
        inpainting_pipeline: pipeline returned by ``load_inpainting_pipeline``.
        net: segmentation model returned by ``load_u2net``.

    Returns:
        tuple ``(result_image, mask_PIL)``: the RGB result and the "L"-mode
        mask that was inpainted.

    Raises:
        ValueError: if ``args.part`` does not name an attribute of ``Parts``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    img = Image.open(args.image).convert("RGB")
    img = img.resize((args.resolution, args.resolution))
    if args.rembg:
        # remove() yields an RGBA cut-out; change_bg_color already returns RGB.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    mask_PIL = _cloth_mask(net, img_with_green_bg, args.part, device)

    clothed_image_from_pipeline = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]

    clothed_image_from_pipeline = change_bg_color(remove(clothed_image_from_pipeline), "WHITE")
    # Return a real 2-tuple: the previous "....convert('RGB'). mask_PIL" was an
    # attribute access and crashed callers that unpack (image, mask).
    return clothed_image_from_pipeline.convert("RGB"), mask_PIL
# Streamlit re-executes this script on every widget interaction. Wrapping the
# loaders in a resource cache loads each model once per server process instead
# of on every rerun. Prefer st.cache_resource, fall back to the older
# st.experimental_singleton, and as a last resort call the loader uncached.
# NOTE(review): cached objects are shared by all sessions, so concurrent
# users share one pipeline instance — confirm that is acceptable here.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda fn: fn)
)
net = _cache_resource(load_u2net)()
inpainting_pipeline = _cache_resource(load_inpainting_pipeline)()
# Centered header with project links, rendered as raw HTML.
_HEADER_HTML = """
<p style='text-align: center'>
<a href='https://github.com/ovshake' target='_blank'>ovshake Github</a> | <a href='https://github.com/ovshake/stable-fashion' target='_blank'>Stable Fashion Github</a> | <a href='https://huggingface.co/spaces/maiti/stable-fashion' target='_blank'>Stable Fashion Demo</a>
<br />
Follow me for more! <a href='https://twitter.com/o_v_shake' target='_blank'> <img src="https://img.icons8.com/color/48/000000/twitter--v1.png" height="30"></a><a href='https://github.com/ovshake' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/github.png" height="27"></a><a href='https://www.linkedin.com/in/ovshake/' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.title("Stable Fashion Huggingface Spaces")

# Collect the user's choices. Widgets appear on the page in call order, so
# this sequence determines the layout.
file_name = st.file_uploader("Upload a clear full length picture of yourself, preferably in a less noisy background")
body_part = st.radio("Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ", ('Upper', 'Lower'))
resolution = st.radio("Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)", (128, 256, 512), index=2)
rembg_status = st.radio("Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)", ("Yes", "No"), index=0)
guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
prompt = st.text_input('Write the description of cloth you want to try', 'a bright yellow t shirt')
num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)

# Bundle the selections into the argument object consumed by process_image.
stable_fashion_args = StableFashionCLIArgs()
stable_fashion_args.image = file_name
stable_fashion_args.part = body_part
stable_fashion_args.resolution = resolution
stable_fashion_args.rembg = rembg_status == "Yes"
stable_fashion_args.guidance_scale = guidance_scale
stable_fashion_args.prompt = prompt
stable_fashion_args.num_steps = num_steps
if file_name is not None:
    # The diffusion run can take a long time; show progress feedback meanwhile.
    with st.spinner("Generating..."):
        result_image, mask_PIL = process_image(stable_fashion_args, inpainting_pipeline, net)
    st.image(result_image, caption='Result')
    st.image(mask_PIL, caption='Mask')
else:
    # No upload yet: show the bundled example. The context manager closes the
    # file handle after st.image has rendered it, instead of leaking one
    # handle per rerun.
    with Image.open('assets/abhishek_yellow.jpg') as stock_image:
        st.image(stock_image, caption='Result')