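# NOTE(review): the three lines originally here ("Spaces:", "Runtime error",
# "Runtime error") look like web-page status text captured together with the
# code rather than part of this module -- confirm and remove.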
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple

import torch
from transformers import DataCollatorForSeq2Seq
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.

    Every input feature is read through the keys "prompt_ids", "chosen_ids" and
    "rejected_ids". The prompt is concatenated with each answer, the resulting
    sequences are padded by the tokenizer, and labels are derived so that only
    answer tokens keep their ids.
    """

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        r"""
        Builds labels from padded input ids, masking everything except the answer.

        Args:
            batch: padded input ids, iterated row by row (one row per example).
            positions: one (prompt_len, answer_len) pair per row of `batch`.

        Returns:
            A contiguous tensor stacked from one label row per example, where
            every position outside the answer span holds `self.label_pad_token_id`.
        """
        # The padding side is the same for every row, so read it only once.
        pad_on_left = self.tokenizer.padding_side == "left"

        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if pad_on_left:
                # Left padding pushes the sequence to the end of the row,
                # so the answer occupies the last `answer_len` positions.
                start, end = feature.size(0) - answer_len, feature.size(0)
            else:
                # Right padding keeps the sequence at the start of the row,
                # so the answer directly follows the prompt.
                start, end = prompt_len, prompt_len + answer_len

            # Start from an all-padding row, then copy the answer span over.
            padded_tensor = torch.full_like(feature, self.label_pad_token_id)
            padded_tensor[start:end] = feature[start:end]
            padded_labels.append(padded_tensor)

        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.

        Args:
            features: examples providing "prompt_ids", "chosen_ids" and "rejected_ids";
                the values must support `len()` and `+` concatenation
                (presumably lists of token ids -- confirm against the dataset code).

        Returns:
            The batch produced by `self.tokenizer.pad` (called with this collator's
            `padding`, `max_length`, `pad_to_multiple_of` and `return_tensors`
            settings), with an additional "labels" entry built by `_pad_labels`.
        """
        concatenated_features = []
        label_positions = []
        # Outer loop over the answer kind so that all chosen examples come
        # first and all rejected examples follow in the same order.
        for key in ("chosen_ids", "rejected_ids"):
            for feature in features:
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
                concatenated_features.append(
                    {
                        "input_ids": feature["prompt_ids"] + feature[key],
                        "attention_mask": [1] * (prompt_len + answer_len),
                    }
                )
                # Remembered so the answer span can be located after padding.
                label_positions.append((prompt_len, answer_len))

        batch = self.tokenizer.pad(
            concatenated_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        return batch