Spanicin committed
Commit 8b72e71 · verified · 1 Parent(s): cf2e13b

Update src/facerender/modules/make_animation.py

src/facerender/modules/make_animation.py CHANGED
@@ -99,50 +99,50 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
     return {'value': kp_transformed}
 
 
-def make_animation(source_image, source_semantics, target_semantics,
-                   generator, kp_detector, he_estimator, mapping,
-                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True):
-    with torch.no_grad():
-        predictions = []
-        device = 'cuda'
-        source_image = source_image.to(device)
-        source_semantics = source_semantics.to(device)
-        target_semantics = target_semantics.to(device)
+# def make_animation(source_image, source_semantics, target_semantics,
+#                    generator, kp_detector, he_estimator, mapping,
+#                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+#                    use_exp=True):
+#     with torch.no_grad():
+#         predictions = []
+#         device = 'cuda'
+#         source_image = source_image.to(device)
+#         source_semantics = source_semantics.to(device)
+#         target_semantics = target_semantics.to(device)
 
-        kp_canonical = kp_detector(source_image)
-        he_source = mapping(source_semantics)
-        kp_source = keypoint_transformation(kp_canonical, he_source)
+#         kp_canonical = kp_detector(source_image)
+#         he_source = mapping(source_semantics)
+#         kp_source = keypoint_transformation(kp_canonical, he_source)
 
 
-        for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
-            target_semantics_frame = target_semantics[:, frame_idx]
-            he_driving = mapping(target_semantics_frame)
-            if yaw_c_seq is not None:
-                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
-            if pitch_c_seq is not None:
-                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
-            if roll_c_seq is not None:
-                he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+#         for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+#             target_semantics_frame = target_semantics[:, frame_idx]
+#             he_driving = mapping(target_semantics_frame)
+#             if yaw_c_seq is not None:
+#                 he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+#             if pitch_c_seq is not None:
+#                 he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+#             if roll_c_seq is not None:
+#                 he_driving['roll_in'] = roll_c_seq[:, frame_idx]
 
-            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+#             kp_driving = keypoint_transformation(kp_canonical, he_driving)
 
-            #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
-            #kp_driving_initial=kp_driving_initial)
-            kp_norm = kp_driving
-            out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-            '''
-            source_image_new = out['prediction'].squeeze(1)
-            kp_canonical_new = kp_detector(source_image_new)
-            he_source_new = he_estimator(source_image_new)
-            kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
-            kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
-            out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
-            '''
-            predictions.append(out['prediction'])
-        torch.cuda.empty_cache()
-        predictions_ts = torch.stack(predictions, dim=1)
-        return predictions_ts
+#             #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+#             #kp_driving_initial=kp_driving_initial)
+#             kp_norm = kp_driving
+#             out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+#             '''
+#             source_image_new = out['prediction'].squeeze(1)
+#             kp_canonical_new = kp_detector(source_image_new)
+#             he_source_new = he_estimator(source_image_new)
+#             kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
+#             kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
+#             out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
+#             '''
+#             predictions.append(out['prediction'])
+#         torch.cuda.empty_cache()
+#         predictions_ts = torch.stack(predictions, dim=1)
+#         return predictions_ts
 
 # import torch
 # from torch.cuda.amp import autocast
@@ -200,49 +200,68 @@ def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
                    use_exp=True, batch_size=8):
+
     device = 'cuda'
+    # Move inputs to GPU
     source_image = source_image.to(device)
     source_semantics = source_semantics.to(device)
     target_semantics = target_semantics.to(device)
 
-    with torch.no_grad():
-        predictions = []
+    num_frames = target_semantics.shape[1]
+    predictions = []
+
+    with torch.no_grad():  # No gradients needed
         kp_canonical = kp_detector(source_image)
         he_source = mapping(source_semantics)
         kp_source = keypoint_transformation(kp_canonical, he_source)
 
-        num_frames = target_semantics.shape[1]
         for start_idx in tqdm(range(0, num_frames, batch_size), desc='Face Renderer:', unit='batch'):
            end_idx = min(start_idx + batch_size, num_frames)
            target_semantics_batch = target_semantics[:, start_idx:end_idx]
-            he_driving = mapping(target_semantics_batch)
-
-            if yaw_c_seq is not None:
-                he_driving['yaw_in'] = yaw_c_seq[:, start_idx:end_idx]
-            if pitch_c_seq is not None:
-                he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx]
-            if roll_c_seq is not None:
-                he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx]
-
-            kp_driving = keypoint_transformation(kp_canonical, he_driving)
-            kp_norm = kp_driving
-
-            with autocast():
-                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-
-            predictions.append(out['prediction'])
-
-            # Optional: Explicitly synchronize (use only if necessary)
-            torch.cuda.synchronize()
-
-        # Stack predictions into a single tensor
-        predictions_ts = torch.stack(predictions, dim=1)
+            batch_predictions = []
+
+            for frame_idx in range(target_semantics_batch.shape[1]):
+                target_semantics_frame = target_semantics_batch[:, frame_idx]
+                he_driving = mapping(target_semantics_frame)
+
+                if yaw_c_seq is not None:
+                    he_driving['yaw_in'] = yaw_c_seq[:, start_idx:end_idx, frame_idx]
+                if pitch_c_seq is not None:
+                    he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx, frame_idx]
+                if roll_c_seq is not None:
+                    he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx, frame_idx]
+
+                kp_driving = keypoint_transformation(kp_canonical, he_driving)
+                kp_norm = kp_driving
+
+                # Ensure correct input shape for conv1d layers
+                source_image_reshaped = source_image.view(source_image.size(0), source_image.size(1), -1)
+                kp_source_reshaped = kp_source.view(kp_source.size(0), kp_source.size(1), -1)
+                kp_norm_reshaped = kp_norm.view(kp_norm.size(0), kp_norm.size(1), -1)
+
+                # Use mixed precision for faster computation
+                with autocast():
+                    out = generator(source_image_reshaped, kp_source=kp_source_reshaped, kp_driving=kp_norm_reshaped)
+
+                batch_predictions.append(out['prediction'])
+
+            # Optional: Explicitly synchronize (use only if necessary)
+            torch.cuda.synchronize()
+
+            # Stack predictions for this batch
+            batch_predictions_ts = torch.stack(batch_predictions, dim=1)
+            predictions.append(batch_predictions_ts)
+
+        # Stack all batch predictions into a single tensor
+        predictions_ts = torch.cat(predictions, dim=1)
 
         return predictions_ts
 
 
 
+
 class AnimateModel(torch.nn.Module):
     """
     Merge all generator related updates into single model for better multi-gpu usage
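
For reference, a minimal usage sketch (not part of the commit) of how the batched make_animation above might be invoked from the rest of the SadTalker facerender pipeline. The import path follows the file location shown in this diff; the tensor shapes and the pre-built generator, kp_detector, he_estimator and mapping modules are assumptions for illustration only.

    from src.facerender.modules.make_animation import make_animation

    # Assumed to be built and moved to CUDA elsewhere in the pipeline:
    #   generator, kp_detector, he_estimator, mapping
    # Illustrative shapes (not taken from this diff):
    #   source_image:     (B, 3, H, W)
    #   source_semantics: (B, C, win)
    #   target_semantics: (B, num_frames, C, win)
    video = make_animation(source_image, source_semantics, target_semantics,
                           generator, kp_detector, he_estimator, mapping,
                           yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
                           use_exp=True, batch_size=8)
    # video: per-frame predictions stacked along dim 1, roughly (B, num_frames, 3, H, W)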