theobjectivedad committed on
Commit
b1f4f7a
1 Parent(s): 83237b8

Fixed pre-commit problems, fixed small bug in logging_config to handle LOG_LEVEL env var

Browse files
scripts/finetune.py CHANGED
@@ -17,6 +17,7 @@ import yaml
17
  from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
 
20
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
21
  from axolotl.utils.dict import DictDefault
22
  from axolotl.utils.models import load_model, load_tokenizer
@@ -24,7 +25,6 @@ from axolotl.utils.tokenization import check_dataset_labels
24
  from axolotl.utils.trainer import setup_trainer
25
  from axolotl.utils.validation import validate_config
26
  from axolotl.utils.wandb import setup_wandb_env_vars
27
- from axolotl.logging_config import configure_logging
28
 
29
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
30
  src_dir = os.path.join(project_root, "src")
 
17
  from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
+ from axolotl.logging_config import configure_logging
21
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
22
  from axolotl.utils.dict import DictDefault
23
  from axolotl.utils.models import load_model, load_tokenizer
 
25
  from axolotl.utils.trainer import setup_trainer
26
  from axolotl.utils.validation import validate_config
27
  from axolotl.utils.wandb import setup_wandb_env_vars
 
28
 
29
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
30
  src_dir = os.path.join(project_root, "src")
src/axolotl/datasets.py CHANGED
@@ -16,6 +16,7 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
16
 
17
  LOG = logging.getLogger("axolotl")
18
 
 
19
  class TokenizedPromptDataset(IterableDataset):
20
  """
21
  Iterable dataset that returns tokenized prompts from a stream of text files.
 
16
 
17
  LOG = logging.getLogger("axolotl")
18
 
19
+
20
  class TokenizedPromptDataset(IterableDataset):
21
  """
22
  Iterable dataset that returns tokenized prompts from a stream of text files.
src/axolotl/logging_config.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import sys
2
  from logging.config import dictConfig
3
  from typing import Any, Dict
@@ -18,7 +21,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
18
  "stream": sys.stdout,
19
  },
20
  },
21
- "root": {"handlers": ["console"], "level": "INFO"},
22
  }
23
 
24
 
 
1
+ """Logging configuration settings"""
2
+
3
+ import os
4
  import sys
5
  from logging.config import dictConfig
6
  from typing import Any, Dict
 
21
  "stream": sys.stdout,
22
  },
23
  },
24
+ "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
25
  }
26
 
27
 
src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED
@@ -52,6 +52,7 @@ from transformers.utils import (
52
  logging,
53
  replace_return_docstrings,
54
  )
 
55
  LOG = logging.getLogger("axolotl")
56
 
57
  _CONFIG_FOR_DOC = "LlamaConfig"
@@ -861,7 +862,7 @@ class LlamaModel(LlamaPreTrainedModel):
861
 
862
  if self.gradient_checkpointing and self.training:
863
  if use_cache:
864
- logger.warning_once(
865
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
866
  )
867
  use_cache = False
 
52
  logging,
53
  replace_return_docstrings,
54
  )
55
+
56
  LOG = logging.getLogger("axolotl")
57
 
58
  _CONFIG_FOR_DOC = "LlamaConfig"
 
862
 
863
  if self.gradient_checkpointing and self.training:
864
  if use_cache:
865
+ LOG.warning_once(
866
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
867
  )
868
  use_cache = False
src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -11,6 +11,8 @@ from axolotl.prompt_tokenizers import (
11
  tokenize_prompt_default,
12
  )
13
 
 
 
14
  IGNORE_TOKEN_ID = -100
15
 
16
 
 
11
  tokenize_prompt_default,
12
  )
13
 
14
+ LOG = logging.getLogger("axolotl")
15
+
16
  IGNORE_TOKEN_ID = -100
17
 
18
 
src/axolotl/prompters.py CHANGED
@@ -5,6 +5,7 @@ import logging
5
  from enum import Enum, auto
6
  from typing import Generator, List, Optional, Tuple, Union
7
 
 
8
  IGNORE_TOKEN_ID = -100
9
 
10
 
 
5
  from enum import Enum, auto
6
  from typing import Generator, List, Optional, Tuple, Union
7
 
8
+ LOG = logging.getLogger("axolotl")
9
  IGNORE_TOKEN_ID = -100
10
 
11
 
src/axolotl/utils/data.py CHANGED
@@ -258,9 +258,7 @@ def load_tokenized_prepared_datasets(
258
  suffix = ""
259
  if ":load_" in d.type:
260
  suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
261
- LOG.error(
262
- f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
263
- )
264
  raise ValueError(
265
  f"unhandled prompt tokenization strategy: {d.type} {suffix}"
266
  )
@@ -271,9 +269,7 @@ def load_tokenized_prepared_datasets(
271
  samples = samples + list(d)
272
  dataset = Dataset.from_list(samples).shuffle(seed=seed)
273
  if cfg.local_rank == 0:
274
- LOG.info(
275
- f"Saving merged prepared dataset to disk... {prepared_ds_path}"
276
- )
277
  dataset.save_to_disk(prepared_ds_path)
278
  if cfg.push_dataset_to_hub:
279
  LOG.info(
@@ -366,9 +362,7 @@ def load_prepare_datasets(
366
  [dataset],
367
  seq_length=max_packed_sequence_len,
368
  )
369
- LOG.info(
370
- f"packing master dataset to len: {cfg.max_packed_sequence_len}"
371
- )
372
  dataset = Dataset.from_list(list(constant_len_dataset))
373
 
374
  # filter out bad data
 
258
  suffix = ""
259
  if ":load_" in d.type:
260
  suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
261
+ LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
 
 
262
  raise ValueError(
263
  f"unhandled prompt tokenization strategy: {d.type} {suffix}"
264
  )
 
269
  samples = samples + list(d)
270
  dataset = Dataset.from_list(samples).shuffle(seed=seed)
271
  if cfg.local_rank == 0:
272
+ LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
 
 
273
  dataset.save_to_disk(prepared_ds_path)
274
  if cfg.push_dataset_to_hub:
275
  LOG.info(
 
362
  [dataset],
363
  seq_length=max_packed_sequence_len,
364
  )
365
+ LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
 
 
366
  dataset = Dataset.from_list(list(constant_len_dataset))
367
 
368
  # filter out bad data
tests/test_prompt_tokenizers.py CHANGED
@@ -16,9 +16,6 @@ from axolotl.prompt_tokenizers import (
16
  ShareGPTPromptTokenizingStrategy,
17
  )
18
  from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
19
- from axolotl.logging_config import configure_logging
20
-
21
- configure_logging()
22
 
23
  LOG = logging.getLogger("axolotl")
24
 
 
16
  ShareGPTPromptTokenizingStrategy,
17
  )
18
  from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
 
 
 
19
 
20
  LOG = logging.getLogger("axolotl")
21