PEFT
Safetensors
German
trl
sft
Generated from Trainer
JanPf committed on
Commit
babd921
1 Parent(s): dd106d3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -2
README.md CHANGED
@@ -31,7 +31,7 @@ torch.manual_seed(42)
31
  # script config
32
  base_model_name = "LSX-UniWue/LLaMmlein_1B"
33
  chat_adapter_name = "LSX-UniWue/LLaMmlein_1B_chat_evol_instruct"
34
- device = "mps" # or cuda
35
 
36
  # chat history
37
  messages = [
@@ -45,7 +45,6 @@ messages = [
45
  config = PeftConfig.from_pretrained(chat_adapter_name)
46
  base_model = model = AutoModelForCausalLM.from_pretrained(
47
  base_model_name,
48
- attn_implementation="flash_attention_2" if device == "cuda" else None,
49
  torch_dtype=torch.bfloat16,
50
  device_map=device,
51
  )
 
31
  # script config
32
  base_model_name = "LSX-UniWue/LLaMmlein_1B"
33
  chat_adapter_name = "LSX-UniWue/LLaMmlein_1B_chat_evol_instruct"
34
+ device = "cuda" # or mps
35
 
36
  # chat history
37
  messages = [
 
45
  config = PeftConfig.from_pretrained(chat_adapter_name)
46
  base_model = model = AutoModelForCausalLM.from_pretrained(
47
  base_model_name,
 
48
  torch_dtype=torch.bfloat16,
49
  device_map=device,
50
  )