winglian's picture
support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
d2e7f27 unverified
raw
history blame
765 Bytes
"""Module to load prompt strategies."""
import importlib
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
def load(strategy, tokenizer, cfg, ds_cfg):
    """Resolve a prompt strategy by name and build it.

    ``strategy`` is a dotted path relative to ``axolotl.prompt_strategies``.
    When its last component starts with ``load_``, that component names the
    factory function to call and the preceding components name the module;
    otherwise the whole path is the module and its ``load`` function is used.

    Args:
        strategy: Strategy path, e.g. ``"module"`` or ``"module.load_variant"``.
        tokenizer: Forwarded as the factory's first positional argument.
        cfg: Forwarded as the factory's second positional argument.
        ds_cfg: Mapping of dataset options. Only used when the resolved
            module is ``"user_defined"``, where it is expanded into a
            ``UserDefinedDatasetConfig`` and passed as ``ds_cfg=``.

    Returns:
        Whatever the factory returns, or ``None`` when the strategy cannot be
        resolved or building it fails. This function never propagates an
        ``Exception``; unexpected failures are logged with their traceback.
    """
    log = logging.getLogger(__name__)

    # Phase 1: locate the factory. Not finding it is an expected outcome, so
    # that case stays quiet and simply reports "no strategy" via None.
    try:
        load_fn = "load"
        module_path, _, last_part = strategy.rpartition(".")
        if last_part.startswith("load_"):
            load_fn = last_part
            strategy = module_path
        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
        func = getattr(mod, load_fn)
    except (ModuleNotFoundError, AttributeError):
        log.debug("No prompt strategy found for %r", strategy)
        return None
    except Exception:  # pylint: disable=broad-exception-caught
        # The module exists but could not be imported (or the name was
        # malformed in some other way): keep the None contract, but say why.
        log.exception("Failed to import prompt strategy %r", strategy)
        return None

    # Phase 2: build the strategy. A failure here is a real error in the
    # strategy or its configuration, so it is logged instead of being dropped.
    try:
        load_kwargs = {}
        if strategy == "user_defined":
            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
        return func(tokenizer, cfg, **load_kwargs)
    except Exception:  # pylint: disable=broad-exception-caught
        log.exception("Failed to load prompt strategy %r", strategy)
        return None