youngseng's picture
Upload 187 files
da855ff
import argparse
import json
import pathlib
from pathlib import Path
from shutil import copyfile

import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig
from rich.console import Console

from anim import bvh, quat
from anim.txform import *
from audio.audio_files import read_wavfile
from data_pipeline import preprocess_animation, preprocess_audio
from helpers import split_by_ratio
from utils import write_bvh
def _to_path(value):
    """Return ``value`` as a ``Path`` if it is a ``str`` or path object, otherwise ``None``.

    Used to tell BVH file references apart from precomputed style embeddings and
    animation dictionaries. Any ``pathlib.PurePath`` subclass is accepted, not only
    the concrete ``WindowsPath``/``PosixPath`` classes.
    """
    if isinstance(value, (str, pathlib.PurePath)):
        return Path(value)
    return None


def _animation_tensors(anim_data, device):
    """Preprocess an animation dictionary and return the network inputs as float32 tensors.

    Args:
        anim_data: Animation dictionary as produced by ``bvh.load``.
        device: Torch device the tensors are placed on.

    Returns:
        dict mapping ``root_pos``, ``root_rot``, ``root_vel``, ``root_vrt``, ``lpos``,
        ``ltxy``, ``lvel``, ``lvrt`` and ``gaze_pos`` to tensors. The other outputs of
        ``preprocess_animation`` are not used for inference and are dropped.
    """
    (
        root_pos,
        root_rot,
        root_vel,
        root_vrt,
        lpos,
        _lrot,
        ltxy,
        lvel,
        lvrt,
        _cpos,
        _crot,
        _ctxy,
        _cvel,
        _cvrt,
        gaze_pos,
        _gaze_dir,
    ) = preprocess_animation(anim_data)
    features = {
        "root_pos": root_pos,
        "root_rot": root_rot,
        "root_vel": root_vel,
        "root_vrt": root_vrt,
        "lpos": lpos,
        "ltxy": ltxy,
        "lvel": lvel,
        "lvrt": lvrt,
        "gaze_pos": gaze_pos,
    }
    return {
        name: torch.as_tensor(value, dtype=torch.float32, device=device)
        for name, value in features.items()
    }


def _load_network(path, device, use_script):
    """Load a pickled network onto ``device`` in eval mode, optionally TorchScript-compiled.

    ``map_location`` lets checkpoints that were saved from a GPU be loaded on a
    CPU-only machine.
    """
    network = torch.load(path, map_location=device).to(device)
    network.eval()
    if use_script:
        network = torch.jit.script(network)
        network.eval()
    return network


def _load_stats(path, device):
    """Load the input/output normalization statistics from ``stats.npz`` as float32 tensors."""
    keys = (
        "audio_input_mean",
        "audio_input_std",
        "anim_input_mean",
        "anim_input_std",
        "anim_output_mean",
        "anim_output_std",
    )
    # The context manager closes the underlying .npz file handle.
    with np.load(path) as stat_data:
        return {
            key: torch.as_tensor(stat_data[key], dtype=torch.float32, device=device)
            for key in keys
        }


def _encode_bvh_style(anim_data, style_encoder, input_mean, input_std, temperature, device):
    """Encode an animation dictionary into a style embedding with the style encoder.

    The per-frame feature vector concatenates ``root_vel``, ``root_vrt``, ``lpos``,
    ``ltxy``, ``lvel`` and ``lvrt`` (each flattened per frame), followed by a zero
    block the width of ``root_vel``. It is then normalized with the animation input
    statistics.
    """
    feats = _animation_tensors(anim_data, device)
    nframes = len(anim_data["rotations"])
    parts = [
        feats[name].reshape(nframes, -1)
        for name in ("root_vel", "root_vrt", "lpos", "ltxy", "lvel", "lvrt")
    ]
    # Zero block shaped like root_vel completes the layout the style encoder expects
    # (presumably a channel that is unused at inference time; TODO confirm).
    parts.append(torch.zeros_like(parts[0]))
    feature_vec = (torch.cat(parts, dim=1) - input_mean) / input_std
    style_encoding, _, _ = style_encoder(feature_vec[np.newaxis], temperature)
    return style_encoding


def _first_pose_animation(first_pose, fallback_anim_data):
    """Resolve the animation dictionary whose first frame seeds the generation.

    Args:
        first_pose: None, a path (``str`` or path object) to a BVH file, or an
            animation dictionary.
        fallback_anim_data: Animation dictionary of the last BVH style example, or None.

    Raises:
        ValueError: ``first_pose`` is None and no BVH style example was given.
        TypeError: ``first_pose`` has an unsupported type.
    """
    if first_pose is None:
        if fallback_anim_data is None:
            raise ValueError(
                "first_pose must be given when no style is provided as a BVH file"
            )
        return fallback_anim_data
    first_pose_path = _to_path(first_pose)
    if first_pose_path is not None:
        return bvh.load(first_pose_path)
    if isinstance(first_pose, dict):
        return first_pose.copy()
    raise TypeError(
        f"first_pose must be a path or an animation dict, got {type(first_pose).__name__}"
    )


def generate_gesture(
    audio_file,
    styles,
    network_path,
    data_path,
    results_path,
    blend_type="add",
    blend_ratio=None,
    file_name=None,
    first_pose=None,
    temperature=1.0,
    seed=1234,
    use_gpu=True,
    use_script=False,
):
    """Generate stylized gesture from raw audio and style example (ZEGGS)

    Args:
        audio_file (str | Path | None): Path to the audio file. If None, no gesture is
            generated and only the style embedding is returned.
        styles (list): Styles to use. This is a list of tuples S, one per style. Multiple
            styles are given for blending or stitching styles. Tuple S contains:
            - S[0]: the path (str or Path) to a BVH example, or a style embedding
              (np.ndarray) to be used directly.
            - S[1]: for a BVH example, a list or tuple ``(start, end)`` of the frames to
              use, or None to use all frames. For an embedding, it is used as the label
              in the generated file name.
        network_path (str | Path): Directory containing the networks.
        data_path (str | Path): Directory containing the processing information
            (``stats.npz``, ``data_definition.json``, ``data_pipeline_conf.json``).
        results_path (str | Path | None): Result directory, created if missing. It must
            be None exactly when ``audio_file`` is None.
        blend_type (str, optional): Blending type, "stitch" (transitioning) or "add"
            (mixing). Defaults to "add".
        blend_ratio (sequence, optional): The proportion of blending. If blend type is
            "stitch", this is the proportion of the output length for each style. If
            the blend type is "add", this is the interpolation weight of each style.
            Defaults to [0.5, 0.5].
        file_name (str, optional): Output file name without extension. If None, the
            audio and example file names are used. Defaults to None.
        first_pose (optional): The info required for the first pose. Either the path
            to a BVH file whose first frame is used, or an animation dictionary as
            returned by loading a BVH file. If None, the pose from the last BVH style
            example is used. Defaults to None.
        temperature (float, optional): VAE temperature. This adjusts the amount of
            stochasticity. Defaults to 1.0.
        seed (int, optional): Random seed. Defaults to 1234.
        use_gpu (bool, optional): Use the GPU if available, otherwise the CPU.
            Defaults to True.
        use_script (bool, optional): Use TorchScript. Defaults to False.

    Returns:
        final_style_encoding: The final style embedding.
            - Without audio: the (blended) style embedding for "add". For "stitch" with
              several styles, it is the list of per-style embeddings.
            - With audio: the per-frame style embedding that was fed to the decoder.

    Raises:
        ValueError: Inconsistent arguments: only one of audio_file/results_path is
            given, an unknown blend_type, no styles, a blend_ratio length mismatch, a
            non-60 fps style example, or no first pose available.
        TypeError: A style or first_pose of unsupported type.
    """
    if blend_ratio is None:
        blend_ratio = [0.5, 0.5]
    network_path = Path(network_path)
    data_path = Path(data_path)
    if audio_file is not None:
        audio_file = Path(audio_file)
    if results_path is not None:
        results_path = Path(results_path)

    # Validate arguments before touching the file system or loading any model.
    if (audio_file is None) != (results_path is None):
        raise ValueError(
            "audio_file and results_path must either both be given or both be None"
        )
    if blend_type not in ("add", "stitch"):
        raise ValueError(f"Unknown blend_type {blend_type!r}; expected 'add' or 'stitch'")
    if not styles:
        raise ValueError("At least one style must be given")

    if results_path is not None:
        results_path.mkdir(parents=True, exist_ok=True)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.set_num_threads(1)
    use_gpu = use_gpu and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")

    # Data pipeline conf (we must use the same processing configuration as in training)
    with open(data_path / "data_pipeline_conf.json", "r") as f:
        data_pipeline_conf = DictConfig(json.load(f))
    # Animation static info (skeleton, frame time, etc.)
    with open(data_path / "data_definition.json", "r") as f:
        details = json.load(f)
    bone_names = details["bone_names"]
    parents = torch.as_tensor(details["parents"], dtype=torch.long, device=device)
    dt = details["dt"]

    # Mean and std of the network inputs/outputs
    stats = _load_stats(data_path / "stats.npz", device)

    network_speech_encoder = _load_network(network_path / "speech_encoder.pt", device, use_script)
    network_decoder = _load_network(network_path / "decoder.pt", device, use_script)
    network_style_encoder = _load_network(network_path / "style_encoder.pt", device, use_script)

    with torch.no_grad():
        # Without audio, only the style encodings are computed.
        if audio_file is not None:
            _, audio_data = read_wavfile(
                audio_file,
                rescale=True,
                desired_fs=16000,
                desired_nb_channels=None,
                out_type="float32",
                logger=None,
            )
            # Number of output frames at 60 fps for audio resampled to 16 kHz.
            n_frames = int(round(60.0 * (len(audio_data) / 16000)))
            audio_features = torch.as_tensor(
                preprocess_audio(
                    audio_data,
                    60,
                    n_frames,
                    data_pipeline_conf.audio_conf,
                    feature_type=data_pipeline_conf.audio_feature_type,
                ),
                device=device,
                dtype=torch.float32,
            )
            speech_encoding = network_speech_encoder(
                (audio_features[np.newaxis] - stats["audio_input_mean"])
                / stats["audio_input_std"]
            )

        # Style encoding
        style_encodings = []
        anim_name = None
        last_anim_data = None  # last BVH example; default source of the first pose
        for example in styles:
            example_path = _to_path(example[0])
            if example_path is not None:
                anim_name = example_path.stem
                anim_data = bvh.load(example_path)
                frames = example[1]
                if frames is not None:
                    anim_data["rotations"] = anim_data["rotations"][frames[0]: frames[1]]
                    anim_data["positions"] = anim_data["positions"][frames[0]: frames[1]]
                anim_fps = int(np.ceil(1 / anim_data["frametime"]))
                if anim_fps != 60:
                    raise ValueError(
                        f"Style example {example_path} is {anim_fps} fps; only 60 fps is supported"
                    )
                style_encodings.append(
                    _encode_bvh_style(
                        anim_data,
                        network_style_encoder,
                        stats["anim_input_mean"],
                        stats["anim_input_std"],
                        temperature,
                        device,
                    )
                )
                last_anim_data = anim_data
            elif isinstance(example[0], np.ndarray):
                anim_name = example[1]
                style_encodings.append(
                    torch.as_tensor(example[0], dtype=torch.float32, device=device)[np.newaxis]
                )
            else:
                raise TypeError(
                    "Each style must start with a BVH path or a style embedding array, "
                    f"got {type(example[0]).__name__}"
                )

        if len(style_encodings) > 1 and len(style_encodings) != len(blend_ratio):
            if blend_type == "add" or audio_file is not None:
                raise ValueError(
                    f"blend_ratio has {len(blend_ratio)} entries for {len(style_encodings)} styles"
                )

        if len(style_encodings) == 1:
            final_style_encoding = style_encodings[0]
        elif blend_type == "stitch":
            if audio_file is None:
                final_style_encoding = style_encodings
            else:
                # Each style covers a contiguous share of the output frames.
                se = split_by_ratio(n_frames, blend_ratio)
                final_style_encoding = torch.cat(
                    [
                        style_encoding.unsqueeze(1).repeat((1, se[i][-1] - se[i][0], 1))
                        for i, style_encoding in enumerate(style_encodings)
                    ],
                    dim=1,
                )
        else:  # "add": weighted sum of the style embeddings
            # Explicit float32 so integer ratios such as [1, 0] do not produce an
            # int64 tensor that matmul cannot combine with the float encodings.
            weights = torch.tensor(blend_ratio, dtype=torch.float32, device=device)
            final_style_encoding = torch.matmul(
                torch.stack(style_encodings, dim=1).transpose(2, 1), weights
            )

        if audio_file is not None:
            pose = _animation_tensors(_first_pose_animation(first_pose, last_anim_data), device)
            n_out = speech_encoding.shape[1]
            if final_style_encoding.dim() == 2:
                # One embedding for the whole clip: repeat it for every output frame.
                final_style_encoding = final_style_encoding.unsqueeze(1).repeat((1, n_out, 1))
            (
                V_root_pos,
                V_root_rot,
                _V_root_vel,
                _V_root_vrt,
                V_lpos,
                V_ltxy,
                _V_lvel,
                _V_lvrt,
            ) = network_decoder(
                pose["root_pos"][:1],
                pose["root_rot"][:1],
                pose["root_vel"][:1],
                pose["root_vrt"][:1],
                pose["lpos"][:1],
                pose["ltxy"][:1],
                pose["lvel"][:1],
                pose["lvrt"][:1],
                pose["gaze_pos"][:1].repeat_interleave(n_out, dim=0)[np.newaxis],
                speech_encoding,
                final_style_encoding,
                parents,
                stats["anim_input_mean"],
                stats["anim_input_std"],
                stats["anim_output_mean"],
                stats["anim_output_std"],
                dt,
            )
            V_lrot = quat.from_xform(xform_orthogonalize_from_xy(V_ltxy).detach().cpu().numpy())
            if file_name is None:
                file_name = f"audio_{audio_file.stem}_label_{anim_name}"
            try:
                write_bvh(
                    str(results_path / (file_name + ".bvh")),
                    V_root_pos[0].detach().cpu().numpy(),
                    V_root_rot[0].detach().cpu().numpy(),
                    V_lpos[0].detach().cpu().numpy(),
                    V_lrot[0],
                    parents=parents.detach().cpu().numpy(),
                    names=bone_names,
                    order="zyx",
                    dt=dt,
                    start_position=np.array([0, 0, 0]),
                    start_rotation=np.array([1, 0, 0, 0]),
                )
                copyfile(audio_file, str(results_path / (file_name + ".wav")))
            except OSError as e:
                # Best effort: a failed write is reported but does not abort generation.
                print(e)
    return final_style_encoding
def build_parser():
    """Build the command-line parser for the ZEGGS gesture generation script.

    Two modes are supported: a single audio/style pair (``--audio`` and ``--style``),
    or a batch described by a CSV file (``--csv``). In CSV mode the per-sample
    options are read from the CSV columns instead of the command line.
    """
    parser = argparse.ArgumentParser(prog="ZEGGS", description="Generate samples by ZEGGS model")
    parser.add_argument(
        "-o",
        "--options",
        type=str,
        required=True,
        help="Options filename (generated during training)",
    )
    parser.add_argument('-p', '--results_path', type=str,
                        help="Results path. Default is the 'results' directory inside the output "
                             "directory given in the options file",
                        nargs="?", const=None, required=False)
    # 1. Generating gesture from a single pair of audio and style files
    parser.add_argument('-s', '--style', type=str, help="Path to style example file", required=False)
    parser.add_argument('-a', '--audio', type=str, help="Path to audio file", required=False)
    parser.add_argument('-n', '--file_name', type=str,
                        help="Output file name. If not given it will be automatically constructed", required=False)
    parser.add_argument('-t', '--temperature', type=float,
                        help="VAE temperature. This adjusts the amount of stochasticity.", nargs="?", default=1.0,
                        required=False)
    parser.add_argument('-r', '--seed', type=int, help="Random seed", nargs="?", default=1234, required=False)
    parser.add_argument('-g', '--use_gpu', help="Use GPU (Default is using CPU)", action="store_true", required=False)
    parser.add_argument('-f', '--frames', type=int, help="Start and end frame of the style example to be used", nargs=2,
                        required=False)
    # 2. Generating gesture(s) from a csv file (some of the other arguments will be ignored)
    parser.add_argument('-c', '--csv', type=str,
                        help="CSV file containing information about pairs of audio/style and other parameters",
                        required=False)
    return parser


def _csv_optional(value):
    """Return ``None`` for an empty CSV cell (parsed by pandas as NaN), otherwise ``value``."""
    return None if pd.isna(value) else value


def main():
    """Generate gestures for one audio/style pair, or for every enabled row of a CSV file."""
    console = Console()
    parser = build_parser()
    args = parser.parse_args()
    if args.csv is None and (args.audio is None or args.style is None):
        parser.error("--audio and --style are required unless --csv is given")

    with open(args.options, "r") as f:
        options = json.load(f)
    paths = options["paths"]
    data_path = Path(paths["base_path"]) / paths["path_processed_data"]
    network_path = Path(paths["models_dir"])
    output_path = Path(paths["output_dir"])
    # --results_path arrives as a string; generate_gesture works with Path objects.
    if args.results_path is None:
        results_path = output_path / "results"
    else:
        results_path = Path(args.results_path)

    if args.csv is not None:
        console.print("Getting arguments from CSV file")
        df = pd.read_csv(args.csv)
        for index, row in df.iterrows():
            if not row["generate"]:
                continue
            console.rule(f"Generating Gesture {index + 1}/{len(df)}")
            # Added only so the printed arguments show the effective settings.
            row["results_path"] = results_path
            row["options"] = args.options
            base_path = Path(row["base_path"])
            frames_cell = row["frames"]
            # "start end" as a space-separated string; any other cell (e.g. NaN) means all frames.
            frames = [int(x) for x in frames_cell.split()] if isinstance(frames_cell, str) else None
            console.print("Arguments:")
            console.print(row.to_string(index=True))
            with console.status("Generating gesture..."):
                generate_gesture(
                    audio_file=base_path / Path(row["audio"]),
                    styles=[(base_path / Path(row["style"]), frames)],
                    network_path=network_path,
                    data_path=data_path,
                    results_path=results_path,
                    # An empty cell is read as NaN; None lets the name be auto-generated.
                    file_name=_csv_optional(row["file_name"]),
                    temperature=float(row["temperature"]),
                    seed=int(row["seed"]),
                    use_gpu=bool(row["use_gpu"]),
                )
    else:
        console.rule("Generating Gesture")
        console.print("Arguments:")
        console.print(pd.DataFrame([vars(args)]).iloc[0].to_string(index=True))
        with console.status("Generating gesture..."):
            generate_gesture(
                audio_file=Path(args.audio),
                styles=[(Path(args.style), args.frames)],
                network_path=network_path,
                data_path=data_path,
                results_path=results_path,
                file_name=args.file_name,
                temperature=args.temperature,
                seed=args.seed,
                use_gpu=args.use_gpu,
            )


if __name__ == "__main__":
    # CLI for generating gesture from one pair of audio and style files or multiple pairs through a csv file.
    # For full functionality, please use the generate_gesture function.
    main()