|
""" |
|
test module for the axolotl.utis.data module |
|
""" |
|
import unittest |
|
|
|
from transformers import LlamaTokenizer |
|
|
|
from axolotl.utils.data import encode_pretraining, md5 |
|
|
|
|
|
class TestEncodePretraining(unittest.TestCase):
    """
    test class for encode pretraining and md5 helper
    """

    def setUp(self):
        """Create the tokenizer and row length shared by the tests."""
        self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "<s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
            }
        )
        # Row length passed to encode_pretraining; the test below asserts
        # that each returned row has exactly this many entries.
        self.max_tokens = 15

    def test_encode_pretraining(self):
        """Check row count, row length, and BOS/EOS/PAD positions of row 0."""
        examples = {
            "text": [
                "Hello, world!",
                "Nice to meet you.",
                "lorem ipsum dolor sit amet.",
                "Nice to meet you again!.",
                "hello, hello",
            ]
        }
        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)

        # The five input texts are expected to come back as three rows.
        self.assertEqual(len(result["input_ids"]), 3)

        # Both the ids and the attention mask of a row span max_tokens entries.
        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)

        # First segment of row 0: BOS at index 0, EOS at index 5, PAD at index 6.
        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)

        # Second segment of row 0: BOS at index 7, EOS at index 13, and PAD in
        # the final slot (index 14 == max_tokens - 1).
        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)

    def test_md5(self):
        """Check md5 against the known hex digest of "hello world"."""
        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
        # Same digest when a second positional argument is given explicitly
        # ("utf-8" is presumably the text encoding -- confirm against md5).
        self.assertEqual(
            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
        )
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|