import torch
import transformers
import json

from dataclasses import dataclass
from typing import Dict, Sequence
from tqdm import tqdm
from torch.utils.data import Dataset


class ChatDataset(Dataset):
    """Dataset of tokenized chat conversations for supervised fine-tuning.

    Expects a JSON-lines file in which each line holds one conversation
    as a dict with a "messages" list of {"role", "content"} entries.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int):
        super().__init__()
        data = []
        with open(data_path, "r") as file:
            for line in file:
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Skipping malformed JSON line: {e}")
                    continue

        data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForChatDataset:
    """
    Collate examples for supervised fine-tuning.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Right-pad each field to the longest sequence in the batch; labels
        # are padded with -100 so padding never contributes to the loss.
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
    

class ChatDataModule:
    """Convenience wrapper pairing a ChatDataset with its data collator."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template: str, max_tokens: int):
        self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens)
        self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)

def _find_sublist(haystack: Sequence[int], needle: Sequence[int]) -> int:
    """Return the start index of `needle` within `haystack`, or -1 if absent."""
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1


def preprocess(conversations: Sequence[dict], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
    """
    Tokenize each conversation and build labels that supervise only the
    assistant turns; every other token is set to the ignore index (-100).
    """
    all_input_ids = []
    all_label_ids = []
    tokenizer.use_default_system_prompt = False

    print("Tokenizing dataset...")
    for conv in tqdm(conversations):
        current_conv = conv["messages"]

        # Tokenize each assistant response separately so it can be located
        # inside the fully templated conversation below.
        tokenized_responses = []
        for msg in current_conv:
            if msg["role"] == "assistant":
                tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False))

        tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True)
        input_ids = torch.LongTensor(tokenized_conv)
        all_input_ids.append(input_ids)

        # Supervise only the assistant responses. The match is best-effort:
        # re-encoding a response can tokenize differently from the templated
        # conversation, and unmatched responses stay fully masked.
        labels = torch.full_like(input_ids, -100)
        for response in tokenized_responses:
            start = _find_sublist(tokenized_conv, response)
            if start != -1:
                labels[start:start + len(response)] = input_ids[start:start + len(response)]
        all_label_ids.append(labels)

    return dict(input_ids=all_input_ids, labels=all_label_ids)
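

# Example usage: a minimal sketch of wiring the module together. The model
# name, data path, and max_tokens below are illustrative placeholders, not
# values prescribed by this module.
if __name__ == "__main__":
    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # The collator pads with pad_token_id, so one must be defined.
        tokenizer.pad_token = tokenizer.eos_token

    data_module = ChatDataModule(
        tokenizer=tokenizer,
        data_path="train.jsonl",  # placeholder JSON-lines file
        conversation_template=tokenizer.chat_template,
        max_tokens=2048,
    )

    # Collate the first two examples into a padded batch and inspect shapes.
    batch = data_module.data_collator([data_module.dataset[0], data_module.dataset[1]])
    print({key: value.shape for key, value in batch.items()})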