# Trainer dataset utilities.
# Commit: "trainer commit (#1)", 56d31bf, by Pratik Dwivedi.
import torch
import transformers
import json
from dataclasses import dataclass
from typing import Dict, Sequence
from tqdm import tqdm
from torch.utils.data import Dataset
class ChatDataset(Dataset):
    """Dataset of tokenized chat conversations loaded from a JSONL file.

    Each line of ``data_path`` is parsed as one JSON conversation record.
    Malformed lines are reported and skipped rather than aborting the load.
    The parsed records are tokenized by the module-level ``preprocess``
    helper, which supplies both ``input_ids`` and ``labels``.
    """

    def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int):
        """
        Args:
            data_path: Path to a JSONL file, one conversation object per line.
            tokenizer: Tokenizer used to encode the conversations.
            conversation_template: Chat template forwarded to ``preprocess``.
            max_tokens: Maximum sequence length; longer sequences are truncated.
        """
        super().__init__()
        data = []
        with open(data_path, "r") as file:
            for line in file:
                try:
                    data.append(json.loads(line))
                # Catch only JSON parse errors; anything else (I/O, Unicode)
                # should surface instead of being swallowed.
                except json.JSONDecodeError as e:
                    print("json processing exception", e)
                    continue
        data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        """Number of tokenized conversations."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return one example as a dict of ``input_ids`` and ``labels`` tensors."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class DataCollatorForChatDataset(object):
    """
    Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad token and ``labels`` with
    -100 (the ignore index of ``torch.nn.CrossEntropyLoss``), and derives an
    attention mask that is False exactly on padding positions.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # BUG FIX: the key tuple was ("input_ids", "input_ids"), so labels
        # were silently gathered from input_ids and the per-example "labels"
        # tensors were discarded. Gather labels from the "labels" key.
        input_ids, labels = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # -100 is ignored by the loss, so padded label positions do not train.
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
class ChatDataModule:
    """Convenience bundle of a ``ChatDataset`` and its padding collator.

    Builds the dataset from a JSONL file and pairs it with a
    ``DataCollatorForChatDataset`` using the same tokenizer, ready to hand
    to a trainer.
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int):
        self.dataset = ChatDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            conversation_template=conversation_template,
            max_tokens=max_tokens,
        )
        self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)
def preprocess(conversations: Sequence[dict], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
    """
    Preprocess the data by tokenizing each conversation with the chat template.

    Args:
        conversations: Records shaped like
            ``{"messages": [{"role": ..., "content": ...}, ...]}``.
            (The previous annotation ``Sequence[Sequence[dict]]`` was wrong:
            each element is a dict indexed by ``"messages"``.)
        tokenizer: Tokenizer providing ``apply_chat_template``.
        conversation_template: Chat template applied to each conversation.
        max_tokens: Maximum sequence length; longer conversations are truncated.

    Returns:
        Dict with ``"input_ids"`` and ``"labels"``, each a list of 1-D
        ``LongTensor``s of per-conversation token ids.

    NOTE(review): labels are identical to input_ids, so the loss is computed
    over every token (prompt and response alike). A dead-code loop that
    tokenized assistant messages but never used the result has been removed;
    if response-only supervision was intended, assistant spans must be located
    and all other label positions set to -100 — confirm with the training
    owner before changing behavior.
    """
    all_input_ids = []
    # Some chat templates inject a default system prompt; disable it so only
    # the messages actually present in the data are rendered.
    tokenizer.use_default_system_prompt = False
    print("Tokenizing dataset...")
    for conv in tqdm(conversations):
        tokenized_conv = tokenizer.apply_chat_template(
            conv["messages"],
            chat_template=conversation_template,
            max_length=max_tokens,
            truncation=True,
        )
        all_input_ids.append(torch.LongTensor(tokenized_conv))
    # Labels mirror input_ids (full-sequence supervision). Return a distinct
    # list object so mutating one entry list cannot silently affect the other.
    return dict(input_ids=all_input_ids, labels=list(all_input_ids))