franciszzj commited on
Commit
24e151d
1 Parent(s): 1eea134

update app.py

Browse files
Files changed (2) hide show
  1. app.py +130 -52
  2. utils/utils.py +31 -0
app.py CHANGED
@@ -6,11 +6,12 @@ from leffa.model import LeffaModel
6
  from leffa.inference import LeffaInference
7
  from utils.garment_agnostic_mask_predictor import AutoMasker
8
  from utils.densepose_predictor import DensePosePredictor
 
9
 
10
  import gradio as gr
11
 
12
  # Download checkpoints
13
- snapshot_download(repo_id="franciszzj/Leffa", local_dir="./")
14
 
15
 
16
  def leffa_predict(src_image_path, ref_image_path, control_type):
@@ -18,6 +19,8 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
18
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
19
  src_image = Image.open(src_image_path)
20
  ref_image = Image.open(ref_image_path)
 
 
21
 
22
  src_image_array = np.array(src_image)
23
  ref_image_array = np.array(ref_image)
@@ -74,6 +77,14 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
74
  return np.array(gen_image)
75
 
76
 
 
 
 
 
 
 
 
 
77
  if __name__ == "__main__":
78
  # import sys
79
 
@@ -82,56 +93,123 @@ if __name__ == "__main__":
82
  # control_type = sys.argv[3]
83
  # leffa_predict(src_image_path, ref_image_path, control_type)
84
 
85
- with gr.Blocks().queue() as demo:
86
- gr.Markdown(
87
- "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation")
88
- gr.Markdown("Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer).")
89
- with gr.Row():
90
- with gr.Column():
91
- src_image = gr.Image(
92
- sources=["upload"],
93
- type="filepath",
94
- label="Source Person Image",
95
- width=384,
96
- height=512,
97
- )
98
- with gr.Row():
99
- control_type = gr.Dropdown(
100
- ["virtual_tryon", "pose_transfer"], label="Control Type")
101
-
102
- example = gr.Examples(
103
- inputs=src_image,
104
- examples_per_page=10,
105
- examples=["./examples/14684_00_person.jpg",
106
- "./examples/14092_00_person.jpg"],
107
- )
108
-
109
- with gr.Column():
110
- ref_image = gr.Image(
111
- sources=["upload"],
112
- type="filepath",
113
- label="Reference Image",
114
- width=384,
115
- height=512,
116
- )
117
- with gr.Row():
118
- gen_button = gr.Button("Generate")
119
-
120
- example = gr.Examples(
121
- inputs=ref_image,
122
- examples_per_page=10,
123
- examples=["./examples/04181_00_garment.jpg",
124
- "./examples/14684_00_person.jpg"],
125
- )
126
-
127
- with gr.Column():
128
- gen_image = gr.Image(
129
- label="Generated Person Image",
130
- width=384,
131
- height=512,
132
- )
133
-
134
- gen_button.click(fn=leffa_predict, inputs=[
135
- src_image, ref_image, control_type], outputs=[gen_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  demo.launch(share=True, server_port=7860)
 
6
  from leffa.inference import LeffaInference
7
  from utils.garment_agnostic_mask_predictor import AutoMasker
8
  from utils.densepose_predictor import DensePosePredictor
9
+ from utils.utils import resize_and_center
10
 
11
  import gradio as gr
12
 
13
  # Download checkpoints
14
+ snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
15
 
16
 
17
  def leffa_predict(src_image_path, ref_image_path, control_type):
 
19
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
20
  src_image = Image.open(src_image_path)
21
  ref_image = Image.open(ref_image_path)
22
+ src_image = resize_and_center(src_image, 768, 1024)
23
+ ref_image = resize_and_center(ref_image, 768, 1024)
24
 
25
  src_image_array = np.array(src_image)
26
  ref_image_array = np.array(ref_image)
 
77
  return np.array(gen_image)
78
 
79
 
80
+ def leffa_predict_vt(src_image_path, ref_image_path):
81
+ return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
82
+
83
+
84
+ def leffa_predict_pt(src_image_path, ref_image_path):
85
+ return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
86
+
87
+
88
  if __name__ == "__main__":
89
  # import sys
90
 
 
93
  # control_type = sys.argv[3]
94
  # leffa_predict(src_image_path, ref_image_path, control_type)
95
 
96
+ title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
97
+ description = "Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer)."
98
+
99
+ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)).queue() as demo:
100
+ gr.Markdown(title)
101
+ gr.Markdown(description)
102
+
103
+ with gr.Tab("Control Appearance (Virtual Try-on)"):
104
+ with gr.Row():
105
+ with gr.Column():
106
+ gr.Markdown("#### Person Image")
107
+ vt_src_image = gr.Image(
108
+ sources=["upload"],
109
+ type="filepath",
110
+ label="Person Image",
111
+ width=512,
112
+ height=512,
113
+ )
114
+
115
+ gr.Examples(
116
+ inputs=vt_src_image,
117
+ examples_per_page=5,
118
+ examples=["./ckpts/examples/person1/01320_00.jpg",
119
+ "./ckpts/examples/person1/01350_00.jpg",
120
+ "./ckpts/examples/person1/01365_00.jpg",
121
+ "./ckpts/examples/person1/01376_00.jpg",
122
+ "./ckpts/examples/person1/01416_00.jpg",],
123
+ )
124
+
125
+ with gr.Column():
126
+ gr.Markdown("#### Garment Image")
127
+ vt_ref_image = gr.Image(
128
+ sources=["upload"],
129
+ type="filepath",
130
+ label="Garment Image",
131
+ width=512,
132
+ height=512,
133
+ )
134
+
135
+ gr.Examples(
136
+ inputs=vt_ref_image,
137
+ examples_per_page=5,
138
+ examples=["./ckpts/examples/garment/01449_00.jpg",
139
+ "./ckpts/examples/garment/01486_00.jpg",
140
+ "./ckpts/examples/garment/01853_00.jpg",
141
+ "./ckpts/examples/garment/02070_00.jpg",
142
+ "./ckpts/examples/garment/03553_00.jpg",],
143
+ )
144
+
145
+ with gr.Column():
146
+ gr.Markdown("#### Generated Image")
147
+ vt_gen_image = gr.Image(
148
+ label="Generated Image",
149
+ width=512,
150
+ height=512,
151
+ )
152
+
153
+ with gr.Row():
154
+ vt_gen_button = gr.Button("Generate")
155
+
156
+ vt_gen_button.click(fn=leffa_predict_vt, inputs=[
157
+ vt_src_image, vt_ref_image], outputs=[vt_gen_image])
158
+
159
+ with gr.Tab("Control Pose (Pose Transfer)"):
160
+ with gr.Row():
161
+ with gr.Column():
162
+ gr.Markdown("#### Person Image")
163
+ pt_ref_image = gr.Image(
164
+ sources=["upload"],
165
+ type="filepath",
166
+ label="Person Image",
167
+ width=512,
168
+ height=512,
169
+ )
170
+
171
+ gr.Examples(
172
+ inputs=vt_src_image,
173
+ examples_per_page=5,
174
+ examples=["./ckpts/examples/person1/01320_00.jpg",
175
+ "./ckpts/examples/person1/01350_00.jpg",
176
+ "./ckpts/examples/person1/01365_00.jpg",
177
+ "./ckpts/examples/person1/01376_00.jpg",
178
+ "./ckpts/examples/person1/01416_00.jpg",],
179
+ )
180
+
181
+ with gr.Column():
182
+ gr.Markdown("#### Target Pose Person Image")
183
+ pt_src_image = gr.Image(
184
+ sources=["upload"],
185
+ type="filepath",
186
+ label="Target Pose Person Image",
187
+ width=512,
188
+ height=512,
189
+ )
190
+
191
+ gr.Examples(
192
+ inputs=pt_src_image,
193
+ examples_per_page=5,
194
+ examples=["./ckpts/examples/person2/01850_00.jpg",
195
+ "./ckpts/examples/person2/01875_00.jpg",
196
+ "./ckpts/examples/person2/02532_00.jpg",
197
+ "./ckpts/examples/person2/02902_00.jpg",
198
+ "./ckpts/examples/person2/05346_00.jpg",],
199
+ )
200
+
201
+ with gr.Column():
202
+ gr.Markdown("#### Generated Image")
203
+ pt_gen_image = gr.Image(
204
+ label="Generated Image",
205
+ width=512,
206
+ height=512,
207
+ )
208
+
209
+ with gr.Row():
210
+ pose_transfer_gen_button = gr.Button("Generate")
211
+
212
+ pose_transfer_gen_button.click(fn=leffa_predict_pt, inputs=[
213
+ pt_src_image, pt_ref_image], outputs=[pt_gen_image])
214
 
215
  demo.launch(share=True, server_port=7860)
utils/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+
4
+ def resize_and_center(image, target_width, target_height, fill_color=(255, 255, 255)):
5
+ """
6
+ Resize the image to fit within (target_width, target_height) while maintaining aspect ratio,
7
+ and center it with padding to match the exact target size.
8
+
9
+ Parameters:
10
+ - image: PIL.Image object
11
+ - target_width: Desired width of the final image
12
+ - target_height: Desired height of the final image
13
+ - fill_color: Background color used for padding
14
+
15
+ Returns:
16
+ - A resized and centered PIL.Image object
17
+ """
18
+ # Resize the image while maintaining the aspect ratio
19
+ image.thumbnail((target_width, target_height), Image.Resampling.LANCZOS)
20
+
21
+ # Create a new image with the desired size and fill color
22
+ new_image = Image.new("RGB", (target_width, target_height), fill_color)
23
+
24
+ # Calculate the position to center the resized image
25
+ x_offset = (target_width - image.width) // 2
26
+ y_offset = (target_height - image.height) // 2
27
+
28
+ # Paste the resized image onto the new image with padding
29
+ new_image.paste(image, (x_offset, y_offset))
30
+
31
+ return new_image