Update README.md
README.md CHANGED
```diff
@@ -124,9 +124,8 @@ from transformers import (
 from peft import LoraConfig, PeftModel
 from datasets import load_dataset
 
-
 BASE_MODEL = "llm-jp/llm-jp-3-13b"
-PEFT_MODEL = "libkazz/llm-jp-3-13b-
+PEFT_MODEL = "libkazz/llm-jp-3-13b-it_lora"
 
 tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL)
 bnb_config = BitsAndBytesConfig(
```
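Between the two hunks the diff elides the README lines that run from the `bnb_config = BitsAndBytesConfig(` call down to the `model = PeftModel.from_pretrained(base_model, PEFT_MODEL)` line quoted in the next hunk header. A minimal sketch of how those pieces typically fit together follows; the 4-bit quantization settings and `device_map` here are assumptions, not the README's actual values:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Assumed 4-bit NF4 settings; the README's actual flags are not visible in this diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model, then attach the LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
```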
```diff
@@ -148,30 +147,32 @@ model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
 
 # Load the elyza-tasks-100-TV_0.jsonl data
 from datasets import load_dataset
+from tqdm import tqdm
 
-
+datasets = load_dataset("json", data_files="./elyza-tasks-100-TV_0.jsonl", split="train")
 
 results = []
-for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+for dt in tqdm(datasets):
+    input = dt["input"]
+    prompt = f"次の指示に忠実に回答を作成しなさい。\n\n### 指示:\n{input}\n\n### 回答:\n"
+
+    model_input = tokenizer([prompt], return_tensors="pt").to(model.device)
+
+    if "token_type_ids" in model_input:
+        del model_input["token_type_ids"]
+
+    outputs = model.generate(
+        **model_input,
+        max_new_tokens=512,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        use_cache=True,
+        do_sample=False,
+        repetition_penalty=1.2
+    )
+
+    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]
+    results.append({"task_id": dt["task_id"], "input": input, "output": prediction})
 
 # Save the results to a file
 import json
```
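One detail worth noting in the new loop: `tokenizer.decode(outputs[0], ...)` re-decodes the prompt together with the completion, and `split('\n### 回答')[-1]` leaves the `:\n` remainder of the answer header attached to the prediction unless it is stripped afterwards. A common alternative, shown here as a sketch and not part of this commit, is to decode only the newly generated tokens:

```python
# Decode only the tokens generated after the prompt, so no string splitting is needed.
prompt_len = model_input["input_ids"].shape[1]
prediction = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
```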
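The diff cuts off right after `# Save the results to a file` and `import json`. A plausible completion of that step, assuming the results are written as JSON Lines, is sketched below; the output filename is hypothetical, since the actual path is not visible in this diff:

```python
import json

# Hypothetical output path; the README's actual filename is not shown in the diff.
with open("llm-jp-3-13b-it_lora-outputs.jsonl", "w", encoding="utf-8") as f:
    for result in results:
        # ensure_ascii=False keeps the Japanese text human-readable in the file.
        json.dump(result, f, ensure_ascii=False)
        f.write("\n")
```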