---
library_name: transformers
license: apache-2.0
tags:
- jamba
- mamba
- moe
---

# Please refrain from using this model yet. It does not yet contain any actual weights.

# Expert weights of [Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)

These expert weights are required for follow-up research.

The original model is **[AI21lab's Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)**, which requires **more than 80GB of VRAM**. Unfortunately, that much memory is generally not available via Google Colab or common cloud computing services. Thus, **MoE (Mixture of Experts) splitting** was attempted, using the following resources as a basis:
- **Original Model:** [Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
- **MoE Layer Separation**: Consult [this script](https://github.com/TechxGenus/Jamba-utils/blob/main/dense_downcycling.py) written by [@TechxGenus](https://github.com/TechxGenus) and use [TechxGenus/Jamba-v0.1-9B](https://huggingface.co/TechxGenus/Jamba-v0.1-9B). A rough sketch of the idea is shown below.
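
The snippet below is a minimal, hypothetical sketch of what dense downcycling does conceptually: keep a single expert per MoE layer, drop the router and the remaining experts, and rename the surviving weights to dense FFN keys. The key names (`feed_forward.experts.0.`, `feed_forward.router.`) and file names are illustrative assumptions and may not match the actual Jamba checkpoint layout; consult the referenced script for the real implementation.

```python
import torch

# Hypothetical sketch: process one checkpoint shard, keeping only expert 0.
state_dict = torch.load("jamba_shard.bin", map_location="cpu")

dense_state = {}
for name, tensor in state_dict.items():
    if ".feed_forward.experts.0." in name:
        # e.g. "...feed_forward.experts.0.gate_proj.weight" -> "...feed_forward.gate_proj.weight"
        dense_state[name.replace(".experts.0.", ".")] = tensor
    elif ".feed_forward.experts." in name or ".feed_forward.router." in name:
        continue  # drop the remaining experts and the router
    else:
        dense_state[name] = tensor  # attention / Mamba / norm weights pass through unchanged

torch.save(dense_state, "jamba_dense_shard.bin")
```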


<br><br><br><br><br><br>


# Original Model Card from **[AI21lab's Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)**.


## Usage

The usage code below follows the original **[AI21lab's Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)** model card.

### Prerequisites

To use Jamba, it is recommended to install `transformers` version 4.40.0 or higher (version 4.39.0 is the minimum required):
```bash
pip install transformers>=4.40.0
```

For optimized Mamba implementations, install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm causal-conv1d>=1.2.0
```
Ensure the model is on a CUDA device.

You can run the model without the optimized Mamba kernels, but this is **not** recommended, since it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model, as shown below.
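
A minimal sketch of that fallback path, assuming the same repository name used elsewhere in this card:

```python
from transformers import AutoModelForCausalLM

# Fallback: pure-PyTorch Mamba implementation (no mamba-ssm / causal-conv1d kernels).
# Expect noticeably slower generation than with the optimized kernels.
model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
                                             use_mamba_kernels=False)
```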

### Run the model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base")
tokenizer = AutoTokenizer.from_pretrained("danielpark/asp-9b-inst-base")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))
# ["In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
```

When using `transformers<4.40.0`, make sure to pass `trust_remote_code=True` so the new Jamba architecture can be loaded.
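
A minimal sketch of that call, using the repository name assumed throughout this card:

```python
from transformers import AutoModelForCausalLM

# On transformers < 4.40.0 the Jamba modeling code is fetched from the Hub,
# so remote code execution must be explicitly allowed.
model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
                                             trust_remote_code=True)
```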

<details>
<summary><strong>Loading the model in half precision</strong></summary>

  The published checkpoint is saved in BF16. To load it into RAM in BF16/FP16, specify `torch_dtype`:
  
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
                                             torch_dtype=torch.bfloat16)    # you can also use torch_dtype=torch.float16
```

When using half precision, enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. To use it, ensure the model is on a CUDA device. Since the model is too big to fit on a single 80GB GPU, parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")
```

</details>
<details><summary><strong>Load the model in 8-bit</strong></summary>

  **Using 8-bit precision, sequence lengths of up to 140K tokens fit on a single 80GB GPU.** You can quantize the model to 8-bit with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index); to avoid degrading model quality, exclude the Mamba blocks from quantization:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config)
```
</details>

### Fine-tuning example

Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). Fine-tune it using any technique of your choice. Here's an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("danielpark/asp-9b-inst-base")
model = AutoModelForCausalLM.from_pretrained("danielpark/asp-9b-inst-base", device_map='auto')

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

trainer.train()
```


## Further
Check [ai21labs/Jamba-tiny-random](https://huggingface.co/ai21labs/Jamba-tiny-random), which has 128M parameters (instead of 52B), is initialized with random weights, and did not undergo any training.
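
As an illustrative, non-authoritative example, the tiny random checkpoint can be used to smoke-test a loading and generation pipeline before moving to the full-size weights:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny, untrained Jamba: loads quickly and exercises the same code paths,
# but its outputs are meaningless.
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-tiny-random")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-tiny-random")

input_ids = tokenizer("Hello", return_tensors="pt")["input_ids"]
print(tokenizer.batch_decode(model.generate(input_ids, max_new_tokens=8)))
```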