---
library_name: transformers
tags:
- finance
license: llama3
base_model: meta-llama/Meta-Llama-3-8B-Instruct
datasets:
- virattt/financial-qa-10K
language:
- en
pipeline_tag: text-generation
---

# Llama 3 8B Instruct (Financial RAG)

This model is a fine-tuned version of the original [Llama 3 8B Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model, trained on 4000 examples from the [virattt/financial-qa-10K](https://huggingface.co/datasets/virattt/financial-qa-10K) dataset.

The model is fine-tuned with a LoRA adapter for RAG use cases. It is optimized to answer a question based on a context:

```txt
Answer the question:
{question}

Using the information:
{context}
```

## Usage

Load the model:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "curiousily/Llama-3-8B-Instruct-Finance-RAG"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
)

pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    return_full_text=False,
)
```
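
If GPU memory is tight, the weights can optionally be loaded in 4-bit. This is a minimal sketch, not part of the original workflow; it assumes the `bitsandbytes` package is installed and uses the standard `BitsAndBytesConfig` from transformers:

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Optional: 4-bit NF4 quantization so the 8B model fits on smaller GPUs.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# MODEL_NAME is the same checkpoint name defined above.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
)
```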

Format the prompt (uses the original Instruct prompt format):

````py
prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Use only the information to answer the question<|eot_id|><|start_header_id|>user<|end_header_id|>

How much did the company's net earnings amount to in fiscal 2022?

Information:

```
Net earnings were $17.1 billion in fiscal 2022.
```<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
````

And make a prediction:

```py
outputs = pipe(prompt)
print(outputs[0]["generated_text"])
```

```
$17.1 billion
```

Here's a helper function to build your prompts:

````py
from textwrap import dedent

def create_test_prompt(data_row):
    prompt = dedent(f"""
    {data_row["question"]}

    Information:

    ```
    {data_row["context"]}
    ```
    """)
    messages = [
        {"role": "system", "content": "Use only the information to answer the question"},
        {"role": "user", "content": prompt},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
````

Where `data_row` must be a dict:

```py
data_row = {
    "question": "...",
    "context": "..."
}
```
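
Putting the pieces together, here's a minimal end-to-end sketch that reuses `pipe` and `create_test_prompt` from above (the question and context are the illustrative example from earlier, not rows from the dataset):

```py
data_row = {
    "question": "How much did the company's net earnings amount to in fiscal 2022?",
    "context": "Net earnings were $17.1 billion in fiscal 2022.",
}

# Build the chat-formatted prompt and generate an answer.
outputs = pipe(create_test_prompt(data_row))
print(outputs[0]["generated_text"])
# Expected output similar to: $17.1 billion
```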

## License

Uses the original Llama 3 License. A custom commercial license is available at: https://llama.meta.com/llama3/license