feat: enable trl's autounwrap (#1060)
Browse files* feat: test trl's autounwrap
* fix: add check for adapter
* feat: add config to disable autounwrap
* chore: fix lint
- .vscode/launch.json +1 -1
- devtools/README.md +1 -1
- docs/debugging.md +4 -4
- docs/rlhf.md +9 -0
- src/axolotl/train.py +9 -4
.vscode/launch.json
CHANGED
@@ -11,7 +11,7 @@
|
|
11 |
"request": "launch",
|
12 |
"args": [
|
13 |
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
14 |
-
// The flags below simplify debugging by overriding the axolotl config
|
15 |
// with the debugging tips above. Modify as needed.
|
16 |
"--dataset_processes=1", // limits data preprocessing to one process
|
17 |
"--max_steps=1", // limits training to just one step
|
|
|
11 |
"request": "launch",
|
12 |
"args": [
|
13 |
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
14 |
+
// The flags below simplify debugging by overriding the axolotl config
|
15 |
// with the debugging tips above. Modify as needed.
|
16 |
"--dataset_processes=1", // limits data preprocessing to one process
|
17 |
"--max_steps=1", // limits training to just one step
|
devtools/README.md
CHANGED
@@ -1 +1 @@
|
|
1 |
-
This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
|
|
|
1 |
+
This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
|
docs/debugging.md
CHANGED
@@ -30,13 +30,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
|
30 |
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
|
31 |
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
|
32 |
- `micro_batch_size: 1`
|
33 |
-
- `max_steps: 1`
|
34 |
- `val_set_size: 0`
|
35 |
5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
|
36 |
- Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
|
37 |
- HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
|
38 |
- **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**
|
39 |
-
|
40 |
|
41 |
## Debugging with VSCode
|
42 |
|
@@ -74,7 +74,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
|
74 |
"request": "launch",
|
75 |
"args": [
|
76 |
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
77 |
-
// The flags below simplify debugging by overriding the axolotl config
|
78 |
// with the debugging tips above. Modify as needed.
|
79 |
"--dataset_processes=1", // limits data preprocessing to one process
|
80 |
"--max_steps=1", // limits training to just one step
|
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
|
101 |
|
102 |
- The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`.
|
103 |
- The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:
|
104 |
-
- `./devtools/temp_debug/axolotl_outputs`
|
105 |
- `./devtools/temp_debug/.hf-cache/datasets`
|
106 |
|
107 |
>[!Tip]
|
|
|
30 |
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
|
31 |
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
|
32 |
- `micro_batch_size: 1`
|
33 |
+
- `max_steps: 1`
|
34 |
- `val_set_size: 0`
|
35 |
5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
|
36 |
- Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
|
37 |
- HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
|
38 |
- **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**
|
39 |
+
|
40 |
|
41 |
## Debugging with VSCode
|
42 |
|
|
|
74 |
"request": "launch",
|
75 |
"args": [
|
76 |
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
77 |
+
// The flags below simplify debugging by overriding the axolotl config
|
78 |
// with the debugging tips above. Modify as needed.
|
79 |
"--dataset_processes=1", // limits data preprocessing to one process
|
80 |
"--max_steps=1", // limits training to just one step
|
|
|
101 |
|
102 |
- The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`.
|
103 |
- The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:
|
104 |
+
- `./devtools/temp_debug/axolotl_outputs`
|
105 |
- `./devtools/temp_debug/.hf-cache/datasets`
|
106 |
|
107 |
>[!Tip]
|
docs/rlhf.md
CHANGED
@@ -33,3 +33,12 @@ datasets:
|
|
33 |
```yaml
|
34 |
rl: ipo
|
35 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
```yaml
|
34 |
rl: ipo
|
35 |
```
|
36 |
+
|
37 |
+
#### Trl autounwrap for peft
|
38 |
+
|
39 |
+
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
40 |
+
|
41 |
+
```yaml
|
42 |
+
# load ref model when adapter training.
|
43 |
+
rl_adapter_ref_model: true
|
44 |
+
```
|
src/axolotl/train.py
CHANGED
@@ -63,10 +63,15 @@ def train(
|
|
63 |
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
64 |
model_ref = None
|
65 |
if cfg.rl:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
safe_serialization = cfg.save_safetensors is True
|
72 |
|
|
|
63 |
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
64 |
model_ref = None
|
65 |
if cfg.rl:
|
66 |
+
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
67 |
+
# use built-in trl autounwrap
|
68 |
+
LOG.debug("Passing model_ref: None to RL trainer")
|
69 |
+
model_ref = None # explicit setting to None
|
70 |
+
else:
|
71 |
+
# load the model again for model_ref/baseline
|
72 |
+
model_ref, _ = load_model(
|
73 |
+
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
74 |
+
)
|
75 |
|
76 |
safe_serialization = cfg.save_safetensors is True
|
77 |
|