File size: 1,987 Bytes
4f6613a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Hydra defaults list: compose the `base` config first, then this file.
# `_self_` is listed last, so values defined in this file override any
# matching keys coming from `base`. Keep this ordering.
defaults:
  - base
  - _self_

# Run name; not referenced in this file (presumably consumed by `base`).
project: text2semantic_finetune_dual_ar
# Checkpoint directory used for both the tokenizer and the model weights.
pretrained_ckpt_path: checkpoints/fish-speech-1.4
# Sequence length shared by the datasets, data module and model via ${max_length}.
max_length: 4096

# Lightning Trainer
trainer:
  # Optimization: 1000 steps in bf16, no accumulation, grad-norm clipping at 1.0.
  max_steps: 1000
  accumulate_grad_batches: 1
  precision: bf16-true
  gradient_clip_algorithm: norm
  gradient_clip_val: 1.0
  # Validation: run every 100 training batches, capped at 10 val batches.
  # val_check_interval also sets the checkpoint cadence (see callbacks).
  val_check_interval: 100
  limit_val_batches: 10

# Tokenizer Configuration
tokenizer:
  # Built via AutoTokenizer.from_pretrained from the same checkpoint dir as the model.
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  # Shared values pulled in by interpolation.
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  # Source protobuf data.
  proto_files:
    - data/protos
  # Sampling options (interactive_prob presumably controls how often the
  # interactive format is used — TODO confirm in the dataset class).
  causal: true
  use_speaker: false
  interactive_prob: 0.7

val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  # Shared values pulled in by interpolation.
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  # NOTE(review): same proto_files as train_dataset, so validation runs on
  # training data — point this at a held-out set if one exists.
  proto_files:
    - data/protos
  # NOTE(review): interactive_prob < 1 presumably makes validation samples
  # vary between runs — confirm whether that is intended.
  causal: true
  use_speaker: false
  interactive_prob: 0.7

data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  # Reuse the configs defined above via interpolation.
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  # Loader settings.
  batch_size: 8
  num_workers: 4

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  # Backbone loaded (with weights) from the pretrained checkpoint directory.
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    # NOTE(review): null presumably means no LoRA adapters (full fine-tuning)
    # — confirm against BaseTransformer.from_pretrained.
    lora_config: null

  # _partial_: true makes Hydra return a functools.partial; the remaining
  # arguments (e.g. model parameters) are presumably bound by the LightningModule.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Written with a decimal point so YAML 1.1 loaders (e.g. PyYAML) also
    # resolve these as floats; bare `1e-4` would load as a string there.
    lr: 1.0e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1.0e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    # By its name, a warmup-then-constant multiplier; 10 warmup steps.
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
callbacks:
  # Overrides only the checkpoint cadence; other model_checkpoint settings
  # presumably come from `base`.
  model_checkpoint:
    # Checkpoint at the same interval as validation.
    every_n_train_steps: ${trainer.val_check_interval}