xiaozaa commited on
Commit
d19cc56
1 Parent(s): 6260538

try off version

Browse files
.gitignore ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Distribution / packaging
7
+ dist/
8
+ build/
9
+ *.egg-info/
10
+
11
+ # Virtual environments
12
+ venv/
13
+ env/
14
+ .env/
15
+ .venv/
16
+
17
+ # IDE specific files
18
+ .idea/
19
+ .vscode/
20
+ *.swp
21
+ *.swo
22
+
23
+ # Unit test / coverage reports
24
+ htmlcov/
25
+ .tox/
26
+ .coverage
27
+ .coverage.*
28
+ coverage.xml
29
+ *.cover
30
+
31
+ # Jupyter Notebook
32
+ .ipynb_checkpoints
33
+
34
+ # Local development settings
35
+ .env
36
+ .env.local
37
+
38
+ # Logs
39
+ *.log
40
+
41
+ # Database files
42
+ *.db
43
+ *.sqlite3
44
+
45
+ # OS generated files
46
+ .DS_Store
47
+ .DS_Store?
48
+ ._*
49
+ .Spotlight-V100
50
+ .Trashes
51
+ ehthumbs.db
52
+ Thumbs.db
53
+
54
+ # Gradio cache
55
+ .gradio/example/github.mp4
56
+
57
+ aws/
58
+ checkpoints/
README.md CHANGED
@@ -1,14 +1,17 @@
1
  ---
2
- title: Cat Try Off Flux
3
- emoji: 👀
4
  colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.8.0
8
  app_file: app.py
9
  pinned: false
10
- license: cc-by-nc-4.0
11
- short_description: Extract and reconstruct the front view of clothing
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
1
  ---
2
+ title: cat-tryoff-flux
3
+ emoji: 🖥️
4
  colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.0.1
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
+
13
+ # cat-tryoff-flux
14
+
15
+ CAT-Tryoff-Flux is an advanced tryoff model. This model can extract and reconstruct the front view of clothing items from images of people wearing them. It used the same method of (CATVTON-FLUX)[https://huggingface.co/xiaozaa/catvton-flux-alpha].
16
+
17
+ The github repo is [here](https://github.com/nftblackmagic/catvton-flux).
app.py CHANGED
@@ -1,7 +1,189 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ import spaces
2
+
3
  import gradio as gr
4
+ from tryoff_inference import run_inference
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import tempfile
9
+ import torch
10
+ from diffusers import FluxTransformer2DModel, FluxFillPipeline
11
+ import subprocess
12
+
13
+ subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
14
+ dtype = torch.bfloat16
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ print('Loading diffusion model ...')
18
+ transformer = FluxTransformer2DModel.from_pretrained(
19
+ "xiaozaa/cat-tryoff-flux",
20
+ torch_dtype=dtype
21
+ )
22
+ pipe = FluxFillPipeline.from_pretrained(
23
+ "black-forest-labs/FLUX.1-dev",
24
+ transformer=transformer,
25
+ torch_dtype=dtype
26
+ ).to(device)
27
+ print('Loading Finished!')
28
+
29
+ @spaces.GPU(duration=120)
30
+ def gradio_inference(
31
+ image_data,
32
+ garment,
33
+ num_steps=50,
34
+ guidance_scale=30.0,
35
+ seed=-1,
36
+ width=768,
37
+ height=1024
38
+ ):
39
+ """Wrapper function for Gradio interface"""
40
+ # Check if mask has been drawn
41
+ if image_data is None or "layers" not in image_data or not image_data["layers"]:
42
+ raise gr.Error("Please draw a mask over the clothing area before generating!")
43
+
44
+ # Check if mask is empty (all black)
45
+ mask = image_data["layers"][0]
46
+ mask_array = np.array(mask)
47
+ if np.all(mask_array < 10):
48
+ raise gr.Error("The mask is empty! Please draw over the clothing area you want to replace.")
49
+
50
+ # Use temporary directory
51
+ with tempfile.TemporaryDirectory() as tmp_dir:
52
+ # Save inputs to temp directory
53
+ temp_image = os.path.join(tmp_dir, "image.png")
54
+ temp_mask = os.path.join(tmp_dir, "mask.png")
55
+
56
+ # Extract image and mask from ImageEditor data
57
+ image = image_data["background"]
58
+ mask = image_data["layers"][0] # First layer contains the mask
59
+
60
+ # Convert to numpy array and process mask
61
+ mask_array = np.array(mask)
62
+ is_black = np.all(mask_array < 10, axis=2)
63
+ mask = Image.fromarray(((~is_black) * 255).astype(np.uint8))
64
+
65
+ # Save files to temp directory
66
+ image.save(temp_image)
67
+ mask.save(temp_mask)
68
+
69
+ try:
70
+ # Run inference
71
+ garment_result, _ = run_inference(
72
+ pipe=pipe,
73
+ image_path=temp_image,
74
+ mask_path=temp_mask,
75
+ num_steps=num_steps,
76
+ guidance_scale=guidance_scale,
77
+ seed=seed,
78
+ size=(width, height)
79
+ )
80
+ return garment_result
81
+ except Exception as e:
82
+ raise gr.Error(f"Error during inference: {str(e)}")
83
+
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown("""
86
+ # CAT-TRYOFF-FLUX Virtual Try-Off Demo
87
+ Upload a model image, draw a mask, and a garment image to generate virtual try-off results.
88
+
89
+ """)
90
+
91
+ # gr.Video("example/github.mp4", label="Demo Video: How to use the tool")
92
+
93
+ with gr.Column():
94
+ gr.Markdown("""
95
+ ### ⚠️ Important:
96
+ 1. Choose a model image or upload your own
97
+ 2. Use the Pen tool to draw a mask over the clothing area you want to restore
98
+ """)
99
+
100
+ with gr.Row():
101
+ with gr.Column():
102
+ image_input = gr.ImageMask(
103
+ label="Model Image (Click 'Edit' and draw mask over the clothing area)",
104
+ type="pil",
105
+ height=600,
106
+ width=300
107
+ )
108
+ gr.Examples(
109
+ examples=[
110
+ ["./example/person/00008_00.jpg"],
111
+ ["./example/person/00055_00.jpg"],
112
+ ["./example/person/00064_00.jpg"],
113
+ ["./example/person/00067_00.jpg"],
114
+ ["./example/person/00069_00.jpg"],
115
+ ],
116
+ inputs=[image_input],
117
+ label="Person Images",
118
+ )
119
+ with gr.Column():
120
+ garment_output = gr.Image(label="Try-On Result", height=600, width=300)
121
+
122
+ with gr.Row():
123
+ num_steps = gr.Slider(
124
+ minimum=1,
125
+ maximum=100,
126
+ value=30,
127
+ step=1,
128
+ label="Number of Steps"
129
+ )
130
+ guidance_scale = gr.Slider(
131
+ minimum=1.0,
132
+ maximum=50.0,
133
+ value=30.0,
134
+ step=0.5,
135
+ label="Guidance Scale"
136
+ )
137
+ seed = gr.Slider(
138
+ minimum=-1,
139
+ maximum=2147483647,
140
+ step=1,
141
+ value=-1,
142
+ label="Seed (-1 for random)"
143
+ )
144
+ width = gr.Slider(
145
+ minimum=256,
146
+ maximum=1024,
147
+ step=64,
148
+ value=768,
149
+ label="Width"
150
+ )
151
+ height = gr.Slider(
152
+ minimum=256,
153
+ maximum=1024,
154
+ step=64,
155
+ value=1024,
156
+ label="Height"
157
+ )
158
+
159
+
160
+ submit_btn = gr.Button("Generate Try-On", variant="primary")
161
+
162
+
163
+ with gr.Row():
164
+ gr.Markdown("""
165
+ ### Notes:
166
+ - The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation.
167
+ - The mask should indicate the region where the garment will be placed.
168
+ - The garment image should be on a clean background.
169
+ - The model is not perfect. It may generate some artifacts.
170
+ - The model is slow. Please be patient.
171
+ - The model is just for research purpose.
172
+ """)
173
+
174
+ submit_btn.click(
175
+ fn=gradio_inference,
176
+ inputs=[
177
+ image_input,
178
+ num_steps,
179
+ guidance_scale,
180
+ seed,
181
+ width,
182
+ height
183
+ ],
184
+ outputs=[garment_output],
185
+ api_name="try-off"
186
+ )
187
 
 
 
188
 
 
189
  demo.launch()
example/person/00008_00.jpg ADDED
example/person/00008_00_mask.png ADDED
example/person/00055_00.jpg ADDED
example/person/00055_00_mask.png ADDED
example/person/00057_00.jpg ADDED
example/person/00057_00_mask.png ADDED
example/person/00064_00.jpg ADDED
example/person/00064_00_mask.png ADDED
example/person/00067_00.jpg ADDED
example/person/00067_00_mask.png ADDED
example/person/00069_00.jpg ADDED
example/person/00069_00_mask.png ADDED
example/person/1.jpg ADDED
example/person/1_mask.png ADDED
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ git+https://github.com/huggingface/diffusers.git
3
+ gradio==5.6.0
4
+ gradio_client==1.4.3
5
+ torch==2.4.0
6
+ torchvision==0.19.0
7
+ tqdm==4.66.5
8
+ transformers==4.43.3
9
+ numpy==1.26.4
10
+ sentencepiece
11
+ peft==0.13.2
12
+ huggingface-hub
13
+ spaces
14
+ protobuf
tryoff.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python tryoff_inference.py \
2
+ --image ./example/person/00069_00.jpg \
3
+ --mask ./example/person/00069_00_mask.png \
4
+ --seed 41 \
5
+ --output_tryon test_original.png \
6
+ --output_garment restored_garment6.png \
7
+ --steps 30
tryoff_inference.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from diffusers.utils import load_image, check_min_version
4
+ from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
5
+ from diffusers import FluxTransformer2DModel
6
+ import numpy as np
7
+ from torchvision import transforms
8
+
9
+ def run_inference(
10
+ image_path,
11
+ mask_path,
12
+ size=(576, 768),
13
+ num_steps=50,
14
+ guidance_scale=30,
15
+ seed=42,
16
+ pipe=None
17
+ ):
18
+ # Build pipeline
19
+ if pipe is None:
20
+ transformer = FluxTransformer2DModel.from_pretrained(
21
+ "xiaozaa/cat-tryoff-flux",
22
+ torch_dtype=torch.bfloat16
23
+ )
24
+ pipe = FluxFillPipeline.from_pretrained(
25
+ "black-forest-labs/FLUX.1-dev",
26
+ transformer=transformer,
27
+ torch_dtype=torch.bfloat16
28
+ ).to("cuda")
29
+ else:
30
+ pipe.to("cuda")
31
+
32
+ pipe.transformer.to(torch.bfloat16)
33
+
34
+ # Add transform
35
+ transform = transforms.Compose([
36
+ transforms.ToTensor(),
37
+ transforms.Normalize([0.5], [0.5]) # For RGB images
38
+ ])
39
+ mask_transform = transforms.Compose([
40
+ transforms.ToTensor()
41
+ ])
42
+
43
+ # Load and process images
44
+ # print("image_path", image_path)
45
+ image = load_image(image_path).convert("RGB").resize(size)
46
+ mask = load_image(mask_path).convert("RGB").resize(size)
47
+
48
+ # Transform images using the new preprocessing
49
+ image_tensor = transform(image)
50
+ mask_tensor = mask_transform(mask)[:1] # Take only first channel
51
+ garment_tensor = torch.zeros_like(image_tensor)
52
+ image_tensor = image_tensor * mask_tensor
53
+
54
+ # Create concatenated images
55
+ inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
56
+ garment_mask = torch.zeros_like(mask_tensor)
57
+ extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)
58
+
59
+ prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
60
+ f"[IMAGE1] Detailed product shot of a clothing" \
61
+ f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
62
+
63
+ generator = torch.Generator(device="cuda").manual_seed(seed)
64
+
65
+ result = pipe(
66
+ height=size[1],
67
+ width=size[0] * 2,
68
+ image=inpaint_image,
69
+ mask_image=extended_mask,
70
+ num_inference_steps=num_steps,
71
+ generator=generator,
72
+ max_sequence_length=512,
73
+ guidance_scale=guidance_scale,
74
+ prompt=prompt,
75
+ ).images[0]
76
+
77
+ # Split and save results
78
+ width = size[0]
79
+ garment_result = result.crop((0, 0, width, size[1]))
80
+ tryon_result = result.crop((width, 0, width * 2, size[1]))
81
+
82
+ return garment_result, tryon_result
83
+
84
+ def main():
85
+ parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
86
+ parser.add_argument('--image', required=True, help='Path to the model image')
87
+ parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
88
+ parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
89
+ parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
90
+ parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
91
+ parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
92
+ parser.add_argument('--seed', type=int, default=0, help='Random seed')
93
+ parser.add_argument('--width', type=int, default=576, help='Width')
94
+ parser.add_argument('--height', type=int, default=768, help='Height')
95
+
96
+ args = parser.parse_args()
97
+
98
+ check_min_version("0.30.2")
99
+
100
+ garment_result, tryon_result = run_inference(
101
+ image_path=args.image,
102
+ mask_path=args.mask,
103
+ num_steps=args.steps,
104
+ guidance_scale=args.guidance_scale,
105
+ seed=args.seed,
106
+ size=(args.width, args.height)
107
+ )
108
+ output_tryon_path=args.output_tryon
109
+ output_garment_path=args.output_garment
110
+
111
+ tryon_result.save(output_tryon_path)
112
+ garment_result.save(output_garment_path)
113
+
114
+ print("Successfully saved garment and try-on images")
115
+
116
+ if __name__ == "__main__":
117
+ main()