from transformers import (
    AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
    PreTrainedTokenizerFast
)
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset, Dataset
from tokenizers import ByteLevelBPETokenizer
from huggingface_hub import HfApi
from itertools import islice
from logging import getLogger, StreamHandler, INFO

logger = getLogger(__name__)
logger.setLevel(INFO)
handler = StreamHandler()
logger.addHandler(handler)

class Config:
    # Model and training hyperparameters
    BATCH_SIZE = 16
    EPOCHS = 3
    LEARNING_RATE = 2e-4
    MAX_SEQ_LENGTH = 512
    VOCAB_SIZE = 32000
    FP16 = True
    WEIGHT_DECAY = 1e-3
    GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4

    # Dataset configurations
    INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
    INSTRUCT_DATASET = "nroggendorff/elephant"
    SHARD_SIZE = int(2e+5)

    # Output and repo settings
    OUTPUT_REPO = "nroggendorff/smallama"
    PUSH_TO_HUB = True
    INSTRUCT_FINETUNE_BOOL = False

    # Training steps and warmup
    FACTOR = 12 ** 3 // 3
    TOTAL_STEPS = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
    WARMUP_STEPS = int(TOTAL_STEPS * 0.1)

    # Initial state for shard offset
    INIT = 0

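# Thin wrapper around the Hub API; pause() stops the hosting Space so it does not
# keep running once training ends or fails.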
class Space:
    def __init__(self):
        self.api = HfApi()
        self.pause = lambda: self.api.pause_space("nroggendorff/train-llama")

space = Space()

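# Sentinel exception signalling a finished run; the __main__ handler logs it and
# pauses the Space like any other exception.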
class FineError(Exception):
    def __init__(self, message="Training completed successfully."):
        self.message = message
        super().__init__(self.message)

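# Stream the dataset and materialize a single shard: skip init_offset * shard_size
# examples, then take the next shard_size and wrap them in an in-memory Dataset.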
def load_data(dataset_name: str, split: str, shard_size: int, init_offset: int = 0) -> Dataset:
    dataset = load_dataset(dataset_name, split=split, streaming=True)
    shard_start = init_offset * shard_size
    data_list = list(islice(dataset, shard_start, shard_start + shard_size))
    return Dataset.from_dict({'text': [example.get('text', '') for example in data_list]})

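# Tokenize with padding/truncation to MAX_SEQ_LENGTH, then decode back, so every
# training example becomes a fixed-length string containing explicit pad tokens.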
def encode_decode(texts, tokenizer):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenized_texts = tokenizer(
        texts, padding="max_length", truncation=True, max_length=Config.MAX_SEQ_LENGTH, return_tensors="pt"
    ).input_ids
    return tokenizer.batch_decode(tokenized_texts) if tokenized_texts.dim() >= 1 else [tokenizer.pad_token * Config.MAX_SEQ_LENGTH]

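# Train a byte-level BPE tokenizer from scratch on the corpus and wrap it as a
# PreTrainedTokenizerFast so it plugs into transformers/TRL.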
def create_tokenizer(training_corpus):
    tokenizer = ByteLevelBPETokenizer()
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
    tokenizer.train_from_iterator(training_corpus, vocab_size=Config.VOCAB_SIZE, min_frequency=2, special_tokens=special_tokens)
    # Register the special tokens on the wrapper so bos/eos/pad ids are defined
    # downstream (format_prompts and create_model rely on them).
    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer._tokenizer,
        bos_token="<s>", eos_token="</s>", unk_token="<unk>",
        pad_token="<pad>", mask_token="<mask>",
    )

def load_tokenizer(repo: str):
    return AutoTokenizer.from_pretrained(repo)

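# Yield the corpus in chunks of 1,000 texts for tokenizer training.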
def get_training_corpus(dataset):
    for i in range(0, len(dataset['text']), 1000):
        yield dataset['text'][i : i + 1000]

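# Convert raw text into training strings: in instruct mode, split on the
# <|user|>/<|bot|>/<|end|> markers and rebuild each turn pair through the chat
# template; otherwise wrap the plain text in BOS/EOS tokens.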
def format_prompts(examples, tokenizer, is_instructional):
    texts = []
    for text in examples['text']:
        if text and len(text.strip()) > 0:
            if is_instructional:
                conversation = []
                parts = text.split('<|end|>')
                for i in range(0, len(parts) - 1, 2):
                    prompt = parts[i].replace("<|user|>", "").strip()
                    response = parts[i + 1].replace("<|bot|>", "").strip()
                    conversation.append({"role": "user", "content": prompt})
                    conversation.append({"role": "assistant", "content": response})
                coded_text = tokenizer.apply_chat_template(conversation, tokenize=False)
                texts.append(coded_text)
            else:
                texts.append(tokenizer.bos_token + text + tokenizer.eos_token)
    if not texts:
        raise ValueError("No valid texts found in examples for formatting.")
    # Pass the formatted strings through encode_decode so every example is padded
    # or truncated to a fixed MAX_SEQ_LENGTH string (tokenizers have no `.code` method).
    return {'text': encode_decode(texts, tokenizer)}

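# Build a small Llama architecture from scratch: hidden size FACTOR (576) with a
# 4x MLP expansion, 12 layers, and 12 attention heads.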
def create_model(tokenizer):
    config = LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=Config.FACTOR,
        intermediate_size=Config.FACTOR * 4,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=Config.MAX_SEQ_LENGTH,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
    )
    return LlamaForCausalLM(config)
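
# load_model() is called in main() but never defined in this file. A minimal
# sketch, assuming the previously trained checkpoint lives in Config.OUTPUT_REPO.
def load_model():
    return AutoModelForCausalLM.from_pretrained(Config.OUTPUT_REPO)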

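# Format the dataset, fine-tune with TRL's SFTTrainer, then push the model and
# tokenizer to the Hub (or save them locally when pushing is disabled).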
def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
    config = SFTConfig(
        output_dir="model",
        num_train_epochs=Config.EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
        fp16=Config.FP16,
        save_steps=int(Config.WARMUP_STEPS * 5),
        logging_steps=int(Config.WARMUP_STEPS),
        save_total_limit=2,
        report_to="none",
    )
    dataset = dataset.map(
        lambda examples: format_prompts(examples, tokenizer, is_instructional), 
        batched=True, 
        remove_columns=dataset.column_names
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,  # newer trl releases may expect `processing_class=` instead
        args=config,  # SFTTrainer takes its SFTConfig via `args`, not `config`
        train_dataset=dataset
    )
    train_result = trainer.train()

    if push_to_hub:
        repo_id = Config.OUTPUT_REPO + "-it" if Config.INSTRUCT_FINETUNE_BOOL else Config.OUTPUT_REPO
        trainer.model.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
        trainer.tokenizer.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
    else:
        trainer.model.save_pretrained("model")
        trainer.tokenizer.save_pretrained("tokenizer")

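# Load one shard, build or resume the tokenizer and model, and run training.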
def main():
    dataset = load_data(Config.INPUT_DATASET, "train", Config.SHARD_SIZE, Config.INIT)
    tokenizer = (
        load_tokenizer(Config.OUTPUT_REPO)
        if Config.INSTRUCT_FINETUNE_BOOL and Config.INIT > 0
        else create_tokenizer(get_training_corpus(dataset))
    )
    model = (
        load_model()
        if Config.INSTRUCT_FINETUNE_BOOL or Config.INIT > 0
        else create_model(tokenizer)
    )
    train_model(model, tokenizer, dataset, Config.PUSH_TO_HUB, Config.INSTRUCT_FINETUNE_BOOL)
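    # FineError is defined above but never raised; raising it here (assumed intent)
    # routes a successful run through the handler below, which pauses the Space.
    raise FineError()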

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"{type(e).__name__}: {e}")
        space.pause()