---
license: gemma
---
## Introduction
This repo contains Gemma-2-9b-Medical, a medical language model with 9 billion parameters. The model builds on Gemma-2-9b-base and has been instruction-tuned on a diverse mix of medical and general instructions. We also apply the three strategies from the paper 'Efficient Continual Pre-training by Mitigating the Stability Gap' to mitigate the stability gap during instruction tuning, which improves the model's performance on medical tasks and reduces computation costs.
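
The exact training recipe is not documented in this card; as a rough, illustrative sketch of the kind of data curation those strategies involve (the subset size, quality-scoring function, and mixing ratio below are assumptions, not the actual settings), an instruction-tuning mixture could be assembled along these lines:

```python
import random

def build_training_mixture(medical_examples, general_examples, quality_score,
                           subset_size=50_000, general_ratio=0.2):
    # Illustrative only: the real subset size, quality metric, and mixing ratio
    # used for Gemma-2-9b-Medical are not published in this card.
    # Keep a fixed-size, high-quality subset of the in-domain data and reuse it
    # for multiple epochs instead of a single pass over the full corpus.
    medical_subset = sorted(medical_examples, key=quality_score, reverse=True)[:subset_size]
    # Mix in general instructions so the training distribution stays closer to
    # what the base model already saw, which helps mitigate the stability gap.
    n_general = min(len(general_examples), int(len(medical_subset) * general_ratio))
    mixture = medical_subset + random.sample(general_examples, n_general)
    random.shuffle(mixture)
    return mixture
```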

## 💻 Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "YiDuo1999/Gemma-2-9b-medical"
device_map = "auto"

# Load the model and tokenizer; device_map="auto" places the weights on the available GPU(s)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

def askme(question):
    sys_message = '''
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    '''
    # Create messages structured for the chat template
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]

    # Apply the chat template and move the inputs to the same device as the model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)

    # Decode and return only the generated answer, stripping the prompt
    response_text = tokenizer.batch_decode(outputs)[0].strip()
    answer = response_text.split('<|im_start|>assistant')[-1].strip()
    return answer
```
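
For example (the question below is only illustrative; generations will vary with decoding settings):

```python
# Illustrative query; any medical question string works here
print(askme("What are the common symptoms of iron-deficiency anemia?"))
```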
## 🏆 Evaluation
For medical question-answering tasks, the results are as follows:

| Model                          | MMLU-Medical | PubMedQA | MedMCQA | MedQA-4-Option | Avg  |
|:-------------------------------|:-------------|:---------|:--------|:---------------|:-----|
| Mistral-7B-instruct            | 55.8         | 17.8     | 40.2    | 41.1           | 37.5 |
| Zephyr-7B-instruct-β           | 63.3         | 46.0     | 43.0    | 48.5           | 48.7 |
| PMC-Llama-7B                   | 59.7         | 59.2     | 57.6    | 49.2           | 53.6 |
| Medalpaca-13B                  | 55.2         | 50.4     | 21.2    | 20.2           | 36.7 |
| AlpaCare-13B                   | 60.2         | 53.8     | 38.5    | 30.4           | 45.7 |
| BioMedGPT-LM 7B                | 52.0         | 58.6     | 34.9    | 39.3           | 46.2 |
| Me-Llama-13B                   | -            | 70.0     | 44.9    | 42.7           | -    |
| Llama-3-8B instruct            | 82.0         | 74.6     | 57.1    | 60.3           | 68.5 |
| JSL-Med-Sft-Llama-3-8B         | 83.0         | 75.4     | 57.5    | 74.8           | 72.7 |
| GPT-3.5-turbo-1106             | 74.0         | 72.6     | 34.9    | 39.3           | 60.6 |
| GPT-4                          | 85.5         | 69.2     | 69.5    | 83.9           | 77.0 |
| Gemma-2-9b-int                 | 75.0         | 76.0     | 40.3    | 48.9           | 60.0 |
| Gemma-2-9b-Medical             | 75.0         | 76.0     | 61.3    | 59.7           | 68.0 |
| Llama-3-physician-8B instruct  | 80.0         | 76.0     | 80.2    | 60.3           | 74.1 |

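The evaluation harness is not included in this repo; as a minimal sketch of how accuracy could be estimated on a multiple-choice benchmark such as MedMCQA (the `eval_set` format and the letter-matching rule below are assumptions), one could reuse the `askme` helper above:

```python
def evaluate(eval_set):
    # `eval_set` is a hypothetical list of dicts with keys "question",
    # "options" (a dict like {"A": "...", "B": "..."}), and "answer" (e.g. "A").
    correct = 0
    for item in eval_set:
        options = "\n".join(f"{label}. {text}" for label, text in item["options"].items())
        prompt = f'{item["question"]}\n{options}\nAnswer with the letter of the correct option.'
        prediction = askme(prompt)
        # Count a hit if the model's reply starts with the gold letter
        if prediction.strip().upper().startswith(item["answer"].upper()):
            correct += 1
    return correct / len(eval_set)
```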
## Citation
```
@inproceedings{Guo2024EfficientCP,
  title={Efficient Continual Pre-training by Mitigating the Stability Gap},
  author={Yiduo Guo and Jie Fu and Huishuai Zhang and Dongyan Zhao and Yikang Shen},
  year={2024},
  url={https://api.semanticscholar.org/CorpusID:270688100}
}
```