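"""CatVTON-Flux virtual try-on inference.

Loads the FLUX.1 Fill transformer with a try-on LoRA, concatenates the
garment and person images side by side, and inpaints the masked person
half so the garment appears worn on the model.
"""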
import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from torchvision import transforms

def run_inference(
    image_path,
    mask_path,
    garment_path,
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None
):
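    """Run side-by-side garment/try-on inpainting with FLUX.1 Fill.

    Args:
        image_path: Path to the person (model) image.
        mask_path: Path to the agnostic mask marking the region to inpaint.
        garment_path: Path to the garment reference image.
        size: (width, height) that each input is resized to.
        num_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA random generator.
        pipe: Optional pre-built FluxFillPipeline to reuse across calls.

    Returns:
        Tuple of (garment_result, tryon_result) as PIL images.
    """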
    # Build pipeline
    if pipe is None:
        transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/flux1-fill-dev-diffusers",  # FLUX.1 Fill transformer weights in diffusers format
            torch_dtype=torch.bfloat16
        )
        print("Start loading LoRA weights")
        state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
            pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha",     ## The tryon Lora weights
            weight_name="pytorch_lora_weights.safetensors",
            return_alphas=True
        )
        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        FluxFillPipeline.load_lora_into_transformer(
            state_dict=state_dict,
            network_alphas=network_alphas,
            transformer=transformer,
        )
        
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16
        ).to("cuda")
    else:
        pipe.to("cuda")

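    # Keep the transformer (including any loaded LoRA weights) in bfloat16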
    pipe.transformer.to(torch.bfloat16)

    # Add transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # Map RGB values from [0, 1] to [-1, 1]
    ])
    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load and process images
    # print("image_path", image_path)
    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)
    garment = load_image(garment_path).convert("RGB").resize(size)

    # Normalize the RGB images to [-1, 1]; the mask stays in [0, 1]
    image_tensor = transform(image)
    mask_tensor = mask_transform(mask)[:1]  # Keep only the first channel
    garment_tensor = transform(garment)

    # Create concatenated images
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)  # Concatenate along width
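    # The garment half gets an all-zero mask (kept as reference); only the
    # masked region of the person half is inpainted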
    garment_mask = torch.zeros_like(mask_tensor)
    extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)

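    # The prompt mirrors the side-by-side layout: [IMAGE1] is the garment
    # reference, [IMAGE2] the try-on result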
    prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
            f"[IMAGE1] Detailed product shot of a clothing" \
            f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting."

    generator = torch.Generator(device="cuda").manual_seed(seed)

    result = pipe(
        height=size[1],
        width=size[0] * 2,
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=prompt,
    ).images[0]

    # Split the side-by-side result into garment and try-on halves
    width = size[0]
    garment_result = result.crop((0, 0, width, size[1]))
    tryon_result = result.crop((width, 0, width * 2, size[1]))
    
    return garment_result, tryon_result

def main():
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--garment', required=True, help='Path to the garment image')
    parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--width', type=int, default=576, help='Width')
    parser.add_argument('--height', type=int, default=768, help='Height')
    
    args = parser.parse_args()
    
    check_min_version("0.30.2")
    
    garment_result, tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        garment_path=args.garment,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        size=(args.width, args.height)
    )
    garment_result.save(args.output_garment)
    tryon_result.save(args.output_tryon)

    print("Successfully saved garment and try-on images")

if __name__ == "__main__":
    main()
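
# Example usage (script name and paths are illustrative):
#   python tryon_inference.py \
#       --image person.jpg \
#       --mask agnostic_mask.png \
#       --garment garment.jpg \
#       --output_garment garment_out.png \
#       --output_tryon tryon_out.png
#
# To avoid rebuilding the pipeline on repeated calls, pass a cached pipeline
# via the `pipe` argument of run_inference().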