---
license: gemma
---

## Introduction

This repo contains Gemma-2-9b-Medical, a medical language model with 9 billion parameters. The model builds on Gemma-2-9b-base and has been tuned on a diverse mix of medical and general instructions. We also apply the three strategies from the paper "Efficient Continual Pre-training by Mitigating the Stability Gap" to mitigate the stability gap during instruction tuning, which improves the model's performance on medical tasks and reduces computation cost.

## 💻 Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_name = "YiDuo1999/Gemma-2-9b-medical"
device_map = 'auto'
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, use_cache=False, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

def askme(question):
    sys_message = '''
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    '''
    # Create messages structured for the chat template
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]

    # Apply the chat template and tokenize the prompt
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)

    # Extract and return the generated text, removing the prompt
    response_text = tokenizer.batch_decode(outputs)[0].strip()
    answer = response_text.split('<|im_start|>assistant')[-1].strip()
    return answer
```
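A quick call of the helper above; the question string below is only an illustration, any medical question can be passed in:

```python
# Example query; replace with any medical question.
print(askme("What are the common symptoms of iron-deficiency anemia?"))
```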
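The `BitsAndBytesConfig` import in the snippet above is only needed if you want quantized loading. Below is a minimal 4-bit variant of the model-loading step, assuming the `bitsandbytes` package is installed; this variant is a sketch and not part of the original instructions:

```python
# Optional: load the model in 4-bit to reduce GPU memory (requires bitsandbytes).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
    quantization_config=bnb_config,  # swap in quantized weights for the full-precision load above
)
```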
## 🏆 Evaluation

For question-answering tasks, the results are as follows:

| Model                           | MMLU-Medical | PubMedQA | MedMCQA | MedQA-4-Option | Avg  |
|:--------------------------------|:-------------|:---------|:--------|:---------------|:-----|
| Mistral-7B-instruct             | 55.8         | 17.8     | 40.2    | 41.1           | 37.5 |
| Zephyr-7B-instruct-β            | 63.3         | 46.0     | 43.0    | 48.5           | 48.7 |
| PMC-Llama-7B                    | 59.7         | 59.2     | 57.6    | 49.2           | 53.6 |
| Medalpaca-13B                   | 55.2         | 50.4     | 21.2    | 20.2           | 36.7 |
| AlpaCare-13B                    | 60.2         | 53.8     | 38.5    | 30.4           | 45.7 |
| BioMedGPT-LM 7B                 | 52.0         | 58.6     | 34.9    | 39.3           | 46.2 |
| Me-Llama-13B                    | -            | 70.0     | 44.9    | 42.7           | -    |
| Llama-3-8B instruct             | 82.0         | 74.6     | 57.1    | 60.3           | 68.5 |
| JSL-Med-Sft-Llama-3-8B          | 83.0         | 75.4     | 57.5    | 74.8           | 72.7 |
| GPT-3.5-turbo-1106              | 74.0         | 72.6     | 34.9    | 39.3           | 60.6 |
| GPT-4                           | 85.5         | 69.2     | 69.5    | 83.9           | 77.0 |
| Gemma-2-9b-it                   | 75.0         | 76.0     | 40.3    | 48.9           | 60.1 |
| Llama-3-physician-8B instruct   | 80.0         | 76.0     | 80.2    | 60.3           | 74.1 |