"""Module containing data utilities"""
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import yaml
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_tokenizers import (
    AlpacaMultipleChoicePromptTokenizingStrategy,
    AlpacaPromptTokenizingStrategy,
    AlpacaReflectionPTStrategy,
    GPTeacherPromptTokenizingStrategy,
    JeopardyPromptTokenizingStrategy,
    OpenAssistantPromptTokenizingStrategy,
    SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
    AlpacaPrompter,
    GPTeacherPrompter,
    JeopardyPrompter,
    MultipleChoiceConcisePrompter,
    MultipleChoiceExplainPrompter,
    Prompter,
    ReflectAlpacaPrompter,
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    process_pretraining_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")


def md5(to_hash: str, encoding: str = "utf-8") -> str:
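    """Return the hex md5 digest of `to_hash`, passing `usedforsecurity=False`
    where the running Python supports it."""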
    try:
        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
    except TypeError:
        # the usedforsecurity flag was only added in Python 3.9
        return hashlib.md5(to_hash.encode(encoding)).hexdigest()


def prepare_dataset(cfg, tokenizer):
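    """Load and tokenize the configured train/eval datasets, or wrap a streaming
    pretraining dataset, returning (train_dataset, eval_dataset,
    total_num_steps, prompters)."""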
    prompters = []
    if not cfg.pretraining_dataset:
        with zero_first(is_main_process()):
            if cfg.test_datasets:
                train_dataset, _, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
                )
                _, eval_dataset, _ = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
                )
            else:
                train_dataset, eval_dataset, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
                )
    else:
        path = cfg.pretraining_dataset
        name = None
        if isinstance(cfg.pretraining_dataset, list) and isinstance(
            cfg.pretraining_dataset[0], dict
        ):
            path = cfg.pretraining_dataset[0]["path"]
            name = cfg.pretraining_dataset[0]["name"]

        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            tokenizer,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )

        train_dataset = wrap_pretraining_dataset(
            load_dataset(path, streaming=True, split="train", name=name),
            tokenizer,
            cfg,
            ds_wrapper_partial,
            max_tokens=cfg.sequence_len,
            batch_size=cfg.micro_batch_size,
            seed=cfg.seed or 42,
        )

        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None
        # streaming pretraining datasets have no length, so the step count
        # comes directly from cfg.max_steps
        return train_dataset, eval_dataset, cfg.max_steps, prompters

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`."
            )

    if cfg.max_steps:
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
    return train_dataset, eval_dataset, total_num_steps, prompters


def load_tokenized_prepared_datasets(
    tokenizer,
    cfg,
    default_dataset_prepared_path,
    split="train",
) -> Tuple[DatasetDict, List[Prompter]]:
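    """Tokenize and merge the datasets configured for the given split, reusing a
    previously prepared copy (from the hub or local disk, keyed by a hash of the
    relevant config) when one exists."""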
    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
    tokenizer_name = tokenizer.__class__.__name__
    ds_hash = str(
        md5(
            (
                str(cfg.sequence_len)
                + "@"
                + str(cfg.sample_packing)
                + "@"
                + str(cfg.eval_sample_packing)
                + "@"
                + str(cfg.group_by_length)
                + "@"
                + "|".join(
                    sorted(
                        [
                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                            for d in cfg_datasets
                        ]
                    )
                )
                + "|"
                + tokenizer_name
            )
        )
    )
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(default_dataset_prepared_path) / ds_hash
    )
    dataset = None
    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
            # check whether a prepared copy of this dataset was already pushed
            dataset = load_dataset(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                token=use_auth_token,
            )
            dataset = dataset[split]
    except Exception:
        pass

    if dataset:
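        # the prepared dataset was already loaded from the hub above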
        ...
    elif (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
            )

        if cfg.seed:
            seed = cfg.seed
        else:
            LOG.info("No seed provided, using default seed of 42")
            seed = 42

        datasets = []

        def for_d_in_datasets(dataset_configs):
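            """Yield one config per dataset, fanning a list-valued `name` out
            into a separate config per name, since `load_dataset` only handles
            a single named configuration at a time."""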
            for dataset in dataset_configs:
                if dataset.name and isinstance(dataset.name, list):
                    for name in dataset.name:
                        yield DictDefault({**dataset, "name": name})
                else:
                    yield dataset

        for config_dataset in for_d_in_datasets(cfg_datasets):
            ds: Optional[Union[Dataset, DatasetDict]] = None
            ds_from_hub = False
            try:
                # probe the hub with a cheap streaming load before committing
                # to a full download
                load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=True,
                    token=use_auth_token,
                )
                ds_from_hub = True
            except (FileNotFoundError, ConnectionError, HFValidationError):
                pass

            ds_from_cloud = False
            storage_options = {}
            remote_file_system = None
            if config_dataset.path.startswith("s3://"):
                try:
                    import aiobotocore.session
                    import s3fs
                except ImportError as exc:
                    raise ImportError(
                        "s3:// paths require aiobotocore and s3fs to be installed"
                    ) from exc

                s3_session = aiobotocore.session.AioSession(profile="default")
                storage_options = {"session": s3_session}
                remote_file_system = s3fs.S3FileSystem(**storage_options)
            elif config_dataset.path.startswith(
                "gs://"
            ) or config_dataset.path.startswith("gcs://"):
                try:
                    import gcsfs
                except ImportError as exc:
                    raise ImportError(
                        "gs:// or gcs:// paths require gcsfs to be installed"
                    ) from exc

                # gcsfs falls back to default credentials when token is None
                storage_options = {"token": None}
                remote_file_system = gcsfs.GCSFileSystem(**storage_options)

            try:
                if remote_file_system and remote_file_system.exists(
                    config_dataset.path
                ):
                    ds_from_cloud = True
            except (FileNotFoundError, ConnectionError):
                pass

            local_path = Path(config_dataset.path)
            if local_path.exists():
                if local_path.is_dir():
                    ds = load_dataset(
                        config_dataset.path,
                        name=config_dataset.name,
                        data_files=config_dataset.data_files,
                        streaming=False,
                        split=None,
                    )
                elif local_path.is_file():
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                    )
                else:
                    raise ValueError(
                        "unhandled dataset load: local path exists, but is neither a directory nor a file"
                    )
            elif ds_from_hub:
                ds = load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=False,
                    data_files=config_dataset.data_files,
                    token=use_auth_token,
                )
            elif ds_from_cloud and remote_file_system:
                if remote_file_system.isdir(config_dataset.path):
                    ds = load_from_disk(
                        config_dataset.path,
                        storage_options=storage_options,
                    )
                elif remote_file_system.isfile(config_dataset.path):
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                        storage_options=storage_options,
                    )
            elif config_dataset.path.startswith("https://"):
                ds_type = get_ds_type(config_dataset)
                ds = load_dataset(
                    ds_type,
                    name=config_dataset.name,
                    data_files=config_dataset.path,
                    streaming=False,
                    split=None,
                    storage_options=storage_options,
                )
            else:
                # fall back to downloading individual data files from the hub
                if isinstance(config_dataset.data_files, str):
                    fp = hf_hub_download(
                        repo_id=config_dataset.path,
                        repo_type="dataset",
                        filename=config_dataset.data_files,
                    )
                elif isinstance(config_dataset.data_files, list):
                    fp = []
                    for file in config_dataset.data_files:
                        fp.append(
                            hf_hub_download(
                                repo_id=config_dataset.path,
                                repo_type="dataset",
                                filename=file,
                            )
                        )
                else:
                    raise ValueError(
                        "data_files must be either a string or list of strings"
                    )
                ds = load_dataset(
                    "json",
                    name=config_dataset.name,
                    data_files=fp,
                    streaming=False,
                    split=None,
                )
            if not ds:
                raise ValueError("unhandled dataset load")

            d_base_type = d_prompt_style = None
            d_type = config_dataset.type
            if isinstance(d_type, str):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

            if config_dataset.split and config_dataset.split in ds:
                ds = ds[config_dataset.split]
            elif split in ds:
                ds = ds[split]
            elif isinstance(ds, DatasetDict):
                raise ValueError(
                    f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: ...'"
                )

            if config_dataset.shards:
                shards_idx = config_dataset.get("shards_idx", 0)
                ds = ds.shuffle(seed=seed).shard(
                    num_shards=config_dataset.shards, index=shards_idx
                )

            dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                config_dataset=config_dataset,
                tokenizer=tokenizer,
                cfg=cfg,
                dataset=ds,
                d_base_type=d_base_type,
                d_prompt_style=d_prompt_style,
            )
            datasets.append(dataset_wrapper)
            prompters.append(dataset_prompter)

        LOG.info("merging datasets")
        dataset = concatenate_datasets(datasets)

        if len(datasets) > 1:
            LOG.info("shuffle merged datasets")
            dataset = dataset.shuffle(seed=seed)

        dataset, _ = process_datasets_for_packing(cfg, dataset, None)

        if cfg.local_rank == 0:
            LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
            dataset.save_to_disk(prepared_ds_path)
            if cfg.push_dataset_to_hub:
                LOG.info(
                    f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                )
                dataset.push_to_hub(
                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                )

    return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
    """
    Infer the dataset type from the file extension when `ds_type` is not
    explicitly specified
    """
    ds_type = "json"
    if config_dataset.ds_type:
        ds_type = config_dataset.ds_type
    elif ".parquet" in config_dataset.path:
        ds_type = "parquet"
    elif ".arrow" in config_dataset.path:
        ds_type = "arrow"
    elif ".csv" in config_dataset.path:
        ds_type = "csv"
    elif ".txt" in config_dataset.path:
        ds_type = "text"
    return ds_type


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
    split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
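    """Load the tokenized datasets for a split, optionally shard them, and carve
    out a deterministic train/test split when `val_set_size` is set."""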
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path, split=split
    )

    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
        LOG.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )

    if split == "train" and cfg.val_set_size:
        # derive the new fingerprints from the dataset fingerprint and seed so
        # the train/test split is deterministic across runs
        to_hash_train = (
            dataset._fingerprint
            + "|"
            + str(cfg.val_set_size)
            + "|"
            + "train"
            + "|"
            + str(cfg.seed or 42)
        )
        to_hash_test = (
            dataset._fingerprint
            + "|"
            + str(cfg.val_set_size)
            + "|"
            + "test"
            + "|"
            + str(cfg.seed or 42)
        )
        train_fingerprint = md5(to_hash_train)
        test_fingerprint = md5(to_hash_test)

        dataset = dataset.train_test_split(
            test_size=cfg.val_set_size,
            shuffle=False,
            seed=cfg.seed or 42,
            train_new_fingerprint=train_fingerprint,
            test_new_fingerprint=test_fingerprint,
        )

        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
    elif split == "test":
        train_dataset = None
        eval_dataset = dataset
    else:
        train_dataset = dataset
        eval_dataset = None

    return train_dataset, eval_dataset, prompters


def get_dataset_wrapper(
    config_dataset,
    tokenizer,
    cfg,
    d_base_type,
    dataset,
    d_prompt_style=None,
):
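    """Wrap a raw dataset with the prompt tokenization strategy matching its
    configured type, returning the wrapped dataset and its prompter."""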
    dataset_wrapper = None
    dataset_prompter = None

    ds_kwargs = {
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    if (
        isinstance(dataset, Dataset)
        and "input_ids" in dataset.features
        and "attention_mask" in dataset.features
        and "labels" in dataset.features
    ):
        # dataset is already tokenized, so no need to tokenize it again
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = dataset
    elif isinstance(config_dataset.type, DictDefault):
        ds_strategy = load(
            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
        )
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif d_base_type == "alpaca":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "explainchoice":
        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "concisechoice":
        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "summarizetldr":
        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "jeopardy":
        dataset_prompter = JeopardyPrompter(d_prompt_style)
        ds_strategy = JeopardyPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "oasst":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = OpenAssistantPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "gpteacher":
        dataset_prompter = GPTeacherPrompter(d_prompt_style)
        ds_strategy = GPTeacherPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "reflection":
        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaReflectionPTStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    else:
        suffix = ""
        if ":load_" in config_dataset.type:
            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
        LOG.error(
            f"unhandled prompt tokenization strategy: {config_dataset.type}.{suffix}"
        )
        raise ValueError(
            f"unhandled prompt tokenization strategy: {config_dataset.type}.{suffix}"
        )

    return dataset_wrapper, dataset_prompter


def encode_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
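    """Tokenize pretraining examples, append EOS/pad tokens, and greedily pack
    the sequences into fixed-length blocks of `max_tokens`."""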
    # reserve two slots for the EOS and pad tokens appended below
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
    )
    # convert to PyTorch tensors
    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
    new_input_ids = []
    new_attention_mask = []
    # append EOS and pad tokens to input_ids, and correct attention_mask
    for i, _ in enumerate(input_ids):
        input_ids[i] = torch.cat(
            (
                input_ids[i],
                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
            ),
            dim=0,
        )
        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)

    # concatenate tokens so that their length is less than max_tokens
    buffer_input_ids = torch.tensor([], dtype=torch.long)
    buffer_attention_mask = torch.tensor([], dtype=torch.long)

    for ids, mask in zip(input_ids, attention_mask):
        if buffer_input_ids.numel() == max_tokens:
            new_input_ids.append(buffer_input_ids)
            new_attention_mask.append(buffer_attention_mask)
            buffer_input_ids = torch.tensor([], dtype=torch.long)
            buffer_attention_mask = torch.tensor([], dtype=torch.long)
            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
        else:
            # pad the current buffer out to max_tokens and start a new one
            buffer_input_ids = torch.cat(
                (
                    buffer_input_ids,
                    torch.full(
                        (max_tokens - buffer_input_ids.numel(),),
                        tokenizer.pad_token_id,
                        dtype=torch.long,
                    ),
                ),
                dim=0,
            )
            buffer_attention_mask = torch.cat(
                (
                    buffer_attention_mask,
                    torch.full(
                        (max_tokens - buffer_attention_mask.numel(),),
                        0,
                        dtype=torch.long,
                    ),
                ),
                dim=0,
            )
            new_input_ids.append(buffer_input_ids)
            new_attention_mask.append(buffer_attention_mask)
            buffer_input_ids = torch.tensor([], dtype=torch.long)
            buffer_attention_mask = torch.tensor([], dtype=torch.long)

            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

    if buffer_input_ids.numel() > 0:  # pad any leftover tokens to a full block
        while buffer_input_ids.numel() < max_tokens:
            buffer_input_ids = torch.cat(
                (
                    buffer_input_ids,
                    torch.full(
                        (max_tokens - buffer_input_ids.numel(),),
                        tokenizer.pad_token_id,
                        dtype=torch.long,
                    ),
                ),
                dim=0,
            )
            buffer_attention_mask = torch.cat(
                (
                    buffer_attention_mask,
                    torch.full(
                        (max_tokens - buffer_attention_mask.numel(),),
                        0,
                        dtype=torch.long,
                    ),
                ),
                dim=0,
            )
        new_input_ids.append(buffer_input_ids)
        new_attention_mask.append(buffer_attention_mask)

    ret = {
        "input_ids": [seq.tolist() for seq in new_input_ids],
        "labels": [seq.tolist() for seq in new_input_ids],
        "attention_mask": [seq.tolist() for seq in new_attention_mask],
    }

    LOG.debug(len(ret["input_ids"]))
    return ret


def wrap_pretraining_dataset(
    dataset,
    tokenizer,
    cfg,
    ds_wrapper_fn,
    max_tokens=2048,
    batch_size=1,
    seed=42,
    buffer_size=10_000,
):
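    """Shuffle a streaming pretraining dataset and map it through the
    appropriate encoder (packed or unpacked, depending on `sample_packing`)."""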
    if cfg.sample_packing:
        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens * batch_size,
        )
        encode = functools.partial(
            encode_packed_pretraining,
            collate_fn,
            ds_wrapper_fn,
            max_seq_length=max_tokens,
            batch_size=batch_size,
        )
        # set this to 1 so the downstream data_loader doesn't try to increase
        # the batch again
        cfg.micro_batch_size = 1
    else:
        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=buffer_size,
        # remove all the existing columns after mapping since they end up
        # having a different length than the encoded/tokenized column
        remove_columns=dataset.features.keys(),
    )
    return dataset


def encode_packed_pretraining(
    collate_fn,
    ds_wrapper: Callable,
    examples: Dict[str, List],
    max_seq_length: int = 2048,
    batch_size: int = 4,
) -> Dict[str, List]:
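    """Tokenize and pack a batch of raw pretraining examples into dense
    `batch_size * max_seq_length` sequences using the multipack sampler."""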
    # tokenize the incoming batch of raw examples
    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset, max_seq_length
    )

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=1,
        drop_last=True,
        batch_max_len=batch_size * max_seq_length,
        lengths=get_dataset_lengths(train_dataset),
    )

    chunked_data = defaultdict(list)

    for batch in sampler:
        for data in batch:
            features = train_dataset[data]
            if "num_truncated_tokens" in features:
                del features["num_truncated_tokens"]
            if "overflow_to_sample_mapping" in features:
                del features["overflow_to_sample_mapping"]
            if "labels" not in features:
                features["labels"] = features["input_ids"].copy()
            collated_features = collate_fn(features)

            for feature in features.keys():
                if feature == "length":
                    continue
                chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data


def _get_path(ds_hash, cfg):
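    """Return the on-disk path for a prepared dataset with the given hash."""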
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
    )

    return prepared_ds_path


def _load_preprocessed_ds(cfg, sub_cfg):
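    """Load a previously preprocessed dataset from disk if one exists for this
    sub-config, returning None otherwise."""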
    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    prepared_ds_path = _get_path(ds_hash, cfg)
    dataset = None

    if (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))

    return dataset


def _save_preprocessed_ds(cfg, sub_cfg, dataset):
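    """Persist a preprocessed dataset to disk, keyed by the hash of its
    sub-config (main process only, during preprocessing)."""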
    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    prepared_ds_path = _get_path(ds_hash, cfg)

    if cfg.is_preprocess and is_main_process():
        LOG.info(f"Saving preprocessed dataset to disk at {prepared_ds_path}...")
        dataset.save_to_disk(str(prepared_ds_path))


def load_prepare_dpo_datasets(cfg):
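    """Load and transform the configured RL/DPO train and eval datasets,
    reusing preprocessed copies when available."""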
    def load_split(dataset_cfgs, _cfg):
        split_datasets: List[Any] = []
        for i, ds_cfg in enumerate(dataset_cfgs):
            if ds_cfg["ds_type"] == "json":
                for data_file in ds_cfg["data_files"]:
                    data_files = {ds_cfg["split"]: data_file}
                    ds = load_dataset(
                        "json",
                        data_files=data_files,
                        split=ds_cfg["split"],
                    )
                    split_datasets.insert(i, ds)
            else:
                ds = load_dataset(
                    ds_cfg["path"],
                    split=ds_cfg["split"],
                )
                split_datasets.insert(i, ds)

        for i, data_set in enumerate(split_datasets):
            _type = dataset_cfgs[i]["type"]
            if _type:
                if isinstance(_type, DictDefault):
                    _type = "user_defined.default"
                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
                split_datasets[i] = data_set.map(
                    ds_transform_fn,
                    desc="Mapping RL Dataset",
                )
            else:
                # if no `type` is configured, assume the dataset is already in
                # the expected format with "prompt", "chosen" and "rejected"
                split_datasets[i] = data_set

        return concatenate_datasets(split_datasets)

    with zero_first(is_main_process()):
        train_is_preprocessed = False
        eval_is_preprocessed = False
        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
            train_is_preprocessed = True
        else:
            train_dataset = load_split(cfg.datasets, cfg)

        eval_dataset = None
        if cfg.test_datasets:
            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
                eval_is_preprocessed = True
            else:
                eval_dataset = load_split(cfg.test_datasets, cfg)
        if not eval_dataset:
            eval_dataset = None

        if not train_is_preprocessed:
            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
        if eval_dataset and not eval_is_preprocessed:
            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

    return train_dataset, eval_dataset