import os
from sys import exit
import torch
import trl
from transformers import (
AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
)
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset, Dataset
from tokenizers import ByteLevelBPETokenizer
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from itertools import islice
from typing import Optional
from logging import getLogger, StreamHandler, INFO
# Module-level logger: records at INFO and above go to a stream handler.
handler = StreamHandler()
logger = getLogger(__name__)
logger.addHandler(handler)
logger.setLevel(INFO)
class Config:
    """Static settings for the training run, read as class attributes.

    Derived values (gradient accumulation, total steps, warmup steps) are
    computed once, when the class body is executed.
    """

    # --- Model and training hyperparameters ---
    BATCH_SIZE = 16
    EPOCHS = 3
    LEARNING_RATE = 2e-4
    MAX_SEQ_LENGTH = 512
    VOCAB_SIZE = 32000
    FP16 = True
    WEIGHT_DECAY = 1e-3
    # One accumulation step per four samples of the batch size.
    GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4

    # --- Dataset configuration ---
    INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
    INSTRUCT_DATASET = "nroggendorff/elephant"
    # Number of examples taken from the stream per run.
    SHARD_SIZE = 200_000

    # --- Output and repository settings ---
    OUTPUT_REPO = "nroggendorff/smallama"
    PUSH_TO_HUB = True
    INSTRUCT_FINETUNE_BOOL = False

    # --- Model width and step bookkeeping ---
    # 12 ** 3 // 3 == 576; used as the model's hidden size.
    FACTOR = (12 ** 3) // 3
    TOTAL_STEPS = (SHARD_SIZE * EPOCHS) // (
        BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    )
    # Warm up for the first 10% of the optimizer steps.
    WARMUP_STEPS = int(TOTAL_STEPS * 0.1)

    # --- Shard offset: index of the shard this run starts from ---
    INIT = 0
class Space:
    """Small wrapper around the Hub API for pausing a Space.

    Args:
        space_id: Repository id of the Space to act on. Defaults to the
            Space id this script has always used.
    """

    def __init__(self, space_id: str = "nroggendorff/train-llama"):
        self.api = HfApi()
        self.space_id = space_id

    def pause(self):
        """Call `pause_space` for the configured Space and return its result."""
        return self.api.pause_space(self.space_id)
# Shared Space handle; the entry-point error handler calls space.pause().
space = Space()
class FineError(Exception):
    """Exception whose default message reports that training completed.

    Args:
        message: Text stored on `self.message` and passed to `Exception`.
    """

    def __init__(self, message="Training completed successfully."):
        super().__init__(message)
        self.message = message
def load_data(dataset_name: str, split: str, shard_size: int, init_offset: int = 0) -> Dataset:
    """Stream one shard of a dataset and return its text as an in-memory Dataset.

    Args:
        dataset_name: Name passed to `load_dataset`.
        split: Split passed to `load_dataset`.
        shard_size: Number of examples in one shard.
        init_offset: Zero-based shard index; the shard starts at example
            `init_offset * shard_size`.

    Returns:
        A Dataset with a single 'text' column. Examples without a 'text'
        key contribute an empty string.
    """
    stream = load_dataset(dataset_name, split=split, streaming=True)
    first = init_offset * shard_size
    last = first + shard_size
    # islice consumes (and discards) everything before `first` in the stream.
    shard_texts = [record.get('text', '') for record in islice(stream, first, last)]
    return Dataset.from_dict({'text': shard_texts})
def encode_decode(texts, tokenizer):
    """Round-trip texts through the tokenizer at a fixed sequence length.

    The input is encoded with padding and truncation to
    `Config.MAX_SEQ_LENGTH`, and the resulting ids are decoded back to text.

    Side effect: when the tokenizer has no pad token, its pad token is set
    to its eos token.

    Args:
        texts: Input passed straight to the tokenizer call.
        tokenizer: Tokenizer used for both encoding and decoding.

    Returns:
        The list produced by `tokenizer.batch_decode`, or a single dummy
        entry of repeated pad tokens when the id tensor has no dimensions.
    """
    max_len = Config.MAX_SEQ_LENGTH
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoding = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    input_ids = encoding.input_ids

    # Zero-dimensional result: hand back one all-padding placeholder entry.
    if input_ids.dim() < 1:
        return [tokenizer.pad_token * max_len]
    return tokenizer.batch_decode(input_ids)
def create_tokenizer(training_corpus):
    """Train a byte-level BPE tokenizer and wrap it for use with transformers.

    The special tokens are reserved during training and also registered on
    the returned wrapper, so `bos_token`, `eos_token`, `pad_token`,
    `unk_token` and `mask_token` (and their ids) are defined for callers.

    Args:
        training_corpus: Iterator of training text, passed to
            `train_from_iterator`.

    Returns:
        A `PreTrainedTokenizerFast` backed by the trained tokenizer, with a
        vocabulary of at most `Config.VOCAB_SIZE` entries.
    """
    # Insertion order fixes the token ids: <s>=0, <pad>=1, </s>=2, <unk>=3, <mask>=4.
    special_tokens = {
        "bos_token": "<s>",
        "pad_token": "<pad>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "mask_token": "<mask>",
    }
    bpe = ByteLevelBPETokenizer()
    bpe.train_from_iterator(
        training_corpus,
        vocab_size=Config.VOCAB_SIZE,
        min_frequency=2,
        special_tokens=list(special_tokens.values()),
    )
    # Without these keyword arguments the wrapper would not know which
    # vocabulary entries act as bos/eos/pad/unk/mask.
    return PreTrainedTokenizerFast(tokenizer_object=bpe._tokenizer, **special_tokens)
def load_tokenizer(repo: str):
    """Load the tokenizer stored in `repo` via `AutoTokenizer.from_pretrained`.

    Args:
        repo: Identifier passed to `from_pretrained`.

    Returns:
        The loaded tokenizer.
    """
    pretrained_tokenizer = AutoTokenizer.from_pretrained(repo)
    return pretrained_tokenizer
def get_training_corpus(dataset, batch_size: int = 1000):
    """Yield the dataset's 'text' column in consecutive batches.

    Args:
        dataset: Mapping-like object whose 'text' entry supports `len()` and
            slicing.
        batch_size: Maximum number of texts per yielded batch.

    Yields:
        Slices of the 'text' column, each holding up to `batch_size` items;
        the final batch may be shorter. Nothing is yielded for an empty column.
    """
    # Fetch the column once instead of once per batch.
    texts = dataset['text']
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]
def format_prompts(examples, tokenizer, is_instructional):
    """Turn a batch of raw texts into the training strings.

    Empty, whitespace-only and missing texts are skipped. In instructional
    mode each text is split on '<|end|>' into alternating user/assistant
    turns ('<|user|>' and '<|bot|>' markers are removed) and rendered with
    the tokenizer's chat template; otherwise the text is wrapped in the
    tokenizer's bos/eos tokens. Every string is passed through the coding
    step, and the collected list is passed through it once more at the end.

    Args:
        examples: Batch mapping with a 'text' list.
        tokenizer: Tokenizer providing `bos_token`, `eos_token` and
            `apply_chat_template`; may optionally provide a `code` callable.
        is_instructional: Select the chat-formatted path when true.

    Returns:
        A dict with a single 'text' key holding the coded texts.

    Raises:
        ValueError: If no text in the batch is usable.
    """

    def _code(value):
        # Use the tokenizer's own `code` hook when one has been attached.
        # Stock tokenizers have no such attribute, so fall back to this
        # module's encode_decode round trip instead of raising AttributeError.
        hook = getattr(tokenizer, "code", None)
        if callable(hook):
            return hook(value)
        decoded = encode_decode(value, tokenizer)
        # encode_decode returns a list; unwrap it for a single-string input
        # so the result can be concatenated with bos/eos below.
        return decoded[0] if isinstance(value, str) else decoded

    texts = []
    for text in examples['text']:
        if not text or not text.strip():
            continue

        if is_instructional:
            parts = text.split('<|end|>')
            conversation = []
            # Pair up (user, bot) segments; a trailing unpaired segment is dropped.
            for user_part, bot_part in zip(parts[::2], parts[1::2]):
                prompt = user_part.replace("<|user|>", "").strip()
                response = bot_part.replace("<|bot|>", "").strip()
                conversation.append({"role": "user", "content": prompt})
                conversation.append({"role": "assistant", "content": response})
            rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
            texts.append(_code(rendered))
        else:
            texts.append(tokenizer.bos_token + _code(text) + tokenizer.eos_token)

    if not texts:
        raise ValueError("No valid texts found in examples for formatting.")
    return {'text': _code(texts)}
def create_model(tokenizer):
    """Build a randomly initialised Llama causal LM sized for `tokenizer`.

    Args:
        tokenizer: Tokenizer whose length and pad/bos/eos token ids are
            copied into the model configuration.

    Returns:
        A new `LlamaForCausalLM` with 12 layers, 12 attention heads and a
        hidden size of `Config.FACTOR`.
    """
    hidden_size = Config.FACTOR
    config = LlamaConfig(
        # len(tokenizer) also counts tokens added on top of the base
        # vocabulary, so every id the tokenizer can emit has an embedding row.
        vocab_size=len(tokenizer),
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 4,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=Config.MAX_SEQ_LENGTH,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
    )
    return LlamaForCausalLM(config)
def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
    """Format the dataset, run supervised fine-tuning, then publish or save.

    Args:
        model: Model to train.
        tokenizer: Tokenizer used for prompt formatting and by the trainer.
        dataset: Dataset with a 'text' column; all original columns are
            replaced by the formatted 'text' column.
        push_to_hub: When true, push model and tokenizer to the Hub;
            otherwise save them to the local 'model' and 'tokenizer'
            directories.
        is_instructional: Selects chat-style prompt formatting and the
            '-it' suffix on the destination repository.
    """
    training_args = SFTConfig(
        output_dir="model",
        num_train_epochs=Config.EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        weight_decay=Config.WEIGHT_DECAY,
        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
        fp16=Config.FP16,
        # Checkpoint every five warmup-lengths, log every warmup-length.
        save_steps=int(Config.WARMUP_STEPS * 5),
        logging_steps=int(Config.WARMUP_STEPS),
        save_total_limit=2,
        report_to="none",
    )

    def _format_batch(examples):
        return format_prompts(examples, tokenizer, is_instructional)

    formatted = dataset.map(
        _format_batch,
        batched=True,
        remove_columns=dataset.column_names,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        # The trainer takes its training configuration through `args`.
        args=training_args,
        train_dataset=formatted,
    )
    train_result = trainer.train()

    if push_to_hub:
        # Follow this call's mode rather than re-reading the global flag.
        repo_id = Config.OUTPUT_REPO + "-it" if is_instructional else Config.OUTPUT_REPO
        commit_message = f"Training loss: {train_result.training_loss:.4f}"
        trainer.model.push_to_hub(repo_id, commit_message=commit_message, force=True)
        trainer.tokenizer.push_to_hub(repo_id, commit_message=commit_message, force=True)
    else:
        trainer.model.save_pretrained("model")
        trainer.tokenizer.save_pretrained("tokenizer")
def main():
    """Load one data shard, obtain a tokenizer and model, and train.

    A run is a continuation when instruction fine-tuning is enabled or when
    the shard offset `Config.INIT` is greater than zero. Continuations load
    both tokenizer and model from `Config.OUTPUT_REPO`, so the model is
    always paired with the vocabulary it was trained on. Otherwise a new
    tokenizer is trained on the shard and a new model is created for it.
    """
    instruct = Config.INSTRUCT_FINETUNE_BOOL
    resuming = instruct or Config.INIT > 0

    # Instruction fine-tuning reads the instruct dataset; pre-training reads
    # the input corpus.
    dataset_name = Config.INSTRUCT_DATASET if instruct else Config.INPUT_DATASET
    dataset = load_data(dataset_name, "train", Config.SHARD_SIZE, Config.INIT)

    if resuming:
        tokenizer = load_tokenizer(Config.OUTPUT_REPO)
        # No load_model helper is defined or imported in this module, so the
        # checkpoint is loaded directly from the output repository.
        model = AutoModelForCausalLM.from_pretrained(Config.OUTPUT_REPO)
    else:
        tokenizer = create_tokenizer(get_training_corpus(dataset))
        model = create_model(tokenizer)

    train_model(model, tokenizer, dataset, Config.PUSH_TO_HUB, instruct)
if __name__ == "__main__":
    # Script entry point: run the pipeline. If anything raises, log the
    # exception type and message, then pause the Space.
    try:
        main()
    except Exception as err:
        logger.error(f"{type(err).__name__}: {err}")
        space.pause()