File size: 7,353 Bytes
d869b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import hashlib
import os
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face


def get_bytes(img):
    """Encode a PIL image as JPEG and return the raw bytes.

    Args:
        img: A PIL image, or None.

    Returns:
        The JPEG-encoded bytes, or None when *img* is None.
    """
    if img is None:
        return None

    # JPEG has no alpha channel or palette: Pillow refuses to write e.g. RGBA,
    # LA or P images as JPEG, so anything outside the modes its JPEG writer
    # accepts is converted to RGB first.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Decode an encoded image payload into a PIL image.

    Args:
        image: Encoded image bytes (any format Pillow can open).

    Returns:
        The decoded ``PIL.Image.Image``, backed by an in-memory stream.
    """
    return Image.open(BytesIO(image))


def center_crop(img):
    """Crop the largest centered square out of *img*.

    Args:
        img: A PIL image (anything exposing ``size`` and ``crop``).

    Returns:
        The result of ``img.crop`` for a ``side x side`` box centered in the
        image, where ``side = min(width, height)``.
    """
    width, height = img.size
    side = min(width, height)

    # Use integer offsets. With float (x.5) coordinates Pillow rounds each edge
    # independently (half-to-even), which can produce a box one pixel wider or
    # taller than `side`, i.e. not square.
    left = (width - side) // 2
    top = (height - side) // 2

    return img.crop((left, top, left + side, top + side))


def resize(name):
    """Create an upload handler that prepares the image for slot *name*.

    Args:
        name: Slot label (the UI wires "Face", "Shape" and "Color"), matched
            against the labels selected in the image-cropping checkbox group.

    Returns:
        A callable ``(img, align) -> image``. If *name* is in *align*, the
        image is aligned with ``align_face`` and the result is memoized in
        ``align_cache``, keyed by the MD5 of the image's JPEG bytes. Otherwise
        an image that is not 1024x1024 is center-cropped and resized to
        1024x1024 with LANCZOS, and a 1024x1024 image is returned unchanged.
    """
    def resize_inner(img, align):
        # Alignment not requested for this slot: only normalize the size.
        if name not in align:
            if img.size == (1024, 1024):
                return img
            square = center_crop(img)
            return square.resize((1024, 1024), Image.Resampling.LANCZOS)

        # Reuse a cached alignment for an identical upload.
        digest = hashlib.md5(get_bytes(img)).hexdigest()
        if digest in align_cache:
            return align_cache[digest]

        aligned = align_face(img, return_tensors=False)[0]
        align_cache[digest] = aligned
        return aligned

    return resize_inner


def _error_updates(message):
    """Return Gradio updates that hide the result image and show *message*."""
    return (gr.update(visible=False),
            gr.update(value=message, visible=True))


def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Send the uploaded images to the HairSwap gRPC service and show the result.

    Args:
        face: Source image whose hairstyle is replaced, or None.
        shape: Optional image with the desired hair shape, or None.
        color: Optional image with the desired hair color, or None.
        blending: Color-encoder version selected in the UI.
        poisson_iters: Poisson blending iterations forwarded to the server.
        poisson_erosion: Poisson erosion value forwarded to the server.

    Returns:
        A pair ``(output_update, error_update)`` of ``gr.update`` objects. On
        success the result image is shown and the error box hidden. On any
        failure (missing inputs, missing SERVER configuration, gRPC error,
        other exception) the image is hidden and the error box shows a message.
    """
    # A face plus at least one of shape/color is required.
    if face is None and shape is None and color is None:
        return _error_updates("Need to upload a face and at least a shape or color ❗")
    if face is None:
        return _error_updates("Need to upload a face ❗")
    if shape is None and color is None:
        return _error_updates("Need to upload at least a shape or color ❗")

    # Report a missing server address clearly instead of a bare KeyError('SERVER').
    server = os.environ.get('SERVER')
    if not server:
        return _error_updates("SERVER environment variable is not set ❗")

    try:
        face_bytes = get_bytes(face)
        # Missing images are sent as the placeholders b'face' / b'shape'.
        # Presumably the server reads these as "reuse the face image" / "reuse
        # the shape image"; confirm against the server implementation.
        shape_bytes = b'face' if shape is None else get_bytes(shape)
        color_bytes = b'shape' if color is None else get_bytes(color)

        with grpc.insecure_channel(server) as channel:
            stub = HairSwapServiceStub(channel)

            output: HairSwapResponse = stub.swap(
                HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
                                poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
            )

        output_image = bytes_to_image(output.image)
    except grpc.RpcError as e:
        # Failure reported by the gRPC layer (server error, unavailable, ...).
        return _error_updates(f"gRPC error: {e.code()}: {e.details()}")
    except Exception as e:
        # UI boundary: show any other failure in the error box instead of
        # letting the handler raise.
        return _error_updates(f"Unexpected error: {str(e)}")

    return (gr.update(value=output_image, visible=True),
            gr.update(visible=False))



def get_demo():
    """Build and return the HairFastGAN Gradio Blocks UI (not launched).

    The page has three image inputs (source face, optional shape, optional
    color), advanced options forwarded to ``swap_hair``, a result image and a
    hidden error textbox. Each upload is normalized by the ``resize`` handler
    for its slot, and the button runs ``swap_hair``.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        # Title and badge links (paper, code, model, Colab notebook).
        gr.Markdown("## HairFastGan")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official HairFastGAN Gradio demo:</span>'
            '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
            '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        # Layout: inputs and options on the left, result/error on the right.
        # Component creation order inside these context managers defines the layout.
        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version", info="Selects a model for hair color transfer.")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
                                              info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
                                                info="Smooths out the blending area.")
                    # Labels here must match the names passed to resize() below.
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]",
                                             info="Selects which images to crop by face")
                btn = gr.Button("Get the haircut")
            with gr.Column():
                output = gr.Image(label="Your result")
                # Hidden until swap_hair returns an error update.
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        # Example triples (source, shape, color); None leaves that slot empty.
        gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
                              ["input/10.jpg", None, "input/11.jpg"]],
                    inputs=[source, shape, color], outputs=output)

        # On upload, each image is replaced in place by its aligned or
        # cropped/resized version.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        # swap_hair returns (result image update, error textbox update).
        btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
                  outputs=[output, error_message])

        # Citation footer.
        gr.Markdown('''To cite the paper by the authors
    ```
        @article{nikolaev2024hairfastgan,
          title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
          author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
          journal={arXiv preprint arXiv:2404.01094},
          year={2024}
        }
    ```
        ''')
    return demo


# Cache of the 10 most recent face alignments, used by resize(). It is defined
# at module level rather than under the __main__ guard so resize() can find it
# even when this file is imported instead of executed as a script.
align_cache = LRUCache(maxsize=10)

if __name__ == '__main__':
    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)