prepare does all this already for qlora?
src/axolotl/utils/models.py  CHANGED  (+12 -12)
@@ -204,17 +204,17 @@ def load_model(
     """### Post-processing on the model
     Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
     """
-    if cfg.adapter == "qlora":
-        for param in model.parameters():
-            param.requires_grad = False  # freeze the model - train adapters later
-            if param.ndim == 1:
-                # cast the small parameters (e.g. layernorm) to fp32 for stability
-                param.data = param.data.to(torch.float32)
-        class CastOutputToFloat(nn.Linear):
-            def forward(self, x):
-                return super().forward(x).to(torch.float32)
-
-        model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
+    # if cfg.adapter == "qlora":
+    #     for param in model.parameters():
+    #         param.requires_grad = False  # freeze the model - train adapters later
+    #         if param.ndim == 1:
+    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
+    #             param.data = param.data.to(torch.float32)
+    #     class CastOutputToFloat(nn.Linear):
+    #         def forward(self, x):
+    #             return super().forward(x).to(torch.float32)
+    #
+    #     model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
 
     if not tokenizer:
         try:
@@ -255,7 +255,7 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
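On the question in the commit message: peft's prepare_model_for_int8_training does cover the steps in the block commented out above. The sketch below is a rough, illustrative approximation of that helper, not the library source; the function name prepare_for_int8_training_sketch is hypothetical, and the real helper also handles gradient checkpointing and may differ in detail between peft versions (for example, it wraps the output head in an nn.Sequential rather than subclassing nn.Linear as the removed code did).

    import torch
    import torch.nn as nn

    def prepare_for_int8_training_sketch(model, output_embedding_layer_name="lm_head"):
        # Approximation of peft's prepare_model_for_int8_training for illustration.
        # It performs the same steps as the removed manual block: freeze every
        # parameter, upcast 1-D (layer-norm/bias) parameters to fp32, and cast
        # the LM head output to fp32.
        for param in model.parameters():
            param.requires_grad = False  # freeze the base model; only adapters train
            if param.ndim == 1:
                # keep small parameters such as layer norms in fp32 for stability
                param.data = param.data.to(torch.float32)

        if hasattr(model, output_embedding_layer_name):
            output_layer = getattr(model, output_embedding_layer_name)

            class CastOutputToFloat(nn.Sequential):
                # wrap the existing head so its logits come out in fp32
                def forward(self, x):
                    return super().forward(x).to(torch.float32)

            setattr(model, output_embedding_layer_name, CastOutputToFloat(output_layer))

        return model

Given that, the second hunk simply routes the qlora adapter through the same prepare_model_for_int8_training call that the 8-bit lora path already uses, still skipping it when cfg.load_4bit is set.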