---
license: llama3.2
datasets:
- HuggingFaceH4/ultrachat_200k
base_model:
- meta-llama/Llama-3.2-1B
pipeline_tag: text-generation
tags:
- trl
- llama
- sft
- alignment
- transformers
- custom
- chat
---
# Llama-3.2-1B-ultrachat200k

## Model Details

- **Model type:** SFT model
- **License:** llama3.2
- **Finetuned from model:** [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
- **Training data:** [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- **Training framework:** [trl](https://github.com/huggingface/trl)

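## Quick Start

A minimal chat-style inference sketch with `transformers`. The repo id below is an assumption based on this model's name; replace it with your own path if you reproduced the training locally.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "AIR-hl/Llama-3.2-1B-ultrachat200k"  # assumed repo id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
                                             torch_dtype=torch.bfloat16,
                                             device_map="auto")

# Format the conversation with the tokenizer's chat template
messages = [{"role": "user", "content": "Where is Harbin?"}]
input_ids = tokenizer.apply_chat_template(messages,
                                          add_generation_prompt=True,
                                          return_tensors="pt").to(model.device)

output_ids = model.generate(input_ids,
                            max_new_tokens=512,
                            do_sample=True,
                            temperature=0.7)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
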
## Training Details

### Training Hyperparameters

`attn_implementation`: flash_attention_2 \
`bf16`: True \
`learning_rate`: 2e-5 \
`lr_scheduler_type`: cosine \
`per_device_train_batch_size`: 2 \
`gradient_accumulation_steps`: 16 \
`torch_dtype`: bfloat16 \
`num_train_epochs`: 1 \
`max_seq_length`: 2048 \
`warmup_ratio`: 0.1

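For reference, these hyperparameters roughly correspond to the `SFTConfig` / `ModelConfig` objects sketched below. This is illustrative only: the actual run passes the values to the training script further down through `TrlParser` (CLI flags or a YAML config), and the `output_dir` shown here is hypothetical.

```python
from trl import ModelConfig, SFTConfig

# Illustrative mapping of the listed hyperparameters; the real run supplies
# them via TrlParser rather than hard-coding the config objects.
training_args = SFTConfig(
    output_dir="Llama-3.2-1B-ultrachat200k",  # hypothetical output path
    bf16=True,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    max_seq_length=2048,
    warmup_ratio=0.1,
)

model_config = ModelConfig(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
)
```
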
### Results

`init_train_loss`: 1.726 \
`final_train_loss`: 1.22

### Training Script

```python
import multiprocessing

from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import (
    ModelConfig,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
    SFTConfig,
    ScriptArguments,
    TrlParser
)

tqdm.pandas()

if __name__ == "__main__":
    # Parse dataset, training, and model arguments from the CLI / config file
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,
                                                 **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    # Llama 3.2 has no dedicated pad token, so reuse the end-of-text token for padding
    tokenizer.pad_token = '<|end_of_text|>'

    train_dataset = load_dataset(args.dataset_name,
                                 split=args.dataset_train_split,
                                 num_proc=multiprocessing.cpu_count())

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
```

### Test Script

```python
from vllm import LLM
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer

MODEL_PATH = "autodl-tmp/saves/Llama-3.2-1B-ultrachat200k"

model = LLM(MODEL_PATH,
            tensor_parallel_size=1,
            dtype='bfloat16')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Render the user message with the chat template, then sample a completion
prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Where is Harbin?"}],
                                       tokenize=False,
                                       add_generation_prompt=True)
sampling_params = SamplingParams(max_tokens=1024,
                                 temperature=0.7,
                                 logprobs=1,
                                 stop_token_ids=[tokenizer.eos_token_id])

vllm_generations = model.generate(prompt,
                                  sampling_params)

print(vllm_generations[0].outputs[0].text)
# print result: Harbin is located in northeastern China in the Heilongjiang province. It is the capital of Heilongjiang province in the Northeast Asia.
```
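
Beyond a single hand-written question, the same vLLM setup can be used for a quick qualitative pass over held-out data. The sketch below is self-contained and assumes the same local model path as the test script; it pulls a few prompts from the dataset's `test_sft` split and generates completions for all of them in one batched call.

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM
from vllm.sampling_params import SamplingParams

MODEL_PATH = "autodl-tmp/saves/Llama-3.2-1B-ultrachat200k"  # same local path as above

model = LLM(MODEL_PATH, tensor_parallel_size=1, dtype='bfloat16')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
sampling_params = SamplingParams(max_tokens=1024,
                                 temperature=0.7,
                                 stop_token_ids=[tokenizer.eos_token_id])

# Take a handful of held-out prompts from ultrachat_200k's test_sft split
test_ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft").select(range(4))
prompts = [tokenizer.apply_chat_template([{"role": "user", "content": ex["prompt"]}],
                                         tokenize=False,
                                         add_generation_prompt=True)
           for ex in test_ds]

# vLLM generates completions for all prompts in a single batched call
outputs = model.generate(prompts, sampling_params)
for ex, out in zip(test_ds, outputs):
    print("PROMPT:", ex["prompt"][:120])
    print("COMPLETION:", out.outputs[0].text[:300], "\n")
```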