# Hugging Face Space: Running on Zero (ZeroGPU hardware)
import logging

import gradio as gr
import spaces
import torch
from diffusers import FluxInpaintPipeline
from PIL import Image
# Build the inpainting pipeline once, at import time, so every request reuses
# the same weights: FLUX.1-dev in bfloat16 moved to CUDA, with the
# In-Context LoRA "visual identity design" weights loaded on top.
BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"
LORA_REPO_ID = "ali-vilab/In-Context-LoRA"
LORA_WEIGHT_NAME = "visual-identity-design.safetensors"

pipe = FluxInpaintPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights(LORA_REPO_ID, weight_name=LORA_WEIGHT_NAME)
def square_center_crop(img, target_size=768, background=(255, 255, 255)):
    """Center-crop *img* to a square and resize it to ``target_size`` pixels.

    The image is first normalized to RGB. Images that carry transparency
    (modes ``RGBA``, ``LA``, ``PA``, or any mode with a ``transparency`` entry
    in ``img.info``) are alpha-composited onto ``background`` instead of
    having the alpha channel dropped. This keeps transparent areas of a logo
    from turning into whatever color happens to be stored under them
    (commonly black).

    Args:
        img: PIL image to crop.
        target_size: Side length, in pixels, of the returned square image.
        background: RGB fill placed behind transparent pixels.

    Returns:
        A new ``target_size`` x ``target_size`` RGB image.

    Raises:
        ValueError: If *img* has zero width or height.
    """
    if img.mode != "RGB":
        has_alpha = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
        if has_alpha:
            rgba = img.convert("RGBA")
            flattened = Image.new("RGB", rgba.size, background)
            # The alpha channel acts as the paste mask, compositing the image
            # over the solid background instead of discarding alpha.
            flattened.paste(rgba, mask=rgba.getchannel("A"))
            img = flattened
        else:
            # Covers L, CMYK, 1, ... so the result is always RGB.
            img = img.convert("RGB")

    width, height = img.size
    crop_size = min(width, height)
    if crop_size == 0:
        raise ValueError(f"Input image is empty, got {width}x{height}")
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    cropped = img.crop((left, top, left + crop_size, top + crop_size))
    return cropped.resize((target_size, target_size), Image.Resampling.LANCZOS)
def duplicate_horizontally(img):
    """Place two copies of the square image *img* side by side.

    Returns:
        A new RGB image twice as wide as *img* and just as tall.

    Raises:
        ValueError: If *img* is not square.
    """
    side_w, side_h = img.size
    if side_w == side_h:
        canvas = Image.new('RGB', (side_w * 2, side_h))
        for x_offset in (0, side_w):
            canvas.paste(img, (x_offset, 0))
        return canvas
    raise ValueError(f"Input image must be square, got {side_w}x{side_h}")
# Load the fixed inpainting mask once at import time.
# Image.open() is lazy: it keeps the file handle open and defers decoding
# until the pixels are first used, which would otherwise happen inside a
# request handler. Decode it now (copy() forces a load) and close the file.
# NOTE(review): presumably the mask selects the right-hand panel of the
# two-panel canvas built in generate(); confirm against mask_square.png.
with Image.open("mask_square.png") as _mask_file:
    mask = _mask_file.copy()
# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function runs; elsewhere the decorator has no effect.
@spaces.GPU
def generate(image, prompt_user):
    """Render the uploaded logo into the scene described by *prompt_user*.

    The logo is center-cropped to a square and duplicated into a two-panel
    canvas. The area selected by the module-level ``mask`` is then inpainted
    by ``pipe`` using a prompt that describes the two panels.

    Args:
        image: Uploaded logo as a PIL image.
        prompt_user: Free text completing the sentence "...the right panel
            has this logo applied to ".

    Returns:
        The right half of the generated image.
    """
    prompt_structure = (
        "The two-panel image showcases the logo of a brand, [LEFT] the left panel "
        "is showing the logo [RIGHT] the right panel has this logo applied to "
    )
    prompt = prompt_structure + prompt_user.strip()

    cropped_image = square_center_crop(image)
    logo_dupli = duplicate_horizontally(cropped_image)
    out = pipe(
        prompt=prompt,
        image=logo_dupli,
        mask_image=mask,
        guidance_scale=6,
        # Derive the canvas size from the two-panel image so it always
        # matches the crop size instead of being hard-coded separately.
        height=logo_dupli.height,
        width=logo_dupli.width,
        num_inference_steps=28,
        max_sequence_length=256,
        strength=1,
    ).images[0]

    width, height = out.size
    # Keep only the right panel, which the prompt describes as the application.
    return out.crop((width // 2, 0, width, height))
def process_image(input_image, prompt):
    """Validate the UI inputs and run ``generate``.

    Any exception raised by ``generate`` is logged with its traceback and
    reported through the status message instead of propagating to the UI.

    Args:
        input_image: Uploaded logo (PIL image), or ``None`` if nothing was
            uploaded.
        prompt: Text from the prompt textbox.

    Returns:
        ``(image, status)``: the generated image (or ``None``) and a status
        message for the user.
    """
    if input_image is None:
        return None, "Please upload an image first."
    # A whitespace-only prompt would otherwise pass validation and leave the
    # prompt template's sentence unfinished.
    if not prompt or not prompt.strip():
        return None, "Please provide a prompt."
    try:
        result = generate(input_image, prompt)
    except Exception as e:  # UI boundary: report the failure, keep the handler alive
        logging.getLogger(__name__).exception("Generation failed")
        return None, f"Error during generation: {str(e)}"
    return result, "Generation completed successfully!"
# Usage notes shown under the main row. Kept at column 0 so no line gains
# leading indentation (4+ spaces would render as a Markdown code block).
_INSTRUCTIONS_MD = """
### Instructions:
1. Upload a logo image (preferably square)
2. Describe where you'd like to see the logo applied
3. Click 'Generate Application' and wait for the result
Note: The generation process might take a few moments.
"""

# Layout: a row with inputs (left column) and result + status (right column),
# followed by a row with usage instructions.
with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Logo Image", type="pil", height=384)
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2,
            )
            generate_btn = gr.Button("Generate Application", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Application")
            status_text = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        gr.Markdown(_INSTRUCTIONS_MD)

    # process_image returns (image, status message), so both components must
    # be listed. With only output_image, the whole tuple would be handed to
    # the image component and the status box would never be filled.
    generate_btn.click(
        fn=process_image,
        inputs=[input_image, prompt_input],
        outputs=[output_image, status_text],
    )
def main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()