Feat(test): Add tests for alpaca chatml prompt tokenizer (#1088)
* draft for adding test for tokenizer
* clean up
* clean up
* fix pre commit
* fix pylint
* Revert "fix pylint"
This reverts commit cd2cda3cdae6f31f6d038a0673c2c7abd8e8e46a.
* add pylint exception for pytest fixture
* update comments
* Apply suggestions from code review
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* update spelling and import promptstyle
* rename, restructure
* clean up
* add fmt:on
---------
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
tests/prompt_strategies/test_alpaca.py
ADDED
@@ -0,0 +1,116 @@
"""
Test module for alpaca integration w chatml
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


@pytest.fixture(name="alpaca_dataset")
def fixture_alpaca_dataset():
    return Dataset.from_list(
        [
            {
                "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                "input": "He finnished his meal and left the resturant",
                "output": "He finished his meal and left the restaurant.",
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    # pylint: disable=all
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer.add_special_tokens(
        {
            "eos_token": AddedToken(
                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
            )
        }
    )
    tokenizer.add_tokens(
        [
            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
        ]
    )

    return tokenizer


class TestAlpacaChatml:
    """
    Test class for alpaca prompter
    """

    def test_no_double_im_end(self, alpaca_dataset, tokenizer):
        strategy = AlpacaPromptTokenizingStrategy(
            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, alpaca_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        # fmt: off
        assert input_ids == [
            1,  # Bos
            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction
            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input
            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output
        ]
        # fmt: on

    def test_no_train_on_input(self, alpaca_dataset, tokenizer):
        strategy = AlpacaPromptTokenizingStrategy(
            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, alpaca_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # instruction
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # input
            -100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # Output
        ]
        # fmt: on

    def test_w_train_on_input(self, alpaca_dataset, tokenizer):
        strategy = AlpacaPromptTokenizingStrategy(
            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
            tokenizer,
            True,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, alpaca_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            1,  # Bos
            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction
            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input
            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output
        ]
        # fmt: on
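
Note on the hard-coded token ids: 32001 and 32000 are the added <|im_start|> and <|im_end|> tokens, and the -100 labels mark tokens masked out of the loss, so the prompt portion is ignored when train_on_inputs is False. As a minimal sketch (not part of this PR), the snippet below decodes one tokenized sample to inspect the ChatML-wrapped Alpaca prompt behind those ids; it reuses the same setup as the fixtures above, and the decoded text shown in the comments is an inference from the ids, not a verified output.

    # sketch_decode_chatml_alpaca.py -- illustration only, not part of the PR
    from datasets import Dataset
    from tokenizers import AddedToken
    from transformers import AutoTokenizer

    from axolotl.datasets import TokenizedPromptDataset
    from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
    from axolotl.prompters import AlpacaPrompter, PromptStyle

    # Same tokenizer setup as the fixture: Mistral base vocab plus ChatML tokens.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer.add_special_tokens(
        {"eos_token": AddedToken("<|im_end|>", rstrip=False, lstrip=False, normalized=False)}
    )
    tokenizer.add_tokens(
        [AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False)]
    )

    dataset = Dataset.from_list(
        [
            {
                "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                "input": "He finnished his meal and left the resturant",
                "output": "He finished his meal and left the restaurant.",
            }
        ]
    )

    strategy = AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
        tokenizer,
        False,  # train_on_inputs
        2048,  # sequence_len
    )
    sample = TokenizedPromptDataset(strategy, dataset, process_count=1)[0]

    # Decode the input_ids to see the rendered prompt text.
    print(tokenizer.decode(sample["input_ids"]))
    # Expected shape (roughly) is a ChatML conversation:
    #   <|im_start|>system ...Alpaca system prompt...<|im_end|>
    #   <|im_start|>user ...instruction + input...<|im_end|>
    #   <|im_start|>assistant ...output...<|im_end|>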