---
base_model:
- oxygen65/llm-jp-3-13b-finetune-2
- llm-jp/llm-jp-3-13b
tags:
- text-generation-inference
- transformers
- unsloth
- llama
- trl
license: cc-by-nc-sa-4.0
language:
- ja
datasets:
- elyza/ELYZA-tasks-100
---
|
# How to Use |
|
|
|
## 1. Load the model and tokenizer
|
```python |
|
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import torch
# tqdm and json are used by the later cells (inference loop and JSONL parsing).
from tqdm import tqdm
import json


# Hugging Face repo id of the fine-tuned model described by this card.
model_name = "oxygen65/llm-jp-3-13b-finetune-3"


# QLoRA-style 4-bit quantization: NF4 weights, bfloat16 compute dtype,
# double quantization disabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)


# Load the model with the 4-bit config above; device_map="auto" lets
# transformers place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)


# Load the tokenizer from the same repo.
# NOTE(review): trust_remote_code=True executes code shipped with the repo;
# confirm it is actually required for this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
``` |
|
|
|
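## 2. Load the evaluation datasets

The tasks to answer are read from a local `elyza-tasks-100-TV_0.jsonl` file; the ELYZA-tasks-100 `test` split is loaded as the pool of worked examples used for few-shot prompting.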
|
```python |
|
from datasets import load_dataset

# Evaluation tasks. Stripped lines are concatenated until the buffer ends
# with "}", and the buffer is then decoded as one JSON object. This accepts
# both one-object-per-line JSONL and objects spread over several lines.
# NOTE(review): it assumes no nested object closes at the end of a line.
tasks = []
# The file contains Japanese text, so decode it explicitly as UTF-8 rather
# than relying on the platform's default encoding.
with open("./elyza-tasks-100-TV_0.jsonl", "r", encoding="utf-8") as f:
    item = ""
    for line in f:
        item += line.strip()
        if item.endswith("}"):
            tasks.append(json.loads(item))
            item = ""

# A non-empty leftover means the file ended in the middle of an object.
# Fail loudly instead of silently evaluating one task fewer.
if item:
    raise ValueError(f"Incomplete JSON object at end of task file: {item[:80]!r}")


# Pool of worked examples (ELYZA-tasks-100 test split) used by the
# retrievers below to build few-shot prompts.
sample_task_ds = load_dataset("elyza/ELYZA-tasks-100")
sample_tasks = sample_task_ds['test']
# Notebook preview of the first example instruction (no effect as a script).
sample_tasks['input'][0]
|
``` |
|
|
|
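## 3. Set up the retrievers

Few-shot examples are retrieved from ELYZA-tasks-100 with two retrievers: BM25 and sentence embeddings. If the `rank_bm25` package is not installed in your environment, install it first: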
|
|
|
```bash |
|
# In a Jupyter/Colab cell, run "%pip install rank_bm25" instead.
pip install rank_bm25
|
``` |
|
|
|
```python |
|
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import nltk
import numpy as np


# Download the tokenizer data used by word_tokenize below
# (only needs to succeed on the first run).
nltk.download('punkt')
nltk.download('punkt_tab')
|
|
|
def search_similar_documents_bm25(query, sample_tasks, threshold=20.0):
    """Return indexes of example inputs whose BM25 score for *query* is >= *threshold*.

    Each call tokenizes every candidate document and builds a fresh BM25
    index over them.

    Args:
        query: Instruction text to look up.
        sample_tasks: Example pool whose ``"input"`` column holds the
            candidate documents.
        threshold: Minimum BM25 score for a document to be returned.
            Defaults to 20.0, the cutoff this card was tuned with.

    Returns:
        List of document indexes, highest score first.
    """
    # BM25 works on token lists, so tokenize every candidate document.
    # NOTE(review): word_tokenize is NLTK's English-oriented tokenizer. Japanese
    # text without spaces comes out as coarse chunks, so matches are mostly on
    # identical chunks. Confirm this is intended before retuning the threshold.
    tokenized_documents = [word_tokenize(doc) for doc in sample_tasks['input']]
    bm25 = BM25Okapi(tokenized_documents)

    doc_scores = bm25.get_scores(word_tokenize(query))
    # Highest score first. Because the order is descending, keeping every index
    # at or above the threshold is the same as stopping at the first one below it.
    sorted_indexes = np.argsort(doc_scores)[::-1]
    return [idx for idx in sorted_indexes if doc_scores[idx] >= threshold]
|
|
|
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Sentence encoder shared by every call of the neural retriever below.
# NOTE(review): all-MiniLM-L6-v2 is an English model, so its embeddings of
# Japanese text may be weak. Consider a multilingual model (confirm first).
SentTF = SentenceTransformer('all-MiniLM-L6-v2')

# Document embeddings keyed by the exact tuple of document strings. The
# example pool is then encoded once rather than once per query.
# (Entries are not invalidated if SentTF is reassigned.)
_doc_embedding_cache = {}


def seearch_similar_documents_neuralRetriver(query, sample_tasks, threshold=0.75):
    """Return indexes of example inputs whose cosine similarity to *query* is >= *threshold*.

    Args:
        query: Instruction text to look up.
        sample_tasks: Example pool whose ``"input"`` column holds the
            candidate documents.
        threshold: Minimum cosine similarity between the SentTF embeddings.
            Defaults to 0.75, the cutoff this card was tuned with.

    Returns:
        List of document indexes, most similar first.
    """
    documents = list(sample_tasks['input'])
    cache_key = tuple(documents)
    doc_embeddings = _doc_embedding_cache.get(cache_key)
    if doc_embeddings is None:
        # Encoding the whole pool is the expensive step, so do it only once.
        doc_embeddings = SentTF.encode(documents)
        _doc_embedding_cache[cache_key] = doc_embeddings

    query_embedding = SentTF.encode([query])
    # One query against all documents, giving a single row of similarities.
    similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
    sorted_indexes = np.argsort(similarities)[::-1]
    return [idx for idx in sorted_indexes if similarities[idx] >= threshold]
|
|
|
def create_icl_prompt(input, sample_tasks, task_id):
    """Build the few-shot (in-context learning) prompt prefix for one task.

    Examples related to *input* are retrieved from *sample_tasks* with both
    the BM25 and the neural retriever. Each one is formatted as a
    "### 指示" / "### 回答" pair under an "## 例題" header, and the prefix ends
    with the "## 本題" line that introduces the real task.

    Args:
        input: Instruction text of the task being answered.
        sample_tasks: Example pool with ``"input"`` and ``"output"`` columns.
        task_id: Currently unused; kept so existing calls keep working.

    Returns:
        The prompt prefix, or "" when neither retriever finds an example.
    """
    indexes_bm25 = search_similar_documents_bm25(input, sample_tasks)
    indexes_neu = seearch_similar_documents_neuralRetriver(input, sample_tasks)
    # Merge and de-duplicate with dict.fromkeys, which keeps first-seen order.
    # Examples therefore follow retrieval rank (BM25 hits first) instead of
    # the arbitrary iteration order of a set.
    indexes = list(dict.fromkeys([*indexes_bm25, *indexes_neu]))
    if not indexes:
        return ""

    # Fetch each column once rather than once per example.
    example_inputs = sample_tasks["input"]
    example_outputs = sample_tasks["output"]

    parts = ["## 例題\n"]
    for idx in indexes:
        parts.append(f"### 指示\n{example_inputs[idx]}\n### 回答\n{example_outputs[idx]}\n")
    parts.append("\n## 本題: 以下の指示に従って回答してください。step by stepで回答してください。\n")
    return "".join(parts)
|
|
|
# Quick check: build the few-shot prefix for the third evaluation task
# (the notebook displays the returned string).
create_icl_prompt(tasks[2]["input"], sample_tasks, 0)
|
``` |
|
|
|
## 4. Inference
|
```python |
|
# llm-jp prompt format: "### 指示" (instruction) / "### 回答" (answer) sections.
import re

# A reply that is a single line starting with "以下" ("the following ...") is
# treated as a preamble rather than an answer, and is re-sampled.
pattern = r"^以下.*$"
preamble_re = re.compile(pattern)
# Upper bound on re-sampling per task. Without it, a model that keeps
# emitting only the preamble would loop forever.
MAX_RETRIES = 5


def _is_preamble_only(text):
    """Return True if *text* is exactly one line and that line matches ``pattern``.

    An empty reply is not treated as a preamble and does not raise.
    """
    lines = text.splitlines()
    return len(lines) == 1 and preamble_re.match(lines[0]) is not None


def _generate_answer(tokenized_input, **sampling_kwargs):
    """Generate up to 512 new tokens after *tokenized_input* and decode only those tokens."""
    with torch.no_grad():
        generated = model.generate(
            tokenized_input,
            max_new_tokens=512,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )[0]
    # Slice off the prompt tokens so only the model's continuation is decoded.
    return tokenizer.decode(generated[tokenized_input.size(1):], skip_special_tokens=True)


sys_prompt = ""
results = []
for data in tqdm(tasks):
    task_id = data["task_id"]
    task_input = data["input"]
    # In-context-learning prefix built from similar ELYZA-tasks-100 examples.
    icl_prompt = create_icl_prompt(task_input, sample_tasks, task_id)

    prompt = f"""{sys_prompt}{icl_prompt}### 指示
{task_input}
### 回答
"""
    tokenized_input = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)

    # The first attempt is greedy. Retries sample at a low temperature so they
    # can produce a different reply.
    output = _generate_answer(tokenized_input, do_sample=False)
    for _ in range(MAX_RETRIES):
        if not _is_preamble_only(output):
            break
        print(f"#========================= Unexpected answer =========================#\n {output.splitlines()}")
        output = _generate_answer(tokenized_input, do_sample=True, temperature=0.4)

    results.append({"task_id": task_id, "input": task_input, "output": output})

    print(f"task_id: {task_id}, prompt: {prompt}, output: {output}")
|
|
|
``` |
|
|
|
## 5. Dump results
|
```python |
|
# Write one JSON object per line to "<repo name>-outputs.jsonl". The file name
# uses the part of the Hugging Face repo id after the last "/". A separate
# variable is used so model_name still holds the full repo id and the
# loading cell can be re-run.
output_stem = model_name.rsplit("/", 1)[-1]
with open(f"./{output_stem}-outputs.jsonl", 'w', encoding='utf-8') as f:
    for result in results:
        # ensure_ascii=False writes the Japanese text as-is instead of \u escapes.
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')
|
``` |
|
|
|
# Uploaded model |
|
|
|
- **Developed by:** oxygen65 |
|
|
|
This Llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
|
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth) |