---
license: apache-2.0
datasets:
- harishnair04/mtsamples
language:
- en
base_model:
- google/gemma-2-2b
tags:
- trl
- sft
- quantization
- 4bit
- LoRA
library_name: transformers
---


# Model Card for Medical Transcription Model (Gemma-MedTr)

This model is a fine-tuned variant of `Gemma-2-2b`, optimized for medical transcription tasks with efficient 4-bit quantization and Low-Rank Adaptation (LoRA). It handles transcription processing, keyword extraction, and medical specialty classification.
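The LoRA adaptation mentioned above can be illustrated with a minimal pure-Python sketch. It assumes the standard LoRA formulation y = Wx + (alpha / r) · B(Ax): the pretrained weight `W` stays frozen and only the low-rank factors `A` and `B` are trained. The rank `r=8` matches the training configuration reported in this card. The toy layer sizes, the scaling factor `alpha=16`, and the helper names (`matvec`, `lora_linear`, `lora_param_count`) are hypothetical and are not taken from this model's training code.

```python
import random


def matvec(m, v):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_linear(x, W, A, B, alpha, r):
    """Compute W x + (alpha / r) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)              # output of the frozen pretrained layer
    delta = matvec(B, matvec(A, x))  # low-rank update: project down to r dims, then back up
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters added by LoRA: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# Hypothetical toy sizes and alpha; only r = 8 comes from the model card.
d_in, d_out, r, alpha = 16, 16, 8, 16
rng = random.Random(0)
W = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]  # B starts at zero, so the adapter is initially a no-op

x = [rng.uniform(-1.0, 1.0) for _ in range(d_in)]
y = lora_linear(x, W, A, B, alpha, r)

# For a hypothetical 2304 x 2304 weight, rank 8 trains far fewer parameters
# than updating the full matrix would.
full = 2304 * 2304
adapter = lora_param_count(2304, 2304, 8)
print(adapter, full)  # 36864 5308416
```

Because `B` is initialized to zero, the adapted layer initially reproduces the frozen layer exactly; training then updates only `A` and `B`, which is what keeps the number of trainable parameters small.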

## Model Details

- **Developed by:** Harish Nair
- **Organization:** University of Ottawa
- **License:** Apache 2.0
- **Fine-tuned from:** [Gemma-2-2b](https://huggingface.co/google/gemma-2-2b)
- **Model type:** Transformer-based language model for medical transcription processing
- **Language(s):** English

### Training Details

- **Training Loss:** 1.4791 (final value, at step 10)
- **Training Configuration:**
  - LoRA with rank `r=8`, applied to selected transformer modules.
  - 4-bit quantization using the `nf4` quantization type with `bfloat16` compute precision.
- **Training Runtime:** 20.85 seconds at approximately 1.92 samples per second (roughly 40 samples over the 10 steps).
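The `nf4` setting stores weights as 4-bit indices into a fixed table of 16 values, with one absmax scaling constant per block of weights; during computation the weights are dequantized to the compute dtype (here `bfloat16`). The plain-Python sketch below illustrates that block-wise scheme under loudly stated simplifications: it uses 16 evenly spaced levels instead of the normal-distribution quantile levels that NF4 actually uses, it assumes a block size of 64, and the function names are hypothetical, not part of bitsandbytes.

```python
import random


def make_codebook(bits=4):
    """2**bits evenly spaced levels in [-1, 1]. Simplification: NF4 uses normal-quantile levels."""
    n = 2 ** bits
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def quantize_block(block, codebook):
    """Scale a block by its absolute maximum and store each value as a codebook index."""
    absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero for all-zero blocks
    codes = [
        min(range(len(codebook)), key=lambda i: abs(codebook[i] - v / absmax))
        for v in block
    ]
    return codes, absmax


def quantize(values, codebook, block_size=64):
    """Split a flat weight list into blocks, each with its own absmax scale (assumed size 64)."""
    return [
        quantize_block(values[start:start + block_size], codebook)
        for start in range(0, len(values), block_size)
    ]


def dequantize(blocks, codebook):
    """Map codes back to approximate values by rescaling each level with its block's absmax."""
    out = []
    for codes, absmax in blocks:
        out.extend(codebook[c] * absmax for c in codes)
    return out


rng = random.Random(0)
weights = [rng.gauss(0.0, 1.0) for _ in range(256)]  # stand-in for one layer's weights
codebook = make_codebook(bits=4)
blocks = quantize(weights, codebook, block_size=64)
restored = dequantize(blocks, codebook)
print(len(blocks), max(abs(a - b) for a, b in zip(weights, restored)))
```

With 16 evenly spaced levels the rounding error in each block is at most `absmax / 15`; replacing `make_codebook` with a table of normal-distribution quantiles would move this sketch closer to NF4, which places more levels near zero where most weights lie.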

## How to Use

To load and use this model, initialize it with the following configuration:
```python
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Assumption: a Hugging Face access token is stored in the HF_TOKEN environment
# variable (Gemma-based repositories may be gated).
access_token_read = os.environ.get("HF_TOKEN")

model_id = "harishnair04/Gemma-medtr-2b-sft"

# 4-bit NF4 quantization with bfloat16 compute, matching the training configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token_read)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    token=access_token_read,
)
```