disable eval using multipack for now (#437)
Browse files- src/axolotl/utils/trainer.py +22 -22
src/axolotl/utils/trainer.py
CHANGED
@@ -14,7 +14,7 @@ import bitsandbytes as bnb
|
|
14 |
import numpy as np
|
15 |
import torch.cuda
|
16 |
import transformers
|
17 |
-
from datasets import Dataset, set_caching_enabled
|
18 |
from torch import nn
|
19 |
from torch.optim.lr_scheduler import OneCycleLR
|
20 |
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
@@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer):
|
|
188 |
)
|
189 |
return super().get_train_dataloader()
|
190 |
|
191 |
-
def get_eval_dataloader(
|
192 |
-
self, eval_dataset: Optional[Dataset] = None
|
193 |
-
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
194 |
-
if self.args.sample_packing:
|
195 |
-
eval_dataset = (
|
196 |
-
eval_dataset if eval_dataset is not None else self.eval_dataset
|
197 |
-
)
|
198 |
-
eval_sampler = self._get_eval_sampler(eval_dataset)
|
199 |
-
return self.accelerator.prepare(
|
200 |
-
MultipackDistributedDataloader(
|
201 |
-
eval_dataset,
|
202 |
-
batch_size=self.args.eval_batch_size,
|
203 |
-
seq_max_length=self.args.max_seq_length,
|
204 |
-
collate_fn=self.data_collator,
|
205 |
-
sampler=eval_sampler,
|
206 |
-
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
207 |
-
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
208 |
-
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
209 |
-
)
|
210 |
-
)
|
211 |
-
return super().get_eval_dataloader(eval_dataset)
|
212 |
|
213 |
def compute_loss(self, model, inputs, return_outputs=False):
|
214 |
# use one's weighted cross entropy loss calc
|
|
|
14 |
import numpy as np
|
15 |
import torch.cuda
|
16 |
import transformers
|
17 |
+
from datasets import set_caching_enabled
|
18 |
from torch import nn
|
19 |
from torch.optim.lr_scheduler import OneCycleLR
|
20 |
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
|
|
188 |
)
|
189 |
return super().get_train_dataloader()
|
190 |
|
191 |
+
# def get_eval_dataloader(
|
192 |
+
# self, eval_dataset: Optional[Dataset] = None
|
193 |
+
# ) -> Union[DataLoader, MultipackDistributedDataloader]:
|
194 |
+
# if self.args.sample_packing:
|
195 |
+
# eval_dataset = (
|
196 |
+
# eval_dataset if eval_dataset is not None else self.eval_dataset
|
197 |
+
# )
|
198 |
+
# eval_sampler = self._get_eval_sampler(eval_dataset)
|
199 |
+
# return self.accelerator.prepare(
|
200 |
+
# MultipackDistributedDataloader(
|
201 |
+
# eval_dataset,
|
202 |
+
# batch_size=self.args.eval_batch_size,
|
203 |
+
# seq_max_length=self.args.max_seq_length,
|
204 |
+
# collate_fn=self.data_collator,
|
205 |
+
# sampler=eval_sampler,
|
206 |
+
# packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
207 |
+
# sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
208 |
+
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
209 |
+
# )
|
210 |
+
# )
|
211 |
+
# return super().get_eval_dataloader(eval_dataset)
|
212 |
|
213 |
def compute_loss(self, model, inputs, return_outputs=False):
|
214 |
# use one's weighted cross entropy loss calc
|