|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import pytest |
|
|
|
from llamafactory.train.tuner import export_model, run_exp |
|
|
|
|
|
DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data") |
|
|
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") |
|
|
|
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora") |
|
|
|
TRAIN_ARGS = { |
|
"model_name_or_path": TINY_LLAMA, |
|
"do_train": True, |
|
"finetuning_type": "lora", |
|
"dataset_dir": "REMOTE:" + DEMO_DATA, |
|
"template": "llama3", |
|
"cutoff_len": 1, |
|
"overwrite_cache": False, |
|
"overwrite_output_dir": True, |
|
"per_device_train_batch_size": 1, |
|
"max_steps": 1, |
|
} |
|
|
|
INFER_ARGS = { |
|
"model_name_or_path": TINY_LLAMA, |
|
"adapter_name_or_path": TINY_LLAMA_ADAPTER, |
|
"finetuning_type": "lora", |
|
"template": "llama3", |
|
"infer_dtype": "float16", |
|
"export_dir": "llama3_export", |
|
} |
|
|
|
OS_NAME = os.environ.get("OS_NAME", "") |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"stage,dataset", |
|
[ |
|
("pt", "c4_demo"), |
|
("sft", "alpaca_en_demo"), |
|
("dpo", "dpo_en_demo"), |
|
("kto", "kto_en_demo"), |
|
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")), |
|
], |
|
) |
|
def test_run_exp(stage: str, dataset: str): |
|
output_dir = "train_{}".format(stage) |
|
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS}) |
|
assert os.path.exists(output_dir) |
|
|
|
|
|
def test_export(): |
|
export_model(INFER_ARGS) |
|
assert os.path.exists("llama3_export") |
|
|