metadata
base_model: llm-jp/llm-jp-3-13b
library_name: peft
tags:
  - text-generation-inference
  - llama
  - trl
license: cc-by-sa-4.0

Model Card for togepi55/llm-jp-3-13b-it

  • Developed by: togepi55
  • Finetuned from model: llm-jp/llm-jp-3-13b
  • Language(s) (NLP): English, Japanese
  • License: cc-by-sa-4.0

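Note

The model was trained only on prompts in the format below. The Japanese preamble means "The following is an instruction that describes a task. Write a response that appropriately fulfills the request." The section markers are ### 指示 (Instruction) and ### 応答 (Response).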

"""
<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい

### 指示:
{instruction}

### 応答:
"""

Sample code

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextStreamer,
)


BASE_MODEL = "togepi55/llm-jp-3-13b-it"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# 4-bit NF4 quantization to reduce GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype="auto",
    trust_remote_code=True,
)

# Print tokens as they are generated, without echoing the prompt
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)


instruction = "東京は日本の"  # "Tokyo is Japan's ..."


# The prompt must follow the training format exactly
prompt = f"<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n### 指示:\n{instruction}\n\n### 応答:\n"
print(prompt)
model_input = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = model_input["input_ids"]

model.eval()
with torch.no_grad():
    # Greedy decoding with a mild repetition penalty; the streamer prints the output
    result = model.generate(
        input_ids,
        max_new_tokens=300,
        attention_mask=model_input["attention_mask"],
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=False,
        streamer=streamer,
        repetition_penalty=1.02,
    )
    print("----" * 20)
    del input_ids
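The generation call above uses greedy decoding (do_sample=False) with repetition_penalty=1.02. As implemented in transformers' RepetitionPenaltyLogitsProcessor, tokens that already appear in the sequence (prompt included) have positive logits divided by the penalty and negative logits multiplied by it, so their scores always decrease. The pure-Python sketch below illustrates one greedy step under that rule; apply_repetition_penalty and greedy_step are hypothetical names, and real decoding works on tensors over the full vocabulary.

```python
def apply_repetition_penalty(logits, seen_ids, penalty=1.02):
    """Return a copy of `logits` with already-seen token ids penalized."""
    adjusted = list(logits)
    for token_id in set(seen_ids):
        score = adjusted[token_id]
        # Negative scores are multiplied and positive ones divided, so both go down.
        adjusted[token_id] = score * penalty if score < 0 else score / penalty
    return adjusted


def greedy_step(logits, seen_ids, penalty=1.02):
    """Pick the highest-scoring token after the penalty (do_sample=False)."""
    adjusted = apply_repetition_penalty(logits, seen_ids, penalty)
    return max(range(len(adjusted)), key=adjusted.__getitem__)


if __name__ == "__main__":
    logits = [2.0, 2.03, -1.0]
    # Token 1 was already generated: 2.03 / 1.02 is about 1.990 < 2.0, so token 0 wins.
    print(greedy_step(logits, seen_ids=[1]))
    print(greedy_step(logits, seen_ids=[]))
```

With a penalty as small as 1.02, only tokens whose scores are already close to the top candidate are affected, which keeps greedy output largely unchanged while discouraging near-tied repeats.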

Bias, Risks, and Limitations

Because neither RLHF nor DPO was applied, the model may generate inappropriate content.

Training Details

The following datasets were used for instruction tuning:

  • ichikara-instruction-003-001-1.json
  • ichikara-instruction-003-002-1.json
  • elyza/ELYZA-tasks-100
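The card does not show how these datasets were turned into training text. Given the prompt format above, one plausible approach is to place each record's instruction in the template and append the reference answer plus an end-of-sequence marker. The sketch below is an assumption, not the author's preprocessing: the default field names input and output follow ELYZA-tasks-100 on the Hugging Face Hub, ichikara-instruction records may use different keys, and the "</s>" suffix and the format_example name are hypothetical.

```python
# Hypothetical preprocessing sketch; the author's actual pipeline is not documented.
PROMPT_TEMPLATE = (
    "<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n"
    "### 指示:\n{instruction}\n\n"
    "### 応答:\n"
)
EOS = "</s>"  # assumed end-of-sequence marker


def format_example(record: dict, instruction_key: str = "input", output_key: str = "output") -> str:
    """Turn one instruction/response record into a single training string."""
    prompt = PROMPT_TEMPLATE.format(instruction=record[instruction_key])
    return prompt + record[output_key] + EOS


if __name__ == "__main__":
    sample = {"input": "日本の首都はどこですか？", "output": "東京です。"}
    print(format_example(sample))
```

The key parameters let the same function handle datasets whose instruction column has a different name.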

SFT overview

  • 4-bit quantization
  • SFT with LoRA
  • learning_rate = 2e-4
  • optim="adamw_torch_fused"
  • lr_scheduler_type="cosine"
  • weight_decay=0.01
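With lr_scheduler_type="cosine", the learning rate follows half a cosine period from the peak value (here learning_rate = 2e-4) down to zero, after an optional linear warmup. The card lists neither warmup nor total step counts, so the sketch below takes them as parameters, and the step counts in the usage lines are illustrative. It mirrors the shape of transformers' get_cosine_schedule_with_warmup rather than calling it.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 2e-4, warmup_steps: int = 0) -> float:
    """Learning rate at `step` for linear warmup followed by cosine decay to zero."""
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    # Half a cosine period: factor 1.0 at the start of decay, 0.0 at the end
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for step in (0, 250, 500, 750, 1000):
        print(step, f"{cosine_lr(step, total_steps=1000):.2e}")
```

The learning rate stays near its peak early on and falls fastest around the midpoint of training.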

Generating outputs for elyza-tasks-100-TV_0.jsonl

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
from datasets import load_dataset
from tqdm import tqdm


BASE_MODEL = "llm-jp/llm-jp-3-13b"
PEFT_MODEL = "togepi55/llm-jp-3-13b-it"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
)

base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype="auto",
            trust_remote_code=True,
        )

model = PeftModel.from_pretrained(base_model, PEFT_MODEL)

# Load the elyza-tasks-100-TV_0.jsonl data
dataset = load_dataset("json", data_files="./elyza-tasks-100-TV_0.jsonl", split="train")


results = []

for num in tqdm(range(len(dataset))):
    # Index the row first so the whole "input" column is not materialized each iteration
    instruction = dataset[num]["input"]

    prompt = f"<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n### 指示:\n{instruction}\n\n### 応答:\n"

    model_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = model_input["input_ids"]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=300,
            attention_mask=model_input["attention_mask"],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            repetition_penalty=1.02,
        )[0]
    # Decode only the newly generated tokens, dropping the prompt
    output = tokenizer.decode(outputs[input_ids.size(1):], skip_special_tokens=True)
    results.append({"task_id": num, "input": instruction, "output": output})



# Optional: save the results as JSON Lines
import json
with open("output.jsonl", "wt", encoding='utf-8') as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')
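Before sharing or submitting output.jsonl, it can be worth reading the file back to confirm that every line is valid JSON with the expected keys. The check below is a hypothetical sketch: load_results is an illustrative name, and the required keys are taken only from the fields written by the loop above (task_id, input, output).

```python
import json
from typing import List, Optional


def load_results(path: str, expected_count: Optional[int] = None) -> List[dict]:
    """Read a JSONL file and verify each record has task_id, input and output."""
    required = {"task_id", "input", "output"}
    records = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = required - set(record)
            if missing:
                raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
            records.append(record)
    if expected_count is not None and len(records) != expected_count:
        raise ValueError(f"expected {expected_count} records, found {len(records)}")
    return records
```

For example, load_results("output.jsonl", expected_count=len(dataset)) raises a ValueError if a record is malformed or if any task is missing.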