---
license: other
language:
- en
library_name: transformers
tags:
- RLHF
- Nexusflow
- Athene
- Reward Model
base_model:
- meta-llama/Meta-Llama-3-8B-Instruct
---
# Llama3-Athene-RM-8B
We introduce Llama3-Athene-RM-8B, an open-weights reward model based on Llama-3-8B-Instruct.
- **Developed by:** The Nexusflow Team (Evan Frick\*, Peter Jin\*, Tianle Li\*, Karthik Ganesan, Jian Zhang, Jiantao Jiao and Banghua Zhu).
- **Model type:** Reward Model
- **Finetuned from model:** [Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct).
- **License:** [Nexusflow Research License](https://huggingface.co/Nexusflow/Athene-RM-8B/blob/main/Nexusflow_Research_License.pdf)
- **Blog:** https://nexusflow.ai/blogs/athene
### Usage
```python
from typing import Dict

import torch
from torch import nn
from transformers import (
    AutoTokenizer,
    LlamaModel,
    LlamaPreTrainedModel,
    TextClassificationPipeline,
    pipeline,
)


class AtheneForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Value head: maps each token's final hidden state to a scalar reward
        self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
        # Id of the CLS token appended to every conversation; the reward is read there
        self.CLS_ID = 128003
        # Initialize weights and apply final processing
        self.post_init()

    def get_device(self):
        return self.model.device

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
    ):
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        hidden_states = transformer_outputs.hidden_states[-1]
        # Per-token rewards, shape (batch_size, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)

        scores = []
        bs = int(input_ids.shape[0])
        for i in range(bs):
            # Keep the reward at the last CLS token of each sequence
            c_inds = (input_ids[i] == self.CLS_ID).nonzero()
            c_ind = c_inds[-1].item()
            scores.append(rewards[i, c_ind])
        scores = torch.stack(scores)
        return {"scores": scores}


# Make a pipeline to handle pre- and post-processing
class AtheneRewardPipeline(TextClassificationPipeline):
    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:
        return_tensors = self.framework
        # Render the conversation with the chat template, then append the CLS token
        formatted = self.tokenizer.apply_chat_template(inputs, tokenize=False)
        formatted = formatted + self.tokenizer.cls_token
        return self.tokenizer(
            formatted,
            return_tensors=return_tensors,
            max_length=4096,
            padding="longest",
            truncation=True,
        )

    def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
        # Return the reward as a plain Python float
        return model_outputs["scores"].cpu().float().item()


# Initialize the model (device_map="auto" requires the `accelerate` package)
model = AtheneForSequenceClassification.from_pretrained(
    "Nexusflow/Athene-RM-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Nexusflow/Athene-RM-8B")

# Initialize the pipeline
pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    pipeline_class=AtheneRewardPipeline,
)

messages = [
    {
        "role": "user",
        "content": "What is an Athene Noctura? Explain one sentence.",
    },
    {
        "role": "assistant",
        "content": "The Athene noctua, also known as the little owl, is a small, nocturnal owl species native to Europe, Asia, and North Africa, characterized by its distinctive facial disk and piercing yellow eyes.",
    },
]
print(pipe([messages]))  # Print the reward!
```
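The reward comes from a single position: `forward` applies the value head to every token and keeps the output at the last occurrence of `CLS_ID` (128003) in each row. The sketch below mirrors that selection with plain Python lists instead of tensors so the indexing is easy to check; `pool_last_cls` is a hypothetical helper, not part of the released code. It also makes one failure mode explicit: `preprocess` appends the CLS token before truncating to 4096 tokens, so, assuming the tokenizer keeps the default right-side truncation, an over-long conversation loses its CLS token and the original `c_inds[-1]` indexing then fails.

```python
from typing import List, Sequence

# Same value as self.CLS_ID in AtheneForSequenceClassification
CLS_ID = 128003


def pool_last_cls(
    input_ids: Sequence[Sequence[int]],
    token_rewards: Sequence[Sequence[float]],
    cls_id: int = CLS_ID,
) -> List[float]:
    """Hypothetical helper: pick each row's per-token reward at its LAST cls_id position.

    Plain-list version of the loop in AtheneForSequenceClassification.forward.
    """
    scores = []
    for ids, row in zip(input_ids, token_rewards):
        positions = [j for j, tok in enumerate(ids) if tok == cls_id]
        if not positions:
            # The original c_inds[-1] would fail here; raise a clearer error instead
            raise ValueError("no CLS token in sequence (was it truncated away?)")
        scores.append(row[positions[-1]])
    return scores


ids = [
    [11, 22, CLS_ID, 33, CLS_ID],  # two CLS tokens: only the last one counts
    [11, 44, CLS_ID, 0, 0],        # tokens after the CLS (e.g. padding) are ignored
]
token_rewards = [
    [0.1, 0.2, 0.3, 0.4, 0.5],
    [0.6, 0.7, 0.8, 0.9, 1.0],
]
print(pool_last_cls(ids, token_rewards))  # → [0.5, 0.8]
```

In the first row the CLS token appears twice and only the final one contributes, which is why the original loop indexes `c_inds[-1]`.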
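A common way to use a reward model is best-of-n selection: score several candidate replies to the same prompt and keep the highest-scoring one. A minimal sketch, assuming `pipe` returns one float per conversation when given a list of conversations (as the `pipe([messages])` call above suggests); `build_conversations`, `rank_responses`, and `toy_scorer` are hypothetical names, and `toy_scorer` only stands in for the real pipeline so the example runs without the model.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Message = Dict[str, str]


def build_conversations(prompt: str, responses: Sequence[str]) -> List[List[Message]]:
    """Wrap one user prompt and each candidate reply in the same chat format as `messages`."""
    return [
        [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": reply},
        ]
        for reply in responses
    ]


def rank_responses(
    prompt: str,
    responses: Sequence[str],
    score_fn: Callable[[List[List[Message]]], List[float]],
) -> List[Tuple[float, str]]:
    """Score all candidates in one call and return (score, reply) pairs, best first."""
    conversations = build_conversations(prompt, responses)
    scores = score_fn(conversations)  # with the model loaded, score_fn could be `pipe`
    return sorted(zip(scores, responses), key=lambda pair: pair[0], reverse=True)


def toy_scorer(conversations: List[List[Message]]) -> List[float]:
    """Hypothetical stand-in for the reward pipeline: simply prefers longer replies."""
    return [float(len(conv[-1]["content"])) for conv in conversations]


ranked = rank_responses(
    "What is an Athene noctua?",
    ["An owl.", "The little owl, a small nocturnal owl."],
    toy_scorer,
)
print(ranked[0][1])  # → The little owl, a small nocturnal owl.
```

Passing the scorer in as a parameter keeps the ranking logic testable without downloading the 8B checkpoint; with the model loaded, `rank_responses(prompt, candidates, pipe)` would rank by the actual rewards.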
### Citation
```
@misc{Athene2024,
title = {Athene-70B: Redefining the Boundaries of Post-Training for Open Models},
url = {https://nexusflow.ai/blogs/athene},
author = {Frick, Evan and Jin, Peter and Li, Tianle and Ganesan, Karthik and Zhang, Jian and Jiao, Jiantao and Zhu, Banghua},
month = {July},
year = {2024}
}
```