Spanicin committed
Commit 923c799
1 Parent(s): f657685

Update src/facerender/modules/make_animation.py

src/facerender/modules/make_animation.py CHANGED
@@ -144,131 +144,55 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
     # predictions_ts = torch.stack(predictions, dim=1)
     # return predictions_ts
 
-# import torch
-# from torch.cuda.amp import autocast
-
-# def make_animation(source_image, source_semantics, target_semantics,
-# generator, kp_detector, he_estimator, mapping,
-# yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-# use_exp=True):
-
-# device='cuda'
-# # Move inputs to GPU
-# source_image = source_image.to(device)
-# source_semantics = source_semantics.to(device)
-# target_semantics = target_semantics.to(device)
-
-# with torch.no_grad():  # No gradients needed
-# predictions = []
-# kp_canonical = kp_detector(source_image)
-# he_source = mapping(source_semantics)
-# kp_source = keypoint_transformation(kp_canonical, he_source)
-
-# for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
-# target_semantics_frame = target_semantics[:, frame_idx]
-# he_driving = mapping(target_semantics_frame)
-
-# if yaw_c_seq is not None:
-# he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
-# if pitch_c_seq is not None:
-# he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
-# if roll_c_seq is not None:
-# he_driving['roll_in'] = roll_c_seq[:, frame_idx]
-
-# kp_driving = keypoint_transformation(kp_canonical, he_driving)
-# kp_norm = kp_driving
-
-# # Use mixed precision for faster computation
-# with autocast():
-# out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-
-# predictions.append(out['prediction'])
-
-# # Optional: Explicitly synchronize (use only if necessary)
-# torch.cuda.synchronize()
-
-# # Stack predictions into a single tensor
-# predictions_ts = torch.stack(predictions, dim=1)
-
-# return predictions_ts
-
 import torch
 from torch.cuda.amp import autocast
-from tqdm import tqdm
 
 def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True, batch_size=8):
+                   use_exp=True):
 
-    device = 'cuda'
+    device='cuda'
     # Move inputs to GPU
     source_image = source_image.to(device)
     source_semantics = source_semantics.to(device)
     target_semantics = target_semantics.to(device)
-    print("source_image",source_image)
-    num_frames = target_semantics.shape[1]
-    predictions = []
 
     with torch.no_grad():  # No gradients needed
+        predictions = []
         kp_canonical = kp_detector(source_image)
         he_source = mapping(source_semantics)
         kp_source = keypoint_transformation(kp_canonical, he_source)
-        print("kp_canonical",kp_canonical)
-        print("he_source",he_source)
-        print("kp_source",kp_source)
-
-        for start_idx in tqdm(range(0, num_frames, batch_size), desc='Face Renderer:', unit='batch'):
-            end_idx = min(start_idx + batch_size, num_frames)
-            target_semantics_batch = target_semantics[:, start_idx:end_idx]
-
-            batch_predictions = []
-
-            for frame_idx in range(target_semantics_batch.shape[1]):
-                target_semantics_frame = target_semantics_batch[:, frame_idx]
-                he_driving = mapping(target_semantics_frame)
-
-                if yaw_c_seq is not None:
-                    he_driving['yaw_in'] = yaw_c_seq[:, start_idx:end_idx, frame_idx]
-                if pitch_c_seq is not None:
-                    he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx, frame_idx]
-                if roll_c_seq is not None:
-                    he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx, frame_idx]
-
-                kp_driving = keypoint_transformation(kp_canonical, he_driving)
-                kp_norm = kp_driving
-
-                print("source_image",source_image)
-                print("kp_source",kp_source)
-                print("kp_norm",kp_norm)
-
-                # Ensure correct input shape for conv1d layers
-                source_image_reshaped = source_image.view(source_image.size(0), source_image.size(1), -1)
-                kp_source_reshaped = kp_source.view(kp_source.size(0), kp_source.size(1), -1)
-                kp_norm_reshaped = kp_norm.view(kp_norm.size(0), kp_norm.size(1), -1)
-
-                # Use mixed precision for faster computation
-                with autocast():
-                    out = generator(source_image_reshaped, kp_source=kp_source_reshaped, kp_driving=kp_norm_reshaped)
-
-                batch_predictions.append(out['prediction'])
-
-            # Optional: Explicitly synchronize (use only if necessary)
-            torch.cuda.synchronize()
-
-            # Stack predictions for this batch
-            batch_predictions_ts = torch.stack(batch_predictions, dim=1)
-            predictions.append(batch_predictions_ts)
-
-        # Stack all batch predictions into a single tensor
-        predictions_ts = torch.cat(predictions, dim=1)
+        for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
+            target_semantics_frame = target_semantics[:, frame_idx]
+            he_driving = mapping(target_semantics_frame)
+
+            if yaw_c_seq is not None:
+                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+            if pitch_c_seq is not None:
+                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+            if roll_c_seq is not None:
+                he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+
+            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+            kp_norm = kp_driving
+
+            # Use mixed precision for faster computation
+            with autocast():
+                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+
+            predictions.append(out['prediction'])
+
+        # Optional: Explicitly synchronize (use only if necessary)
+        torch.cuda.synchronize()
+
+        # Stack predictions into a single tensor
+        predictions_ts = torch.stack(predictions, dim=1)
 
     return predictions_ts
 
 
-
-
-
 class AnimateModel(torch.nn.Module):
     """
     Merge all generator related updates into single model for better multi-gpu usage
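
For reference, below is a minimal, self-contained sketch of the per-frame rendering pattern this commit restores: no-grad inference, autocast mixed precision around the generator's forward pass, and stacking the per-frame outputs along dim=1. The names StubGenerator and render_frames, and the tensor shapes in the example call, are illustrative assumptions for this sketch only; they are not part of the SadTalker API.

import torch
from torch import nn
from torch.cuda.amp import autocast


class StubGenerator(nn.Module):
    """Placeholder for the real face-render generator; like it, returns a dict with 'prediction'."""
    def forward(self, src, kp_source=None, kp_driving=None):
        return {'prediction': src}


def render_frames(source_image, target_semantics, generator, device='cuda'):
    # Move inputs to the target device once, up front.
    source_image = source_image.to(device)
    target_semantics = target_semantics.to(device)

    predictions = []
    with torch.no_grad():                 # inference only, no gradients kept
        for frame_idx in range(target_semantics.shape[1]):
            driving = target_semantics[:, frame_idx]
            with autocast():              # mixed precision for the heavy forward pass
                out = generator(source_image, kp_source=None, kp_driving=driving)
            predictions.append(out['prediction'].float())

    # Stack per-frame outputs along a new frame dimension: (batch, frames, ...)
    return torch.stack(predictions, dim=1)


# Illustrative call with dummy tensors (requires a CUDA device):
# frames = render_frames(torch.randn(1, 3, 64, 64),
#                        torch.randn(1, 5, 70),
#                        StubGenerator().to('cuda'))
# frames.shape -> torch.Size([1, 5, 3, 64, 64])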