import argparse

import numpy as np
import torch
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline, FluxTransformer2DModel
from diffusers.utils import check_min_version, load_image
from torchvision import transforms


def run_inference(
    image_path,
    mask_path,
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None,
):
    """Run the FLUX fill pipeline on a side-by-side (blank | masked image) canvas.

    The input image is multiplied by the first channel of the mask and placed
    on the right half of a double-width canvas; the left half is blank and is
    the region the pipeline is asked to fill.

    Args:
        image_path: Path or URL of the model image (anything ``load_image``
            accepts).
        mask_path: Path or URL of the mask image; only its first channel is
            used.
        size: ``(width, height)`` both inputs are resized to. The generated
            canvas is ``2 * width`` wide.
        num_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA random generator.
        pipe: Optional pre-built pipeline to reuse. It is moved to CUDA and
            its transformer is cast to bfloat16. When ``None``, a
            ``FluxFillPipeline`` is built from ``black-forest-labs/FLUX.1-dev``
            with the ``xiaozaa/cat-tryoff-flux`` transformer.

    Returns:
        A ``(garment_result, tryon_result)`` tuple: the left and right halves
        of the generated image, in that order.

    Raises:
        ValueError: If ``size`` is not two positive numbers.
    """
    # Fail fast on a bad size before paying for model loading.
    width, height = size
    if width <= 0 or height <= 0:
        raise ValueError(f"size must be positive (width, height), got {size!r}")

    # Build the pipeline, or reuse the one supplied by the caller.
    if pipe is None:
        transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/cat-tryoff-flux",
            torch_dtype=torch.bfloat16,
        )
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
    else:
        pipe.to("cuda")

    # No-op for the freshly built pipeline (already bfloat16); enforces the
    # dtype for a caller-supplied one.
    pipe.transformer.to(torch.bfloat16)

    # Image preprocessing: tensor conversion followed by Normalize with
    # mean 0.5 / std 0.5. The mask is only converted to a tensor.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    mask_transform = transforms.Compose([transforms.ToTensor()])

    # Load both inputs as RGB at the working resolution.
    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)

    image_tensor = transform(image)
    mask_tensor = mask_transform(mask)[:1]  # keep only the first channel

    # Left half of the canvas: blank placeholder the pipeline will fill.
    garment_tensor = torch.zeros_like(image_tensor)
    # Right half: the input image multiplied by the mask.
    image_tensor = image_tensor * mask_tensor

    # Concatenate along the last (width) axis: [blank | masked image].
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
    garment_mask = torch.zeros_like(mask_tensor)
    # Mask is all ones over the left half and all zeros over the right half.
    extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)

    # NOTE(review): there is no separator between the [IMAGE1] and [IMAGE2]
    # sentences. The text is kept verbatim because the checkpoint may depend
    # on the exact prompt -- confirm before changing it.
    prompt = (
        "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
        "[IMAGE1] Detailed product shot of a clothing"
        "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
    )

    generator = torch.Generator(device="cuda").manual_seed(seed)

    result = pipe(
        height=size[1],
        width=size[0] * 2,
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=prompt,
    ).images[0]

    # Split the double-width output back into its two halves (saving is left
    # to the caller).
    width = size[0]
    garment_result = result.crop((0, 0, width, size[1]))
    tryon_result = result.crop((width, 0, width * 2, size[1]))

    return garment_result, tryon_result


def main():
    """Parse command-line arguments, run inference, and save both output images."""
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    # The CLI seed default (0) differs from run_inference's own default (42).
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--width', type=int, default=576, help='Width')
    parser.add_argument('--height', type=int, default=768, help='Height')

    args = parser.parse_args()

    # Verify the installed diffusers version before doing any heavy work.
    check_min_version("0.30.2")

    garment_result, tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        size=(args.width, args.height),
    )

    output_tryon_path = args.output_tryon
    output_garment_path = args.output_garment

    tryon_result.save(output_tryon_path)
    garment_result.save(output_garment_path)

    print("Successfully saved garment and try-on images")


if __name__ == "__main__":
    main()