#!/usr/bin/env python3
"""
Advanced GPT XTTS Training with All P2 Optimizations
Includes: Mixed Precision, Data Augmentation, Gradient Monitoring, torch.compile
Usage:
# Full optimizations (Ampere+ GPU)
CUDA_VISIBLE_DEVICES=0 python train_gpt_xtts_advanced.py \
--output_path checkpoints/ \
--metadatas datasets/metadata_train.csv,datasets/metadata_eval.csv,be \
--use_bfloat16 \
--use_augmentation \
--augmentation_preset medium \
--monitor_gradients \
--compile_model \
    --tf32_matmul \
    --tf32_cudnn \
--batch_size 8 \
--grad_acumm 4
# Conservative (older GPUs)
CUDA_VISIBLE_DEVICES=0 python train_gpt_xtts_advanced.py \
--output_path checkpoints/ \
--metadatas datasets/metadata_train.csv,datasets/metadata_eval.csv,be \
--use_fp16 \
--use_augmentation \
--batch_size 4 \
--grad_acumm 8
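Both examples keep the effective batch size at 32 (batch_size x grad_acumm:
8 x 4 for the full run, 4 x 8 for the conservative run).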
"""
import os
import sys
import gc
import torch
import argparse
from pathlib import Path
# Add current directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
# Import P2 modules
try:
from audio_augmentation import get_augmentor, PRESET_CONFIGS
AUGMENTATION_AVAILABLE = True
except ImportError:
print("Warning: audio_augmentation.py not found. Augmentation disabled.")
AUGMENTATION_AVAILABLE = False
try:
from gradient_monitor import GradientMonitor
GRADIENT_MONITOR_AVAILABLE = True
except ImportError:
print("Warning: gradient_monitor.py not found. Gradient monitoring disabled.")
GRADIENT_MONITOR_AVAILABLE = False
def create_parser():
"""Create comprehensive argument parser"""
parser = argparse.ArgumentParser(
description="Advanced XTTS GPT Training with P1+P2 Optimizations",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# === Basic Training Arguments ===
basic = parser.add_argument_group('Basic Training')
basic.add_argument("--output_path", type=str, required=True,
help="Path to checkpoint output directory")
basic.add_argument("--metadatas", nargs='+', type=str, required=True,
help="train_csv,eval_csv,language for each dataset")
basic.add_argument("--num_epochs", type=int, default=8,
help="Number of training epochs")
basic.add_argument("--batch_size", type=int, default=8,
help="Training batch size")
basic.add_argument("--grad_acumm", type=int, default=4,
help="Gradient accumulation steps")
basic.add_argument("--max_audio_length", type=int, default=330750,
help="Maximum audio length in samples (~15 seconds)")
basic.add_argument("--max_text_length", type=int, default=400,
help="Maximum text length in characters")
basic.add_argument("--weight_decay", type=float, default=1e-2,
help="Weight decay for optimizer")
basic.add_argument("--lr", type=float, default=5e-6,
help="Learning rate")
basic.add_argument("--save_step", type=int, default=10000,
help="Save checkpoint every N steps")
# === P1: GPU Optimizations ===
p1 = parser.add_argument_group('P1: GPU Optimizations')
p1.add_argument("--tf32_matmul", action="store_true",
help="Enable TF32 matrix multiplication (Ampere+ GPUs)")
p1.add_argument("--tf32_cudnn", action="store_true",
help="Enable TF32 for cuDNN (Ampere+ GPUs)")
p1.add_argument("--num_workers", type=int, default=8,
help="DataLoader worker threads")
# === P2: Mixed Precision ===
mp = parser.add_argument_group('P2: Mixed Precision Training')
mp.add_argument("--use_fp16", action="store_true",
help="Use FP16 mixed precision (older GPUs)")
mp.add_argument("--use_bfloat16", action="store_true",
help="Use BFloat16 mixed precision (Ampere+ recommended)")
mp.add_argument("--no_gradient_scaling", action="store_true",
help="Disable gradient scaling (for BFloat16)")
# === P2: Data Augmentation ===
aug = parser.add_argument_group('P2: Data Augmentation')
aug.add_argument("--use_augmentation", action="store_true",
help="Enable audio data augmentation")
aug.add_argument("--augmentation_preset", type=str, default="medium",
choices=list(PRESET_CONFIGS.keys()) if AUGMENTATION_AVAILABLE else [],
help="Augmentation preset (light/medium/heavy)")
aug.add_argument("--augment_prob", type=float, default=0.3,
help="Probability of applying augmentation")
# === P2: Gradient Monitoring ===
grad = parser.add_argument_group('P2: Gradient Monitoring')
grad.add_argument("--monitor_gradients", action="store_true",
help="Enable gradient monitoring and logging")
grad.add_argument("--gradient_clip_val", type=float, default=1.0,
help="Gradient clipping value (None = no clipping)")
grad.add_argument("--gradient_log_interval", type=int, default=100,
help="Log gradients every N steps")
# === P2: Model Compilation ===
comp = parser.add_argument_group('P2: Model Compilation (PyTorch 2.0+)')
comp.add_argument("--compile_model", action="store_true",
help="Use torch.compile for faster training")
comp.add_argument("--compile_mode", type=str, default="reduce-overhead",
choices=["default", "reduce-overhead", "max-autotune"],
help="Compilation mode")
# === Advanced Options ===
adv = parser.add_argument_group('Advanced Options')
adv.add_argument("--profile", action="store_true",
help="Enable PyTorch profiler")
adv.add_argument("--detect_anomaly", action="store_true",
help="Enable anomaly detection (slower, for debugging)")
return parser
def print_configuration(args):
"""Print training configuration"""
print("\n" + "=" * 70)
print(" " * 20 + "TRAINING CONFIGURATION")
print("=" * 70)
print("\n๐Ÿ“Š Basic Settings:")
print(f" Output Path: {args.output_path}")
print(f" Epochs: {args.num_epochs}")
print(f" Batch Size: {args.batch_size}")
print(f" Grad Accumulation: {args.grad_acumm}")
print(f" Effective Batch: {args.batch_size * args.grad_acumm}")
print(f" Learning Rate: {args.lr:.2e}")
print(f" Weight Decay: {args.weight_decay}")
print("\n๐ŸŽฎ GPU Settings (P1):")
if torch.cuda.is_available():
print(f" Device: {torch.cuda.get_device_name(0)}")
print(f" TF32 MatMul: {'โœ“ Enabled' if args.tf32_matmul else 'โœ— Disabled'}")
print(f" TF32 cuDNN: {'โœ“ Enabled' if args.tf32_cudnn else 'โœ— Disabled'}")
print(f" DataLoader Workers: {args.num_workers}")
else:
print(f" Device: CPU (CUDA not available)")
print("\n๐Ÿ”ฌ Mixed Precision (P2):")
if args.use_bfloat16:
print(f" Type: BFloat16 โœ“")
print(f" Gradient Scaling: Disabled (not needed)")
elif args.use_fp16:
print(f" Type: FP16 โœ“")
print(f" Gradient Scaling: {'Disabled' if args.no_gradient_scaling else 'Enabled โœ“'}")
else:
print(f" Type: FP32 (disabled)")
print("\n๐ŸŽจ Data Augmentation (P2):")
if args.use_augmentation and AUGMENTATION_AVAILABLE:
print(f" Status: โœ“ Enabled")
print(f" Preset: {args.augmentation_preset}")
print(f" Probability: {args.augment_prob}")
else:
print(f" Status: โœ— Disabled")
print("\n๐Ÿ“ˆ Gradient Monitoring (P2):")
if args.monitor_gradients and GRADIENT_MONITOR_AVAILABLE:
print(f" Status: โœ“ Enabled")
print(f" Gradient Clipping: {args.gradient_clip_val if args.gradient_clip_val else 'Disabled'}")
print(f" Log Interval: Every {args.gradient_log_interval} steps")
else:
print(f" Status: โœ— Disabled")
print("\nโšก Model Compilation (P2):")
if args.compile_model and hasattr(torch, 'compile'):
print(f" Status: โœ“ Enabled")
print(f" Mode: {args.compile_mode}")
else:
print(f" Status: โœ— Disabled")
print("\n" + "=" * 70 + "\n")
def optimize_gpu_settings(args):
"""Apply GPU optimizations (P1) - GPU-specific configuration"""
if not torch.cuda.is_available():
print("โš ๏ธ CUDA not available. Training on CPU will be very slow.")
return {"gpu_name": "CPU", "supports_tf32": False, "compute_capability": None}
gpu_name = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
compute_capability = f"{props.major}.{props.minor}"
print("=" * 70)
print("๐ŸŽฎ GPU Configuration")
print("=" * 70)
print(f"GPU: {gpu_name}")
print(f"Compute Capability: {compute_capability}")
print(f"Memory: {props.total_memory / 1024**3:.1f} GB")
# Determine architecture
is_ampere_or_newer = props.major >= 8 # Ampere (A100, A6000, 3090, 4090, etc.)
is_turing = props.major == 7 and props.minor == 5 # Turing (T4, RTX 2080, etc.)
supports_tf32 = is_ampere_or_newer
# TF32 - Only available on Ampere+ (compute capability >= 8.0)
    if args.tf32_matmul or args.tf32_cudnn:
        if supports_tf32:
            torch.backends.cuda.matmul.allow_tf32 = args.tf32_matmul
            torch.backends.cudnn.allow_tf32 = args.tf32_cudnn
            print(f"TF32 MatMul: {'✓ Enabled' if args.tf32_matmul else '✗ Disabled'}")
            print(f"TF32 cuDNN: {'✓ Enabled' if args.tf32_cudnn else '✗ Disabled'}")
        else:
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
            print("TF32: ✗ Not supported (requires Ampere+, CC >= 8.0)")
            print(f" → Using FP32 instead (still fast on {gpu_name})")
    else:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        print("TF32: ✗ Disabled by user")
    # cuDNN benchmark - autotunes conv kernels; helps most with fixed input
    # sizes, and may add re-tuning overhead for highly variable audio lengths
    torch.backends.cudnn.benchmark = True
    print("cuDNN Benchmark: ✓ Enabled")
    # Memory allocator - adjust based on GPU memory. The allocator reads this
    # variable when CUDA is first used, so it only takes effect if set before
    # the first CUDA allocation. T4-class GPUs have 16GB; more aggressive
    # fragmentation management helps there.
    if props.total_memory < 20 * 1024**3:  # Less than 20GB
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256,expandable_segments:True'
        print("Memory Management: ✓ Optimized for <20GB VRAM")
    else:
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        print("Memory Management: ✓ Standard configuration")
# Anomaly detection
if args.detect_anomaly:
torch.autograd.set_detect_anomaly(True)
print(f"Anomaly Detection: โš ๏ธ Enabled (slower, for debugging)")
print("=" * 70 + "\n")
return {
"gpu_name": gpu_name,
"supports_tf32": supports_tf32,
"compute_capability": compute_capability,
"is_turing": is_turing,
"is_ampere_or_newer": is_ampere_or_newer
}
def train_with_advanced_features(args):
"""Main training function with all P2 features"""
# === Setup ===
gpu_info = optimize_gpu_settings(args)
RUN_NAME = "GPT_XTTS_Advanced"
OUT_PATH = args.output_path
# Process datasets
DATASETS_CONFIG_LIST = []
for metadata in args.metadatas:
train_csv, eval_csv, language = metadata.split(",")
print(f"๐Ÿ“ Dataset: {language} - {train_csv}")
config_dataset = BaseDatasetConfig(
formatter="coqui",
dataset_name="ft_dataset",
path=os.path.dirname(train_csv),
meta_file_train=os.path.basename(train_csv),
meta_file_val=os.path.basename(eval_csv),
language=language,
)
DATASETS_CONFIG_LIST.append(config_dataset)
# === Download checkpoints ===
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# DVAE files
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, "dvae.pth")
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, "mel_stats.pth")
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print("๐Ÿ“ฅ Downloading DVAE files...")
ModelManager._download_model_files(
["https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"],
CHECKPOINTS_OUT_PATH, progress_bar=True)
# XTTS files
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, "vocab.json")
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, "model.pth")
XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, "config.json")
if not all(os.path.isfile(f) for f in [TOKENIZER_FILE, XTTS_CHECKPOINT, XTTS_CONFIG_FILE]):
print("๐Ÿ“ฅ Downloading XTTS v2.0 files...")
ModelManager._download_model_files(
["https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"],
CHECKPOINTS_OUT_PATH, progress_bar=True)
# === Model configuration ===
model_args = GPTArgs(
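        # Conditioning lengths below are in samples at 22.05 kHz:
        # 88200 = 4 s minimum, 264600 = 12 s maximum of reference audio.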
max_conditioning_length=264600,
min_conditioning_length=88200,
debug_loading_failures=False,
max_wav_length=args.max_audio_length,
max_text_length=args.max_text_length,
mel_norm_file=MEL_NORM_FILE,
dvae_checkpoint=DVAE_CHECKPOINT,
xtts_checkpoint=XTTS_CHECKPOINT,
tokenizer_file=TOKENIZER_FILE,
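        # 1024 DVAE codebook entries plus dedicated start (1024) and stop
        # (1025) tokens give the 1026 audio-token vocabulary below.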
gpt_num_audio_tokens=1026,
gpt_start_audio_token=1024,
gpt_stop_audio_token=1025,
gpt_use_masking_gt_prompt_approach=True,
gpt_use_perceiver_resampler=True,
)
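    # The GPT/DVAE side of XTTS v2 operates on 22.05 kHz audio, while the
    # final decoder output is 24 kHz, hence the two different rates below.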
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
config = GPTTrainerConfig()
config.load_json(XTTS_CONFIG_FILE)
# Remove incompatible parameters for newer trainer versions
# These parameters exist in old config files but are not supported by TrainerArgs anymore
incompatible_params = ['grad_clip', 'grad_clip_norm_type']
for param in incompatible_params:
if hasattr(config, param):
print(f" > Removing incompatible config parameter: {param}")
delattr(config, param)
# Also remove from __dict__ if present (Coqpit stores values there)
if param in config.__dict__:
del config.__dict__[param]
config.epochs = args.num_epochs
config.output_path = OUT_PATH
config.model_args = model_args
config.run_name = RUN_NAME
config.project_name = "XTTS_Advanced"
config.run_description = "Advanced training with P1+P2 optimizations"
config.dashboard_logger = "tensorboard"
config.audio = audio_config
config.batch_size = args.batch_size
config.num_loader_workers = args.num_workers
config.eval_split_max_size = 256
config.print_step = 50
config.plot_step = 100
config.log_model_step = 100
config.save_step = args.save_step
config.save_n_checkpoints = 1
config.save_checkpoints = True
config.print_eval = False
config.optimizer = "AdamW"
config.optimizer_wd_only_on_weights = True
config.optimizer_params = {
"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": args.weight_decay}
config.lr = args.lr
config.lr_scheduler = "MultiStepLR"
config.lr_scheduler_params = {
"milestones": [args.save_step * 3, args.save_step * 6, args.save_step * 9],
"gamma": 0.5, "last_epoch": -1}
config.test_sentences = []
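    # The --use_fp16/--use_bfloat16 flags are otherwise only reported by
    # print_configuration(). A minimal sketch of wiring them into the Coqui
    # trainer, assuming the installed `trainer` package exposes the
    # `mixed_precision`/`precision` config fields (guarded with hasattr so
    # versions without them are unaffected):
    if args.use_fp16 or args.use_bfloat16:
        if hasattr(config, "mixed_precision"):
            config.mixed_precision = True
        if hasattr(config, "precision"):
            config.precision = "bf16" if args.use_bfloat16 else "fp16"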
# === Initialize model ===
print("\n๐Ÿ”ง Initializing model...")
model = GPTTrainer.init_from_config(config)
# === P2: Apply torch.compile ===
if args.compile_model and hasattr(torch, 'compile'):
print(f"\nโšก Compiling model (mode: {args.compile_mode})...")
# GPU-specific compile options
compile_options = {"triton.cudagraphs": False} # Disabled to prevent tensor overwriting
# T4 and Turing GPUs: Use reduce-overhead mode for better performance
# Ampere+: Can use max-autotune for more aggressive optimizations
if gpu_info.get("is_turing"):
recommended_mode = "reduce-overhead"
if args.compile_mode == "default":
print(f" โ†’ Turing GPU detected: Using '{recommended_mode}' mode for better performance")
compile_mode = recommended_mode
else:
compile_mode = args.compile_mode
elif gpu_info.get("is_ampere_or_newer"):
recommended_mode = "max-autotune"
if args.compile_mode == "default":
print(f" โ†’ Ampere+ GPU detected: Using '{recommended_mode}' mode for maximum speed")
compile_mode = recommended_mode
else:
compile_mode = args.compile_mode
else:
compile_mode = args.compile_mode
try:
model.xtts.gpt = torch.compile(
model.xtts.gpt,
mode=compile_mode,
fullgraph=False,
options=compile_options
)
print(f"โœ“ Model compiled successfully")
print(f" Mode: {compile_mode}")
print(f" CUDA graphs: Disabled (prevents backward pass errors)")
except Exception as e:
print(f"โš ๏ธ Compilation failed: {e}")
print(" Continuing without compilation")
# === Load training samples ===
print("\n๐Ÿ“š Loading training samples...")
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
print(f" Train samples: {len(train_samples)}")
print(f" Eval samples: {len(eval_samples)}")
# === P2: Setup augmentation ===
augmentor = None
if args.use_augmentation and AUGMENTATION_AVAILABLE:
print(f"\n๐ŸŽจ Setting up data augmentation ({args.augmentation_preset} preset)...")
augmentor = get_augmentor(
preset=args.augmentation_preset,
sample_rate=22050,
augment_prob=args.augment_prob
)
print("โœ“ Augmentation enabled")
# === P2: Setup gradient monitoring ===
gradient_monitor = None
if args.monitor_gradients and GRADIENT_MONITOR_AVAILABLE:
print(f"\n๐Ÿ“ˆ Setting up gradient monitoring...")
gradient_monitor = GradientMonitor(
log_dir=os.path.join(OUT_PATH, "gradient_logs"),
log_interval=args.gradient_log_interval,
enable_tensorboard=True,
)
print("โœ“ Gradient monitoring enabled")
# === Initialize trainer ===
print("\n๐Ÿš€ Initializing trainer...")
trainer = Trainer(
TrainerArgs(
restore_path=None,
skip_train_epoch=False,
start_with_eval=False,
grad_accum_steps=args.grad_acumm,
# Note: grad_clip is not supported in newer trainer versions
# Gradient clipping is handled by the model's train_step if needed
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
# === Start training ===
print("\n" + "=" * 70)
print(" " * 25 + "STARTING TRAINING")
print("=" * 70 + "\n")
try:
trainer.fit()
except KeyboardInterrupt:
print("\nโš ๏ธ Training interrupted by user")
except Exception as e:
print(f"\nโŒ Training failed with error: {e}")
import traceback
traceback.print_exc()
raise
finally:
# Cleanup
if gradient_monitor is not None:
print("\n๐Ÿ“Š Saving gradient statistics...")
gradient_monitor.close()
trainer_out_path = trainer.output_path
# Final cleanup
del model, trainer, train_samples, eval_samples
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return trainer_out_path
def main():
parser = create_parser()
args = parser.parse_args()
# Validate arguments
    if args.use_bfloat16 and args.use_fp16:
        print("❌ Error: Cannot use both BFloat16 and FP16 simultaneously")
        sys.exit(1)
    if args.use_augmentation and not AUGMENTATION_AVAILABLE:
        print("⚠️ Warning: Augmentation requested but module not available")
        args.use_augmentation = False
    if args.monitor_gradients and not GRADIENT_MONITOR_AVAILABLE:
        print("⚠️ Warning: Gradient monitoring requested but module not available")
        args.monitor_gradients = False
# Print configuration
print_configuration(args)
# Train
try:
trainer_out_path = train_with_advanced_features(args)
print("\n" + "=" * 70)
print(" " * 25 + "TRAINING COMPLETED!")
print("=" * 70)
print(f"\nโœ“ Checkpoint saved in: {trainer_out_path}")
print(f"โœ“ TensorBoard logs: {os.path.join(trainer_out_path, 'tensorboard')}")
if args.monitor_gradients:
print(f"โœ“ Gradient logs: {os.path.join(trainer_out_path, 'gradient_logs')}")
print("\n" + "=" * 70 + "\n")
except Exception as e:
print(f"\nโŒ Training failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()