import hashlib
import os
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
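
# Gradio front end for the HairFastGAN demo. Uploaded images are sent over gRPC to a
# separate inference server (its address is read from the SERVER environment variable),
# which returns the edited image. A sketch of how the gRPC stubs are presumably generated
# from the service's proto definition (the file name "inference.proto" is an assumption):
#   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. inference.proto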


def get_bytes(img):
    # Serialize a PIL image to JPEG bytes; None is passed through unchanged.
    if img is None:
        return img

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    image = Image.open(BytesIO(image))
    return image


def center_crop(img):
    # Crop the largest centered square from the image.
    width, height = img.size
    side = min(width, height)

    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2
    img = img.crop((left, top, right, bottom))

    return img


def resize(name):
    def resize_inner(img, align):
        global align_cache

        if name in align:
            # Face alignment is expensive, so cache the aligned result by image hash.
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()

            if img_hash not in align_cache:
                img = align_face(img, return_tensors=False)[0]
                align_cache[img_hash] = img
            else:
                img = align_cache[img_hash]
        elif img.size != (1024, 1024):
            # Without alignment, fall back to a plain center crop and resize to 1024x1024.
            img = center_crop(img)
            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)

        return img

    return resize_inner


def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
    # A face image is required, together with at least one of shape or color.
    if not face and not (shape or color):
        return (gr.update(visible=False),
                gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True))
    elif not face:
        return (gr.update(visible=False),
                gr.update(value="Need to upload a face ❗", visible=True))
    elif not (shape or color):
        return (gr.update(visible=False),
                gr.update(value="Need to upload at least a shape or color ❗", visible=True))

    # Encode the uploaded images as JPEG bytes.
    face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item) if item else None, (face, shape, color))

    # Sentinel defaults: presumably the server reuses the face image as the shape reference
    # and the shape image as the color reference when none is provided.
    if shape_bytes is None:
        shape_bytes = b'face'
    if color_bytes is None:
        color_bytes = b'shape'

    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)
            output: HairSwapResponse = stub.swap(
                HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
                                poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
            )

        output_image = bytes_to_image(output.image)
        return (gr.update(value=output_image, visible=True),
                gr.update(visible=False))
    except grpc.RpcError as e:
        # Surface gRPC failures in the error textbox instead of crashing the app.
        error_message = f"gRPC error: {e.code()}: {e.details()}"
        return (gr.update(visible=False),
                gr.update(value=error_message, visible=True))
    except Exception as e:
        # Catch-all for any other failure.
        error_message = f"Unexpected error: {str(e)}"
        return (gr.update(visible=False),
                gr.update(value=error_message, visible=True))
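

# A minimal sketch of calling the backend directly, without the Gradio UI (assumes SERVER
# is set; the file names below are placeholders):
#
#   face_img, shape_img = Image.open("face.jpg"), Image.open("shape.jpg")
#   with grpc.insecure_channel(os.environ['SERVER']) as channel:
#       stub = HairSwapServiceStub(channel)
#       response = stub.swap(HairSwapRequest(
#           face=get_bytes(face_img), shape=get_bytes(shape_img), color=b'shape',
#           blending="Article", poisson_iters=0, poisson_erosion=15, use_cache=True))
#   bytes_to_image(response.image).save("result.jpg")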


def get_demo():
    with gr.Blocks() as demo:
        gr.Markdown("## HairFastGAN")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official HairFastGAN Gradio demo:</span>'
            '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
            '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )

        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version", info="Selects a model for hair color transfer.")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
                                              info="The power of blending with the original image; helps to recover more details. Not included in the article, disabled by default.")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
                                                info="Smooths out the blending area.")
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]",
                                             info="Selects which images to crop by face")
                btn = gr.Button("Get the haircut")
            with gr.Column():
                output = gr.Image(label="Your result")
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

                gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
                                      ["input/10.jpg", None, "input/11.jpg"]],
                            inputs=[source, shape, color], outputs=output)

        # Crop/align each image on upload; 'Face', 'Shape' and 'Color' match the CheckboxGroup labels.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
                  outputs=[output, error_message])

        gr.Markdown('''To cite the paper by the authors
```
@article{nikolaev2024hairfastgan,
  title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
  author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
  journal={arXiv preprint arXiv:2404.01094},
  year={2024}
}
```
''')
    return demo


if __name__ == '__main__':
    align_cache = LRUCache(maxsize=10)

    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
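
# Local run sketch (the backend address/port below is an assumption, and "app.py" is the
# usual Spaces file name rather than something stated here):
#   SERVER=localhost:50051 python app.py
# The SERVER variable must point at the running HairSwapService gRPC backend; the Gradio
# UI is then served on port 7860.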