Spaces:
Build error
Build error
import torch | |
import transformers | |
import json | |
from dataclasses import dataclass | |
from typing import Dict, Sequence | |
from tqdm import tqdm | |
from torch.utils.data import Dataset | |
class ChatDataset(Dataset): | |
def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int): | |
super(ChatDataset, self).__init__() | |
data = [] | |
with open(data_path, "r") as file: | |
for line in file: | |
try: | |
data.append(json.loads(line)) | |
except Exception as e: | |
print("json processing exception", e) | |
continue | |
data_dict = preprocess(data, tokenizer, conversation_template, max_tokens) | |
self.input_ids = data_dict["input_ids"] | |
self.labels = data_dict["labels"] | |
def __len__(self): | |
return len(self.input_ids) | |
def __getitem__(self, i) -> Dict[str, torch.Tensor]: | |
return dict(input_ids=self.input_ids[i], labels=self.labels[i]) | |
class DataCollatorForChatDataset(object): | |
""" | |
Collate examples for supervised fine-tuning. | |
""" | |
tokenizer: transformers.PreTrainedTokenizer | |
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: | |
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids")) | |
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) | |
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) | |
return dict( | |
input_ids=input_ids, | |
labels=labels, | |
attention_mask=input_ids.ne(self.tokenizer.pad_token_id), | |
) | |
class ChatDataModule(): | |
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int): | |
self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens) | |
self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer) | |
def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict: | |
""" | |
Preprocess the data by tokenizing. | |
""" | |
all_input_ids = [] | |
all_label_ids = [] | |
tokenizer.use_default_system_prompt = False | |
print("Tokenizing dataset...") | |
for conv in tqdm(conversations): | |
current_conv = conv["messages"] | |
tokenized_responses = [] | |
for msg in current_conv: | |
if msg["role"] == "assistant": | |
tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False)) | |
tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True) | |
all_input_ids.append(torch.LongTensor(tokenized_conv)) | |
return dict(input_ids=all_input_ids, labels=all_input_ids) |