"""Gradio front-end for the HairFastGAN hair-swap service.

Uploaded images are squared up (optionally face-aligned) by upload callbacks,
and the hair transfer itself is delegated to a gRPC server whose address is
read from the ``SERVER`` environment variable.
"""

import hashlib
import os
import threading
from io import BytesIO

import gradio as gr
import grpc
from PIL import Image
from cachetools import LRUCache

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face

# Side length that non-aligned uploads are center-cropped and resized to.
TARGET_SIZE = 1024

# Face-alignment results keyed by the MD5 of the uploaded image's JPEG bytes.
# Created at import time (not only under ``__main__``) so the upload callbacks
# also work when ``get_demo()`` is used from another module.
align_cache = LRUCache(maxsize=10)
# cachetools caches are not thread-safe and Gradio may run event handlers in
# worker threads, so every access to ``align_cache`` goes through this lock.
_align_cache_lock = threading.Lock()

CITATION_MARKDOWN = '''To cite the paper by the authors
```
@article{nikolaev2024hairfastgan,
  title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
  author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
  journal={arXiv preprint arXiv:2404.01094},
  year={2024}
}
```
'''


def get_bytes(img):
    """Encode a PIL image as JPEG bytes.

    Args:
        img: A ``PIL.Image.Image``, or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # JPEG cannot store an alpha channel or a palette; Pillow raises OSError
    # for these modes, so flatten them to RGB before encoding.
    if img.mode in ("RGBA", "LA", "P", "PA"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Decode encoded image bytes into a PIL image (any format Pillow can open)."""
    return Image.open(BytesIO(image))


def center_crop(img):
    """Return the largest centered square crop of ``img``.

    Offsets are computed with integer division so the crop box is exactly
    ``side x side``. Float coordinates would be rounded half-to-even by Pillow
    and could produce a box that is one pixel off square.
    """
    width, height = img.size
    side = min(width, height)

    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))


def resize(name):
    """Build a Gradio ``upload`` callback that normalizes one input image.

    Args:
        name: Checkbox label ("Face", "Shape" or "Color"). When that label is
            selected in the alignment CheckboxGroup, the image is face-aligned
            via ``align_face``; otherwise it is center-cropped and resized to
            ``TARGET_SIZE`` x ``TARGET_SIZE`` (if not already that size).

    Returns:
        A callable ``(img, align) -> img`` suitable as an upload handler.
    """
    def resize_inner(img, align):
        if img is None:
            return img

        if align and name in align:
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            with _align_cache_lock:
                aligned = align_cache.get(img_hash)
            if aligned is None:
                # Run the (presumably expensive) alignment outside the lock so
                # other handlers are not blocked while it runs.
                aligned = align_face(img, return_tensors=False)[0]
                with _align_cache_lock:
                    align_cache[img_hash] = aligned
            return aligned

        if img.size != (TARGET_SIZE, TARGET_SIZE):
            img = center_crop(img)
            img = img.resize((TARGET_SIZE, TARGET_SIZE), Image.Resampling.LANCZOS)
        return img

    return resize_inner


def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion,
              progress=gr.Progress(track_tqdm=True)):
    """Send the images to the hair-swap gRPC service and return the result.

    Args:
        face: Image whose hair should be replaced (required).
        shape: Reference image for the hairstyle shape, or ``None``.
        color: Reference image for the hair color, or ``None``.
        blending: Blending variant name, forwarded to the server unchanged.
        poisson_iters: Poisson blending iterations, forwarded unchanged.
        poisson_erosion: Poisson erosion amount, forwarded unchanged.
        progress: Gradio progress tracker (``track_tqdm=True``); not referenced
            in the body.

    Returns:
        The result image decoded from the server response.

    Raises:
        ValueError: If ``face`` is missing, or both ``shape`` and ``color`` are.
        KeyError: If the ``SERVER`` environment variable is not set.
        grpc.RpcError: If the remote call fails.
    """
    if face is None or (shape is None and color is None):
        raise ValueError("Need to upload a face and at least a shape or color")

    face_bytes, shape_bytes, color_bytes = map(get_bytes, (face, shape, color))

    # Missing references are sent as placeholder tokens. Presumably the server
    # reads b'face' as "reuse the face image" and b'shape' as "reuse the shape
    # image" -- TODO confirm against the server implementation.
    if shape_bytes is None:
        shape_bytes = b'face'
    if color_bytes is None:
        color_bytes = b'shape'

    with grpc.insecure_channel(os.environ['SERVER']) as channel:
        stub = HairSwapServiceStub(channel)
        response: HairSwapResponse = stub.swap(
            HairSwapRequest(
                face=face_bytes,
                shape=shape_bytes,
                color=color_bytes,
                blending=blending,
                poisson_iters=poisson_iters,
                poisson_erosion=poisson_erosion,
                use_cache=True,
            )
        )

    return bytes_to_image(response.image)


def get_demo():
    """Build and return the Gradio Blocks UI for the hair-swap demo."""
    with gr.Blocks() as demo:
        gr.Markdown("## HairFastGan")
        # NOTE(review): in this copy of the file the markup around this line
        # (likely link badges) appears to have been stripped; only the visible
        # text is kept. Restore the links if they are wanted.
        gr.Markdown("Official HairFastGAN Gradio demo:")
        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Photo that you want to replace the hair", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Reference hair you want to get (optional)", type="pil")
                    color = gr.Image(label="Reference color hair you want to get (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                        label="Blending version", info="Selects a model for hair color transfer.",
                    )
                    poisson_iters = gr.Slider(
                        0, 2500, value=0, step=1, label="Poisson iters",
                        info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.",
                    )
                    poisson_erosion = gr.Slider(
                        1, 100, value=15, step=1, label="Poisson erosion",
                        info="Smooths out the blending area.",
                    )
                    align = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                        label="Image cropping [recommended]", info="Selects which images to crop by face",
                    )
                btn = gr.Button("Get the haircut")
            with gr.Column():
                output = gr.Image(label="Your result")

        # NOTE(review): selecting an example presumably does not fire the
        # ``upload`` events below, so example images skip resize() -- confirm.
        gr.Examples(
            examples=[
                ["input/0.png", "input/1.png", "input/2.png"],
                ["input/6.png", "input/7.png", None],
                ["input/10.jpg", None, "input/11.jpg"],
            ],
            inputs=[source, shape, color],
            outputs=output,
        )

        # Normalize each image as soon as it is uploaded; the CheckboxGroup
        # decides whether face alignment or a plain center crop is used.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(
            fn=swap_hair,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=output,
        )

        gr.Markdown(CITATION_MARKDOWN)
    return demo


if __name__ == '__main__':
    demo = get_demo()
    # Bind on all interfaces so the UI is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0", server_port=7860)