# Hugging Face Space file view: app.py by multimodalart
# Commit 079a382 "Create app.py" (verified), 3.83 kB
import logging

import gradio as gr
import torch
import spaces
from diffusers import FluxInpaintPipeline
from PIL import Image
# Build the FLUX.1-dev inpainting pipeline once at import time (bfloat16
# weights), place it on the GPU, and attach the In-Context LoRA
# "visual identity design" weights used by every generation.
pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("ali-vilab/In-Context-LoRA", weight_name="visual-identity-design.safetensors")
def square_center_crop(img, target_size=768):
    """Center-crop *img* to a square and resize it to ``target_size`` pixels.

    The largest centered square is cut out of the image and resampled with
    Lanczos filtering.

    Args:
        img: PIL image of any size and mode.
        target_size: edge length, in pixels, of the returned square image.

    Returns:
        A new ``target_size`` x ``target_size`` image in RGB mode.
    """
    # Normalize every non-RGB mode (previously only RGBA and P were handled),
    # so the returned image is always RGB regardless of the input mode.
    # Any alpha channel is dropped by the conversion, not composited.
    if img.mode != "RGB":
        img = img.convert("RGB")
    width, height = img.size
    side = min(width, height)
    # Integer offsets center the square; any odd leftover pixel goes to the
    # right/bottom edge.
    left = (width - side) // 2
    top = (height - side) // 2
    img_cropped = img.crop((left, top, left + side, top + side))
    return img_cropped.resize((target_size, target_size), Image.Resampling.LANCZOS)
def duplicate_horizontally(img):
    """Place two copies of a square image side by side on a new RGB canvas.

    Args:
        img: square PIL image.

    Returns:
        A new RGB image twice as wide as *img* and equally tall.

    Raises:
        ValueError: if *img* is not square.
    """
    w, h = img.size
    if w != h:
        raise ValueError(f"Input image must be square, got {w}x{h}")
    canvas = Image.new('RGB', (2 * w, h))
    # Left copy at x=0, right copy starting exactly one image-width over.
    for x_offset in (0, w):
        canvas.paste(img, (x_offset, 0))
    return canvas
# Load the inpainting mask once at startup.
# ``Image.open`` is lazy: it only reads the header and keeps the file open
# until the pixel data is first needed. Calling ``load()`` decodes it now.
# A missing or corrupt mask then fails at startup rather than inside the
# first request, and Pillow can close the file for this single-frame image
# opened by path.
mask = Image.open("mask_square.png")
mask.load()
@spaces.GPU
def generate(image, prompt_user):
    """Render the uploaded logo into the scene described by the user.

    The logo is center-cropped to a square and duplicated into a two-panel
    image. That image, the module-level ``mask``, and a fixed two-panel
    In-Context LoRA prompt (ending with *prompt_user*) are passed to the
    inpainting pipeline. Presumably the mask covers the right panel; this is
    an assumption to confirm against mask_square.png.

    Args:
        image: PIL image of the logo.
        prompt_user: text appended to the fixed prompt prefix.

    Returns:
        The right half of the generated two-panel image (PIL image).
    """
    prefix = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
    prompt = prefix + prompt_user
    two_panel = duplicate_horizontally(square_center_crop(image))
    generated = pipe(
        prompt=prompt,
        image=two_panel,
        mask_image=mask,
        guidance_scale=6,
        height=768,   # matches square_center_crop's default 768 px square
        width=1536,   # two 768 px panels side by side
        num_inference_steps=28,
        max_sequence_length=256,
        strength=1,
    ).images[0]
    # Return only the right panel, which is where the logo is applied.
    full_width, full_height = generated.size
    right_panel_box = (full_width // 2, 0, full_width, full_height)
    return generated.crop(right_panel_box)
def process_image(input_image, prompt):
    """Validate the UI inputs, run ``generate`` and report a status message.

    Args:
        input_image: the uploaded logo, or ``None`` when nothing was uploaded.
        prompt: description of where to apply the logo; may be empty or ``None``.

    Returns:
        A ``(image, status)`` tuple:
        - ``image`` is the generated image, or ``None`` when validation or
          generation failed.
        - ``status`` is a human-readable message.
    """
    if input_image is None:
        return None, "Please upload an image first."
    # Reject whitespace-only prompts too. Otherwise they reach the model as a
    # dangling "... the right panel has this logo applied to " prompt.
    if not prompt or not prompt.strip():
        return None, "Please provide a prompt."
    try:
        result = generate(input_image, prompt.strip())
    except Exception as e:
        # UI boundary: show the error to the user, but keep the full
        # traceback in the server logs instead of discarding it.
        logging.exception("Generation failed")
        return None, f"Error during generation: {e}"
    return result, "Generation completed successfully!"
# Gradio UI layout:
# - left column: logo upload, prompt box and Generate button;
# - right column: generated image and a read-only status box.
with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Logo Image",
                type="pil",
                height=384,
            )
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2,
            )
            generate_btn = gr.Button("Generate Application", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Application")
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
            )
    with gr.Row():
        # The blank line before "Note:" keeps Markdown from folding the note
        # into list item 3 as a lazy continuation line.
        gr.Markdown("""
### Instructions:
1. Upload a logo image (preferably square)
2. Describe where you'd like to see the logo applied
3. Click 'Generate Application' and wait for the result

Note: The generation process might take a few moments.
""")
    # process_image returns (image, status), so both components must be
    # listed as outputs. With only output_image, Gradio would treat the whole
    # tuple as the image value, and the status box would never be updated.
    generate_btn.click(
        fn=process_image,
        inputs=[input_image, prompt_input],
        outputs=[output_image, status_text],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()