E2e device cuda (#575)
Browse files* use torch.cuda.current_device() instead of local_rank
* ignore NVML errors for gpu stats
* llama lora packing e2e tests
- .github/workflows/e2e.yml +1 -0
- src/axolotl/utils/bench.py +8 -5
- src/axolotl/utils/config.py +1 -1
- tests/e2e/test_lora_llama.py +42 -0
.github/workflows/e2e.yml
CHANGED
@@ -24,6 +24,7 @@ jobs:
|
|
24 |
- name: Install dependencies
|
25 |
run: |
|
26 |
pip3 install -e .
|
|
|
27 |
pip3 install -r requirements-tests.txt
|
28 |
|
29 |
- name: Run e2e tests
|
|
|
24 |
- name: Install dependencies
|
25 |
run: |
|
26 |
pip3 install -e .
|
27 |
+
pip3 install flash-attn
|
28 |
pip3 install -r requirements-tests.txt
|
29 |
|
30 |
- name: Run e2e tests
|
src/axolotl/utils/bench.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
|
3 |
import pynvml
|
4 |
import torch
|
|
|
5 |
|
6 |
|
7 |
def gpu_memory_usage(device=0):
|
@@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0):
|
|
20 |
device = device.index
|
21 |
if isinstance(device, str) and device.startswith("cuda:"):
|
22 |
device = int(device[5:])
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
28 |
|
29 |
|
30 |
def log_gpu_memory_usage(log, msg, device):
|
|
|
2 |
|
3 |
import pynvml
|
4 |
import torch
|
5 |
+
from pynvml.nvml import NVMLError
|
6 |
|
7 |
|
8 |
def gpu_memory_usage(device=0):
|
|
|
21 |
device = device.index
|
22 |
if isinstance(device, str) and device.startswith("cuda:"):
|
23 |
device = int(device[5:])
|
24 |
+
try:
|
25 |
+
pynvml.nvmlInit()
|
26 |
+
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
27 |
+
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
28 |
+
return info.used / 1024.0**3
|
29 |
+
except NVMLError:
|
30 |
+
return 0.0
|
31 |
|
32 |
|
33 |
def log_gpu_memory_usage(log, msg, device):
|
src/axolotl/utils/config.py
CHANGED
@@ -29,7 +29,7 @@ def choose_device(cfg):
|
|
29 |
cfg.device_map = "auto"
|
30 |
else:
|
31 |
if cfg.device.startswith("cuda"):
|
32 |
-
cfg.device_map = {"":
|
33 |
else:
|
34 |
cfg.device_map = {"": cfg.device}
|
35 |
|
|
|
29 |
cfg.device_map = "auto"
|
30 |
else:
|
31 |
if cfg.device.startswith("cuda"):
|
32 |
+
cfg.device_map = {"": torch.cuda.current_device()}
|
33 |
else:
|
34 |
cfg.device_map = {"": cfg.device}
|
35 |
|
tests/e2e/test_lora_llama.py
CHANGED
@@ -78,3 +78,45 @@ class TestLoraLlama(unittest.TestCase):
|
|
78 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
79 |
|
80 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
79 |
|
80 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
81 |
+
|
82 |
+
def test_lora_packing(self):
|
83 |
+
cfg = DictDefault(
|
84 |
+
{
|
85 |
+
"base_model": "JackFram/llama-68m",
|
86 |
+
"base_model_config": "JackFram/llama-68m",
|
87 |
+
"tokenizer_type": "LlamaTokenizer",
|
88 |
+
"sequence_len": 1024,
|
89 |
+
"sample_packing": True,
|
90 |
+
"flash_attention": True,
|
91 |
+
"load_in_8bit": True,
|
92 |
+
"adapter": "lora",
|
93 |
+
"lora_r": 32,
|
94 |
+
"lora_alpha": 64,
|
95 |
+
"lora_dropout": 0.05,
|
96 |
+
"lora_target_linear": True,
|
97 |
+
"val_set_size": 0.1,
|
98 |
+
"special_tokens": {
|
99 |
+
"unk_token": "<unk>",
|
100 |
+
"bos_token": "<s>",
|
101 |
+
"eos_token": "</s>",
|
102 |
+
},
|
103 |
+
"datasets": [
|
104 |
+
{
|
105 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
106 |
+
"type": "alpaca",
|
107 |
+
},
|
108 |
+
],
|
109 |
+
"num_epochs": 2,
|
110 |
+
"micro_batch_size": 8,
|
111 |
+
"gradient_accumulation_steps": 1,
|
112 |
+
"output_dir": tempfile.mkdtemp(),
|
113 |
+
"learning_rate": 0.00001,
|
114 |
+
"optimizer": "adamw_torch",
|
115 |
+
"lr_scheduler": "cosine",
|
116 |
+
}
|
117 |
+
)
|
118 |
+
normalize_config(cfg)
|
119 |
+
cli_args = TrainerCliArgs()
|
120 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
121 |
+
|
122 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|