Spaces:
Runtime error
Runtime error
File size: 6,768 Bytes
6724ca0 46e4e58 6724ca0 ad35187 6724ca0 ad35187 6724ca0 64135f7 6724ca0 5664062 6724ca0 7f50dc5 6724ca0 55f0240 6724ca0 6252c13 6724ca0 47512ef 5664062 da6de3d 1b1a677 da6de3d 1646c1f da6de3d 8b4c4c8 5664062 6724ca0 152d97b ae6d3bd 5664062 152d97b 975c056 45007eb 975c056 6724ca0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
from huggingface_hub import hf_hub_download
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
import argparse
from enum import Enum
from rembg import remove
from dataclasses import dataclass
@dataclass
class StableFashionCLIArgs:
    """Options collected from the UI and consumed by ``process_image``.

    Every field defaults to ``None`` so the object can be created empty and filled in
    attribute by attribute, or built directly with keyword arguments.
    """

    # Uploaded image (path or file-like object handed to PIL.Image.open).
    image: "object" = None
    # Garment region name; upper-cased and looked up on ``Parts`` (e.g. "Upper").
    part: "str | None" = None
    # Square side length in pixels used for resizing and generation.
    resolution: "int | None" = None
    # Text description of the clothing to paint in.
    prompt: "str | None" = None
    # Number of diffusion inference steps.
    num_steps: "int | None" = None
    # Classifier-free guidance scale for the inpainting pipeline.
    guidance_scale: "float | None" = None
    # Whether to remove the photo background before segmentation.
    rembg: "bool | None" = None

    @property
    def promt(self):
        """Deprecated, misspelled alias of ``prompt`` kept for backward compatibility."""
        return self.prompt

    @promt.setter
    def promt(self, value):
        self.prompt = value
class Parts:
    """Label values in the segmentation output that mark each selectable garment region."""

    UPPER, LOWER = 1, 2
@st.cache(allow_output_mutation=True)
def load_u2net():
    """Return the cloth-segmentation U2NET in eval mode on the best available device.

    Downloads ``cloth_segm_u2net_latest.pth`` from the ``maiti/cloth-segmentation`` Hub repo
    and loads it into a 3-in / 4-out channel U2NET. Wrapped in ``st.cache`` so Streamlit
    reruns reuse the same model object.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    weights_file = hf_hub_download(
        repo_id="maiti/cloth-segmentation",
        filename="cloth_segm_u2net_latest.pth",
    )
    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_file)
    return model.to(target).eval()
def change_bg_color(rgba_image, color):
    """Flatten ``rgba_image`` onto a solid ``color`` background and return it in RGB mode."""
    canvas = Image.new("RGBA", rgba_image.size, color)
    # The image doubles as its own paste mask, so transparent areas reveal the background.
    canvas.paste(rgba_image, (0, 0), rgba_image)
    flattened = canvas.convert("RGB")
    return flattened
@st.cache(allow_output_mutation=True)
def load_inpainting_pipeline():
    """Load the runwayml Stable Diffusion inpainting pipeline onto GPU if present, else CPU.

    Uses half precision on CUDA and float32 on CPU. Reads the Hub token from the
    ``hf_auth_token`` environment variable (raises KeyError if it is unset).
    """
    on_gpu = torch.cuda.is_available()
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_auth_token=os.environ["hf_auth_token"],
    )
    return pipeline.to("cuda" if on_gpu else "cpu")
def _garment_mask(net, image, part, device):
    """Segment ``image`` with ``net`` and return an "L" mask (255/0) for the garment ``part``.

    ``part`` is upper-cased and looked up as an attribute of ``Parts`` (e.g. "Upper" ->
    ``Parts.UPPER``); an unknown name raises AttributeError.
    """
    transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
    image_tensor = transform_rgb(image).unsqueeze(0)  # add a batch dimension

    # Pure inference: no_grad skips building an autograd graph for the forward pass.
    with torch.no_grad(), torch.autocast(device_type=device):
        output_tensor = net(image_tensor.to(device))
        # Index 0 of the net output is treated as per-class scores; argmax over dim 1 gives
        # one label per pixel (assumes shape (1, C, H, W) -- TODO confirm against U2NET).
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        labels = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_arr = labels.squeeze(0).squeeze(0).cpu().numpy()

    # getattr instead of eval: resolves the class id without executing arbitrary text.
    mask_code = getattr(Parts, part.upper())
    print("Numbers in output_arr")
    print(np.unique(output_arr))
    print(f"mask code {mask_code}")

    mask = output_arr == mask_code
    return Image.fromarray(mask.astype("uint8") * 255, mode="L")


def process_image(args, inpainting_pipeline, net):
    """Repaint the selected garment in the uploaded photo with Stable Diffusion inpainting.

    Args:
        args: StableFashionCLIArgs-like object. Reads ``image`` (anything PIL.Image.open
            accepts), ``resolution``, ``rembg``, ``part``, ``prompt``, ``guidance_scale``
            and ``num_steps``.
        inpainting_pipeline: the pipeline returned by ``load_inpainting_pipeline``.
        net: the segmentation model returned by ``load_u2net``.

    Returns:
        ``(result, mask)``: the inpainted image with its background replaced by white
        (RGB), and the "L" mask of the region that was repainted.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Close the source file once decoded; convert() loads the pixel data before the close.
    with Image.open(args.image) as source:
        img = source.convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # Cut out the subject and fill the removed background with solid green
        # (change_bg_color already returns RGB).
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    mask_PIL = _garment_mask(net, img_with_green_bg, args.part, device)

    clothed_image = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]

    # Strip the generated background and composite the result onto white.
    clothed_image = change_bg_color(remove(clothed_image), "WHITE")
    return clothed_image.convert("RGB"), mask_PIL
# Load (or fetch from Streamlit's cache) the segmentation net first, then the diffusion pipeline.
net, inpainting_pipeline = load_u2net(), load_inpainting_pipeline()
# Centered links/social badges rendered above the app title (raw HTML, hence unsafe_allow_html).
_HEADER_HTML = """
<p style='text-align: center'>
<a href='https://github.com/ovshake' target='_blank'>ovshake Github</a> | <a href='https://github.com/ovshake/stable-fashion' target='_blank'>Stable Fashion Github</a> | <a href='https://huggingface.co/spaces/maiti/stable-fashion' target='_blank'>Stable Fashion Demo</a>
<br />
Follow me for more! <a href='https://twitter.com/o_v_shake' target='_blank'> <img src="https://img.icons8.com/color/48/000000/twitter--v1.png" height="30"></a><a href='https://github.com/ovshake' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/github.png" height="27"></a><a href='https://www.linkedin.com/in/ovshake/' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.title("Stable Fashion Huggingface Spaces")

# Widgets are rendered top to bottom in this order.
file_name = st.file_uploader(
    "Upload a clear full length picture of yourself, preferably in a less noisy background"
)
body_part = st.radio(
    "Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ",
    ('Upper', 'Lower'),
)
resolution = st.radio(
    "Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)",
    (128, 256, 512),
    index=2,
)
rembg_status = st.radio(
    "Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)",
    ("Yes", "No"),
    index=0,
)
guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
prompt = st.text_input('Write the description of cloth you want to try', 'a bright yellow t shirt')
num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)

# Bundle the widget values into the argument object consumed by process_image.
stable_fashion_args = StableFashionCLIArgs()
stable_fashion_args.image = file_name
stable_fashion_args.part = body_part
stable_fashion_args.resolution = resolution
stable_fashion_args.rembg = rembg_status == "Yes"
stable_fashion_args.guidance_scale = guidance_scale
stable_fashion_args.prompt = prompt
stable_fashion_args.num_steps = num_steps
if file_name is not None:
    # Diffusion inference can take a long time; show progress instead of a frozen page.
    with st.spinner("Generating your new outfit..."):
        result_image, mask_PIL = process_image(stable_fashion_args, inpainting_pipeline, net)
    print(np.unique(np.asarray(mask_PIL)))
    st.image(result_image, caption='Result')
    st.image(mask_PIL, caption='Mask')
else:
    # Nothing uploaded yet: show the bundled sample image. The `with` closes the file
    # handle once st.image has consumed it.
    with Image.open('assets/abhishek_yellow.jpg') as stock_image:
        st.image(stock_image, caption='Result')
|