Gemma fine-tuning won't work with any method other than SFT
I have a library named EasyDeL, and I have re-implemented Gemma for it with some extra options such as flash-attention, ring-attention, blockwise_ffn, and more. But there's a problem: training doesn't do anything. The loss starts at about 8 and won't go any lower than 4.23, no matter which model you try; I have already tried all of the Gemma models. Here's an example of training and fine-tuning a model with EasyDeL (this code is not for the DPOTrainer, but the same thing happens with the DPOTrainer: the loss averages ~50 and the model doesn't learn anything).
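For reference, the easiest way to see where the loss "should" start is to compute the cross-entropy of the unmodified PyTorch checkpoint on a Gemma-formatted sample. This is only a minimal sketch with plain transformers, not EasyDeL code, and the sample text below is made up:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

name = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

# Made-up sample in Gemma chat format, only for illustration.
text = (
    "<start_of_turn>user\nWhat is JAX?<end_of_turn>\n"
    "<start_of_turn>model\nJAX is a numerical computing library.<end_of_turn>\n"
)
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, labels=inputs["input_ids"])
# Compare this number with the loss the trainer reports at step 0.
print(f"baseline cross-entropy: {out.loss.item():.3f}")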
Installing dependencies
You need EasyDeL installed from the head of the repository:
pip install git+https://github.com/erfanzar/EasyDeL.git -q -U
pip install "jax[tpu]==0.4.22" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
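After installing, a quick check that JAX actually sees the TPU devices (a generic JAX snippet, not part of EasyDeL):

import jax

print(jax.__version__)   # pinned to 0.4.22 above
print(jax.devices())     # should list TPU devices, not just the CPU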
Fine-tuning Code
from EasyDel import (
AutoEasyDelModelForCausalLM,
TrainArguments,
CausalLanguageModelTrainer,
EasyDelOptimizers,
EasyDelSchedulers,
EasyDelGradientCheckPointers,
EasyDelState,
EasyDeLXRapTureConfig,
get_modules_by_type,
easystate_to_huggingface_model
)
from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp
import jax
from transformers import GemmaForCausalLM as ModuleTorch
def main(use_lora=False):
    pretrained_model_name_or_path = "google/gemma-2b-it"

    # Load the torch checkpoint and convert it into an EasyDeL (Flax) module + params.
    model, params = AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=(1, 1),
        device_map="auto",
        sharding_axis_dims=(1, 1, 1, -1)
    )

    config = model.config
    model_parameters = FrozenDict({"params": params})
    dtype = jnp.bfloat16

    # Attention/block settings for the re-implemented Gemma module.
    config.add_basic_configurations(
        attn_mechanism="normal",
        block_b=1,
        block_q=128,
        block_k=128,
        block_k_major=128,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True
    )

    max_length = 4096
    configs_to_initialize_model_class = {
        'config': config,
        'dtype': dtype,
        'param_dtype': dtype,
        'input_shape': (1, max_length)
    }

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Optional LoRA configuration (only used when use_lora=True).
    rapture_config = EasyDeLXRapTureConfig(
        model_parameters,
        lora_dim=64,
        fully_fine_tune_parameters=["embed_tokens"],
        lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],
        verbose=True
    ) if use_lora else None

    dataset = load_dataset(
        "erfanzar/Zeus-v0.1-Llama",
        split="train",
    )

    # Convert the Llama-2 chat markup in the dataset into Gemma chat markup.
    def gemma_prompt(x):
        return x.replace(
            "[/INST]", "<end_of_turn>\n<start_of_turn>model\n").replace(
            "</s><s>[INST]", "<end_of_turn>\n").replace(
            "<s>[INST] <<SYS>>\n", "<start_of_turn>system\n").replace(
            "<s>[INST]", "<start_of_turn>user\n").replace(
            "<</SYS>>\n", "<end_of_turn>\n").replace(
            "<end_of_turn>\n\n", "<end_of_turn>\n"
        )

    def tokenization_process(data_chunk) -> dict:
        return tokenizer(
            gemma_prompt(data_chunk["prompt"]),
            add_special_tokens=False,
            max_length=max_length,
            padding="max_length"
        )

    dataset = dataset.map(
        tokenization_process,
        num_proc=18,
        remove_columns=dataset.column_names
    )

    train_args = TrainArguments(
        model_class=get_modules_by_type(config.model_type)[1],
        configs_to_initialize_model_class=configs_to_initialize_model_class,
        custom_rule=config.get_partition_rules(True),
        model_name="Jupyter",
        num_train_epochs=2,
        learning_rate=5e-5,
        learning_rate_end=7e-6,
        warmup_steps=200,
        optimizer=EasyDelOptimizers.ADAMW,
        scheduler=EasyDelSchedulers.LINEAR,
        weight_decay=0.02,
        total_batch_size=64,
        max_sequence_length=max_length,
        gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
        sharding_array=(1, 1, 1, -1),
        use_pjit_attention_force=False,
        gradient_accumulation_steps=1,
        init_input_shape=(1, max_length),
        dtype=dtype,
        param_dtype=dtype,
        step_start_point=0,
        training_time="7H",
        rapture_config=rapture_config,
        wandb_entity=None
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset.shuffle().shuffle().shuffle(),
        checkpoint_path=None
    )

    # When LoRA is enabled, the parameters were already handed to rapture_config above.
    model_parameters = model_parameters if not use_lora else None

    output = trainer.train(
        model_parameters=model_parameters,
        state=None
    )

    # Convert the trained EasyDeL state back into a Hugging Face torch model and push it.
    with jax.default_device(jax.devices("cpu")[0]):
        model = easystate_to_huggingface_model(
            state=EasyDelState.load_state(
                output.checkpoint_path
            ),
            base_huggingface_module=ModuleTorch,
            config=config
        )

    model = model.half()
    model.push_to_hub("Gemma-2B-Fine-tuned")
    tokenizer.push_to_hub("Gemma-2B-Fine-tuned")


if __name__ == "__main__":
    main()
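Before launching training, it can also help to print what the prompt conversion actually produces. A minimal sketch using the same gemma_prompt logic as above (the sample Llama-2-formatted string is made up, not taken from the dataset):

# Preview the Llama-2 -> Gemma prompt conversion on a made-up sample (sketch only).
def gemma_prompt(x):
    return x.replace(
        "[/INST]", "<end_of_turn>\n<start_of_turn>model\n").replace(
        "</s><s>[INST]", "<end_of_turn>\n").replace(
        "<s>[INST] <<SYS>>\n", "<start_of_turn>system\n").replace(
        "<s>[INST]", "<start_of_turn>user\n").replace(
        "<</SYS>>\n", "<end_of_turn>\n").replace(
        "<end_of_turn>\n\n", "<end_of_turn>\n"
    )

sample = "<s>[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\nHello [/INST] Hi!</s>"
print(gemma_prompt(sample))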
I haven't seen this code before; do you track any other metrics at the start of training that might indicate what's wrong?
Just to check, are you finetuning the pretrained checkpoints?
Yes, I'm fine-tuning the pre-trained model, and EasyDeL is my library. I track metrics such as TPU/GPU/CPU usage, mean loss and mean accuracy, loss, accuracy, perplexity, trained tokens, and learning rate.
These are the gemma-7b-it charts; here I tried a higher learning rate, but exactly the same thing happens at lower learning rates.
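For reference, those per-step metrics can be computed from the logits and labels roughly like this (a plain JAX/optax sketch, not necessarily how EasyDeL implements them):

import jax.numpy as jnp
import optax

def step_metrics(logits, labels):
    # logits: (batch, seq, vocab) float array; labels: (batch, seq) integer token ids.
    # Ignores padding / loss masking for brevity.
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return {"loss": loss, "accuracy": accuracy, "perplexity": jnp.exp(loss)}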
Additional information
- The model generates text fine
- Code used:
from EasyDel import JAXServer, JAXServerConfig, EasyServe
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from jax import numpy as jnp, lax
import jax
from typing import List, Union, Optional
max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25
dtype = "bf16"
prompter_type = "gemma"
sharding_axis_dims = (1, 1, 1, -1)
pretrained_model_name_or_path = "google/gemma-7b-it"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False
)
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "qwen2": Qwen2Prompter()
}
prompter: BasePrompter = prompters[prompter_type]
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(history, user_prompt, None)
    past_response_length = 0
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        print(response[past_response_length:], end="")
        past_response_length = len(response)
    history.append([user_prompt, response])
- Trainer Loops (DPO, CLM) are both working fine.