Spaces:
Paused
Paused
nroggendorff
committed on
Commit
•
9547c62
1
Parent(s):
089175b
swap for sftc
Browse files
wdym trainingargs is deprecated :(
train.py
CHANGED
@@ -4,8 +4,9 @@ import torch
|
|
4 |
import trl
|
5 |
from transformers import (
|
6 |
AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
|
7 |
-
|
8 |
)
|
|
|
9 |
from datasets import load_dataset, Dataset
|
10 |
from tokenizers import ByteLevelBPETokenizer
|
11 |
from huggingface_hub import HfApi
|
@@ -126,7 +127,7 @@ def create_model(tokenizer):
|
|
126 |
return LlamaForCausalLM(config)
|
127 |
|
128 |
def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
|
129 |
-
|
130 |
output_dir="model",
|
131 |
num_train_epochs=Config.EPOCHS,
|
132 |
per_device_train_batch_size=Config.BATCH_SIZE,
|
@@ -145,7 +146,12 @@ def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
|
|
145 |
batched=True,
|
146 |
remove_columns=dataset.column_names
|
147 |
)
|
148 |
-
trainer =
|
|
|
|
|
|
|
|
|
|
|
149 |
train_result = trainer.train()
|
150 |
|
151 |
if push_to_hub:
|
|
|
4 |
import trl
|
5 |
from transformers import (
|
6 |
AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
|
7 |
+
PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
|
8 |
)
|
9 |
+
from trl import SFTConfig, SFTTrainer
|
10 |
from datasets import load_dataset, Dataset
|
11 |
from tokenizers import ByteLevelBPETokenizer
|
12 |
from huggingface_hub import HfApi
|
|
|
127 |
return LlamaForCausalLM(config)
|
128 |
|
129 |
def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
|
130 |
+
config = SFTConfig(
|
131 |
output_dir="model",
|
132 |
num_train_epochs=Config.EPOCHS,
|
133 |
per_device_train_batch_size=Config.BATCH_SIZE,
|
|
|
146 |
batched=True,
|
147 |
remove_columns=dataset.column_names
|
148 |
)
|
149 |
+
trainer = SFTTrainer(
|
150 |
+
model=model,
|
151 |
+
tokenizer=tokenizer,
|
152 |
+
config=config,
|
153 |
+
train_dataset=dataset
|
154 |
+
)
|
155 |
train_result = trainer.train()
|
156 |
|
157 |
if push_to_hub:
|