|
## zR-Llama-1B-chatglm2-6b-tokenizer |
|
|
|
本模型是基于 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 自行训练的一个1B模型。 |
|
|
|
## 模型参数 |
|
+ 1B 参数量 |
|
+ 训练语料670亿。 |
|
+ 模型支持token长度 896 |
|
|
|
|
|
## 预训练模型 |
|
|
|
+ 使用 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 的预训练数据集,自己完成 Tokenize 过程。 |
|
+ 使用 8 x 80GB A800 GPU 训练。 |
|
+ 训练 1 Epoch,bs=32 (每张卡) , lr=1.5e-4。 |
|
+ 共耗时 1 天。 |
|
|
|
## SFT模型 |
|
+ 使用 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 提供的全部数据集 |
|
+ 使用 单卡A800 微调。 |
|
+ 微调 5 Epoch, bs=8, lr=2e-5。 |
|
+ 共耗时 3 天 12 小时。 |
|
|
|
## 使用模型 |
|
|
|
```python |
|
import os |
|
import torch |
|
from transformers import AutoTokenizer, LlamaForCausalLM |
|
|
|
max_length = 896 |
|
HUMAN = '<human>' |
|
ROBOT = '<robot>' |
|
def build_prompt(query, history) -> str: |
|
texts = '' |
|
for user_input, response in history: |
|
texts += f'{HUMAN}{user_input}{ROBOT}{response}' |
|
|
|
texts += f'{HUMAN}{query}{ROBOT}' |
|
return texts |
|
|
|
def build_cli_history(history): |
|
prompt = '' |
|
for query, response in history: |
|
prompt += f"\n\nUser:{query.strip()}" |
|
prompt += f"\n\nRobot:{response.strip()}" |
|
return prompt |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
tokenizer = AutoTokenizer.from_pretrained("zRzRzRzRzRzRzR/zR-Llama-1b-ChatGLM2-6b-tokenizer", trust_remote_code=True) |
|
model = LlamaForCausalLM.from_pretrained("zRzRzRzRzRzRzR/zR-Llama-1b-ChatGLM2-6b-tokenizer").to(device) |
|
|
|
history = [] |
|
clear_command = 'cls' if os.name == 'nt' else 'clear' |
|
while True: |
|
query = input('\n输入:') |
|
if query.strip() == "stop": |
|
break |
|
if query.strip() == "clear": |
|
history = [] |
|
os.system(clear_command) |
|
continue |
|
|
|
inputs = tokenizer.encode(build_prompt(query, history), return_tensors='pt', add_special_tokens=False).to(device) |
|
response = model.generate(inputs) |
|
response = tokenizer.decode(response[0].cpu(), skip_special_tokens=True) |
|
|
|
os.system(clear_command) |
|
print(build_cli_history(history + [(query, response)]), flush=True) |
|
``` |
|
|
|
|
|
|