testing mpt triton
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -8,7 +8,7 @@ import transformers
|
|
8 |
from transformers import (
|
9 |
AutoModelForCausalLM,
|
10 |
AutoTokenizer,
|
11 |
-
PreTrainedModel,
|
12 |
)
|
13 |
try:
|
14 |
from transformers import (
|
@@ -116,8 +116,14 @@ def load_model(
|
|
116 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
117 |
)
|
118 |
else:
|
|
|
|
|
|
|
|
|
|
|
119 |
model = AutoModelForCausalLM.from_pretrained(
|
120 |
base_model,
|
|
|
121 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
122 |
torch_dtype=torch_dtype,
|
123 |
device_map=cfg.device_map,
|
|
|
8 |
from transformers import (
|
9 |
AutoModelForCausalLM,
|
10 |
AutoTokenizer,
|
11 |
+
PreTrainedModel, AutoConfig,
|
12 |
)
|
13 |
try:
|
14 |
from transformers import (
|
|
|
116 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
117 |
)
|
118 |
else:
|
119 |
+
config = AutoConfig.from_pretrained(
|
120 |
+
base_model,
|
121 |
+
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
122 |
+
)
|
123 |
+
config.attn_config['attn_impl'] = 'triton'
|
124 |
model = AutoModelForCausalLM.from_pretrained(
|
125 |
base_model,
|
126 |
+
config=config,
|
127 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
128 |
torch_dtype=torch_dtype,
|
129 |
device_map=cfg.device_map,
|
src/axolotl/utils/wandb.py
CHANGED
@@ -2,7 +2,9 @@ import os
|
|
2 |
|
3 |
|
4 |
def setup_wandb_env_vars(cfg):
|
5 |
-
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
|
|
|
|
6 |
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
7 |
cfg.use_wandb = True
|
8 |
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
|
|
2 |
|
3 |
|
4 |
def setup_wandb_env_vars(cfg):
|
5 |
+
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
6 |
+
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
7 |
+
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
8 |
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
9 |
cfg.use_wandb = True
|
10 |
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|