|
--- |
|
language: |
|
- en |
|
- hi |
|
license: llama2 |
|
tags: |
|
- multilingual |
|
- instruction-tuning |
|
- llama2 |
|
datasets: |
|
- ai4bharat/indic-instruct-data-v0.1 |
|
--- |
|
|
|
# Airavata |
|
|
|
This model is a 7B [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base) model fine-tuned on the [IndicInstruct dataset](https://huggingface.co/datasets/ai4bharat/indic-instruct-data-v0.1),

a collection of instruction datasets (Anudesh, wikiHow, Flan v2, Dolly, Anthropic-HHH, OpenAssistant v1, and LMSYS-Chat).

Please check the corresponding Hugging Face dataset card for more details.
|
|
|
This model was trained as part of the technical report [Airavata: Introducing Hindi Instruction-tuned LLM](https://arxiv.org/abs/2401.15006).
|
The codebase used to train and evaluate this model can be found at [https://github.com/AI4Bharat/IndicInstruct](https://github.com/AI4Bharat/IndicInstruct). |
|
|
|
## Usage |
|
|
|
Clone [https://github.com/AI4Bharat/IndicInstruct](https://github.com/AI4Bharat/IndicInstruct) and install the required dependencies. Then download or clone this model to the same machine. |
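
If you prefer to fetch the model weights programmatically rather than cloning, a minimal sketch using `huggingface_hub` (assuming it is installed, e.g. via `pip install huggingface_hub`) looks like this:

```python
from huggingface_hub import snapshot_download

# Download the Airavata weights to a local directory; the returned path
# can then be passed to the IndicInstruct scripts or to transformers.
local_path = snapshot_download(repo_id="ai4bharat/Airavata")
print(local_path)
```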
|
|
|
## Input Format |
|
|
|
The model is trained to use a chat format similar to that of the [open-instruct code repository](https://github.com/allenai/open-instruct) (note the newlines):
|
``` |
|
<|user|> |
|
Your message here! |
|
<|assistant|> |
|
``` |
|
|
|
For best results, format all inputs in this manner. **Make sure to include a newline after `<|assistant|>`; omitting it can noticeably degrade generation quality.**
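
As a concrete illustration, here is a minimal sketch of building a single-turn prompt by hand (the `create_prompt_with_chat_format` helper in the example below does the same for full message lists):

```python
# Minimal sketch: a single-turn prompt in the expected chat format.
# Note the trailing newline after <|assistant|>; it matters for quality.
user_message = "Your message here!"
prompt = f"<|user|>\n{user_message}\n<|assistant|>\n"
```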
|
|
|
## Hyperparameters |
|
|
|
We fine-tune the OpenHathi base model on the aforementioned IndicInstruct dataset with LoRA. The hyperparameters for the LoRA fine-tuning are listed below, followed by a configuration sketch:
|
- LoRA Rank: 16 |
|
- LoRA alpha: 32 |
|
- LoRA Dropout: 0.05 |
|
- LoRA Target Modules: ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"] |
|
- Epochs: 4 |
|
- Learning rate: 5e-4 |
|
- Batch Size: 128 |
|
- Floating Point Precision: bfloat16 |
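
These settings map onto a [`peft`](https://github.com/huggingface/peft) `LoraConfig` roughly as follows. This is a sketch for orientation, not the actual training code; see the IndicInstruct repository for the exact setup.

```python
from peft import LoraConfig

# Sketch of a LoraConfig mirroring the hyperparameters listed above;
# the authoritative training configuration lives in IndicInstruct.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"],
    task_type="CAUSAL_LM",
)
```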
|
|
|
We recommend that readers check out [our official blog post](https://ai4bharat.github.io/airavata) for more details on model training, ablations, and evaluation results.
|
|
|
## Example |
|
|
|
```python
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
# Build a prompt string in the <|system|>/<|user|>/<|assistant|> chat format described above.
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
|
formatted_text = "" |
|
for message in messages: |
|
if message["role"] == "system": |
|
formatted_text += "<|system|>\n" + message["content"] + "\n" |
|
elif message["role"] == "user": |
|
formatted_text += "<|user|>\n" + message["content"] + "\n" |
|
elif message["role"] == "assistant": |
|
formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n" |
|
else: |
|
raise ValueError( |
|
"Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format( |
|
message["role"] |
|
) |
|
) |
|
formatted_text += "<|assistant|>\n" |
|
formatted_text = bos + formatted_text if add_bos else formatted_text |
|
return formatted_text |
|
|
|
|
|
# Run batched generation and strip the prompts from the decoded outputs.
def inference(input_prompts, model, tokenizer):
|
input_prompts = [ |
|
        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)  # the tokenizer adds BOS itself
|
for input_prompt in input_prompts |
|
] |
|
|
|
encodings = tokenizer(input_prompts, padding=True, return_tensors="pt") |
|
encodings = encodings.to(device) |
|
|
|
with torch.inference_mode(): |
|
outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250) |
|
|
|
output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True) |
|
|
|
input_prompts = [ |
|
tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts |
|
] |
|
output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)] |
|
return output_texts |
|
|
|
|
|
model_name = "ai4bharat/Airavata" |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")  # left padding for batched decoder-only generation

tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers define no pad token by default

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
|
|
|
input_prompts = [ |
|
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।", |
|
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।", |
|
] |
|
outputs = inference(input_prompts, model, tokenizer) |
|
print(outputs) |
|
``` |
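
Running this script prints a list with one generated response per input prompt. Generation above is greedy (`do_sample=False`); for more varied outputs you can pass `do_sample=True` and a `temperature` to `model.generate`.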
|
|
|
## Citation |
|
|
|
```bibtex |
|
@article{gala2024airavata, |
|
title = {Airavata: Introducing Hindi Instruction-tuned LLM}, |
|
author = {Jay Gala and Thanmay Jayakumar and Jaavid Aktar Husain and Aswanth Kumar M and Mohammed Safi Ur Rahman Khan and Diptesh Kanojia and Ratish Puduppully and Mitesh M. Khapra and Raj Dabre and Rudra Murthy and Anoop Kunchukuttan}, |
|
year = {2024}, |
|
journal = {arXiv preprint arXiv:2401.15006}
|
} |
|
``` |
|
|