xiaozaa committed
Commit 4fb0ca5
1 Parent(s): a978bc7

using lora version for spaces zeroGPU

Files changed (4):
  1. README.md +8 -2
  2. app.py +17 -7
  3. app_no_lora.py +215 -0
  4. requirements.txt +1 -0
README.md CHANGED
@@ -70,11 +70,17 @@ python tryon_inference.py \
   --steps 30
 ```
 
-Run the following command to start a gradio demo:
+Run the following command to start a gradio demo with LoRA weights:
 ```bash
 python app.py
 ```
-Gradio demo:
+
+Run the following command to start a gradio demo without LoRA weights:
+```bash
+python app_no_lora.py
+```
+
+<!-- Gradio demo: -->
 
 <!-- Option 2: Using a thumbnail linked to the video -->
 <!-- [![Demo](example/github.jpg)](https://github.com/user-attachments/assets/e1e69dbf-f8a8-4f34-a84a-e7be5b3d0aec) -->
app.py CHANGED
@@ -37,16 +37,26 @@ else:
 
 device = torch.device('cuda')
 
-print('Loading diffusion model ...')
-transformer = FluxTransformer2DModel.from_pretrained(
-    "xiaozaa/catvton-flux-alpha",
-    torch_dtype=torch.bfloat16
+print("Start loading LoRA weights")
+state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
+    pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha",  # The try-on LoRA weights
+    weight_name="pytorch_lora_weights.safetensors",
+    return_alphas=True
 )
+is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+if not is_correct_format:
+    raise ValueError("Invalid LoRA checkpoint.")
+print('Loading diffusion model ...')
 pipe = FluxFillPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    transformer=transformer,
+    "black-forest-labs/FLUX.1-Fill-dev",
     torch_dtype=torch.bfloat16
 ).to(device)
+FluxFillPipeline.load_lora_into_transformer(
+    state_dict=state_dict,
+    network_alphas=network_alphas,
+    transformer=pipe.transformer,
+)
+
 print('Loading Finished!')
 
 @spaces.GPU
@@ -99,7 +109,7 @@ def gradio_inference(
 
 with gr.Blocks() as demo:
     gr.Markdown("""
-    # CATVTON FLUX Virtual Try-On Demo
+    # CATVTON FLUX Virtual Try-On Demo (using LoRA weights)
     Upload a model image, draw a mask, and a garment image to generate virtual try-on results.
 
     [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
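
For reference, diffusers also exposes a one-call LoRA loader on the pipeline itself. The sketch below is an alternative to the explicit `lora_state_dict` / `load_lora_into_transformer` pair in the hunk above; it assumes `FluxFillPipeline.load_lora_weights` accepts this checkpoint, which is not verified against this commit:

```python
import torch
from diffusers import FluxFillPipeline

# A minimal sketch, not the commit's code: load the base Fill model, then let
# diffusers' generic LoRA loader inject the try-on weights in one call.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Assumption: load_lora_weights handles this checkpoint the same way as the
# explicit lora_state_dict / load_lora_into_transformer pair in app.py.
pipe.load_lora_weights(
    "xiaozaa/catvton-flux-lora-alpha",
    weight_name="pytorch_lora_weights.safetensors",
)
```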
app_no_lora.py ADDED
@@ -0,0 +1,215 @@
+import spaces
+
+import gradio as gr
+from tryon_inference import run_inference
+import os
+import numpy as np
+from PIL import Image
+import tempfile
+import torch
+from diffusers import FluxTransformer2DModel, FluxFillPipeline
+
+import shutil
+
+def find_cuda():
+    # Check if CUDA_HOME or CUDA_PATH environment variables are set
+    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    # Search for the nvcc executable in the system's PATH
+    nvcc_path = shutil.which('nvcc')
+
+    if nvcc_path:
+        # Remove the 'bin/nvcc' part to get the CUDA installation path
+        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+        return cuda_path
+
+    return None
+
+cuda_path = find_cuda()
+
+if cuda_path:
+    print(f"CUDA installation found at: {cuda_path}")
+else:
+    print("CUDA installation not found")
+
+device = torch.device('cuda')
+
+print('Loading diffusion model ...')
+transformer = FluxTransformer2DModel.from_pretrained(
+    "xiaozaa/catvton-flux-alpha",
+    torch_dtype=torch.bfloat16
+)
+pipe = FluxFillPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16
+).to(device)
+print('Loading Finished!')
+
+@spaces.GPU
+def gradio_inference(
+    image_data,
+    garment,
+    num_steps=50,
+    guidance_scale=30.0,
+    seed=-1,
+    width=768,
+    height=1024
+):
+    """Wrapper function for Gradio interface"""
+    # Use temporary directory
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # Save inputs to temp directory
+        temp_image = os.path.join(tmp_dir, "image.png")
+        temp_mask = os.path.join(tmp_dir, "mask.png")
+        temp_garment = os.path.join(tmp_dir, "garment.png")
+
+        # Extract image and mask from ImageEditor data
+        image = image_data["background"]
+        mask = image_data["layers"][0]  # First layer contains the mask
+
+        # Convert to numpy array and process mask
+        mask_array = np.array(mask)
+        is_black = np.all(mask_array < 10, axis=2)
+        mask = Image.fromarray(((~is_black) * 255).astype(np.uint8))
+
+        # Save files to temp directory
+        image.save(temp_image)
+        mask.save(temp_mask)
+        garment.save(temp_garment)
+
+        try:
+            # Run inference
+            _, tryon_result = run_inference(
+                pipe=pipe,
+                image_path=temp_image,
+                mask_path=temp_mask,
+                garment_path=temp_garment,
+                num_steps=num_steps,
+                guidance_scale=guidance_scale,
+                seed=seed,
+                size=(width, height)
+            )
+            return tryon_result
+        except Exception as e:
+            raise gr.Error(f"Error during inference: {str(e)}")
+
+with gr.Blocks() as demo:
+    gr.Markdown("""
+    # CATVTON FLUX Virtual Try-On Demo
+    Upload a model image, draw a mask, and a garment image to generate virtual try-on results.
+
+    [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
+    [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/nftblackmagic/catvton-flux)
+    """)
+
+    # gr.Video("example/github.mp4", label="Demo Video: How to use the tool")
+
+    with gr.Column():
+        with gr.Row():
+            with gr.Column():
+                image_input = gr.ImageMask(
+                    label="Model Image (Click 'Edit' and draw mask over the clothing area)",
+                    type="pil",
+                    height=600,
+                    width=300
+                )
+                gr.Examples(
+                    examples=[
+                        ["./example/person/00008_00.jpg"],
+                        ["./example/person/00055_00.jpg"],
+                        ["./example/person/00057_00.jpg"],
+                        ["./example/person/00067_00.jpg"],
+                        ["./example/person/00069_00.jpg"],
+                    ],
+                    inputs=[image_input],
+                    label="Person Images",
+                )
+            with gr.Column():
+                garment_input = gr.Image(label="Garment Image", type="pil", height=600, width=300)
+                gr.Examples(
+                    examples=[
+                        ["./example/garment/04564_00.jpg"],
+                        ["./example/garment/00055_00.jpg"],
+                        ["./example/garment/00396_00.jpg"],
+                        ["./example/garment/00067_00.jpg"],
+                        ["./example/garment/00069_00.jpg"],
+                    ],
+                    inputs=[garment_input],
+                    label="Garment Images",
+                )
+            with gr.Column():
+                tryon_output = gr.Image(label="Try-On Result", height=600, width=300)
+
+        with gr.Row():
+            num_steps = gr.Slider(
+                minimum=1,
+                maximum=100,
+                value=30,
+                step=1,
+                label="Number of Steps"
+            )
+            guidance_scale = gr.Slider(
+                minimum=1.0,
+                maximum=50.0,
+                value=30.0,
+                step=0.5,
+                label="Guidance Scale"
+            )
+            seed = gr.Slider(
+                minimum=-1,
+                maximum=2147483647,
+                step=1,
+                value=-1,
+                label="Seed (-1 for random)"
+            )
+            width = gr.Slider(
+                minimum=256,
+                maximum=1024,
+                step=64,
+                value=768,
+                label="Width"
+            )
+            height = gr.Slider(
+                minimum=256,
+                maximum=1024,
+                step=64,
+                value=1024,
+                label="Height"
+            )
+
+        submit_btn = gr.Button("Generate Try-On", variant="primary")
+
+    with gr.Row():
+        gr.Markdown("""
+        ### Notes:
+        - The model is trained on the VITON-HD dataset and focuses on upper-body try-on generation for women.
+        - The mask should indicate the region where the garment will be placed.
+        - The garment image should be on a clean background.
+        - The model is not perfect; it may generate some artifacts.
+        - Inference is slow, so please be patient.
+        - The model is intended for research purposes only.
+        """)
+
+    submit_btn.click(
+        fn=gradio_inference,
+        inputs=[
+            image_input,
+            garment_input,
+            num_steps,
+            guidance_scale,
+            seed,
+            width,
+            height
+        ],
+        outputs=[tryon_output],
+        api_name="try-on"
+    )
+
+demo.launch()
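
Both apps share the same mask-handling convention inside `gradio_inference`: the painted ImageEditor layer is binarized by treating pixels that are near-black in every channel as untouched, and everything else as mask. A standalone sketch of that step (the helper name `layer_to_mask` is ours, not from the repo):

```python
import numpy as np
from PIL import Image

def layer_to_mask(layer: Image.Image, threshold: int = 10) -> Image.Image:
    """Binarize an ImageEditor layer: painted pixels -> white, rest -> black."""
    arr = np.array(layer)
    # A pixel counts as "untouched" only if every channel is near zero,
    # mirroring np.all(mask_array < 10, axis=2) in gradio_inference above.
    is_black = np.all(arr < threshold, axis=2)
    return Image.fromarray(((~is_black) * 255).astype(np.uint8))
```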
requirements.txt CHANGED
@@ -8,6 +8,7 @@ numpy==1.26.4
 accelerate==1.1.1
 sentencepiece==0.2.0
 protobuf==5.27.3
+peft==0.13.2
 huggingface-hub
 spaces
 git+https://github.com/huggingface/diffusers.git
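
The new `peft` pin backs the LoRA injection in app.py. A quick environment check, a sketch that assumes the pinned requirements are installed:

```python
import importlib.metadata

# Verify the LoRA-related dependencies resolve; peft==0.13.2 is the pin
# added by this commit, the others were already in requirements.txt.
for pkg in ("peft", "diffusers", "accelerate"):
    try:
        print(pkg, importlib.metadata.version(pkg))
    except importlib.metadata.PackageNotFoundError:
        print(pkg, "NOT INSTALLED")
```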