update table for rwkv4 support, fix process count for dataset (#822)
Browse files- README.md +1 -0
- src/axolotl/datasets.py +8 -2
- src/axolotl/utils/data.py +30 -10
README.md
CHANGED
@@ -74,6 +74,7 @@ Features:
|
|
74 |
| gpt-j | β
| β
| β
| β | β | β | β |
|
75 |
| XGen | β
| β | β
| β | β | β | β
|
|
76 |
| phi | β
| β
| β
| β | β | β | β |
|
|
|
77 |
|
78 |
|
79 |
## Quickstart β‘
|
|
|
74 |
| gpt-j | β
| β
| β
| β | β | β | β |
|
75 |
| XGen | β
| β | β
| β | β | β | β
|
|
76 |
| phi | β
| β
| β
| β | β | β | β |
|
77 |
+
| RWKV | β
| β | β | β | β | β | β |
|
78 |
|
79 |
|
80 |
## Quickstart β‘
|
src/axolotl/datasets.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
|
3 |
import logging
|
4 |
import os
|
5 |
-
from typing import List
|
6 |
|
7 |
import torch
|
8 |
from datasets import Dataset, IterableDataset
|
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
|
30 |
self,
|
31 |
prompt_tokenizer: PromptTokenizingStrategy,
|
32 |
dataset: IterableDataset,
|
|
|
33 |
**kwargs,
|
34 |
):
|
35 |
self.prompt_tokenizer = prompt_tokenizer
|
|
|
36 |
super().__init__(self.process(dataset).data, **kwargs)
|
37 |
|
38 |
def process(self, dataset):
|
39 |
features = dataset.features.keys()
|
40 |
-
num_proc =
|
|
|
|
|
|
|
|
|
41 |
map_kwargs = {}
|
42 |
if self.prompt_tokenizer.supports_batched:
|
43 |
map_kwargs["batched"] = True
|
|
|
2 |
|
3 |
import logging
|
4 |
import os
|
5 |
+
from typing import List, Optional
|
6 |
|
7 |
import torch
|
8 |
from datasets import Dataset, IterableDataset
|
|
|
30 |
self,
|
31 |
prompt_tokenizer: PromptTokenizingStrategy,
|
32 |
dataset: IterableDataset,
|
33 |
+
process_count: Optional[int] = None,
|
34 |
**kwargs,
|
35 |
):
|
36 |
self.prompt_tokenizer = prompt_tokenizer
|
37 |
+
self.process_count = process_count
|
38 |
super().__init__(self.process(dataset).data, **kwargs)
|
39 |
|
40 |
def process(self, dataset):
|
41 |
features = dataset.features.keys()
|
42 |
+
num_proc = (
|
43 |
+
min(64, self.process_count)
|
44 |
+
if self.process_count
|
45 |
+
else min(64, os.cpu_count())
|
46 |
+
)
|
47 |
map_kwargs = {}
|
48 |
if self.prompt_tokenizer.supports_batched:
|
49 |
map_kwargs["batched"] = True
|
src/axolotl/utils/data.py
CHANGED
@@ -482,10 +482,14 @@ def get_dataset_wrapper(
|
|
482 |
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
483 |
)
|
484 |
dataset_prompter = UnsupportedPrompter()
|
485 |
-
dataset_wrapper = TokenizedPromptDataset(
|
|
|
|
|
486 |
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
487 |
dataset_prompter = UnsupportedPrompter()
|
488 |
-
dataset_wrapper = TokenizedPromptDataset(
|
|
|
|
|
489 |
elif d_base_type == "alpaca":
|
490 |
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
491 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
|
|
494 |
cfg.train_on_inputs,
|
495 |
cfg.sequence_len,
|
496 |
)
|
497 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
498 |
dataset_wrapper = ds_wrapper
|
499 |
elif d_base_type == "explainchoice":
|
500 |
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
|
|
504 |
cfg.train_on_inputs,
|
505 |
cfg.sequence_len,
|
506 |
)
|
507 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
508 |
dataset_wrapper = ds_wrapper
|
509 |
elif d_base_type == "concisechoice":
|
510 |
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
|
|
514 |
cfg.train_on_inputs,
|
515 |
cfg.sequence_len,
|
516 |
)
|
517 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
518 |
dataset_wrapper = ds_wrapper
|
519 |
elif d_base_type == "summarizetldr":
|
520 |
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
|
|
524 |
cfg.train_on_inputs,
|
525 |
cfg.sequence_len,
|
526 |
)
|
527 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
528 |
dataset_wrapper = ds_wrapper
|
529 |
elif d_base_type == "jeopardy":
|
530 |
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
|
|
534 |
cfg.train_on_inputs,
|
535 |
cfg.sequence_len,
|
536 |
)
|
537 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
538 |
dataset_wrapper = ds_wrapper
|
539 |
elif d_base_type == "oasst":
|
540 |
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
|
|
544 |
cfg.train_on_inputs,
|
545 |
cfg.sequence_len,
|
546 |
)
|
547 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
548 |
dataset_wrapper = ds_wrapper
|
549 |
elif d_base_type == "gpteacher":
|
550 |
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
|
|
554 |
cfg.train_on_inputs,
|
555 |
cfg.sequence_len,
|
556 |
)
|
557 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
558 |
dataset_wrapper = ds_wrapper
|
559 |
elif d_base_type == "reflection":
|
560 |
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
|
|
564 |
cfg.train_on_inputs,
|
565 |
cfg.sequence_len,
|
566 |
)
|
567 |
-
ds_wrapper = TokenizedPromptDataset(
|
|
|
|
|
568 |
dataset_wrapper = ds_wrapper
|
569 |
else:
|
570 |
suffix = ""
|
|
|
482 |
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
483 |
)
|
484 |
dataset_prompter = UnsupportedPrompter()
|
485 |
+
dataset_wrapper = TokenizedPromptDataset(
|
486 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
487 |
+
)
|
488 |
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
489 |
dataset_prompter = UnsupportedPrompter()
|
490 |
+
dataset_wrapper = TokenizedPromptDataset(
|
491 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
492 |
+
)
|
493 |
elif d_base_type == "alpaca":
|
494 |
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
495 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
|
|
498 |
cfg.train_on_inputs,
|
499 |
cfg.sequence_len,
|
500 |
)
|
501 |
+
ds_wrapper = TokenizedPromptDataset(
|
502 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
503 |
+
)
|
504 |
dataset_wrapper = ds_wrapper
|
505 |
elif d_base_type == "explainchoice":
|
506 |
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
|
|
510 |
cfg.train_on_inputs,
|
511 |
cfg.sequence_len,
|
512 |
)
|
513 |
+
ds_wrapper = TokenizedPromptDataset(
|
514 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
515 |
+
)
|
516 |
dataset_wrapper = ds_wrapper
|
517 |
elif d_base_type == "concisechoice":
|
518 |
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
|
|
522 |
cfg.train_on_inputs,
|
523 |
cfg.sequence_len,
|
524 |
)
|
525 |
+
ds_wrapper = TokenizedPromptDataset(
|
526 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
527 |
+
)
|
528 |
dataset_wrapper = ds_wrapper
|
529 |
elif d_base_type == "summarizetldr":
|
530 |
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
|
|
534 |
cfg.train_on_inputs,
|
535 |
cfg.sequence_len,
|
536 |
)
|
537 |
+
ds_wrapper = TokenizedPromptDataset(
|
538 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
539 |
+
)
|
540 |
dataset_wrapper = ds_wrapper
|
541 |
elif d_base_type == "jeopardy":
|
542 |
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
|
|
546 |
cfg.train_on_inputs,
|
547 |
cfg.sequence_len,
|
548 |
)
|
549 |
+
ds_wrapper = TokenizedPromptDataset(
|
550 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
551 |
+
)
|
552 |
dataset_wrapper = ds_wrapper
|
553 |
elif d_base_type == "oasst":
|
554 |
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
|
|
558 |
cfg.train_on_inputs,
|
559 |
cfg.sequence_len,
|
560 |
)
|
561 |
+
ds_wrapper = TokenizedPromptDataset(
|
562 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
563 |
+
)
|
564 |
dataset_wrapper = ds_wrapper
|
565 |
elif d_base_type == "gpteacher":
|
566 |
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
|
|
570 |
cfg.train_on_inputs,
|
571 |
cfg.sequence_len,
|
572 |
)
|
573 |
+
ds_wrapper = TokenizedPromptDataset(
|
574 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
575 |
+
)
|
576 |
dataset_wrapper = ds_wrapper
|
577 |
elif d_base_type == "reflection":
|
578 |
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
|
|
582 |
cfg.train_on_inputs,
|
583 |
cfg.sequence_len,
|
584 |
)
|
585 |
+
ds_wrapper = TokenizedPromptDataset(
|
586 |
+
ds_strategy, dataset, process_count=cfg.dataset_processes
|
587 |
+
)
|
588 |
dataset_wrapper = ds_wrapper
|
589 |
else:
|
590 |
suffix = ""
|