TDN-M commited on
Commit
058abf3
1 Parent(s): 53efe07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -27
app.py CHANGED
@@ -7,52 +7,46 @@ import sys
7
  import spaces
8
  from PIL import Image
9
 
 
10
  subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
11
  os.chdir("HairFastGAN")
12
-
13
  subprocess.run(["git", "clone", "https://huggingface.co/AIRI-Institute/HairFastGAN"], check=True)
14
-
15
  os.chdir("HairFastGAN")
16
  subprocess.run(["git", "lfs", "pull"], check=True)
17
  os.chdir("..")
18
-
19
  shutil.move("HairFastGAN/pretrained_models", "pretrained_models")
20
  shutil.move("HairFastGAN/input", "input")
21
-
22
  shutil.rmtree("HairFastGAN")
23
-
24
  items = os.listdir()
25
-
26
  for item in items:
27
  print(item)
28
  shutil.move(item, os.path.join('..', item))
29
-
30
  os.chdir("..")
31
-
32
  shutil.rmtree("HairFastGAN")
33
 
 
34
  from hair_swap import HairFast, get_parser
35
-
36
  hair_fast = HairFast(get_parser().parse_args([]))
37
 
 
38
  def resize(image_path):
39
  img = Image.open(image_path)
40
  square_size = 1024
41
-
42
  left = (img.width - square_size) / 2
43
  top = (img.height - square_size) / 2
44
  right = (img.width + square_size) / 2
45
- bottom = (img.height + square_size) / 2
46
-
47
  img_cropped = img.crop((left, top, right, bottom))
48
  return img_cropped
49
 
 
50
  @spaces.GPU
51
  def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
52
  target_2 = target_2 if target_2 else target_1
53
  final_image = hair_fast.swap(source, target_1, target_2)
54
  return T.functional.to_pil_image(final_image)
55
-
 
56
  with gr.Blocks() as demo:
57
  gr.Markdown("## TDNM Demo")
58
  gr.Markdown("Thanks to [AIRI Institute]")
@@ -62,23 +56,15 @@ with gr.Blocks() as demo:
62
  source = gr.Image(label="Ảnh nhân vật", type="filepath")
63
  target_1 = gr.Image(label="Mẫu tóc", type="filepath")
64
  with gr.Accordion("Màu tóc (optional)", type="filepath"):
 
65
  btn = gr.Button("Thực hiện")
66
  with gr.Column():
67
  output = gr.Image(label="Kết quả")
68
- gr.Examples(examples=[["michael_cera-min.png", "leo_square-min.png", "pink_hair_celeb-min.png"]], inputs=[source, target_1, target_2], outputs=output)
69
  source.upload(fn=resize, inputs=source, outputs=source)
70
  target_1.upload(fn=resize, inputs=target_1, outputs=target_1)
71
- target_2.upload(fn=resize, inputs=target_2, outputs=target_2)
72
- btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])
73
- gr.Markdown('''To cite the paper by the authors
74
- ```
75
- @article{nikolaev2024hairfastgan,
76
- title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
77
- author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
78
- journal={arXiv preprint arXiv:2404.01094},
79
- year={2024}
80
- }
81
- ```
82
- ''')
83
-
84
  demo.launch()
 
7
  import spaces
8
  from PIL import Image
9
 
10
+ # Clone và di chuyển các thư mục cần thiết
11
  subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
12
  os.chdir("HairFastGAN")
 
13
  subprocess.run(["git", "clone", "https://huggingface.co/AIRI-Institute/HairFastGAN"], check=True)
 
14
  os.chdir("HairFastGAN")
15
  subprocess.run(["git", "lfs", "pull"], check=True)
16
  os.chdir("..")
 
17
  shutil.move("HairFastGAN/pretrained_models", "pretrained_models")
18
  shutil.move("HairFastGAN/input", "input")
 
19
  shutil.rmtree("HairFastGAN")
 
20
  items = os.listdir()
 
21
  for item in items:
22
  print(item)
23
  shutil.move(item, os.path.join('..', item))
 
24
  os.chdir("..")
 
25
  shutil.rmtree("HairFastGAN")
26
 
27
+ # Import và khởi tạo HairFast
28
  from hair_swap import HairFast, get_parser
 
29
  hair_fast = HairFast(get_parser().parse_args([]))
30
 
31
+ # Hàm resize ảnh
32
  def resize(image_path):
33
  img = Image.open(image_path)
34
  square_size = 1024
 
35
  left = (img.width - square_size) / 2
36
  top = (img.height - square_size) / 2
37
  right = (img.width + square_size) / 2
38
+ bottom = (img.height - square_size) / 2
 
39
  img_cropped = img.crop((left, top, right, bottom))
40
  return img_cropped
41
 
42
+ # Hàm swap hair sử dụng GPU
43
  @spaces.GPU
44
  def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
45
  target_2 = target_2 if target_2 else target_1
46
  final_image = hair_fast.swap(source, target_1, target_2)
47
  return T.functional.to_pil_image(final_image)
48
+
49
+ # Xây dựng giao diện Gradio
50
  with gr.Blocks() as demo:
51
  gr.Markdown("## TDNM Demo")
52
  gr.Markdown("Thanks to [AIRI Institute]")
 
56
  source = gr.Image(label="Ảnh nhân vật", type="filepath")
57
  target_1 = gr.Image(label="Mẫu tóc", type="filepath")
58
  with gr.Accordion("Màu tóc (optional)", type="filepath"):
59
+ pass # Đặt các thành phần tùy chọn ở đây nếu cần
60
  btn = gr.Button("Thực hiện")
61
  with gr.Column():
62
  output = gr.Image(label="Kết quả")
63
+ gr.Examples(examples=[["michael_cera-min.png", "leo_square-min.png", "pink_hair_celeb-min.png"]], inputs=[source, target_1], outputs=output)
64
  source.upload(fn=resize, inputs=source, outputs=source)
65
  target_1.upload(fn=resize, inputs=target_1, outputs=target_1)
66
+ btn.click(fn=swap_hair, inputs=[source, target_1], outputs=[output])
67
+ gr.Markdown('''
68
+ To cite the paper by the authors
69
+ ''')
 
 
 
 
 
 
 
 
 
70
  demo.launch()