Spaces:
Running
Running
File size: 7,353 Bytes
d869b2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import hashlib
import os
from io import BytesIO
import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache
from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
def get_bytes(img):
    """Encode an image as JPEG and return the encoded bytes.

    Args:
        img: A PIL-style image (exposes ``mode``, ``convert`` and ``save``),
            or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return img

    # JPEG has no alpha channel or palette: saving e.g. an RGBA or P image
    # (typical for PNG uploads) fails, so such images are converted to RGB
    # first. Modes the JPEG writer accepts are passed through untouched.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")

    with BytesIO() as buffered:
        img.save(buffered, format="JPEG")
        return buffered.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL image.

    Args:
        image: The encoded image file contents.

    Returns:
        The image object produced by ``Image.open`` on an in-memory stream.
    """
    stream = BytesIO(image)
    return Image.open(stream)
def center_crop(img):
    """Crop the largest centered square out of an image.

    Args:
        img: A PIL-style image exposing ``size`` as ``(width, height)`` and
            a ``crop(box)`` method.

    Returns:
        Whatever ``img.crop`` returns for the centered square box whose side
        equals the shorter image dimension.
    """
    width, height = img.size
    side = min(width, height)

    # Use integer offsets and derive the far edges from them. Halving with
    # true division yields fractional coordinates (e.g. 0.5 and 3.5 for a
    # 4x3 image); once those are rounded by the imaging library the box can
    # end up one pixel wider or taller than ``side``, i.e. not square.
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize(name):
    """Build the upload handler that normalizes one of the input images.

    Args:
        name: Label of the image slot this handler serves; it is looked up
            in the ``align`` selection passed to the handler.

    Returns:
        A function ``(img, align) -> img``. When ``name`` is in ``align`` the
        image is passed through ``align_face`` and the result is memoized in
        the module-level ``align_cache`` under the MD5 of its JPEG bytes.
        Otherwise the image is center-cropped and resized to 1024x1024
        unless it already has that size.
    """

    def resize_inner(img, align):
        global align_cache

        if name not in align:
            # No face alignment requested: only enforce the 1024x1024 size.
            if img.size == (1024, 1024):
                return img
            squared = center_crop(img)
            return squared.resize((1024, 1024), Image.Resampling.LANCZOS)

        # Face alignment requested: reuse an earlier result for the same image.
        digest = hashlib.md5(get_bytes(img)).hexdigest()
        if digest in align_cache:
            return align_cache[digest]

        aligned = align_face(img, return_tensors=False)[0]
        align_cache[digest] = aligned
        return aligned

    return resize_inner
def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Send the images to the hair-swap gRPC service and report the outcome.

    Args:
        face: Source image; required.
        shape: Image supplying the hairstyle shape; may be empty.
        color: Image supplying the hair color; may be empty.
        blending: Forwarded as the request's ``blending`` field.
        poisson_iters: Forwarded as the request's ``poisson_iters`` field.
        poisson_erosion: Forwarded as the request's ``poisson_erosion`` field.

    Returns:
        A pair of ``gr.update`` objects: the first for the result image, the
        second for the error text. On success the image is shown and the
        error hidden; on invalid input or any failure the image is hidden
        and the error text is shown.
    """

    def show_error(text):
        # Hide the result image and surface ``text`` in the error field.
        return (gr.update(visible=False),
                gr.update(value=text, visible=True))

    # A face plus at least one of shape/color is required. The table is keyed
    # by (face present, shape-or-color present); the valid combination has no
    # entry and therefore yields None.
    validation_errors = {
        (False, False): "Need to upload a face and at least a shape or color ❗",
        (False, True): "Need to upload a face ❗",
        (True, False): "Need to upload at least a shape or color ❗",
    }
    problem = validation_errors.get((bool(face), bool(shape or color)))
    if problem is not None:
        return show_error(problem)

    # Encode whichever images were supplied.
    face_bytes, shape_bytes, color_bytes = (
        get_bytes(item) if item else None for item in (face, shape, color)
    )

    # Placeholder payloads for the missing inputs.
    # NOTE(review): presumably the server reads b'face' / b'shape' as
    # "reuse that image" -- confirm against the service implementation.
    if shape_bytes is None:
        shape_bytes = b'face'
    if color_bytes is None:
        color_bytes = b'shape'

    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)
            request = HairSwapRequest(
                face=face_bytes,
                shape=shape_bytes,
                color=color_bytes,
                blending=blending,
                poisson_iters=poisson_iters,
                poisson_erosion=poisson_erosion,
                use_cache=True,
            )
            response: HairSwapResponse = stub.swap(request)
            output_image = bytes_to_image(response.image)
            return (gr.update(value=output_image, visible=True),
                    gr.update(visible=False))
    except grpc.RpcError as rpc_err:
        # Failure reported by the gRPC layer.
        return show_error(f"gRPC error: {rpc_err.code()}: {rpc_err.details()}")
    except Exception as exc:
        # Anything else (missing SERVER variable, undecodable response, ...).
        return show_error(f"Unexpected error: {str(exc)}")
def get_demo():
    """Assemble the Gradio Blocks interface for the demo.

    The page holds a header with project badges, three image inputs
    (source, shape, color), an "Advanced Options" accordion, a trigger
    button, a result image with a hidden error textbox, example inputs and
    a citation note. Uploads are routed through ``resize`` and the button
    calls ``swap_hair``.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    # Static content is prepared up front so the layout code below stays short.
    badges_html = "".join((
        '<div style="display: flex; align-items: center; gap: 10px;">',
        '<span>Official HairFastGAN Gradio demo:</span>',
        '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>',
        '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>',
        '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>',
        '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>',
        '</div>',
    ))
    example_rows = [
        ["input/0.png", "input/1.png", "input/2.png"],
        ["input/6.png", "input/7.png", None],
        ["input/10.jpg", None, "input/11.jpg"],
    ]
    citation_md = '''To cite the paper by the authors
```
@article{nikolaev2024hairfastgan,
title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
journal={arXiv preprint arXiv:2404.01094},
year={2024}
}
```
'''

    with gr.Blocks() as demo:
        gr.Markdown("## HairFastGan")
        gr.Markdown(badges_html)

        with gr.Row():
            # Left column: inputs and options.
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"],
                        value='Article',
                        label="Color Encoder version",
                        info="Selects a model for hair color transfer.",
                    )
                    poisson_iters = gr.Slider(
                        0, 2500, value=0, step=1,
                        label="Poisson iters",
                        info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.",
                    )
                    poisson_erosion = gr.Slider(
                        1, 100, value=15, step=1,
                        label="Poisson erosion",
                        info="Smooths out the blending area.",
                    )
                    align = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"],
                        value=["Face", "Shape", "Color"],
                        label="Image cropping [recommended]",
                        info="Selects which images to crop by face",
                    )
                btn = gr.Button("Get the haircut")
            # Right column: result image and the (initially hidden) error text.
            with gr.Column():
                output = gr.Image(label="Your result")
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        gr.Examples(examples=example_rows, inputs=[source, shape, color], outputs=output)

        # Every upload is normalized in place by the handler built for its slot.
        for component, slot in ((source, 'Face'), (shape, 'Shape'), (color, 'Color')):
            component.upload(fn=resize(slot), inputs=[component, align], outputs=component)

        btn.click(
            fn=swap_hair,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
        )

        gr.Markdown(citation_md)

    return demo
# Cache of aligned images used by the handlers that resize() builds. It lives
# at module scope rather than under the __main__ guard so that those handlers
# also find it when this module is imported instead of executed as a script
# (otherwise the first aligned upload would raise NameError).
align_cache = LRUCache(maxsize=10)

if __name__ == '__main__':
    # Build the UI and serve it on all interfaces, port 7860.
    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
|