"""Prepares the datasets for calibration. Original code gently shared by TheBloke""" | |
from abc import ABC | |
import time | |
from typing import Dict, List, Optional | |
from datasets import load_dataset, Dataset | |
from transformers import PreTrainedTokenizerBase | |


class CalibrationDataset(ABC):
    tokenizer: Optional[PreTrainedTokenizerBase] = None
    num_samples: int = 128
    seqlen: int = 4096
    dataset_config: dict
    dataset: str
    dataset_name: str
    dataset_limit: int = int(1e7)

    # Defines the field to extract from the HF dataset.
    # If specified, just this field will be returned, and no transformation will be done.
    dataset_field: Optional[str] = None

    # Default parameters for a dataset which requires a transformation.
    # Only used if dataset_field is None.
    # The fields to extract from the original dataset:
    transform_fields: List[str] = []

    # A format string describing how the fields should be joined.
    # Can use {field1}, {field2}, etc. as placeholders for the field names,
    # or the actual field names, e.g. "{input} {output}".
    transform_join: str = "{field1} {field2}"

    # Optional override for the dataset URL.
    # By default this is automatically derived from the dataset path and config name.
    dataset_url: Optional[str] = None

    data: Optional[Dataset] = None
    samples: List[str] = []
    tokenized_samples: List[Dict[str, List[int]]] = []
    randomize: bool = False
    randomize_seed: int = 42

    def __init__(
        self,
        num_samples: int = 128,
        seqlen: int = 4096,
        tokenizer: Optional[PreTrainedTokenizerBase] = None
    ):
        self.num_samples = num_samples
        self.seqlen = seqlen
        self.tokenizer = tokenizer

    @classmethod
    def get_dataset(cls, dataset_name: str, **kwargs) -> "CalibrationDataset":
        """Return an instance of the subclass whose `dataset` attribute matches `dataset_name`."""
        for subclass in cls.__subclasses__():
            if hasattr(subclass, "dataset") and subclass.dataset == dataset_name:
                return subclass(**kwargs)
        raise ValueError(f"No dataset class found for name: {dataset_name}")

    def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, List[int]]]:
        """
        Tokenize the dataset and return a list of samples of `seqlen` length.

        First tokenizes the List[str] of samples as a batch, then flattens the batch
        and splits it into up to `num_samples` rows of `seqlen` tokens each.
        """
        if not self.tokenizer:
            raise ValueError("No tokenizer provided to tokenize_dataset()")

        if not samples:
            if not self.samples:
                self.get_samples()
            samples = self.samples

        print(f"Tokenizing {self.dataset_name} of length {len(samples)}")
        start_time = time.time()

        # Tokenize the list of samples. We don't use return_tensors="pt",
        # as that requires the samples to be the same length, or padding to be used.
        tokenized = self.tokenizer(samples)

        # Output of the tokenizer will be:
        # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]}
        # Flatten that so as to concatenate the samples into a single input_ids and attention_mask.
        flattened = {
            key: [
                item for sublist in value
                for item in sublist
            ]
            for key, value in tokenized.items()
        }
        print(f"Tokenized length: {len(flattened['input_ids'])} tokens.")

        # Slice our single input_ids list into num_samples samples of seqlen length.
        tokenized_samples = []
        for i in range(0, self.num_samples * self.seqlen, self.seqlen):
            if i + self.seqlen >= len(flattened["input_ids"]):
                break
            sample = {
                "input_ids": flattened["input_ids"][i:i + self.seqlen],
                "attention_mask": flattened["attention_mask"][i:i + self.seqlen]
            }
            tokenized_samples.append(sample)

        print(
            f"Returning {len(tokenized_samples)} samples of {self.seqlen} length. "
            f"Time taken: {time.time() - start_time:.2f}s."
        )
        self.tokenized_samples = tokenized_samples
        return self.tokenized_samples
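
    # Shape of the returned value (illustrative only):
    #
    #   [
    #       {"input_ids": [...seqlen ints...], "attention_mask": [...seqlen ints...]},
    #       ...
    #   ]
    #
    # i.e. up to num_samples dicts, each holding exactly seqlen token ids.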

    def get_hf_dataset(
        self,
        path: str,
        limit: Optional[int] = None,
        **kwargs
    ) -> Dataset:
        """Load the Hugging Face dataset at `path`, using the provided kwargs."""
        print(f"Loading HF dataset {path} with params: {kwargs}")
        data: Dataset = load_dataset(path=path, **kwargs)
        limit = min(limit, len(data)) if limit else len(data)
        return data.select(range(limit))
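
    # For example, C4Dataset.dataset_config (defined below) results in roughly:
    #
    #   load_dataset(
    #       path="allenai/c4",
    #       data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
    #       split="train",
    #   )
    #
    # followed by a .select() to cap the number of rows at `limit`.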

    @staticmethod
    def list_with_nls(samples: List[str]) -> List[str]:
        """
        Return a List[str] with each sample ending in a newline.

        Also filters the list by stripping, then removing any empty samples.
        """
        return [
            x.rstrip() + '\n'
            for x in samples
            if x and len(x.strip()) > 0
        ]

    def get_samples(self) -> List[str]:
        """
        Return a list of samples for the dataset.

        If the subclass defines `dataset_field`, that field is extracted from the HF dataset.
        Otherwise, the subclass must implement `process_samples()` for custom filtering.
        Samples are returned as a List[str], each ending in a newline.
        """
        # Load the HF dataset. Subclasses provide the HF dataset details in `dataset_config`.
        if not self.data:
            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.dataset_limit)

        if not self.samples:
            if hasattr(self, "dataset_field") and self.dataset_field:
                samples = self.data[self.dataset_field]
            else:
                try:
                    samples = self.process_samples()
                except NotImplementedError:
                    raise ValueError(
                        f"No dataset field specified for class {self.__class__}, "
                        f"and process_samples() method not defined."
                    )
            if self.randomize:
                import random
                random.seed(self.randomize_seed)
                random.shuffle(samples)
            self.samples = self.list_with_nls(samples)
        return self.samples

    def process_samples(self) -> List[str]:
        if not self.transform_fields or not isinstance(self.transform_fields, list):
            raise ValueError("transform_fields must be a List[str], defined in the subclass")
        if not self.transform_join or not isinstance(self.transform_join, str):
            raise ValueError("transform_join must be a str, defined in the subclass")

        def transform_sample(sample):
            field_values = {field: sample[field] for field in self.transform_fields}
            # We support both:
            #   generic numbered fields: "{field1} {field2}"
            #   and named fields:        "{input} {output}"
            # Create a combined dictionary to handle both specific field names and generic placeholders.
            combined_dict = {
                **field_values,
                **{f"field{i + 1}": value for i, value in enumerate(field_values.values())}
            }
            output = self.transform_join.format_map(combined_dict)
            return {"output": output}

        return self.data.map(transform_sample)["output"]

    def generate_checksum(self) -> str:
        # Create a sha256 checksum of the joined samples.
        # Can be used to confirm that code updates haven't changed the output.
        import hashlib
        samples = self.get_samples()
        combined_samples = ''.join(samples)
        checksum = hashlib.sha256(combined_samples.encode()).hexdigest()
        return checksum

    @classmethod
    def get_dataset_url(cls) -> str:
        """Return the Hugging Face dataset URL for this dataset."""
        if hasattr(cls, "dataset_url") and cls.dataset_url:
            return cls.dataset_url
        else:
            return "https://huggingface.co/datasets/{}/viewer/{}".format(
                cls.dataset_config["path"],
                cls.dataset_config.get("name", "")
            )
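
    # For example, WikitextDataset (defined below) derives:
    #   https://huggingface.co/datasets/wikitext/viewer/wikitext-2-raw-v1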


class WikitextDataset(CalibrationDataset):
    dataset = "wikitext"
    dataset_config = {
        "path": "wikitext",
        "name": "wikitext-2-raw-v1",
        "split": "train"
    }
    dataset_name = "Wikitext2 Full"

    def process_samples(self) -> List[str]:
        return [
            "\n" if len(item) == 0 else item
            for item in self.data["text"]
        ]


class C4Dataset(CalibrationDataset):
    dataset = "c4"
    dataset_field = "text"
    dataset_config = {
        "path": "allenai/c4",
        "data_files": {
            "train": "en/c4-train.00000-of-01024.json.gz"
        },
        "split": "train"
    }
    dataset_name = "C4"


class ThaiDataset(CalibrationDataset):
    dataset = "thai"
    dataset_field = "text"
    dataset_config = {
        "path": "pbwt/all-thai",
        "data_files": {
            "train": "data/train-00000-of-00047-985fbaed08d034cf.parquet"
        },
        "split": "train"
    }
    dataset_name = "All Thai"


class MovieScriptDataset(CalibrationDataset):
    dataset = "movie-scripts"
    dataset_field = "full_script"
    dataset_config = {
        "path": "jondurbin/cinematika-v0.1",
        "data_files": {"train": "full_script.parquet"},
        "split": "train"
    }
    dataset_name = "Cinematika Full Scripts"


class JapaneseEnglishDataset(CalibrationDataset):
    dataset = "japanese-english"
    dataset_config = {
        "path": "augmxnt/shisa-en-ja-dpo-v1",
        "split": "train"
    }
    dataset_name = "Shisa English Japanese DPO"
    randomize = True

    def process_samples(self) -> List[str]:
        def transform_samples(sample):
            prompt = sample["prompt"]
            chosen = sample["chosen"]
            # prompt example: "[INST] <<SYS>>\nYou are a helpful, unbiased, uncensored assistant.\n<</SYS>>\n\nWhat are cardigans made of? Leather or wood? [/INST]"
            try:
                part1 = prompt.split('\n<</SYS>>\n\n')[1]
                extracted_text = part1.split(' [/INST]')[0]
            except Exception as e:
                print(f"Error extracting text from prompt '{prompt}': {e}")
                raise

            prompt = extracted_text
            return {"output": f"{prompt} {chosen}"}

        return self.data.map(transform_samples)["output"]


class PortugueseDataset(CalibrationDataset):
    dataset = "portuguese"
    dataset_config = {
        "path": "adalbertojunior/portuguese_orca",
        "split": "train"
    }
    dataset_name = "Portuguese Orca"
    transform_fields = ["question", "response"]


class MathsDataset(CalibrationDataset):
    dataset = "maths"
    dataset_config = {
        "path": "andersonbcdefg/math",
        "split": "train"
    }
    dataset_name = "CamelAI Math"
    transform_fields = ["message_1", "message_2"]


class MedicalDataset(CalibrationDataset):
    dataset = "medical"
    dataset_config = {
        "path": "medalpaca/medical_meadow_wikidoc",
        "split": "train"
    }
    dataset_name = "Medical Meadow WikiDoc"
    transform_fields = ["input", "output"]


class OpenInstructDataset(CalibrationDataset):
    dataset = "open-instruct"
    dataset_config = {
        "path": "VMware/open-instruct",
        "split": "train"
    }
    dataset_name = "VMware Open Instruct"
    transform_fields = ["instruction", "response"]


class KoreanDataset(CalibrationDataset):
    dataset = "korean"
    dataset_config = {
        "path": "beomi/KoAlpaca-v1.1a",
        "split": "train"
    }
    dataset_name = "Korean Alpaca"
    transform_fields = ["instruction", "output"]


class CodeDataset(CalibrationDataset):
    dataset = "code"
    dataset_field = "output"
    dataset_config = {
        "path": "nickrosh/Evol-Instruct-Code-80k-v1",
        "split": "train"
    }
    dataset_name = "Evol Instruct Code"


class MultiLanguageDataset(CalibrationDataset):
    dataset = "multi-language"
    dataset_field = "text"
    dataset_config = {
        "path": "papluca/language-identification",
        "split": "train"
    }
    dataset_name = "Language Identification"


class RussianDataset(CalibrationDataset):
    dataset = "russian"
    dataset_config = {
        "path": "Den4ikAI/russian_instructions_2",
        "split": "train"
    }
    dataset_name = "Russian Instructions 2"
    transform_fields = ["question", "answer"]


class DutchDataset(CalibrationDataset):
    dataset = "dutch"
    dataset_config = {
        "path": "BramVanroy/dolly-15k-dutch",
        "split": "train"
    }
    dataset_name = "Dolly 15K Dutch"
    transform_fields = ["instruction", "context", "response"]
    transform_join = "{field1} {field2} {field3}"


class VietnameseChineseDataset(CalibrationDataset):
    dataset = "vietnamesechinese"
    dataset_config = {
        "path": "nRuaif/Vietnamese_x_Alpaca",
        "split": "train"
    }
    dataset_name = "Vietnamese and Chinese"

    @classmethod
    def get_dataset_url(cls) -> None:
        # This combined dataset has no single source URL.
        return None

    def process_samples(self) -> List[str]:
        samples = self.data["output"]
        chinese_samples = CalibrationDataset.get_dataset("chinese").get_samples()
        joined_list = samples + chinese_samples

        import random
        random.shuffle(joined_list)

        return joined_list[:self.dataset_limit]


class VietnameseDataset(CalibrationDataset):
    dataset = "vietnamese"
    dataset_field = "output"
    dataset_config = {
        "path": "nRuaif/Vietnamese_x_Alpaca",
        "split": "train"
    }
    dataset_name = "Alpaca Vietnamese"


class ChineseDataset(CalibrationDataset):
    dataset = "chinese"
    dataset_config = {
        "path": "TigerResearch/tigerbot-alpaca-zh-0.5m",
        "split": "train"
    }
    dataset_name = "Tiger Alpaca ZH"
    transform_fields = ["instruction", "input", "output"]
    transform_join = "{field1} {field2} {field3}"


class LatinEnglishDataset(CalibrationDataset):
    dataset = "latin-english"
    dataset_config = {
        "path": "grosenthal/latin_english_parallel",
        "split": "train"
    }
    dataset_name = "Latin English Parallel"
    transform_fields = ["la", "en"]
    transform_join = "{field1}\n{field2}"


class PolishDataset(CalibrationDataset):
    dataset = "polish"
    dataset_field = "content"
    dataset_config = {
        "path": "WiktorS/polish-news",
        "split": "train"
    }
    dataset_name = "Polish News"


class JapaneseDataset(CalibrationDataset):
    dataset = "japanese"
    dataset_field = "output"
    dataset_config = {
        "path": "fujiki/japanese_alpaca_data",
        "split": "train"
    }
    dataset_name = "Alpaca Japanese"


class SpanishDataset(CalibrationDataset):
    dataset = "spanish"
    dataset_field = "output"
    dataset_config = {
        "path": "bertin-project/alpaca-spanish",
        "split": "train"
    }
    dataset_name = "Alpaca Spanish"


class GermanDataset(CalibrationDataset):
    dataset = "german"
    dataset_config = {
        "path": "deepset/germanquad",
        "split": "train"
    }
    dataset_name = "German Quad"

    def process_samples(self) -> List[str]:
        def transform_samples(sample):
            split_context = sample["context"].split("===")
            if len(split_context) >= 3:
                trans_context = split_context[2]
            else:
                trans_context = sample["context"]
            return {"output": trans_context.strip()}

        return self.data.map(transform_samples)["output"]


class FrenchDataset(CalibrationDataset):
    dataset = "french"
    dataset_field = "text"
    dataset_config = {
        "path": "Kant1/French_Wikipedia_articles",
        "data_files": {"train": "wiki_00.txt"},
        "split": "train"
    }
    dataset_name = "French Wikipedia Articles"


def validate_dataset(dataset_name: str, **kwargs) -> bool:
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return True
    return False


# FIXME: a temporary function put in for AutoAWQ, pending a full refactor where it won't be necessary.
def get_dataset_url(dataset_name: str) -> str:
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return cls.get_dataset_url()
    raise ValueError(f"No dataset class found for name: {dataset_name}")


def get_dataset_name(dataset_name: str) -> str:
    for cls in CalibrationDataset.__subclasses__():
        if hasattr(cls, "dataset") and cls.dataset == dataset_name:
            return cls.dataset_name
    raise ValueError(f"No dataset class found for name: {dataset_name}")


def test_datasets(datasets: Optional[List[str]] = None, checksum_only: bool = False):
    import sys
    from transformers import AutoTokenizer

    try:
        failed = []
        for cls in CalibrationDataset.__subclasses__():
            if not hasattr(cls, "dataset") or not cls.dataset:
                failed.append(cls.__name__)
        if failed:
            print(f"The following classes have no 'dataset' attribute: {failed}")
            sys.exit(-1)
        else:
            print("All classes have a 'dataset' attribute.")

        print("Enumerating CalibrationDataset classes")
        classes = CalibrationDataset.__subclasses__()
        dataset_names = [
            cls.dataset
            for cls in classes
            if cls.dataset and (not datasets or cls.dataset in datasets)
        ]
        print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}")
        if datasets:
            print(f"Will test {len(dataset_names)} datasets: {dataset_names}")

        print("Starting test: loading Llama-2 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)

        for name in dataset_names:
            print(f"{name} test: loading dataset.")
            dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer)
            if not checksum_only:
                print(f"{name} test: running tokenize_dataset.")
                toks = dataset.tokenize_dataset()
                print(f"{name} test: getting dataset_url.")
                url = dataset.get_dataset_url()
                print(f"{name} - randomized? {dataset.randomize}")
                print(
                    f"{name} - result: cls.data: length: {len(dataset.data)}, "
                    f"first row length: {len(dataset.data[0])}, "
                    f"first row data: '{dataset.data[0]}'."
                )
                print(
                    f"{name} - result: cls.samples: length: {len(dataset.samples)}, "
                    f"first row length: {len(dataset.samples[0])}, "
                    f"first row sample: '{dataset.samples[0]}'."
                )
                print(
                    f"{name} - result: tokenize_dataset result: length: {len(toks)}, "
                    f"length first row input_ids: {len(toks[0]['input_ids'])}."
                )
                print(f"{name} - result: dataset_url: {url}")
            checksum = dataset.generate_checksum()
            print(f"{name} - result: sha256 checksum: {checksum}")
    except KeyboardInterrupt:
        print("Test aborted")
    except Exception as e:
        print(
            f"Received an exception during test. Test failed. "
            f"Exception: {e}"
        )
        raise
if __name__ == "__main__": | |
import argparse | |
parser = argparse.ArgumentParser(description="Test calibration datasets") | |
parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str, help="Dataset(s) to check; default is all") | |
parser.add_argument("--checksum_only", "-co", action="store_true", help="Only ouput the checksums for the datasets") | |
args = parser.parse_args() | |
test_datasets(args.datasets, checksum_only=args.checksum_only) | |
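
# Example invocations (illustrative; the filename below is assumed, as the source
# does not give one):
#
#   python calibration_datasets.py                     # test every registered dataset
#   python calibration_datasets.py -d wikitext c4      # test only wikitext and c4
#   python calibration_datasets.py --checksum_only     # just print the sha256 checksums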