harlanhong committed on
Commit
9f99eb8
·
1 Parent(s): 57e5bc7
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +2 -3
  3. demo_dagan.py +3 -4
.gitignore CHANGED
@@ -1 +1,2 @@
1
- *.pyc
 
 
1
+ *.pyc
2
+ *.sh
app.py CHANGED
@@ -32,9 +32,8 @@ def inference(img, video):
32
  if not os.path.exists('temp'):
33
  os.system('mkdir temp')
34
  #### Resize the longer edge of the input image
35
- cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy temp/driving_video.mp4"
36
- subprocess.run(cmd.split())
37
- driving_video = "video_input.mp4"
38
  os.system("python demo_dagan.py --source_image {} --driving_video 'temp/driving_video.mp4' --output 'temp/rst.mp4'".format(img))
39
  return f'temp/rst.mp4'
40
 
 
32
  if not os.path.exists('temp'):
33
  os.system('mkdir temp')
34
  #### Resize the longer edge of the input image
35
+ os.system("ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy temp/driving_video.mp4")
36
+ # driving_video = "video_input.mp4"
 
37
  os.system("python demo_dagan.py --source_image {} --driving_video 'temp/driving_video.mp4' --output 'temp/rst.mp4'".format(img))
38
  return f'temp/rst.mp4'
39
 
demo_dagan.py CHANGED
@@ -5,9 +5,7 @@
5
 
6
  import torch
7
  import torch.nn.functional as F
8
- import os
9
  from skimage import img_as_ubyte
10
- import cv2
11
  import argparse
12
  import imageio
13
  from skimage.transform import resize
@@ -123,6 +121,7 @@ def make_animation(source_image, driving_video, generator, kp_detector, relative
123
  return sources, drivings, predictions,depth_gray
124
  with open("config/vox-adv-256.yaml") as f:
125
  config = yaml.load(f)
 
126
  generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
127
  config['model_params']['common_params']['num_channels'] = 4
128
  kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
@@ -139,8 +138,8 @@ kp_detector.load_state_dict(ckp_kp_detector)
139
 
140
  depth_encoder = depth.ResnetEncoder(18, False)
141
  depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
142
- loaded_dict_enc = torch.load('encoder.pth')
143
- loaded_dict_dec = torch.load('depth.pth')
144
  filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
145
  depth_encoder.load_state_dict(filtered_dict_enc)
146
  ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
 
5
 
6
  import torch
7
  import torch.nn.functional as F
 
8
  from skimage import img_as_ubyte
 
9
  import argparse
10
  import imageio
11
  from skimage.transform import resize
 
121
  return sources, drivings, predictions,depth_gray
122
  with open("config/vox-adv-256.yaml") as f:
123
  config = yaml.load(f)
124
+
125
  generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
126
  config['model_params']['common_params']['num_channels'] = 4
127
  kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
 
138
 
139
  depth_encoder = depth.ResnetEncoder(18, False)
140
  depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
141
+ loaded_dict_enc = torch.load('encoder.pth',map_location=device)
142
+ loaded_dict_dec = torch.load('depth.pth',map_location=device)
143
  filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
144
  depth_encoder.load_state_dict(filtered_dict_enc)
145
  ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}