sedrickkeh
committed on
Commit • 443ad2e
1 Parent(s): efb6425
Update README.md
README.md
CHANGED
@@ -94,7 +94,7 @@ We follow their training recipe and release our version of Mamba-7B.
 
 ## Training Details
 - Mamba-7B was trained using AWS SageMaker on 128 H100 80GB GPUs.
-- Training began in March
+- Training began in March 2024 and lasted around 3 weeks (some downtime due to crashes and loss spikes).
 | **Hyperparameter** | **Value** |
 |--------------------|------------|
 | Precision | `bfloat16` |
@@ -108,18 +108,9 @@ We follow their training recipe and release our version of Mamba-7B.
 
 
 ## Usage
-
-This model was trained using [OpenLM](https://github.com/mlfoundations/open_lm/).
-
-To use HuggingFace models trained with OpenLM, first install the OpenLM package
-```bash
-pip install openlm
-```
-
-Importing from `openlm_hf` will automatically import the necessary classes.
+This model was trained using [OpenLM](https://github.com/mlfoundations/open_lm/). The weights have been converted to be compatible with HuggingFace.
 
 ```python
-from openlm_hf import * # registers the Auto* classes
 from transformers import AutoTokenizer, AutoModelForCausalLM
 tokenizer = AutoTokenizer.from_pretrained("tri-ml/mamba-7b-rw")
 model = AutoModelForCausalLM.from_pretrained("tri-ml/mamba-7b-rw").cuda()
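The new Usage snippet stops at loading the model. A minimal generation sketch that continues from the same setup might look like the following; the prompt and `max_new_tokens` value are illustrative assumptions, not values from the model card:

```python
# Minimal sketch: load the converted HF-compatible weights and generate text.
# The prompt and sampling settings below are illustrative assumptions,
# not values from the README.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("tri-ml/mamba-7b-rw")
model = AutoModelForCausalLM.from_pretrained("tri-ml/mamba-7b-rw").cuda()

# Tokenize a prompt, move it to the GPU, and decode the generated continuation.
inputs = tokenizer("The Mamba architecture is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```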