import math
import time

import torch
import tqdm
import yaml
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
)


class DiffusePipeline(object):
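    """Text-to-motion sampling pipeline.

    Wraps a denoising model and a diffusers noise scheduler; the scheduler
    class and its keyword arguments are resolved from
    config/diffuser_params.yaml by ``diffuser_name``.
    """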

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        # Half-precision sampling needs a half-precision model.
        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)
        self.opt = opt
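
        # Scheduler hyper-parameters are keyed by diffuser name in
        # config/diffuser_params.yaml; training-time options from `opt`
        # override the file defaults.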
        with open("config/diffuser_params.yaml", "r") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file)

        if diffuser_name in diffuser_params:
            params = diffuser_params[diffuser_name]
            scheduler_class_name = params["scheduler_class"]
            additional_params = params["additional_params"]

            additional_params["num_train_timesteps"] = opt.diffusion_steps
            additional_params["beta_schedule"] = opt.beta_schedule
            additional_params["prediction_type"] = opt.prediction_type

            # Resolve the scheduler class from the names imported above.
            try:
                scheduler_class = globals()[scheduler_class_name]
            except KeyError:
                raise ValueError(f"Class '{scheduler_class_name}' not found.")

            self.scheduler = scheduler_class(**additional_params)
        else:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")

    def generate_batch(self, caption, m_lens):
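        """Denoise one batch of motions conditioned on the captions.

        `m_lens` gives each motion's target length; sampling runs at the
        batch maximum and `generate` trims the outputs afterwards.
        """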
        B = len(caption)
        T = m_lens.max()
        shape = (B, T, self.model.input_feats)

        # Start from pure Gaussian noise.
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        # Build one (B,)-shaped timestep tensor per scheduler step.
        self.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
        timesteps = [
            torch.tensor([t] * B, device=self.device).long()
            for t in self.scheduler.timesteps
        ]

        enc_text = self.model.encode_text(caption, self.device)
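
        # Reverse diffusion: at each step the model predicts, then the
        # scheduler moves the sample from timestep t toward t - 1. When the
        # model was trained with condition masking, classifier-free guidance
        # is applied instead of a plain forward pass.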
        for t in timesteps:
            with torch.no_grad():
                if getattr(self.model, "cond_mask_prob", 0) > 0:
                    predict = self.model.forward_with_cfg(
                        sample, t, enc_text=enc_text
                    )
                else:
                    predict = self.model(sample, t, enc_text=enc_text)

            # Every entry of t is the same timestep, so pass the scalar.
            sample = self.scheduler.step(predict, t[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
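        """Generate one motion per caption, batching and timing the runs.

        Returns the list of generated motions (each trimmed to its length
        in `m_lens`) and the average per-batch generation time.
        """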
        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = "classifier-free-guidance "
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to "
            f"{infer_mode}generate {N} motions, sampling "
            f"{self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        t_sum = 0
        timed_batches = 0
        cur_idx = 0
        for batch_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
            # Python slicing clamps at the sequence end, so the last
            # (possibly short) batch needs no special case.
            batch_caption = caption[cur_idx : cur_idx + batch_size]
            batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]

            # Synchronize around the call so wall-clock timing reflects the
            # GPU work; skipped on CPU-only machines.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            now_time = time.time()

            # Treat the first ~50 denoising steps as warm-up and exclude
            # them from the timing average.
            if (batch_idx + 1) * self.num_inference_steps >= 50:
                t_sum += now_time - start_time
                timed_batches += 1

            # Trim each motion back to its requested length.
            B = output.shape[0]
            for i in range(B):
                all_output.append(output[i, : batch_m_lens[i]])

            cur_idx += batch_size

        # Average over the timed (non-warm-up) batches only.
        t_eval = t_sum / max(timed_batches, 1)
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval
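

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. The stand-in
    # denoiser below, the "dpmsolver" config key, and the option values are
    # illustrative assumptions; a real run loads a trained checkpoint and
    # its matching options, plus a config/diffuser_params.yaml entry.
    from argparse import Namespace

    class _DummyDenoiser(torch.nn.Module):
        """Stand-in exposing the interface DiffusePipeline expects."""

        input_feats = 263  # assumed motion feature width

        def encode_text(self, caption, device):
            # Placeholder text embedding of assumed width 512.
            return torch.zeros(len(caption), 512, device=device)

        def forward(self, sample, t, enc_text=None):
            # Placeholder prediction; a trained model goes here.
            return torch.zeros_like(sample)

    opt = Namespace(
        diffusion_steps=1000, beta_schedule="linear", prediction_type="sample"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = DiffusePipeline(
        opt,
        _DummyDenoiser(),
        diffuser_name="dpmsolver",
        num_inference_steps=10,
        device=device,
    )
    m_lens = torch.tensor([120], device=device)
    motions, t_eval = pipeline.generate(["a person walks forward"], m_lens)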