jupyter lab fixes (#1139) [skip ci]
Browse files* add a basic notebook for lab users in the root
* update notebook and fix cors for jupyter
* cell is code
* fix eval batch size check
* remove intro notebook
- docker/Dockerfile-cloud +1 -1
- scripts/cloud-entrypoint.sh +1 -1
- src/axolotl/cli/train.py +13 -7
- src/axolotl/core/trainer_builder.py +4 -3
- src/axolotl/utils/bench.py +2 -1
- src/axolotl/utils/models.py +6 -2
docker/Dockerfile-cloud
CHANGED
@@ -12,7 +12,7 @@ EXPOSE 22
|
|
12 |
|
13 |
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
14 |
|
15 |
-
RUN pip install jupyterlab notebook && \
|
16 |
jupyter lab clean
|
17 |
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
18 |
mkdir -p ~/.ssh && \
|
|
|
12 |
|
13 |
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
14 |
|
15 |
+
RUN pip install jupyterlab notebook ipywidgets && \
|
16 |
jupyter lab clean
|
17 |
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
18 |
mkdir -p ~/.ssh && \
|
scripts/cloud-entrypoint.sh
CHANGED
@@ -33,7 +33,7 @@ fi
|
|
33 |
|
34 |
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
35 |
# Run Jupyter Lab in the background
|
36 |
-
jupyter lab --allow-root --
|
37 |
fi
|
38 |
|
39 |
# Execute the passed arguments (CMD)
|
|
|
33 |
|
34 |
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
35 |
# Run Jupyter Lab in the background
|
36 |
+
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
37 |
fi
|
38 |
|
39 |
# Execute the passed arguments (CMD)
|
src/axolotl/cli/train.py
CHANGED
@@ -3,9 +3,11 @@ CLI to run training on a model
|
|
3 |
"""
|
4 |
import logging
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
import fire
|
8 |
import transformers
|
|
|
9 |
|
10 |
from axolotl.cli import (
|
11 |
check_accelerate_default_config,
|
@@ -24,19 +26,23 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|
24 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
25 |
# pylint: disable=duplicate-code
|
26 |
parsed_cfg = load_cfg(config, **kwargs)
|
27 |
-
print_axolotl_text_art()
|
28 |
-
check_accelerate_default_config()
|
29 |
-
check_user_token()
|
30 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
31 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
32 |
return_remaining_strings=True
|
33 |
)
|
|
|
|
|
34 |
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
37 |
else:
|
38 |
-
dataset_meta = load_datasets(cfg=
|
39 |
-
|
|
|
40 |
|
41 |
|
42 |
if __name__ == "__main__":
|
|
|
3 |
"""
|
4 |
import logging
|
5 |
from pathlib import Path
|
6 |
+
from typing import Tuple
|
7 |
|
8 |
import fire
|
9 |
import transformers
|
10 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
11 |
|
12 |
from axolotl.cli import (
|
13 |
check_accelerate_default_config,
|
|
|
26 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
27 |
# pylint: disable=duplicate-code
|
28 |
parsed_cfg = load_cfg(config, **kwargs)
|
|
|
|
|
|
|
29 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
30 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
31 |
return_remaining_strings=True
|
32 |
)
|
33 |
+
return do_train(parsed_cfg, parsed_cli_args)
|
34 |
+
|
35 |
|
36 |
+
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
37 |
+
print_axolotl_text_art()
|
38 |
+
check_accelerate_default_config()
|
39 |
+
check_user_token()
|
40 |
+
if cfg.rl:
|
41 |
+
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
42 |
else:
|
43 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
44 |
+
|
45 |
+
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
46 |
|
47 |
|
48 |
if __name__ == "__main__":
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -746,9 +746,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
746 |
training_arguments_kwargs[
|
747 |
"per_device_train_batch_size"
|
748 |
] = self.cfg.micro_batch_size
|
749 |
-
|
750 |
-
|
751 |
-
|
|
|
752 |
training_arguments_kwargs[
|
753 |
"gradient_accumulation_steps"
|
754 |
] = self.cfg.gradient_accumulation_steps
|
|
|
746 |
training_arguments_kwargs[
|
747 |
"per_device_train_batch_size"
|
748 |
] = self.cfg.micro_batch_size
|
749 |
+
if self.cfg.eval_batch_size:
|
750 |
+
training_arguments_kwargs[
|
751 |
+
"per_device_eval_batch_size"
|
752 |
+
] = self.cfg.eval_batch_size
|
753 |
training_arguments_kwargs[
|
754 |
"gradient_accumulation_steps"
|
755 |
] = self.cfg.gradient_accumulation_steps
|
src/axolotl/utils/bench.py
CHANGED
@@ -20,7 +20,8 @@ def check_cuda_device(default_value):
|
|
20 |
device = kwargs.get("device", args[0] if args else None)
|
21 |
|
22 |
if (
|
23 |
-
|
|
|
24 |
or device == "auto"
|
25 |
or torch.device(device).type == "cpu"
|
26 |
):
|
|
|
20 |
device = kwargs.get("device", args[0] if args else None)
|
21 |
|
22 |
if (
|
23 |
+
device is None
|
24 |
+
or not torch.cuda.is_available()
|
25 |
or device == "auto"
|
26 |
or torch.device(device).type == "cpu"
|
27 |
):
|
src/axolotl/utils/models.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import logging
|
3 |
import math
|
4 |
import os
|
5 |
-
from typing import Any, Optional, Tuple, Union # noqa: F401
|
6 |
|
7 |
import addict
|
8 |
import bitsandbytes as bnb
|
@@ -348,7 +348,11 @@ def load_model(
|
|
348 |
LOG.info("patching _expand_mask")
|
349 |
hijack_expand_mask()
|
350 |
|
351 |
-
model_kwargs = {}
|
|
|
|
|
|
|
|
|
352 |
|
353 |
max_memory = cfg.max_memory
|
354 |
device_map = cfg.device_map
|
|
|
2 |
import logging
|
3 |
import math
|
4 |
import os
|
5 |
+
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
6 |
|
7 |
import addict
|
8 |
import bitsandbytes as bnb
|
|
|
348 |
LOG.info("patching _expand_mask")
|
349 |
hijack_expand_mask()
|
350 |
|
351 |
+
model_kwargs: Dict[str, Any] = {}
|
352 |
+
|
353 |
+
if cfg.model_kwargs:
|
354 |
+
for key, val in model_kwargs.items():
|
355 |
+
model_kwargs[key] = val
|
356 |
|
357 |
max_memory = cfg.max_memory
|
358 |
device_map = cfg.device_map
|