# Spaces: Running on Zero | File size: 4,951 Bytes | 4ae4b3e
import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
from diffusers import FluxTransformer2DModel
import numpy as np
from torchvision import transforms
def run_inference(
    image_path,
    mask_path,
    garment_path,
    size=(576, 768),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None
):
    """Run one FLUX Fill try-on pass and return the two halves of its output.

    The garment picture and the person picture are placed side by side
    (garment on the left, person on the right). The mask handed to the
    pipeline is all zeros over the garment half and the supplied mask over
    the person half.

    Args:
        image_path: Location of the person/model image (given to ``load_image``).
        mask_path: Location of the mask image for the person picture.
        garment_path: Location of the garment image.
        size: ``(width, height)`` each input is resized to. The pipeline is
            asked for an output of the same height and twice that width.
        num_steps: Forwarded to the pipeline as ``num_inference_steps``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Seed for the CUDA ``torch.Generator`` given to the pipeline.
        pipe: Optional ready-made pipeline. When ``None`` a new one is built
            from the checkpoints named in the body; otherwise the supplied
            one is moved to CUDA and its transformer is cast to bfloat16.

    Returns:
        ``(garment_result, tryon_result)``: the left and right
        ``size``-sized crops of the first generated image.

    Raises:
        ValueError: If any key of the loaded LoRA state dict contains
            neither ``"lora"`` nor ``"dora_scale"``.
    """
    if pipe is not None:
        # Reuse the caller's pipeline; just make sure it is on the GPU in bf16.
        pipe.to("cuda")
        pipe.transformer.to(torch.bfloat16)
    else:
        transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/flux1-fill-dev-diffusers",  # Flux-Fill weights
            torch_dtype=torch.bfloat16,
        )

        print("Start loading LoRA weights")
        lora_weights, lora_alphas = FluxFillPipeline.lora_state_dict(
            pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha",  # try-on LoRA weights
            weight_name="pytorch_lora_weights.safetensors",
            return_alphas=True,
        )

        # Reject checkpoints that carry anything besides LoRA/DoRA entries.
        for key in lora_weights:
            if "lora" not in key and "dora_scale" not in key:
                raise ValueError("Invalid LoRA checkpoint.")

        FluxFillPipeline.load_lora_into_transformer(
            state_dict=lora_weights,
            network_alphas=lora_alphas,
            transformer=transformer,
        )

        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

    # Pictures are converted to tensors and normalized with mean 0.5 / std 0.5;
    # the mask is only converted to a tensor.
    to_normalized = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    to_unit = transforms.ToTensor()

    def _load_rgb(path):
        # Every input goes through the same load -> RGB -> resize steps.
        return load_image(path).convert("RGB").resize(size)

    person_rgb = _load_rgb(image_path)
    mask_rgb = _load_rgb(mask_path)
    garment_rgb = _load_rgb(garment_path)

    person = to_normalized(person_rgb)
    person_mask = to_unit(mask_rgb)[:1]  # keep a single channel
    cloth = to_normalized(garment_rgb)

    # Garment on the left, person on the right (dim 2 is the width axis).
    inpaint_image = torch.cat((cloth, person), dim=2)
    # Zero mask over the garment half, supplied mask over the person half.
    extended_mask = torch.cat((torch.zeros_like(person_mask), person_mask), dim=2)

    # NOTE(review): there is no separator between "...of a clothing" and
    # "[IMAGE2]"; the text is kept exactly as-is. Presumably the LoRA expects
    # this exact prompt - confirm before editing it.
    prompt = (
        "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
        "[IMAGE1] Detailed product shot of a clothing"
        "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
    )

    width, height = size[0], size[1]
    generator = torch.Generator(device="cuda").manual_seed(seed)

    output = pipe(
        height=height,
        width=width * 2,
        image=inpaint_image,
        mask_image=extended_mask,
        num_inference_steps=num_steps,
        generator=generator,
        max_sequence_length=512,
        guidance_scale=guidance_scale,
        prompt=prompt,
    )
    side_by_side = output.images[0]

    # Cut the double-width result back into its left and right halves.
    garment_result = side_by_side.crop((0, 0, width, height))
    tryon_result = side_by_side.crop((width, 0, width * 2, height))
    return garment_result, tryon_result
def main():
    """Parse command-line options, run try-on inference and save both images.

    Required options are ``--image``, ``--mask`` and ``--garment``. The
    remaining options (output paths, steps, guidance scale, seed, width and
    height) fall back to the defaults declared below. The garment half of
    the result is written to ``--output_garment`` and the try-on half to
    ``--output_tryon``.
    """
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--garment', required=True, help='Path to the garment image')
    parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--width', type=int, default=576, help='Width')
    parser.add_argument('--height', type=int, default=768, help='Height')
    args = parser.parse_args()

    check_min_version("0.30.2")

    garment_result, tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        garment_path=args.garment,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        size=(args.width, args.height)
    )

    # Write both halves so that --output_garment is honored and the success
    # message below is accurate.
    tryon_result.save(args.output_tryon)
    garment_result.save(args.output_garment)
    print("Successfully saved garment and try-on images")
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()