File size: 5,381 Bytes
9218ebe 03e5907 9218ebe 12a2dbb 9218ebe 12a2dbb 9218ebe c74f045 7de6a56 9218ebe 0402d19 6dc68a6 12a2dbb 9218ebe 2414673 6dc68a6 2414673 6dc68a6 2414673 0402d19 6dc68a6 12a2dbb 2414673 9218ebe 6dc68a6 9218ebe b3a61e8 9218ebe 6dc68a6 03e5907 0402d19 6dc68a6 03e5907 6dc68a6 03e5907 6dc68a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLoraLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_gptq(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
|