xiaozaa commited on
Commit
abf3d6e
·
1 Parent(s): 5458c75

add gradio demo

Browse files
.gitignore CHANGED
@@ -49,4 +49,7 @@ coverage.xml
49
  .Spotlight-V100
50
  .Trashes
51
  ehthumbs.db
52
- Thumbs.db
 
 
 
 
49
  .Spotlight-V100
50
  .Trashes
51
  ehthumbs.db
52
+ Thumbs.db
53
+
54
+ # Gradio cache
55
+ .gradio/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -25,6 +25,7 @@ pip install -r requirements.txt
25
 
26
  ## Usage
27
 
 
28
  ```bash
29
  python tryon_inference.py \
30
  --image ./example/person/00008_00.jpg \
@@ -33,9 +34,15 @@ python tryon_inference.py \
33
  --seed 42
34
  ```
35
 
 
 
 
 
 
 
36
  ## TODO:
37
  - [ ] Release the FID score
38
- - [ ] Add gradio demo
39
  - [ ] Release updated weights with better performance
40
 
41
  ## Citation
 
25
 
26
  ## Usage
27
 
28
+ Run the following command to try on an image:
29
  ```bash
30
  python tryon_inference.py \
31
  --image ./example/person/00008_00.jpg \
 
34
  --seed 42
35
  ```
36
 
37
+ Run the following command to start a gradio demo:
38
+ ```bash
39
+ python app.py
40
+ ```
41
+
42
+
43
  ## TODO:
44
  - [ ] Release the FID score
45
+ - [x] Add gradio demo
46
  - [ ] Release updated weights with better performance
47
 
48
  ## Citation
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from tryon_inference import run_inference
3
+ import os
4
+ import numpy as np
5
+ from PIL import Image
6
+ import tempfile
7
+
8
+ def gradio_inference(
9
+ image_data,
10
+ garment,
11
+ num_steps=50,
12
+ guidance_scale=30.0,
13
+ seed=-1,
14
+ size=(576,768)
15
+ ):
16
+ """Wrapper function for Gradio interface"""
17
+ # Use temporary directory
18
+ with tempfile.TemporaryDirectory() as tmp_dir:
19
+ # Save inputs to temp directory
20
+ temp_image = os.path.join(tmp_dir, "image.png")
21
+ temp_mask = os.path.join(tmp_dir, "mask.png")
22
+ temp_garment = os.path.join(tmp_dir, "garment.png")
23
+
24
+ # Extract image and mask from ImageEditor data
25
+ image = image_data["background"]
26
+ mask = image_data["layers"][0] # First layer contains the mask
27
+
28
+ # Convert to numpy array and process mask
29
+ mask_array = np.array(mask)
30
+ is_black = np.all(mask_array < 10, axis=2)
31
+ mask = Image.fromarray(((~is_black) * 255).astype(np.uint8))
32
+
33
+ # Save files to temp directory
34
+ image.save(temp_image)
35
+ mask.save(temp_mask)
36
+ garment.save(temp_garment)
37
+
38
+ try:
39
+ # Run inference
40
+ _, tryon_result = run_inference(
41
+ image_path=temp_image,
42
+ mask_path=temp_mask,
43
+ garment_path=temp_garment,
44
+ num_steps=num_steps,
45
+ guidance_scale=guidance_scale,
46
+ seed=seed,
47
+ size=size
48
+ )
49
+ return tryon_result
50
+ except Exception as e:
51
+ raise gr.Error(f"Error during inference: {str(e)}")
52
+
53
+ def create_demo():
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("""
56
+ # CATVTON FLUX Virtual Try-On Demo
57
+ Upload a model image, an agnostic mask, and a garment image to generate virtual try-on results.
58
+ """)
59
+
60
+ with gr.Column():
61
+ with gr.Row():
62
+ with gr.Column():
63
+ image_input = gr.ImageMask(
64
+ label="Model Image (Draw mask where garment should go)",
65
+ type="pil",
66
+ height=576,
67
+ )
68
+ gr.Examples(
69
+ examples=[
70
+ ["./example/person/00008_00.jpg"],
71
+ ["./example/person/00055_00.jpg"],
72
+ ["./example/person/00057_00.jpg"],
73
+ ["./example/person/00067_00.jpg"],
74
+ ["./example/person/00069_00.jpg"],
75
+ ],
76
+ inputs=[image_input],
77
+ label="Person Images",
78
+ )
79
+ with gr.Column():
80
+ garment_input = gr.Image(label="Garment Image", type="pil", height=576)
81
+ gr.Examples(
82
+ examples=[
83
+ ["./example/garment/04564_00.jpg"],
84
+ ["./example/garment/00055_00.jpg"],
85
+ ["./example/garment/00057_00.jpg"],
86
+ ["./example/garment/00067_00.jpg"],
87
+ ["./example/garment/00069_00.jpg"],
88
+ ],
89
+ inputs=[garment_input],
90
+ label="Garment Images",
91
+ )
92
+
93
+ with gr.Row():
94
+ num_steps = gr.Slider(
95
+ minimum=1,
96
+ maximum=100,
97
+ value=50,
98
+ step=1,
99
+ label="Number of Steps"
100
+ )
101
+ guidance_scale = gr.Slider(
102
+ minimum=1.0,
103
+ maximum=50.0,
104
+ value=30.0,
105
+ step=0.5,
106
+ label="Guidance Scale"
107
+ )
108
+ seed = gr.Slider(
109
+ minimum=-1,
110
+ maximum=2147483647,
111
+ step=1,
112
+ value=-1,
113
+ label="Seed (-1 for random)"
114
+ )
115
+
116
+ submit_btn = gr.Button("Generate Try-On", variant="primary")
117
+
118
+ with gr.Column():
119
+ tryon_output = gr.Image(label="Try-On Result")
120
+
121
+ with gr.Row():
122
+ gr.Markdown("""
123
+ ### Notes:
124
+ - The model image should be a full-body photo
125
+ - The mask should indicate the region where the garment will be placed
126
+ - The garment image should be on a clean background
127
+ """)
128
+
129
+ submit_btn.click(
130
+ fn=gradio_inference,
131
+ inputs=[
132
+ image_input,
133
+ garment_input,
134
+ num_steps,
135
+ guidance_scale,
136
+ seed
137
+ ],
138
+ outputs=[tryon_output],
139
+ api_name="try-on"
140
+ )
141
+
142
+ return demo
143
+
144
+ if __name__ == "__main__":
145
+ demo = create_demo()
146
+ demo.queue() # Enable queuing for multiple users
147
+ demo.launch(
148
+ share=True,
149
+ server_name="0.0.0.0" # Makes the server accessible from other machines
150
+ )
example/garment/00055_00.jpg ADDED
example/garment/00057_00.jpg ADDED
example/garment/00064_00.jpg ADDED
example/garment/00067_00.jpg ADDED
example/garment/00069_00.jpg ADDED
example/person/00055_00.jpg ADDED
example/person/00055_00_mask.png ADDED
example/person/00057_00.jpg ADDED
example/person/00057_00_mask.png ADDED
example/person/00064_00.jpg ADDED
example/person/00064_00_mask.png ADDED
example/person/00067_00.jpg ADDED
example/person/00067_00_mask.png ADDED
example/person/00069_00.jpg ADDED
example/person/00069_00_mask.png ADDED
requirements.txt CHANGED
@@ -94,5 +94,7 @@ yarl==1.9.4
94
  zipp==3.20.0
95
  peft==0.13.2
96
  bitsandbytes==0.44.1
 
 
97
  prodigyopt
98
  git+https://github.com/huggingface/diffusers.git
 
94
  zipp==3.20.0
95
  peft==0.13.2
96
  bitsandbytes==0.44.1
97
+ gradio==5.6.0
98
+ gradio_client==1.4.3
99
  prodigyopt
100
  git+https://github.com/huggingface/diffusers.git
tryon_inference.py CHANGED
@@ -10,8 +10,6 @@ def run_inference(
10
  image_path,
11
  mask_path,
12
  garment_path,
13
- output_garment_path=None,
14
- output_tryon_path='flux_inpaint_tryon.png',
15
  size=(576, 768),
16
  num_steps=50,
17
  guidance_scale=30,
@@ -82,9 +80,7 @@ def run_inference(
82
  garment_result = result.crop((0, 0, width, size[1]))
83
  tryon_result = result.crop((width, 0, width * 2, size[1]))
84
 
85
- if output_garment_path is not None:
86
- garment_result.save(output_garment_path)
87
- tryon_result.save(output_tryon_path)
88
  return garment_result, tryon_result
89
 
90
  def main():
@@ -115,6 +111,13 @@ def main():
115
  seed=args.seed,
116
  size=(args.width, args.height)
117
  )
 
 
 
 
 
 
 
118
  print("Successfully saved garment and try-on images")
119
 
120
  if __name__ == "__main__":
 
10
  image_path,
11
  mask_path,
12
  garment_path,
 
 
13
  size=(576, 768),
14
  num_steps=50,
15
  guidance_scale=30,
 
80
  garment_result = result.crop((0, 0, width, size[1]))
81
  tryon_result = result.crop((width, 0, width * 2, size[1]))
82
 
83
+
 
 
84
  return garment_result, tryon_result
85
 
86
  def main():
 
111
  seed=args.seed,
112
  size=(args.width, args.height)
113
  )
114
+ output_garment_path=args.output_garment,
115
+ output_tryon_path=args.output_tryon,
116
+
117
+ if output_garment_path is not None:
118
+ garment_result.save(output_garment_path)
119
+ tryon_result.save(output_tryon_path)
120
+
121
  print("Successfully saved garment and try-on images")
122
 
123
  if __name__ == "__main__":