Nanobit committed
Commit: e9650d3
Parent: f1232b3

Fix mypy typing

scripts/alpaca_json_to_jsonl.py CHANGED
@@ -3,7 +3,7 @@
 import os
 import sys

-from typing import Optional
+from typing import Optional, Union
 from pathlib import Path

 import fire
@@ -35,6 +35,7 @@ def main(
     """

     file_reader = FileReader()
+    writer: Union[StdoutWriter, FileWriter]
     if to_stdout or output is None:
         writer = StdoutWriter()
     else:
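
The writer change above follows a common mypy pattern: when a name is assigned a different class in each branch, declare its union type once before the branches. A minimal standalone sketch, with hypothetical StdoutSink/FileSink classes standing in for the script's StdoutWriter/FileWriter:

# Illustrative only: StdoutSink and FileSink are invented stand-ins.
from typing import Union


class StdoutSink:
    def write(self, line: str) -> None:
        print(line)


class FileSink:
    def __init__(self, path: str) -> None:
        self.path = path

    def write(self, line: str) -> None:
        with open(self.path, "a", encoding="utf-8") as handle:
            handle.write(line + "\n")


def get_writer(to_stdout: bool, path: str) -> Union[StdoutSink, FileSink]:
    # Without this up-front annotation, mypy infers the type from the first
    # assignment and flags the second branch as an incompatible assignment.
    writer: Union[StdoutSink, FileSink]
    if to_stdout:
        writer = StdoutSink()
    else:
        writer = FileSink(path)
    return writer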
scripts/extract_lora.py ADDED
@@ -0,0 +1,163 @@
+# import logging
+# import os
+# import random
+# import signal
+# import sys
+# from pathlib import Path
+
+# import fire
+# import torch
+# import yaml
+# from addict import Dict
+
+# from peft import set_peft_model_state_dict, get_peft_model_state_dict
+
+# # add src to the pythonpath so we don't need to pip install this
+# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+# src_dir = os.path.join(project_root, "src")
+# sys.path.insert(0, src_dir)
+
+# from axolotl.utils.data import load_prepare_datasets
+# from axolotl.utils.models import load_model
+# from axolotl.utils.trainer import setup_trainer
+# from axolotl.utils.wandb import setup_wandb_env_vars
+
+# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
+
+
+# def choose_device(cfg):
+#     def get_device():
+#         if torch.cuda.is_available():
+#             return "cuda"
+#         else:
+#             try:
+#                 if torch.backends.mps.is_available():
+#                     return "mps"
+#             except:
+#                 return "cpu"
+
+#     cfg.device = get_device()
+#     if cfg.device == "cuda":
+#         cfg.device_map = {"": cfg.local_rank}
+#     else:
+#         cfg.device_map = {"": cfg.device}
+
+
+# def choose_config(path: Path):
+#     yaml_files = [file for file in path.glob("*.yml")]
+
+#     if not yaml_files:
+#         raise ValueError(
+#             "No YAML config files found in the specified directory. Are you using a .yml extension?"
+#         )
+
+#     print("Choose a YAML file:")
+#     for idx, file in enumerate(yaml_files):
+#         print(f"{idx + 1}. {file}")
+
+#     chosen_file = None
+#     while chosen_file is None:
+#         try:
+#             choice = int(input("Enter the number of your choice: "))
+#             if 1 <= choice <= len(yaml_files):
+#                 chosen_file = yaml_files[choice - 1]
+#             else:
+#                 print("Invalid choice. Please choose a number from the list.")
+#         except ValueError:
+#             print("Invalid input. Please enter a number.")
+
+#     return chosen_file
+
+
+# def save_latest_checkpoint_as_lora(
+#     config: Path = Path("configs/"),
+#     prepare_ds_only: bool = False,
+#     **kwargs,
+# ):
+#     if Path(config).is_dir():
+#         config = choose_config(config)
+
+#     # load the config from the yaml file
+#     with open(config, "r") as f:
+#         cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader))
+#     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
+#     # then overwrite the value
+#     cfg_keys = dict(cfg).keys()
+#     for k in kwargs:
+#         if k in cfg_keys:
+#             # handle booleans
+#             if isinstance(cfg[k], bool):
+#                 cfg[k] = bool(kwargs[k])
+#             else:
+#                 cfg[k] = kwargs[k]
+
+#     # setup some derived config / hyperparams
+#     cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
+#     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
+#     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+#     assert cfg.local_rank == 0, "Run this with only one device!"
+
+#     choose_device(cfg)
+#     cfg.ddp = False
+
+#     if cfg.device == "mps":
+#         cfg.load_in_8bit = False
+#         cfg.tf32 = False
+#         if cfg.bf16:
+#             cfg.fp16 = True
+#         cfg.bf16 = False
+
+#     # Load the model and tokenizer
+#     logging.info("loading model, tokenizer, and lora_config...")
+#     model, tokenizer, lora_config = load_model(
+#         cfg.base_model,
+#         cfg.base_model_config,
+#         cfg.model_type,
+#         cfg.tokenizer_type,
+#         cfg,
+#         adapter=cfg.adapter,
+#         inference=True,
+#     )
+
+#     model.config.use_cache = False
+
+#     if torch.__version__ >= "2" and sys.platform != "win32":
+#         logging.info("Compiling torch model")
+#         model = torch.compile(model)
+
+#     possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
+#     if len(possible_checkpoints) > 0:
+#         sorted_paths = sorted(
+#             possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+#         )
+#         resume_from_checkpoint = sorted_paths[-1]
+#     else:
+#         raise FileNotFoundError("Checkpoints folder not found")
+
+#     pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
+
+#     assert os.path.exists(pytorch_bin_path), "Bin not found"
+
+#     logging.info(f"Loading {pytorch_bin_path}")
+#     adapters_weights = torch.load(pytorch_bin_path, map_location="cpu")
+
+#     # d = get_peft_model_state_dict(model)
+#     print(model.load_state_dict(adapters_weights))
+#     # with open('b.log', "w") as f:
+#     #     f.write(str(d.keys()))
+#     assert False
+
+#     print((adapters_weights.keys()))
+#     with open("a.log", "w") as f:
+#         f.write(str(adapters_weights.keys()))
+#     assert False
+
+#     logging.info("Setting peft model state dict")
+#     set_peft_model_state_dict(model, adapters_weights)
+
+#     logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}")
+#     model.save_pretrained(cfg.output_dir)
+
+
+# if __name__ == "__main__":
+#     fire.Fire(save_latest_checkpoint_as_lora)
src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -3,7 +3,7 @@
 import copy
 import logging
 from collections import defaultdict
-from typing import Generator
+from typing import Generator, List, Tuple

 from axolotl.prompt_tokenizers import (
     PromptTokenizingStrategy,
@@ -19,7 +19,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for Pygmalion.
     """

-    bot_prefix_token_ids = []
+    bot_prefix_token_ids: List[int] = []

     def __init__(self, prompter, tokenizer, *args, **kwargs):
         super().__init__(prompter, tokenizer, *args, **kwargs)
@@ -88,7 +88,7 @@ class PygmalionPrompter:

     def build_prompt(
         self, source, *args, **kwargs  # pylint: disable=unused-argument
-    ) -> Generator[str, None, None]:
+    ) -> Generator[Tuple[str, str], None, None]:
        for msg in source:
            yield msg["role"], msg["value"]

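Both annotations in this file address values mypy cannot infer on its own: an empty class-level list needs an explicit element type, and a generator that yields (role, value) pairs is declared as Generator[Tuple[str, str], None, None]. A minimal sketch of the same two patterns, with made-up names:

# Illustrative only: Strategy and build_prompt are simplified stand-ins.
from typing import Generator, List, Tuple


class Strategy:
    # mypy cannot infer the element type of an empty list, so declare it.
    bot_prefix_token_ids: List[int] = []


def build_prompt(source) -> Generator[Tuple[str, str], None, None]:
    # Each yielded value is a (role, value) pair, hence Tuple[str, str].
    for msg in source:
        yield msg["role"], msg["value"]


for role, value in build_prompt([{"role": "user", "value": "hi"}]):
    print(role, value)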
src/axolotl/prompt_tokenizers.py CHANGED
@@ -226,20 +226,16 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     Tokenizing strategy for Completion prompts.
     """

-    def parse_instruction_fields(self, prompt) -> str:
-        return prompt["text"]
-
     def tokenize_prompt(self, prompt):
-        instruction = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, None, None)
+        full_prompt = self._build_full_prompt(prompt["text"], None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)

         return tokenized_full_prompt

     def _build_full_prompt(
         self, instruction, input, response
-    ):  # pylint: disable=unused-argument, redefined-builtin
-        return next(iter(self.prompter.build_prompt(instruction)))
+    ):  # pylint: disable=redefined-builtin
+        return next(iter(self.prompter.build_prompt(instruction, input, response)))


 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -419,7 +415,7 @@ def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
     Returns the default values for the tokenize prompt function
     """

-    result = {
+    result: Dict[str, List[int]] = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
src/axolotl/prompters.py CHANGED
@@ -3,7 +3,7 @@
 import dataclasses
 import logging
 from enum import auto, Enum
-from typing import List, Union, Generator
+from typing import List, Optional, Union, Generator

 IGNORE_TOKEN_ID = -100

@@ -24,7 +24,7 @@ class AlpacaPrompter:

     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
-    prompt_style = None
+    prompt_style: Optional[PromptStyle] = None

     def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
         self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
@@ -231,18 +231,18 @@ class Conversation:
     offset: int
     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
     sep: str = "###"
-    sep2: str = None
+    sep2: Optional[str] = None

     def get_prompt(self) -> Generator[str, None, None]:
-        seps = [self.sep, self.sep2]
-        preamble = self.system + seps[0]
+        # seps = [self.sep, self.sep2]
+        preamble = self.system + self.sep
         yield preamble
         for _, (role, message) in enumerate(self.messages):
             if message:
-                yield (role + ":", " " + message)
+                yield role + ":" + " " + message
             else:
                 logging.warning(f"role with empty message: {role}")
-                yield (role + ":",)
+                yield role + ":"

     def copy(self):
         return Conversation(
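
The sep2 change illustrates how mypy treats defaults on annotated dataclass fields: a field annotated str may not default to None, so the annotation is widened to Optional[str]. A stripped-down sketch (not the full Conversation class):

# Illustrative only: a reduced dataclass showing the Optional-field pattern.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Conversation:
    system: str
    sep: str = "###"
    # "sep2: str = None" fails type checking; None needs Optional[str].
    sep2: Optional[str] = None


conv = Conversation(system="A chat.")
print(conv.sep2)  # None until a second separator is configured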
src/axolotl/utils/data.py CHANGED
@@ -3,7 +3,7 @@
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import Tuple, Union
+from typing import List, Tuple, Union

 from datasets import (
     load_from_disk,
@@ -95,40 +95,36 @@ def load_tokenized_prepared_datasets(

         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
-            ds: Dataset = load_dataset(
+            ds = load_dataset(
                 "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds: Dataset = load_dataset(
+                ds = load_dataset(
                     d.path,
                     streaming=False,
                     data_files=d.data_files,
                     use_auth_token=use_auth_token,
                 )
             else:
-                ds: Dataset = load_dataset(
+                ds = load_dataset(
                     d.path, streaming=False, use_auth_token=use_auth_token
                 )
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds: Dataset = load_dataset(
-                "json", data_files=fp, streaming=False, split=None
-            )
+            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise ValueError("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
             if "train" in ds:
-                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                ds = ds.shuffle(seed=42)["train"].shard(
                     num_shards=d.shards, index=0
                 )
             else:
-                ds: Dataset = ds.shuffle(seed=42).shard(
-                    num_shards=d.shards, index=0
-                )
+                ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
@@ -232,7 +228,7 @@ def load_tokenized_prepared_datasets(
             logging.error(f"unhandled prompt tokenization strategy: {d.type}")
     logging.info("tokenizing, merging, and shuffling master dataset")

-    samples = []
+    samples: List[int] = []
    for d in datasets:
        samples = samples + list(d)
    dataset = Dataset.from_list(samples).shuffle(seed=42)
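
Dropping the per-branch ds: Dataset annotations sidesteps mypy's redefinition check: annotating the same name in several branches produces a "Name ... already defined" error, so the name is annotated at most once and simply re-assigned afterwards. A minimal sketch with a hypothetical load() helper standing in for load_dataset:

# Illustrative only: Dataset, DatasetDict, and load() are invented stand-ins.
from typing import Union


class Dataset: ...
class DatasetDict: ...


def load(kind: str) -> Union[Dataset, DatasetDict]:
    return DatasetDict() if kind == "dict" else Dataset()


def pick(kind: str) -> Union[Dataset, DatasetDict]:
    if kind == "local":
        ds = load("plain")  # annotate once (or not at all) ...
    else:
        ds = load("dict")   # ... and only re-assign in later branches
    return ds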
src/axolotl/utils/models.py CHANGED
@@ -81,7 +81,7 @@ def load_model(
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
src/axolotl/utils/trainer.py CHANGED
@@ -5,6 +5,7 @@ import math
 import os
 import sys
 from pathlib import Path
+from typing import Optional

 import bitsandbytes as bnb
 import torch.cuda
@@ -28,7 +29,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
         self.lr_scheduler = None

     def create_scheduler(
-        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
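
The optimizer parameter is the implicit-Optional case: recent mypy (no_implicit_optional is on by default) rejects a parameter annotated with a plain type but defaulted to None, so the annotation becomes Optional[...]. A minimal sketch with a placeholder Engine class:

# Illustrative only: Engine stands in for torch.optim.Optimizer.
from typing import Optional


class Engine:
    pass


def create_scheduler(num_steps: int, engine: Optional[Engine] = None) -> Engine:
    # Fall back to a fresh Engine when the caller passes nothing.
    return engine if engine is not None else Engine()


print(create_scheduler(10))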