---
library_name: transformers
tags:
- finance
license: llama3
base_model: meta-llama/Meta-Llama-3-8B-Instruct
datasets:
- virattt/financial-qa-10K
language:
- en
pipeline_tag: text-generation
---
# Llama 3 8B Instruct (Financial RAG)
This model is a fine-tuned version of the original [Llama 3 8B Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model,
trained on 4,000 examples from the [virattt/financial-qa-10K](https://huggingface.co/datasets/virattt/financial-qa-10K) dataset.
The model is fine-tuned with a LoRA adapter for RAG use cases and is optimized to answer questions based on a provided context:
```txt
Answer the question:
{question}
Using the information:
{context}
```
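As a quick illustration, the template can be filled in like this (`build_prompt` is a hypothetical helper for this sketch, not part of the model's API; the example question and context come from the dataset sample used later in this card):

```python
# Hypothetical helper that fills the prompt template above
# with a question and its supporting context.
def build_prompt(question: str, context: str) -> str:
    return f"Answer the question:\n{question}\nUsing the information:\n{context}"

prompt = build_prompt(
    "How much did the company's net earnings amount to in fiscal 2022?",
    "Net earnings were $17.1 billion in fiscal 2022.",
)
print(prompt)
```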
## Usage
Load the model:
```py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "curiousily/Llama-3-8B-Instruct-Finance-RAG"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
device_map="auto"
)
pipe = pipeline(
task="text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=128,
return_full_text=False,
)
```
Format the prompt (uses the original Instruct prompt format):
````py
prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use only the information to answer the question<|eot_id|><|start_header_id|>user<|end_header_id|>
How much did the company's net earnings amount to in fiscal 2022?
Information:
```
Net earnings were $17.1 billion in fiscal 2022.
```<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
````
And make a prediction:
```py
outputs = pipe(prompt)
print(outputs[0]["generated_text"])
```
```
$17.1 billion
```
Here's a helper function to build your prompts:
```py
from textwrap import dedent

def create_test_prompt(data_row):
prompt = dedent(f"""
{data_row["question"]}
Information:
```
{data_row["context"]}
```
""")
messages = [
{"role": "system", "content": "Use only the information to answer the question"},
{"role": "user", "content": prompt},
]
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```
Where `data_row` is a dict with `question` and `context` keys:
```py
data_row = {
"question": "...",
"context": "..."
}
```
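For reference, here is a self-contained sketch of the string `apply_chat_template` produces for this system/user pair, written out by hand so it runs without loading the tokenizer. `format_llama3_prompt` is a hypothetical helper for illustration only; in practice, use `tokenizer.apply_chat_template` as shown above:

````python
# Sketch: the Llama 3 Instruct chat format, assembled manually.
# It mirrors the prompt shown earlier in this card.
def format_llama3_prompt(system: str, user: str) -> str:
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

data_row = {
    "question": "How much did the company's net earnings amount to in fiscal 2022?",
    "context": "Net earnings were $17.1 billion in fiscal 2022.",
}
user_msg = f"{data_row['question']}\nInformation:\n```\n{data_row['context']}\n```"
prompt = format_llama3_prompt(
    "Use only the information to answer the question", user_msg
)
````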
## License
This model inherits the original Llama 3 Community License from its base model.
The full license text is available at: https://llama.meta.com/llama3/license