winglian committed on
Commit
59a31fe
·
unverified ·
1 Parent(s): 814aee6

DPO fixes v2 (#1174)

Browse files

* check that the `length` column exists before trying to remove it

* add validation for sample packing with RLHF

src/axolotl/core/trainer_builder.py CHANGED
@@ -227,7 +227,8 @@ class AxolotlTrainer(Trainer):
227
  def get_train_dataloader(self) -> DataLoader:
228
  if self.args.sample_packing and not self.args.pretraining:
229
  train_dataset = self.train_dataset
230
- train_dataset = train_dataset.remove_columns(["length"])
 
231
  data_collator = self.data_collator
232
  dataloader_params = {
233
  "batch_size": self._train_batch_size,
 
227
  def get_train_dataloader(self) -> DataLoader:
228
  if self.args.sample_packing and not self.args.pretraining:
229
  train_dataset = self.train_dataset
230
+ if "length" in train_dataset.features.keys():
231
+ train_dataset = train_dataset.remove_columns(["length"])
232
  data_collator = self.data_collator
233
  dataloader_params = {
234
  "batch_size": self._train_batch_size,
src/axolotl/utils/config.py CHANGED
@@ -204,6 +204,9 @@ def validate_config(cfg):
204
  if cfg.max_packed_sequence_len:
205
  raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
206
 
 
 
 
207
  if cfg.sample_packing and not cfg.pad_to_sequence_len:
208
  LOG.warning(
209
  "`pad_to_sequence_len: true` is recommended when using sample_packing"
 
204
  if cfg.max_packed_sequence_len:
205
  raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
206
 
207
+ if cfg.sample_packing and cfg.rl:
208
+ raise ValueError("`sample_packing: true` does not work with RLHF training")
209
+
210
  if cfg.sample_packing and not cfg.pad_to_sequence_len:
211
  LOG.warning(
212
  "`pad_to_sequence_len: true` is recommended when using sample_packing"