import os
import random
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import json
# Load the evaluation tasks from a JSONL file
file_path = 'elyza-tasks-100-TV_0.jsonl'
data = pd.read_json(file_path, lines=True)
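# Each record is expected to provide at least "task_id" and "input" fields;
# both are used below when generating answers and writing the results.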
def set_seed(seed):
    """Seed every relevant RNG so repeated runs produce the same samples."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)
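# Note: with do_sample=True (used below), seeding makes runs repeatable on the
# same hardware and library versions, but not necessarily across machines.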
model_name = "hiroki-rad/google-gemma-2-2b-128-ft-3000"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
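# Optional sanity check: report which device and dtype the "auto" settings resolved to
print(f"Loaded {model_name} on {model.device} (dtype: {model.dtype})")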
def generate_text(data):
    # The prompt (kept in Japanese to match the fine-tuning data) says, roughly:
    # "You are an expert Japanese problem solver. Work in steps: identify the
    # question type, extract key information and constraints, lay out the
    # solution steps, then compose a concise answer."
    prompt = f"""## 指示:あなたは優秀な日本人の問題解決のエキスパートです。以下のステップで質問に取り組んでください:\n\n1. 質問の種類を特定する(事実確認/推論/創造的回答/計算など)\n2. 重要な情報や制約条件を抽出する\n3. 解決に必要なステップを明確にする\n4. 回答を組み立てる
質問をよく読んで、冷静に考え、考えをステップバイステップでまとめましょう。それをもう一度じっくり考えて、思考のプロセスを整理してください。質問に対して適切な回答を簡潔に出力してください。
質問:{data.input}\n回答:"""
    # Run inference
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Gemma does not use token_type_ids, so drop them if the tokenizer added any
    inputs.pop('token_type_ids', None)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        do_sample=True,
        top_p=0.95,
        temperature=0.9,
        repetition_penalty=1.1,
    )
    # Decode only the newly generated tokens, skipping the prompt portion
    return tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
# Generate an answer for every task, one at a time
results = []
for d in tqdm(data.itertuples(), position=0):
    results.append(generate_text(d))
jsonl_data = []
# Iterate through the data and outputs in parallel
for i in range(len(data)):
    task_id = data.iloc[i]["task_id"]  # Access task_id using the positional index
    output = results[i]
    # Create a dictionary for each row
    jsonl_object = {
        "task_id": task_id,
        "output": output
    }
    jsonl_data.append(jsonl_object)
with open("gemma2-output.jsonl", "w", encoding="utf-8") as outfile:
    for entry in jsonl_data:
        # Convert task_id to a plain Python int: numpy integers are not JSON-serializable
        entry["task_id"] = int(entry["task_id"])
        json.dump(entry, outfile, ensure_ascii=False)
        outfile.write('\n')
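As a quick post-run sanity check, the submission file can be read back and spot-checked. This is a minimal sketch assuming the script above completed and wrote gemma2-output.jsonl:

# Reload the submission file and inspect the first entry
with open("gemma2-output.jsonl", encoding="utf-8") as f:
    entries = [json.loads(line) for line in f]
print(f"{len(entries)} entries written")
print(entries[0]["task_id"], entries[0]["output"][:100])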
Base model: google/gemma-2-2b