QinOwen commited on
Commit
fe91ef5
·
1 Parent(s): 0e20ffa

add-example-fix-bug

Browse files
VADER-VideoCrafter/scripts/main/train_t2v_lora.py CHANGED
@@ -573,54 +573,11 @@ def run_training(args, peft_model, **kwargs):
573
  gradient_accumulation_steps=args.gradient_accumulation_steps,
574
  mixed_precision=args.mixed_precision,
575
  project_dir=args.project_dir
576
-
577
  )
578
  output_dir = args.project_dir
579
 
580
  # Make one log on every process with the configuration for debugging.
581
  create_logging(logging, logger, accelerator)
582
-
583
- # ## ------------------------step 2: model config-----------------------------
584
- # # download the checkpoint for VideoCrafter2 model
585
- # ckpt_dir = args.ckpt_path.split('/') # args.ckpt='checkpoints/base_512_v2/model.ckpt' -> 'checkpoints/base_512_v2'
586
- # ckpt_dir = '/'.join(ckpt_dir[:-1])
587
- # snapshot_download(repo_id='VideoCrafter/VideoCrafter2', local_dir =ckpt_dir)
588
-
589
- # # load the model
590
- # config = OmegaConf.load(args.config)
591
- # model_config = config.pop("model", OmegaConf.create())
592
- # model = instantiate_from_config(model_config)
593
-
594
- # assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!"
595
- # model = load_model_checkpoint(model, args.ckpt_path)
596
-
597
-
598
- # # convert first_stage_model and cond_stage_model to torch.float16 if mixed_precision is True
599
- # if args.mixed_precision != 'no':
600
- # model.first_stage_model = model.first_stage_model.half()
601
- # model.cond_stage_model = model.cond_stage_model.half()
602
-
603
- # # step 2.1: add LoRA using peft
604
- # config = peft.LoraConfig(
605
- # r=args.lora_rank,
606
- # target_modules=["to_k", "to_v", "to_q"], # only diffusion_model has these modules
607
- # lora_dropout=0.01,
608
- # )
609
-
610
- # peft_model = peft.get_peft_model(model, config)
611
-
612
- # peft_model.print_trainable_parameters()
613
-
614
- # # load the pretrained LoRA model
615
- # if args.lora_ckpt_path is not None:
616
- # if args.lora_ckpt_path == "huggingface-hps-aesthetic": # download the pretrained LoRA model from huggingface
617
- # snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
618
- # args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_hps_aesthetic.pt'
619
- # elif args.lora_ckpt_path == "huggingface-pickscore": # download the pretrained LoRA model from huggingface
620
- # snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
621
- # args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
622
- # # load the pretrained LoRA model
623
- # peft.set_peft_model_state_dict(peft_model, torch.load(args.lora_ckpt_path))
624
 
625
  # Inference Step: only do inference and save the videos. Skip this step if it is training
626
  # ==================================================================
@@ -749,7 +706,7 @@ def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
749
  # download the checkpoint for VideoCrafter2 model
750
  ckpt_dir = args.ckpt_path.split('/') # args.ckpt='checkpoints/base_512_v2/model.ckpt' -> 'checkpoints/base_512_v2'
751
  ckpt_dir = '/'.join(ckpt_dir[:-1])
752
- snapshot_download(repo_id='VideoCrafter/VideoCrafter2', local_dir =ckpt_dir)
753
 
754
  # load the model
755
  config = OmegaConf.load(args.config)
@@ -766,7 +723,7 @@ def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
766
 
767
  # step 2.1: add LoRA using peft
768
  config = peft.LoraConfig(
769
- r=args.lora_rank,
770
  target_modules=["to_k", "to_v", "to_q"], # only diffusion_model has these modules
771
  lora_dropout=0.01,
772
  )
@@ -783,6 +740,14 @@ def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
783
  elif lora_ckpt_path == "huggingface-pickscore": # download the pretrained LoRA model from huggingface
784
  snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
785
  lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
 
 
 
 
 
 
 
 
786
  # load the pretrained LoRA model
787
  peft.set_peft_model_state_dict(peft_model, torch.load(lora_ckpt_path))
788
 
 
573
  gradient_accumulation_steps=args.gradient_accumulation_steps,
574
  mixed_precision=args.mixed_precision,
575
  project_dir=args.project_dir
 
576
  )
577
  output_dir = args.project_dir
578
 
579
  # Make one log on every process with the configuration for debugging.
580
  create_logging(logging, logger, accelerator)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
  # Inference Step: only do inference and save the videos. Skip this step if it is training
583
  # ==================================================================
 
706
  # download the checkpoint for VideoCrafter2 model
707
  ckpt_dir = args.ckpt_path.split('/') # args.ckpt='checkpoints/base_512_v2/model.ckpt' -> 'checkpoints/base_512_v2'
708
  ckpt_dir = '/'.join(ckpt_dir[:-1])
709
+ snapshot_download(repo_id='VideoCrafter/VideoCrafter2', local_dir=ckpt_dir)
710
 
711
  # load the model
712
  config = OmegaConf.load(args.config)
 
723
 
724
  # step 2.1: add LoRA using peft
725
  config = peft.LoraConfig(
726
+ r=lora_rank,
727
  target_modules=["to_k", "to_v", "to_q"], # only diffusion_model has these modules
728
  lora_dropout=0.01,
729
  )
 
740
  elif lora_ckpt_path == "huggingface-pickscore": # download the pretrained LoRA model from huggingface
741
  snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
742
  lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
743
+ elif lora_ckpt_path == "peft_model_532":
744
+ lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_532.pt'
745
+ elif lora_ckpt_path == "peft_model_548":
746
+ lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_548.pt'
747
+ elif lora_ckpt_path == "peft_model_536":
748
+ lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_536.pt'
749
+ elif lora_ckpt_path == "peft_model_400":
750
+ lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_400.pt'
751
  # load the pretrained LoRA model
752
  peft.set_peft_model_state_dict(peft_model, torch.load(lora_ckpt_path))
753
 
app.py CHANGED
@@ -1,15 +1,27 @@
1
  import gradio as gr
2
  import os
3
-
4
  import sys
5
  sys.path.append('./VADER-VideoCrafter/scripts/main')
6
  sys.path.append('./VADER-VideoCrafter/scripts')
7
  sys.path.append('./VADER-VideoCrafter')
8
 
 
9
  from train_t2v_lora import main_fn, setup_model
10
 
 
 
 
 
 
 
 
 
 
 
11
  model = None # Placeholder for model
12
 
 
13
  def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
14
  frames, savefps):
15
  global model
@@ -30,7 +42,7 @@ def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, dd
30
 
31
  def reset_fn():
32
  return ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
33
- 200, 320, 512, 12.0, 25, 1.0, 24, 16, 10, "huggingface-pickscore")
34
 
35
  def update_lora_rank(lora_model):
36
  if lora_model == "huggingface-pickscore":
@@ -38,7 +50,7 @@ def update_lora_rank(lora_model):
38
  elif lora_model == "huggingface-hps-aesthetic":
39
  return gr.update(value=8)
40
  else: # "Base Model"
41
- return gr.update(value=0)
42
 
43
  def update_dropdown(lora_rank):
44
  if lora_rank == 16:
@@ -48,7 +60,7 @@ def update_dropdown(lora_rank):
48
  else: # 0
49
  return gr.update(value="Base Model")
50
 
51
-
52
  def setup_model_progress(lora_model, lora_rank):
53
  global model
54
 
@@ -60,15 +72,58 @@ def setup_model_progress(lora_model, lora_rank):
60
  # Enable buttons after loading and update indicator
61
  yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
62
 
63
- css = """
64
- .centered {
65
- display: flex;
66
- justify-content: center;
67
- }
68
- """
 
 
 
 
 
 
 
 
 
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with gr.Blocks(css=css) as demo:
72
  with gr.Row():
73
  with gr.Column():
74
  gr.HTML(
@@ -152,21 +207,23 @@ with gr.Blocks(css=css) as demo:
152
  """
153
  )
154
 
155
- with gr.Row(elem_classes="centered"):
156
- with gr.Column(scale=0.6):
157
- output_video = gr.Video()
158
-
159
- with gr.Row():
160
- lora_model = gr.Dropdown(
161
- label="VADER Model",
162
- choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
163
- value="huggingface-pickscore"
164
- )
165
- lora_rank = gr.Slider(minimum=0, maximum=16, label="LoRA Rank", step = 8, value=16)
166
  load_btn = gr.Button("Load Model")
167
  # Add a label to show the loading indicator
168
  loading_indicator = gr.Label(value="", label="Loading Indicator")
 
 
 
169
 
 
 
170
  prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
171
  value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
172
 
@@ -176,7 +233,7 @@ with gr.Blocks(css=css) as demo:
176
 
177
 
178
  with gr.Row():
179
- height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=320)
180
  width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
181
 
182
  with gr.Row():
@@ -205,6 +262,12 @@ with gr.Blocks(css=css) as demo:
205
  lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
206
  lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
207
 
208
- demo.launch()
 
 
 
 
 
 
209
 
210
- # main_fn(prompt="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",)
 
1
  import gradio as gr
2
  import os
3
+ import spaces
4
  import sys
5
  sys.path.append('./VADER-VideoCrafter/scripts/main')
6
  sys.path.append('./VADER-VideoCrafter/scripts')
7
  sys.path.append('./VADER-VideoCrafter')
8
 
9
+
10
  from train_t2v_lora import main_fn, setup_model
11
 
12
+ examples = [
13
+ ["A fairy tends to enchanted, glowing flowers.", 'huggingface-hps-aesthetic', 8, 400, 384, 512, 12.0, 25, 1.0, 24, 10],
14
+ ["A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.", 'huggingface-hps-aesthetic', 8, 206, 384, 512, 12.0, 25, 1.0, 24, 10],
15
+ ["A raccoon playing a guitar under a blossoming cherry tree.", 'huggingface-hps-aesthetic', 8, 204, 384, 512, 12.0, 25, 1.0, 24, 10],
16
+ ["A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
17
+ "huggingface-pickscore", 16, 205, 384, 512, 12.0, 25, 1.0, 24, 10],
18
+ ["A talking bird with shimmering feathers and a melodious voice leads an adventure to find a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
19
+ "huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
20
+ ]
21
+
22
  model = None # Placeholder for model
23
 
24
+ @spaces.GPU(duration=70)
25
  def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
26
  frames, savefps):
27
  global model
 
42
 
43
  def reset_fn():
44
  return ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
45
+ 200, 384, 512, 12.0, 25, 1.0, 24, 16, 10, "huggingface-pickscore")
46
 
47
  def update_lora_rank(lora_model):
48
  if lora_model == "huggingface-pickscore":
 
50
  elif lora_model == "huggingface-hps-aesthetic":
51
  return gr.update(value=8)
52
  else: # "Base Model"
53
+ return gr.update(value=8)
54
 
55
  def update_dropdown(lora_rank):
56
  if lora_rank == 16:
 
60
  else: # 0
61
  return gr.update(value="Base Model")
62
 
63
+ @spaces.GPU(duration=120)
64
  def setup_model_progress(lora_model, lora_rank):
65
  global model
66
 
 
72
  # Enable buttons after loading and update indicator
73
  yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
74
 
75
+ @spaces.GPU(duration=120)
76
+ def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
77
+ frames, savefps):
78
+ global model
79
+ model = setup_model(lora_model, lora_rank)
80
+ video_path = main_fn(prompt=prompt,
81
+ seed=int(seed),
82
+ height=int(height),
83
+ width=int(width),
84
+ unconditional_guidance_scale=float(unconditional_guidance_scale),
85
+ ddim_steps=int(ddim_steps),
86
+ ddim_eta=float(ddim_eta),
87
+ frames=int(frames),
88
+ savefps=int(savefps),
89
+ model=model)
90
+ return video_path
91
 
92
+ custom_css = """
93
+ #centered {
94
+ display: flex;
95
+ justify-content: center;
96
+ }
97
+ .column-centered {
98
+ display: flex;
99
+ flex-direction: column;
100
+ align-items: center;
101
+ width: 60%;
102
+ }
103
+ #image-upload {
104
+ flex-grow: 1;
105
+ }
106
+ #params .tabs {
107
+ display: flex;
108
+ flex-direction: column;
109
+ flex-grow: 1;
110
+ }
111
+ #params .tabitem[style="display: block;"] {
112
+ flex-grow: 1;
113
+ display: flex !important;
114
+ }
115
+ #params .gap {
116
+ flex-grow: 1;
117
+ }
118
+ #params .form {
119
+ flex-grow: 1 !important;
120
+ }
121
+ #params .form > :last-child{
122
+ flex-grow: 1;
123
+ }
124
+ """
125
 
126
+ with gr.Blocks(css=custom_css) as demo:
127
  with gr.Row():
128
  with gr.Column():
129
  gr.HTML(
 
207
  """
208
  )
209
 
210
+ with gr.Row(elem_id="centered"):
211
+ with gr.Column(scale=0.3, elem_id="params"):
212
+ lora_model = gr.Dropdown(
213
+ label="VADER Model",
214
+ choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
215
+ value="huggingface-pickscore"
216
+ )
217
+ lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
 
 
 
218
  load_btn = gr.Button("Load Model")
219
  # Add a label to show the loading indicator
220
  loading_indicator = gr.Label(value="", label="Loading Indicator")
221
+
222
+ with gr.Column(scale=0.3):
223
+ output_video = gr.Video(elem_id="image-upload")
224
 
225
+ with gr.Row(elem_id="centered"):
226
+ with gr.Column(scale=0.6):
227
  prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
228
  value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
229
 
 
233
 
234
 
235
  with gr.Row():
236
+ height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
237
  width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
238
 
239
  with gr.Row():
 
262
  lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
263
  lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
264
 
265
+ gr.Examples(examples=examples,
266
+ inputs=[prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
267
+ outputs=output_video,
268
+ fn=generate_example,
269
+ run_on_click=False,
270
+ cache_examples=True,
271
+ )
272
 
273
+ demo.launch(share=True)
gradio_cached_examples/34/log.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ component 0,flag,username,timestamp
2
+ "{""video"": {""path"": ""gradio_cached_examples/34/component 0/098dac4a3713d5d7c6a8/temporal.mp4"", ""url"": ""/file=/tmp/gradio/4bc133becbc469de8da700250f7f7df1103c6f56/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 22:50:14.868519
3
+ "{""video"": {""path"": ""gradio_cached_examples/34/component 0/b32c2706faa4801becfc/temporal.mp4"", ""url"": ""/file=/tmp/gradio/7f62f2e865f6a6eef4c27968ad35c3102d6ba5a4/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 22:51:57.454233
4
+ "{""video"": {""path"": ""gradio_cached_examples/34/component 0/0ced86d109f80abd1456/temporal.mp4"", ""url"": ""/file=/tmp/gradio/2af48d5977a6b60b9c91982ef479e44a2ce2bd42/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 22:53:33.714132
5
+ "{""video"": {""path"": ""gradio_cached_examples/34/component 0/b3018d4fa1632c5c33d3/temporal.mp4"", ""url"": ""/file=/tmp/gradio/50c4df5d030c66ff3f75b5f427bb6ef42eb20597/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 22:55:14.236468
6
+ "{""video"": {""path"": ""gradio_cached_examples/34/component 0/73648e9d504425f92839/temporal.mp4"", ""url"": ""/file=/tmp/gradio/469d6c7ffc22a14449337ee8c966b3a517d581a3/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 22:56:46.543720
requirements.txt CHANGED
@@ -9,7 +9,7 @@ Pillow==9.5.0
9
  pytorch_lightning==2.3.1
10
  PyYAML==6.0
11
  setuptools==65.6.3
12
- tqdm==4.65.0
13
  transformers==4.25.1
14
  moviepy==1.0.3
15
  av==12.2.0
@@ -27,4 +27,5 @@ wandb==0.17.3
27
  ipdb==0.13.13
28
  huggingface-hub==0.23.4
29
  gradio
 
30
  -e git+https://github.com/tgxs002/HPSv2.git#egg=hpsv2
 
9
  pytorch_lightning==2.3.1
10
  PyYAML==6.0
11
  setuptools==65.6.3
12
+ tqdm>=4.66.3
13
  transformers==4.25.1
14
  moviepy==1.0.3
15
  av==12.2.0
 
27
  ipdb==0.13.13
28
  huggingface-hub==0.23.4
29
  gradio
30
+ spaces
31
  -e git+https://github.com/tgxs002/HPSv2.git#egg=hpsv2