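# Fine-tuning configuration for the dual-AR text2semantic model.
# It is typically consumed through the repo's Hydra entrypoint, e.g.
#   python fish_speech/train.py --config-name text2semantic_finetune
# (illustrative command; check the Fish Speech docs for the exact invocation).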
defaults:
  - base
  - _self_
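# Hydra defaults: the shared `base` config is loaded first; listing `_self_`
# after it means values in this file take precedence.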
project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.4
# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 1000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 100
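  # Validation runs every 100 training steps (val_check_interval) on at most
  # 10 batches (limit_val_batches); checkpointing follows the same interval
  # via the callbacks section below.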
# Tokenizer Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}
# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7
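  # data/protos should hold the protobuf shards produced during data
  # preparation; interactive_prob is presumably the probability of rendering a
  # sample in the interactive/instruction format (an assumption from the name,
  # check AutoTextSemanticInstructionDataset before relying on it).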
val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7
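  # The validation set reads the same data/protos directory as training here;
  # point proto_files at a held-out split if one is available.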
data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}
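  # batch_size and num_workers are the main VRAM/throughput knobs; with
  # accumulate_grad_batches: 1 the effective batch size per device is 8.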
# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    lora_config: null
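    # lora_config: null means full-parameter finetuning; supplying a LoRA
    # config here switches to adapter-only training (the repo documents a
    # `+lora@model.model.lora_config=...` override for this, but verify the
    # exact name against the current docs).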
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10
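  # Constant schedule with warmup: the learning rate ramps linearly over the
  # first 10 steps, then stays at 1e-4 for the rest of training.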
# Callbacks
callbacks:
  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}
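    # A checkpoint is written every ${trainer.val_check_interval} (= 100)
    # training steps, matching the validation cadence.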