winglian commited on
Commit
eaaeefc
1 Parent(s): f5a828a

jupyter lab fixes (#1139) [skip ci]

Browse files

* add a basic notebook for lab users in the root

* update notebook and fix cors for jupyter

* cell is code

* fix eval batch size check

* remove intro notebook

docker/Dockerfile-cloud CHANGED
@@ -12,7 +12,7 @@ EXPOSE 22
12
 
13
  COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
14
 
15
- RUN pip install jupyterlab notebook && \
16
  jupyter lab clean
17
  RUN apt install --yes --no-install-recommends openssh-server tmux && \
18
  mkdir -p ~/.ssh && \
 
12
 
13
  COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
14
 
15
+ RUN pip install jupyterlab notebook ipywidgets && \
16
  jupyter lab clean
17
  RUN apt install --yes --no-install-recommends openssh-server tmux && \
18
  mkdir -p ~/.ssh && \
scripts/cloud-entrypoint.sh CHANGED
@@ -33,7 +33,7 @@ fi
33
 
34
  if [ "$JUPYTER_DISABLE" != "1" ]; then
35
  # Run Jupyter Lab in the background
36
- jupyter lab --allow-root --ip 0.0.0.0 &
37
  fi
38
 
39
  # Execute the passed arguments (CMD)
 
33
 
34
  if [ "$JUPYTER_DISABLE" != "1" ]; then
35
  # Run Jupyter Lab in the background
36
+ jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
37
  fi
38
 
39
  # Execute the passed arguments (CMD)
src/axolotl/cli/train.py CHANGED
@@ -3,9 +3,11 @@ CLI to run training on a model
3
  """
4
  import logging
5
  from pathlib import Path
 
6
 
7
  import fire
8
  import transformers
 
9
 
10
  from axolotl.cli import (
11
  check_accelerate_default_config,
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
24
  def do_cli(config: Path = Path("examples/"), **kwargs):
25
  # pylint: disable=duplicate-code
26
  parsed_cfg = load_cfg(config, **kwargs)
27
- print_axolotl_text_art()
28
- check_accelerate_default_config()
29
- check_user_token()
30
  parser = transformers.HfArgumentParser((TrainerCliArgs))
31
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
32
  return_remaining_strings=True
33
  )
 
 
34
 
35
- if parsed_cfg.rl:
36
- dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
 
 
37
  else:
38
- dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
39
- train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
40
 
41
 
42
  if __name__ == "__main__":
 
3
  """
4
  import logging
5
  from pathlib import Path
6
+ from typing import Tuple
7
 
8
  import fire
9
  import transformers
10
+ from transformers import PreTrainedModel, PreTrainedTokenizer
11
 
12
  from axolotl.cli import (
13
  check_accelerate_default_config,
 
26
  def do_cli(config: Path = Path("examples/"), **kwargs):
27
  # pylint: disable=duplicate-code
28
  parsed_cfg = load_cfg(config, **kwargs)
 
 
 
29
  parser = transformers.HfArgumentParser((TrainerCliArgs))
30
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
31
  return_remaining_strings=True
32
  )
33
+ return do_train(parsed_cfg, parsed_cli_args)
34
+
35
 
36
+ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
37
+ print_axolotl_text_art()
38
+ check_accelerate_default_config()
39
+ check_user_token()
40
+ if cfg.rl:
41
+ dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
42
  else:
43
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
44
+
45
+ return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
46
 
47
 
48
  if __name__ == "__main__":
src/axolotl/core/trainer_builder.py CHANGED
@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
746
  training_arguments_kwargs[
747
  "per_device_train_batch_size"
748
  ] = self.cfg.micro_batch_size
749
- training_arguments_kwargs[
750
- "per_device_eval_batch_size"
751
- ] = self.cfg.eval_batch_size
 
752
  training_arguments_kwargs[
753
  "gradient_accumulation_steps"
754
  ] = self.cfg.gradient_accumulation_steps
 
746
  training_arguments_kwargs[
747
  "per_device_train_batch_size"
748
  ] = self.cfg.micro_batch_size
749
+ if self.cfg.eval_batch_size:
750
+ training_arguments_kwargs[
751
+ "per_device_eval_batch_size"
752
+ ] = self.cfg.eval_batch_size
753
  training_arguments_kwargs[
754
  "gradient_accumulation_steps"
755
  ] = self.cfg.gradient_accumulation_steps
src/axolotl/utils/bench.py CHANGED
@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
20
  device = kwargs.get("device", args[0] if args else None)
21
 
22
  if (
23
- not torch.cuda.is_available()
 
24
  or device == "auto"
25
  or torch.device(device).type == "cpu"
26
  ):
 
20
  device = kwargs.get("device", args[0] if args else None)
21
 
22
  if (
23
+ device is None
24
+ or not torch.cuda.is_available()
25
  or device == "auto"
26
  or torch.device(device).type == "cpu"
27
  ):
src/axolotl/utils/models.py CHANGED
@@ -2,7 +2,7 @@
2
  import logging
3
  import math
4
  import os
5
- from typing import Any, Optional, Tuple, Union # noqa: F401
6
 
7
  import addict
8
  import bitsandbytes as bnb
@@ -348,7 +348,11 @@ def load_model(
348
  LOG.info("patching _expand_mask")
349
  hijack_expand_mask()
350
 
351
- model_kwargs = {}
 
 
 
 
352
 
353
  max_memory = cfg.max_memory
354
  device_map = cfg.device_map
 
2
  import logging
3
  import math
4
  import os
5
+ from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
6
 
7
  import addict
8
  import bitsandbytes as bnb
 
348
  LOG.info("patching _expand_mask")
349
  hijack_expand_mask()
350
 
351
+ model_kwargs: Dict[str, Any] = {}
352
+
353
+ if cfg.model_kwargs:
354
+ for key, val in model_kwargs.items():
355
+ model_kwargs[key] = val
356
 
357
  max_memory = cfg.max_memory
358
  device_map = cfg.device_map