multimodalart HF staff commited on
Commit
079a382
1 Parent(s): 3e49fd0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from diffusers import FluxInpaintPipeline
5
+ from PIL import Image
6
+
7
+ # Initialize the pipeline
8
+ pipe = FluxInpaintPipeline.from_pretrained(
9
+ "black-forest-labs/FLUX.1-dev",
10
+ torch_dtype=torch.bfloat16
11
+ )
12
+ pipe.to("cuda")
13
+ pipe.load_lora_weights(
14
+ "ali-vilab/In-Context-LoRA",
15
+ weight_name="visual-identity-design.safetensors"
16
+ )
17
+
18
+ def square_center_crop(img, target_size=768):
19
+ if img.mode in ('RGBA', 'P'):
20
+ img = img.convert('RGB')
21
+
22
+ width, height = img.size
23
+ crop_size = min(width, height)
24
+
25
+ left = (width - crop_size) // 2
26
+ top = (height - crop_size) // 2
27
+ right = left + crop_size
28
+ bottom = top + crop_size
29
+
30
+ img_cropped = img.crop((left, top, right, bottom))
31
+ return img_cropped.resize((target_size, target_size), Image.Resampling.LANCZOS)
32
+
33
+ def duplicate_horizontally(img):
34
+ width, height = img.size
35
+ if width != height:
36
+ raise ValueError(f"Input image must be square, got {width}x{height}")
37
+
38
+ new_image = Image.new('RGB', (width * 2, height))
39
+ new_image.paste(img, (0, 0))
40
+ new_image.paste(img, (width, 0))
41
+ return new_image
42
+
43
+ # Load the mask image
44
+ mask = Image.open("mask_square.png")
45
+
46
+ @spaces.GPU
47
+ def generate(image, prompt_user):
48
+ prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
49
+ prompt = prompt_structure + prompt_user
50
+
51
+ cropped_image = square_center_crop(image)
52
+ logo_dupli = duplicate_horizontally(cropped_image)
53
+
54
+ out = pipe(
55
+ prompt=prompt,
56
+ image=logo_dupli,
57
+ mask_image=mask,
58
+ guidance_scale=6,
59
+ height=768,
60
+ width=1536,
61
+ num_inference_steps=28,
62
+ max_sequence_length=256,
63
+ strength=1
64
+ ).images[0]
65
+
66
+ width, height = out.size
67
+ half_width = width // 2
68
+ image_2 = out.crop((half_width, 0, width, height))
69
+ return image_2
70
+
71
+ def process_image(input_image, prompt):
72
+ try:
73
+ if input_image is None:
74
+ return None, "Please upload an image first."
75
+
76
+ if not prompt:
77
+ return None, "Please provide a prompt."
78
+
79
+ result = generate(input_image, prompt)
80
+ return result, "Generation completed successfully!"
81
+ except Exception as e:
82
+ return None, f"Error during generation: {str(e)}"
83
+
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown("# Logo in Context")
86
+ gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ input_image = gr.Image(
91
+ label="Upload Logo Image",
92
+ type="pil",
93
+ height=384
94
+ )
95
+ prompt_input = gr.Textbox(
96
+ label="Where should the logo be applied?",
97
+ placeholder="e.g., a coffee cup on a wooden table",
98
+ lines=2
99
+ )
100
+ generate_btn = gr.Button("Generate Application", variant="primary")
101
+
102
+ with gr.Column():
103
+ output_image = gr.Image(label="Generated Application")
104
+ status_text = gr.Textbox(
105
+ label="Status",
106
+ interactive=False
107
+ )
108
+
109
+ with gr.Row():
110
+ gr.Markdown("""
111
+ ### Instructions:
112
+ 1. Upload a logo image (preferably square)
113
+ 2. Describe where you'd like to see the logo applied
114
+ 3. Click 'Generate Application' and wait for the result
115
+
116
+ Note: The generation process might take a few moments.
117
+ """)
118
+
119
+ # Set up the click event
120
+ generate_btn.click(
121
+ fn=process_image,
122
+ inputs=[input_image, prompt_input],
123
+ outputs=[output_image]
124
+ )
125
+
126
+ # Launch the interface
127
+ if __name__ == "__main__":
128
+ demo.launch()