File size: 4,144 Bytes
d19cc56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
from diffusers import FluxTransformer2DModel
import numpy as np
from torchvision import transforms

# Prompt the try-off transformer is conditioned on. NOTE(review): there is no
# space between "clothing" and "[IMAGE2]"; kept byte-identical on purpose since
# it presumably matches the prompt used in training — confirm before changing.
_TRYOFF_PROMPT = (
    "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
    "[IMAGE1] Detailed product shot of a clothing"
    "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
)


def _load_pipeline(pipe, device):
    """Return a FluxFillPipeline on *device*, building the default one when *pipe* is None."""
    if pipe is None:
        transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/cat-tryoff-flux",
            torch_dtype=torch.bfloat16,
        )
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    pipe.to(device)
    # Force the transformer to bfloat16 even for a caller-supplied pipeline.
    pipe.transformer.to(torch.bfloat16)
    return pipe


def _prepare_inputs(image_path, mask_path, size):
    """Build the side-by-side conditioning image and inpainting mask.

    Returns ``(inpaint_image, extended_mask)``. In both tensors the left panel
    is the garment region to be generated and the right panel is the masked
    model photo.
    """
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        # Maps [0, 1] -> [-1, 1]; the single mean/std value broadcasts over RGB.
        transforms.Normalize([0.5], [0.5]),
    ])

    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)

    image_tensor = to_normalized_tensor(image)
    # Keep only the first channel so the mask broadcasts over the image channels.
    mask_tensor = transforms.ToTensor()(mask)[:1]
    # Where the mask is 0 the photo becomes 0, i.e. mid-gray after normalization.
    image_tensor = image_tensor * mask_tensor

    # Left panel: blank canvas for the garment; right panel: masked photo.
    # ToTensor yields (C, H, W), so dim=2 concatenates along the width.
    garment_tensor = torch.zeros_like(image_tensor)
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)

    # Inpainting mask: 1 over the left panel (repaint), 0 over the right (keep).
    garment_mask = torch.zeros_like(mask_tensor)
    extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)
    return inpaint_image, extended_mask


def _split_pair(result):
    """Split a side-by-side generation into its ``(left, right)`` panels.

    Splits at the returned image's actual midpoint and height rather than at
    the requested size: FLUX rounds the output size down to a multiple of 16,
    and cropping at the requested size would misalign the panels and pad them
    with black.
    """
    half = result.width // 2
    left = result.crop((0, 0, half, result.height))
    right = result.crop((half, 0, result.width, result.height))
    return left, right


def run_inference(
    image_path,
    mask_path,
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None,
    device="cuda",
):
    """Generate a garment image ("try-off") from a photo of a clothed model.

    Args:
        image_path: Path/URL of the model photo (anything ``load_image`` accepts).
        mask_path: Path/URL of the mask. It is multiplied into the photo, so
            pixels where the mask is 0 are blanked out.
        size: ``(width, height)`` that the photo and mask are resized to; the
            generated canvas is twice this width.
        num_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the ``torch.Generator`` used during sampling.
        pipe: Optional pre-built ``FluxFillPipeline``. When None, the default
            try-off transformer and FLUX.1-dev pipeline are loaded.
        device: Device the pipeline and generator run on (default ``"cuda"``).

    Returns:
        ``(garment_result, tryon_result)``: the left and right panels of the
        generated image, as PIL images.
    """
    pipe = _load_pipeline(pipe, device)
    inpaint_image, extended_mask = _prepare_inputs(image_path, mask_path, size)

    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        height=size[1],
        width=size[0] * 2,  # garment and model panels side by side
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=_TRYOFF_PROMPT,
    ).images[0]

    garment_result, tryon_result = _split_pair(result)
    return garment_result, tryon_result

def _build_arg_parser():
    """Return the command-line parser for the try-off inference script."""
    p = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    # Required inputs.
    p.add_argument('--image', required=True, help='Path to the model image')
    p.add_argument('--mask', required=True, help='Path to the agnostic mask')
    # Destinations for the two output panels.
    p.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    p.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    # Sampling settings.
    p.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    p.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    p.add_argument('--seed', type=int, default=0, help='Random seed')
    # Per-panel resolution.
    p.add_argument('--width', type=int, default=576, help='Width')
    p.add_argument('--height', type=int, default=768, help='Height')
    return p


def main():
    """Command-line entry point: parse arguments, run inference, save both images."""
    opts = _build_arg_parser().parse_args()

    check_min_version("0.30.2")

    panel_size = (opts.width, opts.height)
    garment_img, tryon_img = run_inference(
        image_path=opts.image,
        mask_path=opts.mask,
        size=panel_size,
        num_steps=opts.steps,
        guidance_scale=opts.guidance_scale,
        seed=opts.seed,
    )

    # The try-on image is written first, then the garment image.
    for img, dest in ((tryon_img, opts.output_tryon), (garment_img, opts.output_garment)):
        img.save(dest)

    print("Successfully saved garment and try-on images")

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()