Update src/facerender/modules/make_animation.py
src/facerender/modules/make_animation.py
CHANGED
@@ -144,131 +144,55 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
Old side of the hunk (lines 144-274; removed lines are prefixed with "-", unchanged context with a space):

 # predictions_ts = torch.stack(predictions, dim=1)
 # return predictions_ts
 
-# import torch
-# from torch.cuda.amp import autocast
-
-# def make_animation(source_image, source_semantics, target_semantics,
-#                    generator, kp_detector, he_estimator, mapping,
-#                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-#                    use_exp=True):
-
-#     device='cuda'
-#     # Move inputs to GPU
-#     source_image = source_image.to(device)
-#     source_semantics = source_semantics.to(device)
-#     target_semantics = target_semantics.to(device)
-
-#     with torch.no_grad(): # No gradients needed
-#         predictions = []
-#         kp_canonical = kp_detector(source_image)
-#         he_source = mapping(source_semantics)
-#         kp_source = keypoint_transformation(kp_canonical, he_source)
-
-#         for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
-#             target_semantics_frame = target_semantics[:, frame_idx]
-#             he_driving = mapping(target_semantics_frame)
-
-#             if yaw_c_seq is not None:
-#                 he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
-#             if pitch_c_seq is not None:
-#                 he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
-#             if roll_c_seq is not None:
-#                 he_driving['roll_in'] = roll_c_seq[:, frame_idx]
-
-#             kp_driving = keypoint_transformation(kp_canonical, he_driving)
-#             kp_norm = kp_driving
-
-#             # Use mixed precision for faster computation
-#             with autocast():
-#                 out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-
-#             predictions.append(out['prediction'])
-
-#         # Optional: Explicitly synchronize (use only if necessary)
-#         torch.cuda.synchronize()
-
-#         # Stack predictions into a single tensor
-#         predictions_ts = torch.stack(predictions, dim=1)
-
-#     return predictions_ts
-
 import torch
 from torch.cuda.amp import autocast
-from tqdm import tqdm
 
 def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True
 
-    device
     # Move inputs to GPU
     source_image = source_image.to(device)
     source_semantics = source_semantics.to(device)
     target_semantics = target_semantics.to(device)
-    print("source_image",source_image)
-    num_frames = target_semantics.shape[1]
-    predictions = []
 
     with torch.no_grad(): # No gradients needed
         kp_canonical = kp_detector(source_image)
         he_source = mapping(source_semantics)
         kp_source = keypoint_transformation(kp_canonical, he_source)
-        print("kp_canonical",kp_canonical)
-        print("he_source",he_source)
-        print("kp_source",kp_source)
-
-        for start_idx in tqdm(range(0, num_frames, batch_size), desc='Face Renderer:', unit='batch'):
-            end_idx = min(start_idx + batch_size, num_frames)
-            target_semantics_batch = target_semantics[:, start_idx:end_idx]
 
-            ...
-            if pitch_c_seq is not None:
-                he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx, frame_idx]
-            if roll_c_seq is not None:
-                he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx, frame_idx]
 
-            ...
-            # Ensure correct input shape for conv1d layers
-            source_image_reshaped = source_image.view(source_image.size(0), source_image.size(1), -1)
-            kp_source_reshaped = kp_source.view(kp_source.size(0), kp_source.size(1), -1)
-            kp_norm_reshaped = kp_norm.view(kp_norm.size(0), kp_norm.size(1), -1)
-
-            # Use mixed precision for faster computation
-            with autocast():
-                out = generator(source_image_reshaped, kp_source=kp_source_reshaped, kp_driving=kp_norm_reshaped)
-
-            batch_predictions.append(out['prediction'])
-
-            # Optional: Explicitly synchronize (use only if necessary)
-            torch.cuda.synchronize()
-
-            # Stack predictions for this batch
-            batch_predictions_ts = torch.stack(batch_predictions, dim=1)
-            predictions.append(batch_predictions_ts)
 
-    # Stack
-    predictions_ts = torch.
 
     return predictions_ts
 
 
-
-
-
 class AnimateModel(torch.nn.Module):
     """
     Merge all generator related updates into single model for better multi-gpu usage
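The removed block above was an attempt to render the driving frames in chunks of batch_size rather than one frame at a time, stacking the per-chunk outputs at the end. For orientation only, a minimal, hypothetical sketch of that chunked pattern follows; render_in_chunks, render_chunk and the toy tensors are illustrative assumptions, not code from this repository, and the real generator call is not reproduced here.

import torch

# Hypothetical sketch of a chunked rendering loop: walk the frame axis in
# steps of batch_size and concatenate the per-chunk outputs at the end.
def render_in_chunks(target_semantics, batch_size, render_chunk):
    num_frames = target_semantics.shape[1]
    outputs = []
    for start_idx in range(0, num_frames, batch_size):
        end_idx = min(start_idx + batch_size, num_frames)
        chunk = target_semantics[:, start_idx:end_idx]    # (B, chunk_len, ...)
        outputs.append(render_chunk(chunk))                # stand-in for one generator call
    return torch.cat(outputs, dim=1)                       # reassemble along the frame axis

# Toy usage: "rendering" a chunk just scales it.
dummy = torch.arange(12.0).reshape(1, 6, 2)
print(render_in_chunks(dummy, batch_size=4, render_chunk=lambda c: c * 2).shape)  # torch.Size([1, 6, 2])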
New side of the hunk (lines 144-198; added lines are prefixed with "+", unchanged context with a space):

 # predictions_ts = torch.stack(predictions, dim=1)
 # return predictions_ts
 
 import torch
 from torch.cuda.amp import autocast
 
 def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+                   use_exp=True):
 
+    device='cuda'
     # Move inputs to GPU
     source_image = source_image.to(device)
     source_semantics = source_semantics.to(device)
     target_semantics = target_semantics.to(device)
 
     with torch.no_grad(): # No gradients needed
+        predictions = []
         kp_canonical = kp_detector(source_image)
         he_source = mapping(source_semantics)
         kp_source = keypoint_transformation(kp_canonical, he_source)
 
+        for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
+            target_semantics_frame = target_semantics[:, frame_idx]
+            he_driving = mapping(target_semantics_frame)
 
+            if yaw_c_seq is not None:
+                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+            if pitch_c_seq is not None:
+                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+            if roll_c_seq is not None:
+                he_driving['roll_in'] = roll_c_seq[:, frame_idx]
 
+            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+            kp_norm = kp_driving
 
+            # Use mixed precision for faster computation
+            with autocast():
+                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
 
+            predictions.append(out['prediction'])
+
+        # Optional: Explicitly synchronize (use only if necessary)
+        torch.cuda.synchronize()
 
+        # Stack predictions into a single tensor
+        predictions_ts = torch.stack(predictions, dim=1)
 
     return predictions_ts
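The only part of the new per-frame loop that runs under mixed precision is the generator call. Below is a minimal, self-contained sketch of that autocast pattern; the Conv2d stands in for the generator and an available CUDA device is an assumption. On recent PyTorch versions the same context manager is also spelled torch.amp.autocast('cuda').

import torch
from torch.cuda.amp import autocast

# Minimal sketch of inference under autocast; the Conv2d is only a stand-in
# for the generator, and a CUDA device is assumed to be available.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1).to('cuda')
frame = torch.randn(1, 3, 64, 64, device='cuda')

with torch.no_grad():          # no gradients needed for rendering
    with autocast():           # eligible ops run in float16
        out = model(frame)

print(out.dtype)               # torch.float16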
 class AnimateModel(torch.nn.Module):
     """
     Merge all generator related updates into single model for better multi-gpu usage
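The AnimateModel docstring above states the motivation: bundling the generator-related submodules into one nn.Module so a single wrapper (for example torch.nn.DataParallel) can replicate them across GPUs. A minimal, hypothetical sketch of that pattern, with toy placeholder submodules rather than the repository's real networks:

import torch
import torch.nn as nn

class TinyAnimateModel(nn.Module):
    """Toy illustration: bundle several submodules behind one forward()."""
    def __init__(self):
        super().__init__()
        self.mapping = nn.Linear(8, 8)       # placeholder for the mapping net
        self.generator = nn.Linear(8, 2)     # placeholder for the generator

    def forward(self, x):
        return self.generator(self.mapping(x))

model = TinyAnimateModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)           # one module, replicated across GPUs
print(model(torch.randn(4, 8)).shape)        # torch.Size([4, 2])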