|
import unittest |
|
from Andromeda.dataset_builder import DatasetBuilder |
|
|
|
class TestDatasetBuilder(unittest.TestCase):
    """Unit tests for ``DatasetBuilder`` from ``Andromeda.dataset_builder``."""

    def setUp(self):
        """Create a fresh builder for ``tiiuae/falcon-refinedweb`` before each test."""
        self.builder = DatasetBuilder(dataset_name="tiiuae/falcon-refinedweb")

    def test_initialization(self):
        """Check the attributes of a builder constructed with only ``dataset_name``."""
        self.assertEqual(
            self.builder.dataset_name,
            "tiiuae/falcon-refinedweb",
            "Dataset name is not correctly set.",
        )
        self.assertEqual(
            self.builder.seq_len,
            8192,
            "Sequence length is not correctly set.",
        )
        self.assertEqual(
            self.builder.tokenizer,
            "EleutherAI/gpt-neox-20b",
            "Tokenizer is not correctly set.",
        )

    def test_build_dataset(self):
        """``build_dataset`` returns a non-``None`` object that has a ``map`` attribute."""
        # NOTE(review): presumably this fetches the dataset/tokenizer over the
        # network -- confirm before running in an offline CI environment.
        dataset = self.builder.build_dataset()

        self.assertIsNotNone(dataset, "Dataset is not built.")
        self.assertTrue(
            hasattr(dataset, "map"),
            "Dataset does not have a map method.",
        )

    def test_tokenize_function(self):
        """``tokenize_function`` maps a batch of texts to a dict whose values are lists."""
        example = {"text": ["Hello, world!", "Andromeda is great."]}

        tokenized_example = self.builder.tokenize_function(example)

        self.assertIsInstance(
            tokenized_example,
            dict,
            "Tokenized example is not a dictionary.",
        )
        self.assertTrue(
            all(isinstance(t, list) for t in tokenized_example.values()),
            "Tokenized example values are not lists.",
        )

    def test_group_texts(self):
        """``group_texts`` returns a dict of lists whose ``input_ids`` are ``seq_len`` long."""
        seq_len = self.builder.seq_len
        row = list(range(1, 11))

        # Feed at least ``seq_len`` tokens in total. A fixed 100-token input can
        # never fill a single sequence when seq_len is 8192, which made the
        # length check below either vacuous (empty output) or unsatisfiable.
        # ``-(-a // b)`` is integer ceiling division; keep at least the
        # original 10 rows for small ``seq_len`` values.
        row_count = max(10, -(-seq_len // len(row)))
        examples = {"input_ids": [list(row) for _ in range(row_count)]}

        grouped_examples = self.builder.group_texts(examples)

        self.assertIsInstance(
            grouped_examples,
            dict,
            "Grouped examples is not a dictionary.",
        )
        self.assertTrue(
            all(isinstance(t, list) for t in grouped_examples.values()),
            "Grouped example values are not lists.",
        )
        # ``all()`` is True for an empty iterable, so require at least one
        # grouped sequence before checking the individual lengths.
        self.assertTrue(
            grouped_examples["input_ids"],
            "Grouped examples contain no sequences.",
        )
        self.assertTrue(
            all(len(t) == seq_len for t in grouped_examples["input_ids"]),
            "Grouped example sequences are not the correct length.",
        )
|
|
|
if __name__ == '__main__':
    # Run the test suite when this module is executed directly as a script.
    unittest.main()