|
--- |
|
library_name: transformers |
|
tags: |
|
- medical |
|
- biology |
|
- chemistry |
|
license: gemma |
|
datasets: |
|
- sid6i7/patient-doctor |
|
language: |
|
- en |
|
pipeline_tag: text-generation |
|
base_model: |
|
- google/gemma-2-2b-it |
|
--- |
|
|
|
# Model Card for doctor-gemma-2-2b-it
|
|
|
doctor-gemma-2-2b-it is a fine-tune of `google/gemma-2-2b-it` on patient-doctor conversations, intended to generate doctor-style responses to medical questions in English.
|
|
|
|
|
|
|
## Model Details |
|
|
|
### Model Description

This model is `google/gemma-2-2b-it` fine-tuned with LoRA on the `sid6i7/patient-doctor` dataset. Given a patient query, it generates a response in the style of a doctor, as shown in the example below.
|
## Uses |
|
|
|
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. --> |
|
Below are some code snippets to help you get started with the model. First, install the Transformers library and PyTorch:
|
|
|
```bash
|
pip install -U transformers |
|
pip install -U torch |
|
``` |
|
Then, copy the snippet from the section that is relevant to your use case.
|
|
|
### Direct Use |
|
|
|
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. --> |
|
Running the model on a single or multiple GPUs:
|
|
|
```python
# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Ellio98/doctor-gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "Ellio98/doctor-gemma-2-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

query = "I have a high fever and coughing since morning. What should I do ?"

prompt = [{"role": "user", "content": query}]

# Apply the Gemma chat template, then tokenize. The template already adds <bos>,
# so special tokens are not added a second time here.
model_input = tokenizer(
    tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True),
    add_special_tokens=False,
    return_tensors="pt",
)

outputs = model.generate(
    input_ids=model_input["input_ids"].to(model.device),
    attention_mask=model_input["attention_mask"].to(model.device),
    max_new_tokens=256,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# user
# I have a high fever and coughing since morning. What should I do ?
# model
# Hi,From history it seems that you might be having viral infection giving this problem of fever with cough.
# Take paracetamol or ibuprofen for pain due to fever. Take plenty of water.
# If require take one antispasmodic medicine like Meftal spas as needed.
# Ok and take care.
|
``` |
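
Alternatively, the high-level `pipeline` API can be used with a chat-formatted input. This is a minimal sketch assuming a recent Transformers version in which the text-generation pipeline applies the chat template automatically; adjust `max_new_tokens` and the device settings to your hardware.

```python
from transformers import pipeline
import torch

# Text-generation pipeline; device_map="auto" places the model on available GPUs.
generator = pipeline(
    "text-generation",
    model="Ellio98/doctor-gemma-2-2b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [{"role": "user", "content": "I have a high fever and coughing since morning. What should I do?"}]
result = generator(messages, max_new_tokens=256)

# With chat input, generated_text holds the full conversation; the last message is the reply.
print(result[0]["generated_text"][-1]["content"])
```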
|
|
|
## Bias, Risks, and Limitations |
|
|
|
<!-- This section is meant to convey both technical and sociotechnical limitations. --> |
|
The model may not always generate the desired output and can suggest questionable medical practices that do not make sense.

It is highly recommended to use this model for research purposes only, not for commercial use.

Also, because no data anonymization was performed, the model may generate random names, addresses, or other terms that are not intended.
|
|
|
## Future Works |
|
|
|
Data anonymization needs to be performed on the training data so that the model output becomes more generic.
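
As a rough illustration of the intended direction only: a pattern-based pass could mask obvious identifiers in the training text before fine-tuning. The patterns and placeholder tokens below are hypothetical, and a real setup should rely on a dedicated PII/NER tool rather than hand-written regular expressions.

```python
import re

# Hypothetical pattern-based scrubber; real anonymization should use a
# dedicated PII/NER tool rather than hand-written regexes.
PATTERNS = {
    r"\b[\w.+-]+@[\w-]+\.[\w.]+\b": "[EMAIL]",          # e-mail addresses
    r"(?:\+?\d[\s-]?){7,14}\d\b": "[PHONE]",             # phone-like digit runs
    r"\b(?:Dr|Mr|Mrs|Ms)\.?\s+[A-Z][a-z]+\b": "[NAME]",  # titled names
}

def anonymize(text: str) -> str:
    for pattern, placeholder in PATTERNS.items():
        text = re.sub(pattern, placeholder, text)
    return text

print(anonymize("Thanks, Dr. Smith. Call me at +1 555-123-4567 or jane@example.com."))
# Thanks, [NAME]. Call me at [PHONE] or [EMAIL].
```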
|
|
|
|
|
## Training Details |
|
|
|
### Training Data |
|
|
|
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. --> |
|
For fine-tuning the model, 5,000 samples were drawn from the `sid6i7/patient-doctor` dataset: 4,500 were used for training and 250 each for validation and testing.
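
The split could be reproduced along these lines with the `datasets` library (a sketch only; the split name, shuffling seed, and exact sampling used for this model are assumptions):

```python
from datasets import load_dataset

# Load the dataset (assuming a single "train" split) and take 5,000 shuffled samples.
dataset = load_dataset("sid6i7/patient-doctor", split="train")
subset = dataset.shuffle(seed=42).select(range(5000))

# 4,500 / 250 / 250 split for train / validation / test, as described above.
train_ds = subset.select(range(0, 4500))
val_ds = subset.select(range(4500, 4750))
test_ds = subset.select(range(4750, 5000))

print(len(train_ds), len(val_ds), len(test_ds))  # 4500 250 250
```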
|
|
|
#### Training Hyperparameters |
|
|
|
|
- Precision: bfloat16
- LoRA alpha: 32
- LoRA rank: 16
- LoRA dropout: 0.05
- Target modules: `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`
- Optimizer: paged_adamw_32bit
- Learning rate: 2e-4
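
These hyperparameters map directly onto a standard `peft` LoRA configuration. The sketch below shows one plausible setup; the batch size, number of epochs, and other values not listed above are assumptions, not the actual training script.

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA configuration matching the hyperparameters listed above.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# Training arguments; batch size and epoch count are assumptions, the rest
# follows the list above (bfloat16, paged 32-bit AdamW, lr 2e-4).
training_args = TrainingArguments(
    output_dir="doctor-gemma-2-2b-it",
    per_device_train_batch_size=2,   # assumption
    num_train_epochs=1,              # assumption
    learning_rate=2e-4,
    bf16=True,
    optim="paged_adamw_32bit",
)
```

Such a config would typically be passed, together with the tokenized dataset, to a supervised fine-tuning trainer such as `trl`'s `SFTTrainer`.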
|
|
|
|
|
### Results |
|
|
|
The loss progression during fine-tuning of gemma-2-2b-it on the dataset is shown below:
|
|
|
![image/png](https://cdn-uploads.huggingface.co/production/uploads/635e4f9d19b4345264daecac/_lCch0VfuEVpGzjn1ilzS.png) |
|
|
|
#### Hardware |
|
|
|
The model was trained on 2 x T4 GPUs. The total training time was 6 hours.