"""Prepares the datasets for calibration. Original code gently shared by TheBloke""" from abc import ABC import time from typing import Dict, List, Optional from datasets import load_dataset, Dataset from transformers import PreTrainedTokenizerBase class CalibrationDataset(ABC): tokenizer: Optional[PreTrainedTokenizerBase] = None num_samples: int = 128 seqlen: int = 4096 dataset_config: dict dataset: str dataset_name: str dataset_limit: int = int(1e7) # Defines the field to extract from the HF dataset # If specified, just this field will be returned, and no transformation will be done. dataset_field: Optional[str] = None # Define the default parameters for a dataset which requires a transformation # Only used if dataset_field is None. # The fields to extract from the original dataset transform_fields: List[str] = [] # A format string describing how the fields should be joined # Can use {field1}, {field2}, etc. as placeholders for the field names # Or can use actual names, eg "{input} {output}" transform_join: str = "{field1} {field2}" # Optional override for the dataset URL # By default this is automatically derived from the dataset name and config dataset_url: Optional[str] = None data: Optional[Dataset] = None samples: List[str] = [] tokenized_samples: List[Dict[str, str]] = {} randomize: bool = False randomize_seed: int = 42 def __init__( self, num_samples: int = 128, seqlen: int = 4096, tokenizer: Optional[PreTrainedTokenizerBase] = None ): self.num_samples = num_samples self.seqlen = seqlen self.tokenizer = tokenizer @classmethod def get_dataset(cls, dataset_name, **kwargs): for subclass in cls.__subclasses__(): if hasattr(subclass, "dataset") and subclass.dataset == dataset_name: return subclass(**kwargs) raise ValueError(f"No dataset class found for name: {dataset_name}") def tokenize_dataset(self, samples: Optional[List[str]] = None) -> List[Dict[str, int]]: """ Tokenize the dataset and return a list of tokens of `seqlen` length First tokenize the List[str] of samples, as a batch. Then flatten the batch, and split it into `num_samples` rows of `seqlen` length. """ if not self.tokenizer: raise ValueError("No tokenizer provided to tokenize_dataset()") else: if not samples: if not self.samples: self.get_samples() samples = self.samples print(f"Tokenizing {self.dataset_name} of length {len(samples)}") start_time = time.time() # Tokenize the list of samples. We don't use return_tensors="pt", # as that requires the samples to be the same length, or padding to be used. tokenized = self.tokenizer(samples) # Output of tokenizer will be: # {"input_ids": [[1,2,3], [4,5], [6,7]], "attention_mask": [[1,1,1], [1,1], [1,1]]} # Flatten that so as to concatenate the samples into a single input_mask and attention_mask flattened = { key: [ item for sublist in value for item in sublist ] for key, value in tokenized.items() } print( f"Tokenized length: {len(flattened['input_ids'])} tokens." ) # Slice our single input_mask list into num_samples samples of seqlen length tokenized_samples = [] for i in range(0, self.num_samples * self.seqlen, self.seqlen): if i + self.seqlen >= len(flattened["input_ids"]): break sample = { "input_ids": flattened["input_ids"][i:i + self.seqlen], "attention_mask": flattened["attention_mask"][i:i + self.seqlen] } tokenized_samples.append(sample) print( f"Return {len(tokenized_samples)} samples of {self.seqlen} length. " f"Time taken: {time.time() - start_time:.2f}s." 
        self.tokenized_samples = tokenized_samples
        return self.tokenized_samples

    def get_hf_dataset(
        self,
        path: str,
        limit: Optional[int] = None,
        **kwargs
    ) -> Dataset:
        """Load the Hugging Face dataset at `path`, using the provided kwargs."""
        print(f"Loading HF dataset {path} with params: {kwargs}")
        data: Dataset = load_dataset(path=path, **kwargs)
        limit = min(limit, len(data)) if limit else len(data)
        return data.select(range(limit))

    @staticmethod
    def list_with_nls(samples: List[str]) -> List[str]:
        """
        Return a List[str] with each sample ending in a newline.

        Also filters the list by stripping, then removing any empty samples.
        """
        return [
            x.rstrip() + '\n'
            for x in samples
            if x and len(x.strip()) > 0
        ]

    def get_samples(self) -> List[str]:
        """
        Return a list of samples for the dataset.

        If the subclass implements `dataset_field`, this is used to filter the HF Dataset.
        Otherwise, the subclass must implement `process_samples()`, for custom filtering.

        Samples are returned as a List[str], each ending in a newline.
        """
        # Load HF dataset. Subclasses provide HF dataset details in `dataset_config`
        if not self.data:
            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.dataset_limit)

        if not self.samples:
            if hasattr(self, "dataset_field") and self.dataset_field:
                samples = self.data[self.dataset_field]
            else:
                try:
                    samples = self.process_samples()
                except NotImplementedError:
                    raise ValueError(
                        f"No dataset field specified for class {self.__class__}, "
                        f"and process_samples() method not defined."
                    )
            if self.randomize:
                import random
                random.seed(self.randomize_seed)
                random.shuffle(samples)
            self.samples = self.list_with_nls(samples)

        return self.samples

    def process_samples(self) -> List[str]:
        if not self.transform_fields or not isinstance(self.transform_fields, list):
            raise ValueError("transform_fields must be a List[str], defined in the subclass")
        if not self.transform_join or not isinstance(self.transform_join, str):
            raise ValueError("transform_join must be a str, defined in the subclass")

        def transform_sample(sample):
            field_values = {field: sample[field] for field in self.transform_fields}
            # We support both:
            #   generic numbered fields: "{field1} {field2}"
            #   and named fields: "{input} {output}"
            # Create a combined dictionary to handle both specific field names and generic placeholders
            combined_dict = {
                **field_values,
                **{f'field{i+1}': value for i, value in enumerate(field_values.values())}
            }
            output = self.transform_join.format_map(combined_dict)
            return {"output": output}

        return self.data.map(transform_sample)["output"]

    def generate_checksum(self) -> str:
        # Create a sha256 checksum of the joined samples
        # Can be used to confirm that code updates haven't changed the output
        import hashlib
        samples = self.get_samples()
        combined_samples = ''.join(samples)
        checksum = hashlib.sha256(combined_samples.encode()).hexdigest()
        return checksum

    @classmethod
    def get_dataset_url(cls) -> str:
        """Return the Hugging Face dataset URL for this dataset."""
        if hasattr(cls, "dataset_url") and cls.dataset_url:
            return cls.dataset_url
        else:
            return "https://huggingface.co/datasets/{}/viewer/{}".format(
                cls.dataset_config["path"], cls.dataset_config.get("name", "")
            )


class WikitextDataset(CalibrationDataset):
    dataset = "wikitext"
    dataset_config = {
        "path": "wikitext",
        "name": "wikitext-2-raw-v1",
        "split": "train"
    }
    dataset_name = "Wikitext2 Full"

    def process_samples(self) -> List[str]:
        return [
            "\n" if len(item) == 0 else item
            for item in self.data["text"]
        ]


class C4Dataset(CalibrationDataset):
    dataset = "c4"
    dataset_field = "text"
dataset_config = { "path": "allenai/c4", "data_files": { "train": "en/c4-train.00000-of-01024.json.gz" }, "split": "train" } dataset_name = "C4" class ThaiDataset(CalibrationDataset): dataset = "thai" dataset_field = "text" dataset_config = { "path": "pbwt/all-thai", "data_files": { "train": "data/train-00000-of-00047-985fbaed08d034cf.parquet" }, "split": "train" } dataset_name = "All Thai" class MovieScriptDataset(CalibrationDataset): dataset = "movie-scripts" dataset_field = "full_script" dataset_config = { "path": "jondurbin/cinematika-v0.1", "data_files": { "train": "full_script.parquet" }, "split": "train" } dataset_name = "Cinematika Full Scripts" class JapaneseEnglishDataset(CalibrationDataset): dataset = "japanese-english" dataset_config = { "path": "augmxnt/shisa-en-ja-dpo-v1", "split": "train" } dataset_name = "Shisa English Japanese DPO" randomize = True def process_samples(self) -> List[str]: def transform_samples(sample): prompt = sample["prompt"] chosen = sample["chosen"] # prompt example: "[INST] <>\nYou are a helpful, unbiased, uncensored assistant.\n<>\n\nWhat are cardigans made of? Leather or wood? [/INST]" try: part1 = prompt.split('\n<>\n\n')[1] extracted_text = part1.split(' [/INST]')[0] except Exception as e: print(f"Error extracting text from prompt '{prompt}': {e}") raise prompt = extracted_text return {"output": f"{prompt} {chosen}"} return self.data.map(transform_samples)["output"] class PortugueseDataset(CalibrationDataset): dataset = "portuguese" dataset_config = { "path": "adalbertojunior/portuguese_orca", "split": "train" } dataset_name = "Portuguese Orca" transform_fields = [ "question", "response" ] class MathsDataset(CalibrationDataset): dataset = "maths" dataset_config = { "path": "andersonbcdefg/math", "split": "train" } dataset_name = "CamelAI Math" transform_fields = [ "message_1", "message_2" ] class MedicalDataset(CalibrationDataset): dataset = "medical" dataset_config = { "path": "medalpaca/medical_meadow_wikidoc", "split": "train" } dataset_name = "Medical Medaow WikiDoc" transform_fields = [ "input", "output" ] class OpenInstructDataset(CalibrationDataset): dataset = "open-instruct" dataset_config = { "path": "VMware/open-instruct", "split": "train" } dataset_name = "VMware Open Instruct" transform_fields = [ "instruction", "response" ] class KoreanDataset(CalibrationDataset): dataset = "korean" dataset_config = { "path": "beomi/KoAlpaca-v1.1a", "split": "train" } dataset_name = "Korean Alpaca" transform_fields = [ "instruction", "output" ] class CodeDataset(CalibrationDataset): dataset = "code" dataset_field = "output" dataset_config = { "path": "nickrosh/Evol-Instruct-Code-80k-v1", "split": "train" } dataset_name = "Evol Instruct Code" class MultiLanguageDataset(CalibrationDataset): dataset = "multi-language" dataset_field = "text" dataset_config = { "path": "papluca/language-identification", "split": "train" } dataset_name = "Language Identification" class RussianDataset(CalibrationDataset): dataset = "russian" dataset_config = { "path": "Den4ikAI/russian_instructions_2", "split": "train" } dataset_name = "Russian Instructions 2" transform_fields = [ "question", "answer" ] class DutchDataset(CalibrationDataset): dataset = "dutch" dataset_config = { "path": "BramVanroy/dolly-15k-dutch", "split": "train" } dataset_name = "Dolly 15K Dutch" transform_fields = [ "instruction", "context", "response" ] transform_join = "{field1} {field2} {field3}" class VietnameseChineseDataset(CalibrationDataset): dataset = "vietnamesechinese" dataset_config = { 
"path": "nRuaif/Vietnamese_x_Alpaca", "split": "train" } dataset_name = "Vietnamese and Chinese" def get_dataset_url(self) -> None: return None def process_samples(self) -> List[str]: samples = self.data["output"] chinese_samples = CalibrationDataset.get_dataset("chinese").get_samples() joined_list = samples + chinese_samples import random random.shuffle(joined_list) return joined_list[:self.dataset_limit] class VietnameseDataset(CalibrationDataset): dataset = "vietnamese" dataset_field = "output" dataset_config = { "path": "nRuaif/Vietnamese_x_Alpaca", "split": "train" } dataset_name = "Alpaca Vietnamese" class ChineseDataset(CalibrationDataset): dataset = "chinese" dataset_config = { "path": "TigerResearch/tigerbot-alpaca-zh-0.5m", "split": "train" } dataset_name = "Tiger Alpaca ZH" transform_fields = [ "instruction", "input", "output" ] transform_join = "{field1} {field2} {field3}" class LatinEnglishDataset(CalibrationDataset): dataset = "latin-english" dataset_config = { "path": "grosenthal/latin_english_parallel", "split": "train" } dataset_name = "Latin English Parallel" transform_fields = [ "la", "en" ] transform_join = "{field1}\n{field2}" class PolishDataset(CalibrationDataset): dataset = "polish" dataset_field = "content" dataset_config = { "path": "WiktorS/polish-news", "split": "train" } dataset_name = "Polish News" class JapaneseDataset(CalibrationDataset): dataset = "japanese" dataset_field = "output" dataset_config = { "path": "fujiki/japanese_alpaca_data", "split": "train" } dataset_name = "Alpaca Japanese" class SpanishDataset(CalibrationDataset): dataset = "spanish" dataset_field = "output" dataset_config = { "path": "bertin-project/alpaca-spanish", "split": "train" } dataset_name = "Alpaca Spanish" class GermanDataset(CalibrationDataset): dataset = "german" dataset_config = { "path": "deepset/germanquad", "split": "train" } dataset_name = "German Quad" def process_samples(self) -> List[str]: def transform_samples(sample): split_context = sample["context"].split("===") if len(split_context) >= 3: trans_context = split_context[2] else: trans_context = sample["context"] return {"output": trans_context.strip()} return self.data.map(transform_samples)["output"] class FrenchDataset(CalibrationDataset): dataset = "french" dataset_field = "text" dataset_config = { "path": "Kant1/French_Wikipedia_articles", "data_files": { "wiki_00.txt" }, "split": "train" } dataset_name = "French Wikipedia Articles" def validate_dataset(dataset_name: str, **kwargs): for cls in CalibrationDataset.__subclasses__(): if hasattr(cls, "dataset") and cls.dataset == dataset_name: return True return False # FIXME: a temp function put in for AutoAWQ, pending full refactor where it won't be necessary def get_dataset_url(dataset_name: str): for cls in CalibrationDataset.__subclasses__(): if hasattr(cls, "dataset") and cls.dataset == dataset_name: return cls.get_dataset_url() raise ValueError(f"No dataset class found for name: {dataset_name}") def get_dataset_name(dataset_name: str): for cls in CalibrationDataset.__subclasses__(): if hasattr(cls, "dataset") and cls.dataset == dataset_name: return cls.dataset_name raise ValueError(f"No dataset class found for name: {dataset_name}") def test_datasets(datasets: Optional[List[str]] = None, checksum_only=False): import sys from transformers import AutoTokenizer try: failed = [] for cls in CalibrationDataset.__subclasses__(): if not hasattr(cls, "dataset") or not cls.dataset: failed.append(cls.__name__) if failed: print(f"The following classes have no 'dataset' 
attribute: {failed}") sys.exit(-1) else: print()(f"All classes have 'dataset' attribute.") print(f"Enumerating CalibrationDataset classes") classes = CalibrationDataset.__subclasses__() dataset_names = [ cls.dataset for cls in classes if cls.dataset and (not datasets or cls.dataset in datasets) ] print(f"Found {len(classes)} total dataset classes: {[c.dataset for c in classes]}") if datasets: print(f"Will test {len(dataset_names)} datasets: {dataset_names}") print(f"Starting test: loading Llama-2 tokenizer") tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True) for name in dataset_names: print(f"{name} test: loading dataset.") dataset = CalibrationDataset.get_dataset(name, tokenizer=tokenizer) if not checksum_only: print(f"{name} test: running tokenize_dataset.") toks = dataset.tokenize_dataset() print(f"{name} test: getting dataset_url.") url = dataset.get_dataset_url() print(f"{name} - randomized? {dataset.randomize}") print( f"{name} - result: cls.data: length: {len(dataset.data)}, " f"first row length: {len(dataset.data[0])}, " f"first row data: '{dataset.data[0]}'." ) print( f"{name} - result: cls.samples: length: {len(dataset.samples)}, " f"first row length: {len(dataset.samples[0])}, " f"first row sample: '{dataset.samples[0]}'." ) print( f"{name} - result: tokenize_dataset result: length: {len(toks)}, " f"length first row input_ids: {len(toks[0]['input_ids'])}." ) print( f"{name} - result: dataset_url: {url}" ) checksum = dataset.generate_checksum() print( f"{name} - result: sha256 checksum: {checksum}" ) except KeyboardInterrupt: print("Test aborted") except Exception as e: print( f"Received an exception during test. Test failed. " f"Exception: {e}" ) raise if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Test calibration datasets") parser.add_argument("--datasets", "-d", "-n", nargs="*", type=str, help="Dataset(s) to check; default is all") parser.add_argument("--checksum_only", "-co", action="store_true", help="Only ouput the checksums for the datasets") args = parser.parse_args() test_datasets(args.datasets, checksum_only=args.checksum_only)