"""
Test module for sharegpt integration w chatml
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
    SimpleShareGPTPromptTokenizingStrategy,
    register_chatml_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_chatml_template()

# Full tokenization of the fixture conversation. The expected ids assume the
# tokenizer fixture maps <|im_end|> -> 32000 and <|im_start|> -> 32001.
# "28705, 13" is " \n".
# fmt: off
EXPECTED_INPUT_IDS = [
    1,  # bos
    32001, 1587, 13, 25997, 32000, 28705, 13,  # system
    32001, 2188, 13, 21558, 32000, 28705, 13,  # human
    32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
    32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
    32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
]
# fmt: on


@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
    """Single-row ShareGPT dataset: a system prompt plus two human/gpt turns."""
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {
                        "from": "system",
                        "value": "repeat",
                    },
                    {
                        "from": "human",
                        "value": "hello",
                    },
                    {
                        "from": "gpt",
                        "value": "hello",
                    },
                    {
                        "from": "human",
                        "value": "goodbye",
                    },
                    {
                        "from": "gpt",
                        "value": "goodbye",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Mistral tokenizer extended with the ChatML control tokens.

    <|im_end|> is registered as the EOS token and <|im_start|> as a plain
    added token, both unnormalized and without whitespace stripping, so each
    is encoded as a single id.
    """
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer.add_special_tokens(
        {
            "eos_token": AddedToken(
                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
            )
        }
    )
    tokenizer.add_tokens(
        [
            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
        ]
    )
    return tokenizer


class TestSharegpt:
    """
    Test class for sharegpt prompter
    """

    @staticmethod
    def _tokenize_first_sample(sharegpt_dataset, tokenizer, *, train_on_inputs):
        """Tokenize the fixture dataset with a ChatML ShareGPT strategy.

        Returns the first tokenized row (a mapping with "input_ids" and
        "labels", among possibly other keys).
        """
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            train_on_inputs,
            2048,  # sequence_len
        )
        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset, process_count=1
        )
        return dataset_wrapper[0]

    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
        """Each turn must end with exactly one <|im_end|> (32000)."""
        sample = self._tokenize_first_sample(
            sharegpt_dataset, tokenizer, train_on_inputs=False
        )
        assert sample["input_ids"] == EXPECTED_INPUT_IDS

    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
        """With train_on_inputs=False, only assistant replies keep labels.

        -100 is the loss ignore index. The system and human turns are fully
        masked. In each gpt turn, only the "<|im_start|>assistant" prefix is
        masked.
        """
        sample = self._tokenize_first_sample(
            sharegpt_dataset, tokenizer, train_on_inputs=False
        )
        # fmt: off
        assert sample["labels"] == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
        """With train_on_inputs=True, no position is masked: labels == input_ids."""
        sample = self._tokenize_first_sample(
            sharegpt_dataset, tokenizer, train_on_inputs=True
        )
        assert sample["labels"] == EXPECTED_INPUT_IDS