yerang committed on
Commit
e8f9a31
•
1 Parent(s): c072e5f

Create app.py

Files changed (1)
  1. app.py +391 -0
app.py ADDED
@@ -0,0 +1,391 @@
# coding: utf-8

"""
The entry point of the Gradio app.
"""

import os  # needed for the os.environ edits below
import os.path as osp
from pathlib import Path  # needed for Path(input_video).stem in run_end_to_end()

import tyro
import gradio as gr
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
import spaces
import cv2

# import gdown
# folder_url = f"https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib"
# gdown.download_folder(url=folder_url, output="pretrained_weights", quiet=False)

import sys
from src.utils.video import extract_audio
from elevenlabs_utils import ElevenLabsPipeline
from setup_environment import initialize_environment

initialize_environment()


sys.path.append('/home/user/.local/lib/python3.10/site-packages')
sys.path.append('/home/user/.local/lib/python3.10/site-packages/stf_alternative/src/stf_alternative')
sys.path.append('/home/user/.local/lib/python3.10/site-packages/stf_tools/src/stf_tools')
sys.path.append('/home/user/app/')
sys.path.append('/home/user/app/stf/')
sys.path.append('/home/user/app/stf/stf_alternative/')
sys.path.append('/home/user/app/stf/stf_alternative/src/stf_alternative')
sys.path.append('/home/user/app/stf/stf_tools')
sys.path.append('/home/user/app/stf/stf_tools/src/stf_tools')


# Add the CUDA paths to the environment variables
os.environ['PATH'] = '/usr/local/cuda/bin:' + os.environ.get('PATH', '')
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
# Print them for verification
print("PATH:", os.environ['PATH'])
print("LD_LIBRARY_PATH:", os.environ['LD_LIBRARY_PATH'])

from stf_utils import STFPipeline


def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
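# partial_fields() instantiates target_class from only those kwargs whose keys
# exist as attributes on the class, letting one flat CLI namespace fan out into
# the separate config dataclasses below.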

# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)

# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use the attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use the attributes of args to initialize CropConfig

gradio_pipeline = GradioPipeline(
    inference_cfg=inference_cfg,
    crop_cfg=crop_cfg,
    args=args
)
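# GradioPipeline (src.gradio_pipeline) performs the portrait animation behind
# the UI; the STF and ElevenLabs pipelines imported above supply the driving
# video and the synthesized voice, respectively.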

# The STF pipeline is (re)built inside run_end_to_end() below; the ElevenLabs
# pipeline can be created once here. Without these module-level names, the
# GPU-wrapped functions below would raise NameError.
stf_pipeline = None
elevenlabs_pipeline = ElevenLabsPipeline()  # assumes the constructor needs no arguments here

@spaces.GPU(duration=240)
def gpu_wrapped_execute_video(*args, **kwargs):
    return gradio_pipeline.execute_video(*args, **kwargs)

@spaces.GPU(duration=240)
def gpu_wrapped_execute_image(*args, **kwargs):
    return gradio_pipeline.execute_image(*args, **kwargs)

@spaces.GPU(duration=240)
def gpu_wrapped_stf_pipeline_execute(audio_path):
    return stf_pipeline.execute(audio_path)

@spaces.GPU(duration=240)
def gpu_wrapped_elevenlabs_pipeline_generate_voice(text, voice):
    return elevenlabs_pipeline.generate_voice(text, voice)

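# The @spaces.GPU wrappers above are the ZeroGPU entry points: each decorated
# call is granted a GPU slot for at most `duration` seconds (240 s here), so
# every GPU-bound operation is funneled through one of these thin functions.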

def is_square_video(video_path):
    video = cv2.VideoCapture(video_path)

    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

    video.release()
    if width != height:
        raise gr.Error("Error: the video does not have a square aspect ratio. We currently only support square videos.")

    return gr.update(visible=True)
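# is_square_video() is wired as an upload-time validator in the (currently
# commented-out) legacy UI below: raising gr.Error surfaces the message in the
# browser, and the gr.update(visible=True) return keeps the component shown.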

# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"
data_examples = [
    [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
    [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
    [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
    [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, True],
    [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, True],
    [osp.join(example_portrait_dir, "s22.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
]
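# Each example row appears to be [source portrait, driving video] plus four
# boolean flags (presumably relative motion, do crop, paste-back, crop driving
# video); only the commented-out legacy Examples block below consumes data_examples.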
#################### interface logic ####################


# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
retargeting_input_image = gr.Image(type="filepath")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
output_video_concat = gr.Video()
crop_output_video = gr.Video()  # rendered and used as an output in the Tab below; missing from the original listing

def run_end_to_end(image_path, text, voice, input_video, flag_relative, flag_do_crop, flag_remap, flag_crop_driving_video, male, animal):
    global stf_pipeline  # gpu_wrapped_stf_pipeline_execute() reads the module-level name

    # # Use a different pipeline depending on whether "animal" is checked
    # if animal:
    #     gradio_pipeline = GradioPipelineAnimal(
    #         inference_cfg=inference_cfg,
    #         crop_cfg=crop_cfg,
    #         args=args
    #     )
    # else:
    #     gradio_pipeline = GradioPipeline(
    #         inference_cfg=inference_cfg,
    #         crop_cfg=crop_cfg,
    #         args=args
    #     )

    if not male:
        stf_pipeline = STFPipeline()
    else:
        stf_pipeline = STFPipeline(
            template_video_path="/home/user/app/stf/TEMP/Cam2_2309071202_0012_Natural_Looped.mp4",
            config_path="/home/user/app/stf/TEMP/front_config_v3.json",
            checkpoint_path="/home/user/app/stf/TEMP/0157.pth",
        )

    if input_video is None:
        # audio_path = elevenlabs_pipeline.generate_voice(text, voice)
        audio_path = gpu_wrapped_elevenlabs_pipeline_generate_voice(text, voice)
        # driving_video_path = stf_pipeline.execute(audio_path)
        driving_video_path = gpu_wrapped_stf_pipeline_execute(audio_path)
    else:
        driving_video_path = input_video
        audio_path = osp.join("animations", Path(input_video).stem + ".wav")
        extract_audio(driving_video_path, audio_path)

    # output_path, crop_output_path = gradio_pipeline.execute_video(
    output_path, crop_output_path = gpu_wrapped_execute_video(
        input_source_image_path=image_path,
        input_driving_video_path=driving_video_path,
        # input_driving_video_pickle_path=None,
        flag_do_crop_input=flag_do_crop,
        # flag_remap_input=flag_remap,
        # driving_multiplier=1.0,
        # flag_stitching=False,
        flag_crop_driving_video_input=flag_crop_driving_video,
        # scale=2.3,
        # vx_ratio=0.0,
        # vy_ratio=-0.125,
        # scale_crop_driving_video=2.2,
        # vx_ratio_crop_driving_video=0.0,
        # vy_ratio_crop_driving_video=-0.1,
        # tab_selection=None,
        audio_path=audio_path
    )

    return output_path, crop_output_path
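# End-to-end flow of run_end_to_end(): with no driving video, ElevenLabs
# synthesizes speech from the script text and STFPipeline turns that audio into
# a talking-head driving video; with an uploaded video, its audio track is
# extracted into animations/ instead. gpu_wrapped_execute_video() then animates
# the source portrait with the driving video and muxes in the audio.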


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tabs():
        # First tab: Text to LipSync
        with gr.Tab("Text to LipSync"):
            gr.Markdown("# Text to LipSync")
            with gr.Row():
                script_txt = gr.Text()
                voice = gr.Audio(label="user voice", type="filepath")
                input_video = gr.Video()

            with gr.Row():
                image_input = gr.Image(type="filepath")  # image_input is defined here
                output_video.render()
                crop_output_video.render()

            with gr.Row():
                flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
                flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
                male = gr.Checkbox(value=False, label="male")
                animal = gr.Checkbox(value=False, label="animal")  # the animal checkbox

            with gr.Row():
                generate_speech = gr.Button("🚀 Generate Speech", variant="primary")

            generate_speech.click(
                fn=run_end_to_end,
                inputs=[
                    image_input,
                    script_txt,
                    voice,
                    input_video,
                    flag_relative_input,
                    flag_do_crop_input,
                    flag_remap_input,
                    flag_crop_driving_video_input,
                    male,
                    animal  # the added animal input
                ],
                outputs=[output_video, crop_output_video]
            )

        # # Second tab: FLUX image generation
        # with gr.Tab("FLUX image generation"):
        #     flux_tab(image_input)  # a separate tab for FLUX image generation

        # # Third tab: the Flux dev tab
        # with gr.Tab("FLUX Dev"):
        #     flux_demo = create_flux_tab()  # create the Flux dev tab
        #     # flux_demo.render()  # render this UI only in its own tab


# with gr.Blocks(theme=gr.themes.Soft()) as demo:
#     gr.HTML(load_description(title_md))
#     gr.Markdown(load_description("assets/gradio_description_upload.md"))
#     with gr.Row():
#         with gr.Accordion(open=True, label="Source Portrait"):
#             image_input = gr.Image(type="filepath")
#             gr.Examples(
#                 examples=[
#                     [osp.join(example_portrait_dir, "s9.jpg")],
#                     [osp.join(example_portrait_dir, "s6.jpg")],
#                     [osp.join(example_portrait_dir, "s10.jpg")],
#                     [osp.join(example_portrait_dir, "s5.jpg")],
#                     [osp.join(example_portrait_dir, "s7.jpg")],
#                     [osp.join(example_portrait_dir, "s12.jpg")],
#                     [osp.join(example_portrait_dir, "s22.jpg")],
#                 ],
#                 inputs=[image_input],
#                 cache_examples=False,
#             )
#         with gr.Accordion(open=True, label="Driving Video"):
#             video_input = gr.Video()
#             gr.Examples(
#                 examples=[
#                     [osp.join(example_video_dir, "d0.mp4")],
#                     [osp.join(example_video_dir, "d18.mp4")],
#                     [osp.join(example_video_dir, "d19.mp4")],
#                     [osp.join(example_video_dir, "d14_trim.mp4")],
#                     [osp.join(example_video_dir, "d6_trim.mp4")],
#                 ],
#                 inputs=[video_input],
#                 cache_examples=False,
#             )
#     with gr.Row():
#         with gr.Accordion(open=False, label="Animation Instructions and Options"):
#             gr.Markdown(load_description("assets/gradio_description_animation.md"))
#             with gr.Row():
#                 flag_relative_input = gr.Checkbox(value=True, label="relative motion")
#                 flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
#                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
#     gr.Markdown(load_description("assets/gradio_description_animate_clear.md"))
#     with gr.Row():
#         with gr.Column():
#             process_button_animation = gr.Button("🚀 Animate", variant="primary")
#         with gr.Column():
#             process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
#     with gr.Row():
#         with gr.Column():
#             with gr.Accordion(open=True, label="The animated video in the original image space"):
#                 output_video.render()
#         with gr.Column():
#             with gr.Accordion(open=True, label="The animated video"):
#                 output_video_concat.render()
#     with gr.Row():
#         # Examples
#         gr.Markdown("## You could also choose the examples below by one click ⬇️")
#     with gr.Row():
#         gr.Examples(
#             examples=data_examples,
#             fn=gpu_wrapped_execute_video,
#             inputs=[
#                 image_input,
#                 video_input,
#                 flag_relative_input,
#                 flag_do_crop_input,
#                 flag_remap_input
#             ],
#             outputs=[output_image, output_image_paste_back],
#             examples_per_page=6,
#             cache_examples=False,
#         )
#     gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
#     with gr.Row(visible=True):
#         eye_retargeting_slider.render()
#         lip_retargeting_slider.render()
#     with gr.Row(visible=True):
#         process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
#         process_button_reset_retargeting = gr.ClearButton(
#             [
#                 eye_retargeting_slider,
#                 lip_retargeting_slider,
#                 retargeting_input_image,
#                 output_image,
#                 output_image_paste_back
#             ],
#             value="🧹 Clear"
#         )
#     with gr.Row(visible=True):
#         with gr.Column():
#             with gr.Accordion(open=True, label="Retargeting Input"):
#                 retargeting_input_image.render()
#                 gr.Examples(
#                     examples=[
#                         [osp.join(example_portrait_dir, "s9.jpg")],
#                         [osp.join(example_portrait_dir, "s6.jpg")],
#                         [osp.join(example_portrait_dir, "s10.jpg")],
#                         [osp.join(example_portrait_dir, "s5.jpg")],
#                         [osp.join(example_portrait_dir, "s7.jpg")],
#                         [osp.join(example_portrait_dir, "s12.jpg")],
#                         [osp.join(example_portrait_dir, "s22.jpg")],
#                     ],
#                     inputs=[retargeting_input_image],
#                     cache_examples=False,
#                 )
#         with gr.Column():
#             with gr.Accordion(open=True, label="Retargeting Result"):
#                 output_image.render()
#         with gr.Column():
#             with gr.Accordion(open=True, label="Paste-back Result"):
#                 output_image_paste_back.render()
#     # binding functions for buttons
#     process_button_retargeting.click(
#         # fn=gradio_pipeline.execute_image,
#         fn=gpu_wrapped_execute_image,
#         inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
#         outputs=[output_image, output_image_paste_back],
#         show_progress=True
#     )
#     process_button_animation.click(
#         fn=gpu_wrapped_execute_video,
#         inputs=[
#             image_input,
#             video_input,
#             flag_relative_input,
#             flag_do_crop_input,
#             flag_remap_input
#         ],
#         outputs=[output_video, output_video_concat],
#         show_progress=True
#     )
#     # image_input.change(
#     #     fn=gradio_pipeline.prepare_retargeting,
#     #     inputs=image_input,
#     #     outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
#     # )
#     video_input.upload(
#         fn=is_square_video,
#         inputs=video_input,
#         outputs=video_input
#     )

demo.launch(
    server_port=args.server_port,
    share=args.share,
    server_name=args.server_name
)
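# A minimal launch sketch (an assumption based on the fields used above; tyro.cli()
# exposes each ArgumentConfig field as a command-line option):
#
#   python app.py                      # run with the ArgumentConfig defaults
#   python app.py --server_port 7860   # hypothetical flag, named after the field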