Jhony23German commited on
Commit
d869b2a
1 Parent(s): f9581ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -151
app.py CHANGED
@@ -1,151 +1,170 @@
1
- import hashlib
2
- import os
3
- from io import BytesIO
4
-
5
- import gradio as gr
6
- import grpc
7
- from PIL import Image
8
- from cachetools import LRUCache
9
-
10
- from inference_pb2 import HairSwapRequest, HairSwapResponse
11
- from inference_pb2_grpc import HairSwapServiceStub
12
- from utils.shape_predictor import align_face
13
-
14
-
15
- def get_bytes(img):
16
- if img is None:
17
- return img
18
-
19
- buffered = BytesIO()
20
- img.save(buffered, format="JPEG")
21
- return buffered.getvalue()
22
-
23
-
24
- def bytes_to_image(image: bytes) -> Image.Image:
25
- image = Image.open(BytesIO(image))
26
- return image
27
-
28
-
29
- def center_crop(img):
30
- width, height = img.size
31
- side = min(width, height)
32
-
33
- left = (width - side) / 2
34
- top = (height - side) / 2
35
- right = (width + side) / 2
36
- bottom = (height + side) / 2
37
-
38
- img = img.crop((left, top, right, bottom))
39
- return img
40
-
41
-
42
- def resize(name):
43
- def resize_inner(img, align):
44
- global align_cache
45
-
46
- if name in align:
47
- img_hash = hashlib.md5(get_bytes(img)).hexdigest()
48
-
49
- if img_hash not in align_cache:
50
- img = align_face(img, return_tensors=False)[0]
51
- align_cache[img_hash] = img
52
- else:
53
- img = align_cache[img_hash]
54
-
55
- elif img.size != (1024, 1024):
56
- img = center_crop(img)
57
- img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
58
-
59
- return img
60
-
61
- return resize_inner
62
-
63
-
64
- def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
65
- if not face and not shape and not color:
66
- return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True)
67
- elif not face:
68
- return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True)
69
- elif not shape and not color:
70
- return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True)
71
-
72
- face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color))
73
-
74
- if shape_bytes is None:
75
- shape_bytes = b'face'
76
- if color_bytes is None:
77
- color_bytes = b'shape'
78
-
79
- with grpc.insecure_channel(os.environ['SERVER']) as channel:
80
- stub = HairSwapServiceStub(channel)
81
-
82
- output: HairSwapResponse = stub.swap(
83
- HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
84
- poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
85
- )
86
-
87
- output = bytes_to_image(output.image)
88
- return gr.update(value=output, visible=True), gr.update(visible=False)
89
-
90
-
91
- def get_demo():
92
- with gr.Blocks() as demo:
93
- gr.Markdown("## HairFastGan")
94
- gr.Markdown(
95
- '<div style="display: flex; align-items: center; gap: 10px;">'
96
- '<span>Official HairFastGAN Gradio demo:</span>'
97
- '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
98
- '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
99
- '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
100
- '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
101
- '</div>'
102
- )
103
- with gr.Row():
104
- with gr.Column():
105
- source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
106
- with gr.Row():
107
- shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
108
- color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
109
- with gr.Accordion("Advanced Options", open=False):
110
- blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
111
- label="Color Encoder version", info="Selects a model for hair color transfer.")
112
- poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
113
- info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
114
- poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
115
- info="Smooths out the blending area.")
116
- align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
117
- label="Image cropping [recommended]",
118
- info="Selects which images to crop by face")
119
- btn = gr.Button("Get the haircut")
120
- with gr.Column():
121
- output = gr.Image(label="Your result")
122
- error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
123
-
124
- gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
125
- ["input/10.jpg", None, "input/11.jpg"]],
126
- inputs=[source, shape, color], outputs=output)
127
-
128
- source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
129
- shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
130
- color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
131
-
132
- btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
133
- outputs=[output, error_message])
134
-
135
- gr.Markdown('''To cite the paper by the authors
136
- ```
137
- @article{nikolaev2024hairfastgan,
138
- title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
139
- author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
140
- journal={arXiv preprint arXiv:2404.01094},
141
- year={2024}
142
- }
143
- ```
144
- ''')
145
- return demo
146
-
147
-
148
- if __name__ == '__main__':
149
- align_cache = LRUCache(maxsize=10)
150
- demo = get_demo()
151
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ from io import BytesIO
4
+
5
+ import gradio as gr
6
+ import grpc
7
+ from PIL import Image
8
+ from cachetools import LRUCache
9
+
10
+ from inference_pb2 import HairSwapRequest, HairSwapResponse
11
+ from inference_pb2_grpc import HairSwapServiceStub
12
+ from utils.shape_predictor import align_face
13
+
14
+
15
+ def get_bytes(img):
16
+ if img is None:
17
+ return img
18
+
19
+ buffered = BytesIO()
20
+ img.save(buffered, format="JPEG")
21
+ return buffered.getvalue()
22
+
23
+
24
+ def bytes_to_image(image: bytes) -> Image.Image:
25
+ image = Image.open(BytesIO(image))
26
+ return image
27
+
28
+
29
+ def center_crop(img):
30
+ width, height = img.size
31
+ side = min(width, height)
32
+
33
+ left = (width - side) / 2
34
+ top = (height - side) / 2
35
+ right = (width + side) / 2
36
+ bottom = (height + side) / 2
37
+
38
+ img = img.crop((left, top, right, bottom))
39
+ return img
40
+
41
+
42
+ def resize(name):
43
+ def resize_inner(img, align):
44
+ global align_cache
45
+
46
+ if name in align:
47
+ img_hash = hashlib.md5(get_bytes(img)).hexdigest()
48
+
49
+ if img_hash not in align_cache:
50
+ img = align_face(img, return_tensors=False)[0]
51
+ align_cache[img_hash] = img
52
+ else:
53
+ img = align_cache[img_hash]
54
+
55
+ elif img.size != (1024, 1024):
56
+ img = center_crop(img)
57
+ img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
58
+
59
+ return img
60
+
61
+ return resize_inner
62
+
63
+
64
+ def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
65
+ # Verificar que al menos uno de shape o color esté presente junto con face
66
+ if not face and not (shape or color):
67
+ return (gr.update(visible=False),
68
+ gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True))
69
+ elif not face:
70
+ return (gr.update(visible=False),
71
+ gr.update(value="Need to upload a face ❗", visible=True))
72
+ elif not (shape or color):
73
+ return (gr.update(visible=False),
74
+ gr.update(value="Need to upload at least a shape or color ❗", visible=True))
75
+
76
+ # Obtener los bytes de los blobs
77
+ face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item) if item else None, (face, shape, color))
78
+
79
+ # Asignar valores por defecto si no están presentes
80
+ if shape_bytes is None:
81
+ shape_bytes = b'face'
82
+ if color_bytes is None:
83
+ color_bytes = b'shape'
84
+
85
+ try:
86
+ with grpc.insecure_channel(os.environ['SERVER']) as channel:
87
+ stub = HairSwapServiceStub(channel)
88
+
89
+ output: HairSwapResponse = stub.swap(
90
+ HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
91
+ poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
92
+ )
93
+
94
+ output_image = bytes_to_image(output.image)
95
+ return (gr.update(value=output_image, visible=True),
96
+ gr.update(visible=False))
97
+ except grpc.RpcError as e:
98
+ # Manejo de errores de gRPC
99
+ error_message = f"gRPC error: {e.code()}: {e.details()}"
100
+ return (gr.update(visible=False),
101
+ gr.update(value=error_message, visible=True))
102
+ except Exception as e:
103
+ # Manejo de cualquier otro error
104
+ error_message = f"Unexpected error: {str(e)}"
105
+ return (gr.update(visible=False),
106
+ gr.update(value=error_message, visible=True))
107
+
108
+
109
+
110
+ def get_demo():
111
+ with gr.Blocks() as demo:
112
+ gr.Markdown("## HairFastGan")
113
+ gr.Markdown(
114
+ '<div style="display: flex; align-items: center; gap: 10px;">'
115
+ '<span>Official HairFastGAN Gradio demo:</span>'
116
+ '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
117
+ '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
118
+ '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
119
+ '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
120
+ '</div>'
121
+ )
122
+ with gr.Row():
123
+ with gr.Column():
124
+ source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
125
+ with gr.Row():
126
+ shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
127
+ color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
128
+ with gr.Accordion("Advanced Options", open=False):
129
+ blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
130
+ label="Color Encoder version", info="Selects a model for hair color transfer.")
131
+ poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
132
+ info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
133
+ poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
134
+ info="Smooths out the blending area.")
135
+ align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
136
+ label="Image cropping [recommended]",
137
+ info="Selects which images to crop by face")
138
+ btn = gr.Button("Get the haircut")
139
+ with gr.Column():
140
+ output = gr.Image(label="Your result")
141
+ error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
142
+
143
+ gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
144
+ ["input/10.jpg", None, "input/11.jpg"]],
145
+ inputs=[source, shape, color], outputs=output)
146
+
147
+ source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
148
+ shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
149
+ color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
150
+
151
+ btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
152
+ outputs=[output, error_message])
153
+
154
+ gr.Markdown('''To cite the paper by the authors
155
+ ```
156
+ @article{nikolaev2024hairfastgan,
157
+ title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
158
+ author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
159
+ journal={arXiv preprint arXiv:2404.01094},
160
+ year={2024}
161
+ }
162
+ ```
163
+ ''')
164
+ return demo
165
+
166
+
167
+ if __name__ == '__main__':
168
+ align_cache = LRUCache(maxsize=10)
169
+ demo = get_demo()
170
+ demo.launch(server_name="0.0.0.0", server_port=7860)