|
--- |
|
library_name: peft |
|
language: |
|
- ja |
|
pipeline_tag: text-generation |
|
--- |
|
|
|
# Fine-tuned OpenCALM-7B Adapters for Meeting Summarization |
|
|
|
## Description |
|
|
|
These are LoRA adapter weights for the OpenCALM-7B ([Andonian et al., 2021](https://huggingface.co/cyberagent/open-calm-7b)) model, fine-tuned for Japanese meeting summarization.
|
|
|
## Usage |
|
|
|
### Load model and tokenizer |
|
|
|
Because these LoRA adapters were trained with QLoRA ([Dettmers et al., 2023](https://arxiv.org/abs/2305.14314)), loading the base model in the same 4-bit quantized format is recommended for reliable results.
|
|
|
|
|
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# 4-bit NF4 quantization with double quantization, matching the QLoRA
# configuration the adapters were trained with
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")
model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-7b",
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the fine-tuned LoRA adapters to the quantized base model
model = PeftModel.from_pretrained(model, "haih2/open-calm-7b-summarizer-lora")
```
|
|
|
### Generate summary |
|
|
|
In the prompt provided to the model:

* the first part specifies the length of the summary to be generated,
* and the second part is the source meeting text to be summarized.
|
|
|
```python
# The prompt states the length constraint, then the source meeting text
prompt = "この段落の要約50字以内生成:次に、私立高校の生徒に対する留学支援についてでございますが、都内の私立高校は、それぞれの学校における教育方針に基づきまして、生徒の留学先として海外の学校と提携するなど、既にさまざまな独自の取り組みを進めております。\\nこうした状況等を踏まえ、私立高校を対象とした留学支援のあり方について、今後検討してまいります。\\n\n"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    tokens = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=32,
        top_p=0.9,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        pad_token_id=tokenizer.pad_token_id,
    )

output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
```
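
Note that the decoded `output` includes the prompt itself. If you want only the generated summary, a minimal sketch is to decode just the newly generated tokens:

```python
# Decode only the tokens generated after the prompt
# (slices by input length; adjust if your tokenizer adds special tokens)
new_tokens = tokens[0][inputs["input_ids"].shape[1]:]
summary = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(summary)
```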
|
|
|
## Prompt Format |
|
|
|
Any prompt format will work, but it is recommended to include a `length` part and a `source` part, as in one of the following templates:
|
|
|
``` |
|
"この段落を{length}に要約しなさい:{source}\n要約:" |
|
``` |
|
|
|
or |
|
|
|
``` |
|
"この段落の要約{length}生成:{source}\n" |
|
``` |
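
For instance, the second template can be filled in as follows (a minimal sketch; the `length` and `source` values are illustrative placeholders):

```python
# Illustrative values; substitute your own length constraint and meeting text
length = "50字以内"  # "within 50 characters"
source = "次に、私立高校の生徒に対する留学支援についてでございますが、…"

prompt = f"この段落の要約{length}生成:{source}\n"
```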
|
|
|
## Fine-tuning Details |
|
|
|
### Dataset |
|
|
|
* [Congressional meeting minutes](https://github.com/kmr-y/NTCIR14-QALab-PoliInfo-FormalRunDataset/tree/master) provided by the NTCIR-14 QA Lab PoliInfo task.
|
|
|
### Fine-tuning procedure |
|
|
|
The OpenCALM-7B model was fine-tuned on the above dataset using the QLoRA method with the prompt `この段落の要約{length}生成:{source}\n`. The main hyperparameters are listed below; a configuration sketch follows the table.
|
|
|
| Hyperparameter | Value |
|
|----------------|----------------:| |
|
| **Optimizer** <br>   beta_1 <br>   beta_2 <br>   weight decay | AdamW <br> 0.9 <br> 0.999 <br> 0.01 | |
|
| **Learning rate** <br>   scheduler type | 2e-5 <br> linear | |
|
| **LoRA** <br>   target modules <br>   r <br>   alpha <br>   dropout | <br> query_key_value, dense <br> 4 <br> 64 <br> 0.05 | |
|
| **QLoRA** <br>   compute dtype <br>   storage dtype <br>   quantization strategy | <br> float16 <br> nf4 <br> double quantization | |
|
| **Sequence length** | 1536 | |
|
| **Batch size** | 4 | |
|
| **Gradient accumulation steps** | 2 | |
|
| **Epochs** | 10 | |
|
| **Warmup steps** | 200 | |
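
As a rough guide, the table maps onto the following `peft`/`transformers` configuration. This is a minimal sketch reconstructed from the table, not the authors' published training script; names such as `output_dir` are placeholders:

```python
from transformers import TrainingArguments
from peft import LoraConfig

# LoRA settings from the table above
lora_config = LoraConfig(
    r=4,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["query_key_value", "dense"],
    bias="none",
    task_type="CAUSAL_LM",
)

# Optimizer and schedule settings from the table above
training_args = TrainingArguments(
    output_dir="./open-calm-7b-summarizer",  # placeholder path
    optim="adamw_torch",
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    adam_beta1=0.9,
    adam_beta2=0.999,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    warmup_steps=200,
    fp16=True,  # QLoRA compute dtype: float16
)
```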
|
|
|
## Evaluation |
|
|
|
### Testing data & Metric |
|
|
|
We evaluated the model on two sets: one for *multi-topic* summarization and the other for *single-topic* summarization. ROUGE-L (F1-score-based) with the Japanese [MeCab tokenizer](https://pypi.org/project/mecab-python3/) was used as the evaluation metric.
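
For reference, one way to compute this metric is with the `rouge_score` package and a MeCab word tokenizer. This is a minimal sketch, assuming `mecab-python3` and `rouge-score` are installed; the exact evaluation script is not published, and the Japanese strings below are placeholders:

```python
import MeCab
from rouge_score import rouge_scorer

class MeCabTokenizer:
    """Word tokenizer using MeCab's wakati (space-separated) output."""
    def __init__(self):
        self.tagger = MeCab.Tagger("-Owakati")

    def tokenize(self, text):
        return self.tagger.parse(text).split()

scorer = rouge_scorer.RougeScorer(["rougeL"], tokenizer=MeCabTokenizer())
scores = scorer.score("参照要約のテキスト", "モデルが生成した要約")  # (reference, prediction)
print(scores["rougeL"].fmeasure)
```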
|
|
|
### Results |
|
|
|
| Solution/Model | ROUGE-L <br> (multi-topic) | ROUGE-L <br> (single-topic) | |
|
|----------------|:--------------------------:|:---------------------------:| |
|
|1st place solution* |34.12 |**34.44**| |
|
|2nd place solution* |32.79 |33.65 | |
|
|*OpenCALM-7B (QLoRA)*|***36.75***|*33.31* | |
|
|
|
*\* These scores are extracted from this [leaderboard](https://github.com/PoliInfo/PoliInfo.github.io/blob/master/FormalRunResult.md) for the summarization task.* |
|
|