|
"""
Tests for the raw input/output prompt strategy (segment-labelled prompt data).
"""
|
import pytest |
|
from datasets import Dataset |
|
from tokenizers import AddedToken |
|
from transformers import AutoTokenizer |
|
|
|
from axolotl.datasets import TokenizedPromptDataset |
|
from axolotl.prompt_strategies.input_output import ( |
|
RawInputOutputPrompter, |
|
RawInputOutputStrategy, |
|
) |
|
|
|
|
|
@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
    """Build a one-row dataset of alternating prompt/response segments.

    The row has a single ``segments`` column holding four segments. They
    alternate ``label: False`` (prompt text) and ``label: True`` (response
    text, terminated with ``<eot>``). The test in this module expects the
    ``label: False`` spans to be masked out of the labels.

    Note: the function name is a leftover from the ShareGPT tests. Pytest
    resolves this fixture through ``name="segments_dataset"``, not the
    function name.
    """
    # (prompt text, response text) pairs, in conversation order.
    turns = [
        ("<s>hello ", "hi there.<eot>"),
        ("goodbye ", "farewell<eot>"),
    ]

    segments = []
    for prompt_text, reply_text in turns:
        segments.append({"label": False, "text": prompt_text})
        segments.append({"label": True, "text": reply_text})

    return Dataset.from_list([{"segments": segments}])
|
|
|
|
|
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Load the Mistral-7B-instruct tokenizer with an extra ``<eot>`` token.

    ``<eot>`` is registered un-normalized and without left/right stripping.
    The test in this module expects it to encode to a single id wherever
    it appears in segment text. Loading needs the model repository to be
    available (downloaded or already cached).
    """
    model_id = "casperhansen/mistral-7b-instruct-v0.1-awq"
    eot_token = AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False)

    loaded = AutoTokenizer.from_pretrained(model_id)
    loaded.add_tokens([eot_token])
    return loaded
|
|
|
|
|
class TestRawInputOutputPrompts:
    """
    Tests for tokenizing segment-labelled prompts with the raw i/o strategy.
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        # NOTE(review): the positional False / 2048 are presumably
        # train_on_inputs and sequence_len -- confirm against the
        # RawInputOutputStrategy constructor.
        prompter = RawInputOutputPrompter()
        strategy = RawInputOutputStrategy(prompter, tokenizer, False, 2048)

        tokenized = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )
        first_row = tokenized[0]
        input_ids = first_row["input_ids"]
        labels = first_row["labels"]

        decoded = tokenizer.decode(input_ids)
        assert decoded == "<s> hello hi there.<eot> goodbye farewell<eot>"

        # One (token id, trained?) entry per position. Untrained positions
        # must carry the -100 ignore label; trained ones echo the input id.
        expected = [
            # "<s>hello " segment (label False)
            (1, False),  # <s>
            (6312, False),
            (28709, False),
            (28705, False),
            # "hi there.<eot>" segment (label True)
            (12014, True),
            (736, True),
            (28723, True),
            (32000, True),  # <eot>
            # "goodbye " segment (label False)
            (1179, False),
            (17664, False),
            (28705, False),
            # "farewell<eot>" segment (label True)
            (19111, True),
            (5458, True),
            (32000, True),  # <eot>
        ]

        assert input_ids == [token_id for token_id, _ in expected]
        assert labels == [
            token_id if trained else -100 for token_id, trained in expected
        ]
|
|
|
|