TearGosling and Nanobit committed
Commit f474650
1 Parent(s): 96deb6b

feat: add Metharme prompt strategy (#446)


* Add Metharme tokenizing strategy

This strategy accounts for how the Metharme JSONLs are formatted and adds duplicated EOS tokens, which can help trim model output length.
I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now.

* Redo Metharme tokenizing strategy

lol

* fix: oops

* Rearrange a conditional

* chore: reformat code in accordance with linter

* chore: Make lint not freak out

* chore: fix lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

README.md CHANGED
@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
    ```json
    {"conversations": [{"role": "...", "value": "..."}]}
    ```
+ - `metharme`: instruction, adds additional eos tokens
+   ```json
+   {"prompt": "...", "generation": "..."}
+   ```
  - `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
    ```json
    {"conversations": [{"role": "...", "value": "..."}]}
    ```
src/axolotl/prompt_strategies/metharme.py ADDED
@@ -0,0 +1,82 @@
+ """Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter class"""
+
+ import logging
+ from typing import Tuple
+
+ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+ from axolotl.prompters import AlpacaPrompter
+
+ LOG = logging.getLogger("axolotl")
+
+ IGNORE_TOKEN_ID = -100
+
+ # pylint: disable=duplicate-code
+
+
+ class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+     """
+     Tokenizing strategy for the Metharme models
+     """
+
+     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+         # Metharme records carry the whole formatted prompt in "prompt" and
+         # the target completion in "generation"; there is no separate input.
+         return (prompt["prompt"], "", prompt["generation"])
+
+     def _tokenize(
+         self,
+         prompt: str,
+         add_eos_token: bool = True,
+         strip_bos_token: bool = False,
+         num_eos_tokens: int = 3,
+     ):
+         result = self.tokenizer(
+             prompt,
+             truncation=True,
+             max_length=self.sequence_len,
+             padding=False,
+             return_tensors=None,
+         )
+         if len(result["input_ids"]) == 0:
+             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
+         # If there's already an EOS token there, subtract from the number added
+         elif result["input_ids"][-1] == self.tokenizer.eos_token_id:
+             num_eos_tokens -= 1
+
+         if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
+             for _ in range(num_eos_tokens):
+                 if len(result["input_ids"]) < self.sequence_len:
+                     result["input_ids"].append(self.tokenizer.eos_token_id)
+                     result["attention_mask"].append(1)
+
+         if (
+             strip_bos_token
+             and result["input_ids"]
+             and result["input_ids"][0] == self.tokenizer.bos_token_id
+         ):
+             result["input_ids"] = result["input_ids"][1:]
+             result["attention_mask"] = result["attention_mask"][1:]
+
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+
+ class MetharmePrompter(AlpacaPrompter):
+     """
+     Prompter for the Metharme models.
+     """
+
+     system_prompt = ""
+     system_no_input_prompt = ""
+     system_format = ""
+     turn_format = "{instruction}"
+     turn_no_input_format = "{instruction}"
+
+     def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
+         pass
+
+
+ def load(tokenizer, cfg):
+     return MetharmePromptTokenizingStrategy(
+         MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+     )
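
As a rough end-to-end sketch of how the new strategy would be exercised (the tokenizer checkpoint, the `SimpleNamespace` stand-in for `cfg`, and the record contents are assumptions for illustration, not part of this commit):

```python
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.metharme import load

# Assumed checkpoint; any tokenizer with BOS/EOS tokens should behave similarly.
tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/metharme-7b")
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)

strategy = load(tokenizer, cfg)
tokenized = strategy.tokenize_prompt(
    {
        "prompt": "<|system|>Enter roleplay mode.<|user|>Hello there!<|model|>",
        "generation": "Hi! How can I help you today?",
    }
)

# With num_eos_tokens=3, the sequence ends in up to three EOS ids,
# which is the duplicated-EOS trick the commit message describes.
print(tokenized["input_ids"][-3:])
```

If the tokenizer already emits a trailing EOS, `_tokenize` appends one fewer copy, so the total never exceeds three.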