PRM_DPO_GEMMA_ZD_8_18_1
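This model is a LoRA adapter for google/gemma-2-2b-it, fine-tuned on the prm_dpo dataset. It acts as a pairwise preference reward model: given a problem and two candidate answers, it judges whether the first answer is better than the second (see the usage code below).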
Citation
@article{zhang2024llama,
title={LLaMA-Berry: Pairwise Optimization for O1-like Olympiad-Level Mathematical Reasoning},
author={Zhang, Di and Wu, Jianbo and Lei, Jingdi and Che, Tong and Li, Jiatong and Xie, Tong and Huang, Xiaoshui and Zhang, Shufei and Pavone, Marco and Li, Yuqiang and others},
journal={arXiv preprint arXiv:2410.02884},
year={2024}
}
@article{zhang2024accessing,
title={Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B},
author={Zhang, Di and Li, Jiatong and Huang, Xiaoshui and Zhou, Dongzhan and Li, Yuqiang and Ouyang, Wanli},
journal={arXiv preprint arXiv:2406.07394},
year={2024}
}
Model usage
server.py
import json
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# Initialize FastAPI
app = FastAPI()
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Base model + LoRA adapter loading.
model_name = "google/gemma-2-2b-it"
lora_checkpoint_path = "qq8933/PPRM-gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the weights onto the same device that predict() moves its inputs to
# ("cuda" or "cpu"), so CPU-only hosts work and inputs/weights always agree.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, device_map=device.type
)
model = PeftModel.from_pretrained(
    base_model, lora_checkpoint_path, device_map=device.type
)
# Vocabulary ids of the answer tokens whose logits predict() compares.
yes_token_id = tokenizer.convert_tokens_to_ids("yes")
no_token_id = tokenizer.convert_tokens_to_ids("no")
# convert_tokens_to_ids maps an unknown token to the unk id without complaint;
# fail at startup rather than silently scoring against a meaningless logit.
for _token, _token_id in (("yes", yes_token_id), ("no", no_token_id)):
    if _token_id is None or _token_id == tokenizer.unk_token_id:
        raise ValueError(
            f"Token {_token!r} is not a single token in the {model_name} vocabulary"
        )
# Request model
class InputRequest(BaseModel):
    """Request body accepted by POST /predict.

    ``text`` is not the prompt itself: it is a JSON-encoded object string that
    the endpoint decodes to obtain the problem and the two answers.
    """

    text: str  # JSON object string with the problem and both answers
# Predict function
def predict(qeustion, answer_1, answer_2):
    """Ask the model whether ``answer_1`` is better than ``answer_2``.

    The three strings are filled into a fixed comparison prompt and wrapped
    in the tokenizer's chat template. The scores that ``generate`` reports
    for the first generated token are then read at the "yes" and "no" ids.

    Args:
        qeustion: The problem statement (the misspelled name is kept so
            existing callers keep working).
        answer_1: The first candidate answer.
        answer_2: The second candidate answer.

    Returns:
        dict with ``yes_logit`` and ``no_logit`` (floats) and
        ``logit_difference`` (``yes_logit - no_logit``). A positive difference
        means the model leans towards answering "yes" to
        "Is First Answer better than Second Answer?".
    """
    prompt_template = """Problem:\n\n{}\n\nFirst Answer:\n\n{}\n\nSecond Answer:\n\n{}\n\nIs First Answer better than Second Answer?\n\n"""
    input_text = prompt_template.format(qeustion, answer_1, answer_2)
    input_text = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': input_text}], tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template already contains the BOS token, so do not let
    # the tokenizer prepend a second one.
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(device)
    with torch.no_grad():
        # Only the first generated position is inspected, so one new token is enough.
        generated_outputs = model.generate(
            **inputs, max_new_tokens=1, output_scores=True, return_dict_in_generate=True
        )
    # scores[0] has one row per batch item; there is exactly one item here.
    first_token_logits = generated_outputs.scores[0]
    yes_logit = first_token_logits[0, yes_token_id].item()
    no_logit = first_token_logits[0, no_token_id].item()
    return {
        "yes_logit": yes_logit,
        "no_logit": no_logit,
        "logit_difference": yes_logit - no_logit
    }
# Define API endpoint
@app.post("/predict")
async def get_prediction(input_request: InputRequest):
    """Handle POST /predict by comparing two answers to one problem.

    ``input_request.text`` must be a JSON object string with the keys
    ``qeustion`` (spelling kept for wire compatibility; ``question`` is also
    accepted), ``answer_1`` and ``answer_2``.

    Returns:
        The dict produced by ``predict``.

    Raises:
        HTTPException(400): ``text`` is not a JSON object or a key is missing.
        HTTPException(500): Scoring itself failed.
    """
    # Client input problems are reported as 400, not as server errors.
    try:
        payload = json.loads(input_request.text)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"'text' is not valid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="'text' must encode a JSON object")

    qeustion = payload["qeustion"] if "qeustion" in payload else payload.get("question")
    answer_1 = payload.get("answer_1")
    answer_2 = payload.get("answer_2")
    missing = [
        key
        for key, value in (("qeustion", qeustion), ("answer_1", answer_1), ("answer_2", answer_2))
        if value is None
    ]
    if missing:
        raise HTTPException(
            status_code=400, detail=f"missing key(s) in 'text': {', '.join(missing)}"
        )

    # NOTE: predict() is synchronous model inference called directly from this
    # async handler, so each worker processes requests one at a time.
    try:
        return predict(qeustion, answer_1, answer_2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
Run the PPRM server (with the code above saved as `server.py`; set `$MASTER_PORT` to the port to listen on):
uvicorn server:app --host 0.0.0.0 --port $MASTER_PORT --workers 1
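Query the PPRM server from a client. The helpers below expect the client module to import `requests` (plus the standard-library modules they use) and to define `TIMEOUT_PRM`, `DUMMY_ANSWERS` and `prm_servers` (a list of `/predict` URLs):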
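# Example: qeustion,answer_1,answer_2 = 'What is the capital of France?', 'Berlin', 'Paris'
# -> {'yes_logit': -24.26136016845703, 'no_logit': 19.517587661743164, 'logit_difference': -43.778947830200195}
# The model is asked "Is answer_1 better than answer_2?" (yes or no); the strongly negative
# logit_difference here means it prefers answer_2 ('Paris').
# Reward model entry points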
def request_prediction(
    qeustion, answer_1, answer_2, url="http://10.140.24.56:10085/predict"
):
    """Ask a PPRM server whether ``answer_1`` is better than ``answer_2``.

    Args:
        qeustion (str): The problem statement. The spelling matches the JSON
            key the server reads.
        answer_1 (str): The first candidate answer.
        answer_2 (str): The second candidate answer.
        url (str): The server's ``/predict`` endpoint. Defaults to
            ``http://10.140.24.56:10085/predict``, a deployment-specific
            address you will usually need to override.

    Returns:
        dict: The server's JSON response, with ``yes_logit``, ``no_logit``
        and ``logit_difference``.

    Raises:
        requests.HTTPError: If the server answers with a non-2xx status.
    """
    headers = {"Content-Type": "application/json"}
    # The server's request model has a single "text" field that itself holds
    # a JSON-encoded object with the three inputs.
    payload = {
        "text": json.dumps(
            {"qeustion": qeustion, "answer_1": answer_1, "answer_2": answer_2}
        )
    }
    # NOTE(review): `requests` and `TIMEOUT_PRM` are not defined in this
    # snippet; the surrounding client module must import/define them.
    response = requests.post(url, json=payload, headers=headers, timeout=TIMEOUT_PRM)
    response.raise_for_status()  # surface HTTP errors instead of parsing an error body
    return response.json()
def cal_reward(question, ans, ans2="I don't know"):
    """Return the PPRM's probability that ``ans`` is better than ``ans2``.

    Placeholder answers listed in ``DUMMY_ANSWERS`` short-circuit the call.
    If ``ans2`` is a placeholder the result is 1 (checked first). Otherwise,
    if ``ans`` is a placeholder the result is 0.

    Otherwise the comparison goes to the servers in ``prm_servers``, tried in
    random order. The first successful reply is turned into
    ``exp(yes_logit) / (exp(yes_logit) + exp(no_logit))``.

    If every server fails, the whole round is retried after a delay that
    doubles each time (capped at 30 s) until some server answers.

    NOTE(review): ``DUMMY_ANSWERS``, ``prm_servers`` and ``request_prediction``
    come from the client module. Presumably the default ``ans2``
    ("I don't know") is listed in ``DUMMY_ANSWERS``, which would make the
    two-argument call always return 1. Confirm.

    Raises:
        RuntimeError: If ``prm_servers`` is empty, since no server could ever answer.
    """
    import time  # local import: the client module's import header is not shown

    if ans2 in DUMMY_ANSWERS:  # e.g. "I don't know"
        return 1
    if ans in DUMMY_ANSWERS:
        return 0

    delay = 1.0
    # Loop rather than recurse, so a long outage cannot exhaust the call stack.
    while True:
        urls = list(prm_servers)  # re-read each round in case the list is refreshed
        if not urls:
            raise RuntimeError("No prm servers configured")
        random.shuffle(urls)
        last_error = None
        for url in urls:
            # Guard only the request and response parsing: a failure here means
            # "this server is unusable", so move on to the next one.
            try:
                response = request_prediction(question, ans, ans2, url)
                yes_logit = float(response["yes_logit"])
                no_logit = float(response["no_logit"])
            except Exception as e:
                last_error = e
                continue
            return _yes_probability(yes_logit, no_logit)
        print(f"All prm servers are down (last error: {last_error!r}); retrying in {delay:g}s")
        time.sleep(delay)
        delay = min(delay * 2, 30.0)


def _yes_probability(yes_logit, no_logit):
    """Compute ``exp(yes) / (exp(yes) + exp(no))`` without overflowing ``math.exp``."""
    diff = yes_logit - no_logit
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    scaled = math.exp(diff)
    return scaled / (1.0 + scaled)
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 4
- eval_batch_size: 8
- seed: 42
- distributed_type: multi-GPU
- num_devices: 16
- gradient_accumulation_steps: 2
- total_train_batch_size: 128
- total_eval_batch_size: 128
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 1.0
Framework versions
- PEFT 0.11.1
- Transformers 4.44.0
- Pytorch 2.3.1
- Datasets 2.20.0
- Tokenizers 0.19.1
Downloads last month: 35