8bit and deepspeed changes
Browse files- ds_config.json +5 -3
- src/axolotl/utils/models.py +6 -13
ds_config.json
CHANGED
@@ -20,10 +20,12 @@
|
|
20 |
}
|
21 |
},
|
22 |
"scheduler": {
|
23 |
-
"type": "
|
24 |
"params": {
|
25 |
-
"
|
26 |
-
"
|
|
|
|
|
27 |
}
|
28 |
},
|
29 |
"zero_optimization": {
|
|
|
20 |
}
|
21 |
},
|
22 |
"scheduler": {
|
23 |
+
"type": "WarmupDecayLR",
|
24 |
"params": {
|
25 |
+
"warmup_min_lr": "auto",
|
26 |
+
"warmup_max_lr": "auto",
|
27 |
+
"warmup_num_steps": "auto",
|
28 |
+
"total_num_steps": "auto"
|
29 |
}
|
30 |
},
|
31 |
"zero_optimization": {
|
src/axolotl/utils/models.py
CHANGED
@@ -101,19 +101,12 @@ def load_model(
|
|
101 |
)
|
102 |
load_in_8bit = False
|
103 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
model = LlamaForCausalLM.from_pretrained(
|
111 |
-
base_model,
|
112 |
-
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
113 |
-
torch_dtype=torch_dtype,
|
114 |
-
device_map=cfg.device_map,
|
115 |
-
)
|
116 |
-
|
117 |
elif model_type:
|
118 |
model = getattr(transformers, model_type).from_pretrained(
|
119 |
base_model,
|
|
|
101 |
)
|
102 |
load_in_8bit = False
|
103 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
104 |
+
model = LlamaForCausalLM.from_pretrained(
|
105 |
+
base_model,
|
106 |
+
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
107 |
+
torch_dtype=torch_dtype,
|
108 |
+
device_map=cfg.device_map,
|
109 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
elif model_type:
|
111 |
model = getattr(transformers, model_type).from_pretrained(
|
112 |
base_model,
|