|
--- |
|
base_model: tomo1222/gemma-2-27b-bf16-4bit |
|
tags: |
|
- text-generation-inference |
|
- transformers |
|
- unsloth |
|
- gemma2 |
|
- trl |
|
license: gemma |
|
language: |
|
- ja
|
datasets: |
|
- llm-jp/magpie-sft-v1.0 |
|
- tomo1222/Japanese-QA111dataset |
|
--- |
|
|
|
# Uploaded model |
|
|
|
- **Developed by:** tomo1222 |
|
- **License:** Gemma |
|
- **Finetuned from model:** tomo1222/gemma-2-27b-bf16-4bit
|
|
|
[tomo1222/gemma-2-27b-bf16-4bit](https://huggingface.co/tomo1222/gemma-2-27b-bf16-4bit): [google/gemma-2-27b](https://huggingface.co/google/gemma-2-27b) quantized to 4-bit with BitsAndBytes and saved as-is, so that it can be loaded directly with [Unsloth](https://github.com/unslothai/unsloth).
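The quantization script itself is not part of this repository. A minimal sketch of how such a checkpoint could be produced (the BitsAndBytes settings below are assumptions inferred from the repository name, not the author's confirmed script):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Assumed settings: NF4 4-bit quantization with bfloat16 compute,
# matching the "bf16-4bit" suffix of the repository name.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")

# Save the quantized weights so they can be loaded later with load_in_4bit=True.
model.save_pretrained("gemma-2-27b-bf16-4bit")
tokenizer.save_pretrained("gemma-2-27b-bf16-4bit")
```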
|
|
|
This gemma2 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
|
|
|
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth) |
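The fine-tuning script is not included here. The "r64_alpha64" suffix in the model name suggests LoRA with r=64 and alpha=64; a rough Unsloth + TRL sketch consistent with that and with the datasets listed above might look as follows. All other hyperparameters, and the `dataset_text_field` assumption, are illustrative rather than the author's confirmed settings, and argument names vary slightly across TRL versions.

```python
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

# Start from the 4-bit base checkpoint this model was fine-tuned from.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="tomo1222/gemma-2-27b-bf16-4bit",
    max_seq_length=4096,
    load_in_4bit=True,
)

# LoRA r=64 / alpha=64, matching the "r64_alpha64" suffix of the model name.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

dataset = load_dataset("llm-jp/magpie-sft-v1.0", split="train")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",  # assumes prompts are already formatted into a single text column
    max_seq_length=4096,
    args=TrainingArguments(     # illustrative hyperparameters
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        output_dir="outputs",
    ),
)
trainer.train()
```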
|
|
|
# Inference code
|
|
|
## Libraries
|
```bash |
|
pip install unsloth |
|
pip install --no-deps --upgrade "flash-attn>=2.6.3" |
|
pip install -U ragatouille |
|
pip install fugashi unidic-lite |
|
``` |
|
|
|
### Inference on Google Colaboratory (L4)
|
```python |
|
from datasets import load_dataset
from unsloth import FastLanguageModel
import json
|
|
|
from huggingface_hub import login
from google.colab import userdata

# Authenticate to the Hugging Face Hub with a token stored in Colab secrets ("HFtoken").
login(userdata.get('HFtoken'))
|
|
|
|
|
with open("elyza-tasks-100-TV_0.jsonl","r",encoding='utf-8') as f: |
|
tasks = [json.loads(l) for l in f.readlines()] |
|
|
|
model_name = "tomo1222/Gemma2-27b-ft-jp-r64_alpha64" |
|
|
|
|
|
max_seq_length = 4096 |
|
|
|
model, tokenizer = FastLanguageModel.from_pretrained( |
|
model_name = model_name, |
|
max_seq_length = max_seq_length, |
|
dtype = None, |
|
load_in_4bit = True, |
|
) |
|
|
|
# Chat template taken from google/gemma-2-9b
# (the prompts below are assembled manually, so this is only needed if you call tokenizer.apply_chat_template).
|
tokenizer.chat_template = """ |
|
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %} |
|
""" |
|
FastLanguageModel.for_inference(model) # Enable native 2x faster inference |
|
|
|
dataset = load_dataset("tomo1222/Japanese-QA111dataset") |
|
ref_tasks = list(dataset["train"]) |
|
ref_tasks_input = [task["input"] for task in ref_tasks] |
|
|
|
dic = {} |
|
dic_input = {} |
|
for i, task in enumerate(ref_tasks): |
|
dic[ref_tasks_input[i]] = task["output"] |
|
dic_input[ref_tasks_input[i]] = task["input"] |
|
|
|
"""# 2. RAGのロード""" |
|
|
|
from ragatouille import RAGPretrainedModel

# Load JaColBERTv2 and pre-encode the reference questions for in-memory retrieval.
RAG = RAGPretrainedModel.from_pretrained("bclavie/JaColBERTv2")
RAG.encode(ref_tasks_input)
|
|
|
def search_ref_input(query, k=10):
    # Retrieve the k reference questions most similar to the query with JaColBERT.
    retrieved = RAG.search_encoded_docs(query=query, k=k)
    print(retrieved)
    # Instruction: "Read the question/text carefully and write an accurate, helpful answer."
    text = "質問・文章をよく読んで、正確で親切な回答を書きなさい。\n"
    for data in retrieved[::-1]:  # reverse order, so the most relevant example sits closest to the actual question
        key = data["content"]
        output = dic[key]
        ref_input = dic_input[key]
        text += "### 質問:\n" + ref_input + "\n\n### 回答:\n" + output + "\n\n\n"
    return text
|
|
|
"""# Prompt""" |
|
output_data=[] |
|
|
|
for i, task in enumerate(tasks):
    # Prompt = retrieved few-shot examples + instruction (in Japanese) + the current question.
    text = (
        search_ref_input(task["input"], 20)
        + "あなたは日本語が堪能な優秀な人間です。\n"
        + "**文脈**を踏まえて、改行と箇条書きを駆使して、日本語で**詳細に**書きなさい。\n"
        + "優秀な人間になりきって、推測をいれずに根拠をもってわかりやすく答えてください。"
        + f"### 質問:\n{task['input']}\n\n### 回答:\n"
    )
|
print(task["input"]) |
|
inputs = tokenizer(text, return_tensors="pt").to("cuda") |
|
print(len(inputs['input_ids'][0])) |
|
output = model.generate(**inputs, max_new_tokens=1024,repetition_penalty=1.1,use_cache=True, |
|
bad_words_ids = [tokenizer.encode("質問", add_special_tokens=False), |
|
tokenizer.encode("###", add_special_tokens=False), |
|
tokenizer.encode("#", add_special_tokens=False), |
|
tokenizer.encode("##", add_special_tokens=False), |
|
tokenizer.encode("---", add_special_tokens=False), |
|
tokenizer.encode("<h3>", add_special_tokens=False), |
|
tokenizer.encode("filepath", add_special_tokens=False), |
|
tokenizer.encode("> ", add_special_tokens=False), |
|
] |
|
) |
|
|
|
    # Decode only the newly generated tokens (everything after the prompt).
    output_text = tokenizer.decode(output[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
    print(i, output_text)
    print("---")
    output_data.append({"task_id": i, "output": output_text})

# Write the results as JSON Lines.
with open("output.jsonl", "w", encoding="utf-8") as f:
    for result in output_data:
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')
|
``` |
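The script writes one JSON object per line to `output.jsonl`, each containing a `task_id` and the generated `output`. A minimal snippet to load the results back:

```python
import json

# Read the JSON Lines results written by the inference script.
with open("output.jsonl", encoding="utf-8") as f:
    results = [json.loads(line) for line in f]

print(results[0]["task_id"], results[0]["output"][:100])
```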
|
|