
Jamba-Small v2

This is a pruned version of AI21 Labs' Jamba-v0.1 that is roughly 25% the size of the original model: it keeps 8 of Jamba-v0.1's 32 layers.

Model Details

Whereas Jamba-v0.1 contains 4 Jamba blocks of 8 layers each, Jamba-Small contains only 1. That block follows the same structure as the blocks in Jamba-v0.1: a 1:7 ratio of attention to Mamba layers, with MoE applied every 2 layers.
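As an illustration, the minimal Python sketch below labels each layer of a Jamba-style stack and checks the ratios stated above. The placement rule is an assumption: attention at index 4 of each 8-layer block and MoE on odd-indexed layers, which is consistent with the layer mapping in this card. `layer_kind` and `describe` are hypothetical helpers, not part of any library.

```python
from collections import Counter

# Assumed Jamba layout (not read from a real config object):
ATTN_PERIOD, ATTN_OFFSET = 8, 4      # one attention layer per 8, at index 4
EXPERT_PERIOD, EXPERT_OFFSET = 2, 1  # MoE on every other layer, starting at 1


def layer_kind(i: int) -> str:
    """Label layer i as 'mamba', 'mamba MoE', 'attention' or 'attention MoE'."""
    mixer = "attention" if i % ATTN_PERIOD == ATTN_OFFSET else "mamba"
    is_moe = i % EXPERT_PERIOD == EXPERT_OFFSET
    return f"{mixer} MoE" if is_moe else mixer


def describe(num_blocks: int, layers_per_block: int = 8) -> list[str]:
    """Label every layer in a stack of `num_blocks` Jamba blocks."""
    return [layer_kind(i) for i in range(num_blocks * layers_per_block)]


small = describe(num_blocks=1)  # Jamba-Small: 1 block
full = describe(num_blocks=4)   # Jamba-v0.1: 4 blocks

print(small)
# Count sequence-mixer types (first word of each label).
mixers = Counter(kind.split()[0] for kind in small)
print(mixers["attention"], ":", mixers["mamba"])  # attention-to-Mamba ratio
print(sum("MoE" in kind for kind in small), "of", len(small), "layers use MoE")
print(f"{len(small) / len(full):.0%} of the original layer count")
```

This counts layers only; it does not compute the parameter ratio.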

Jamba-Small's weights are initialized from selected layers of the original Jamba-v0.1 model. For v2, the layers are mapped as follows (left: Jamba-Small layer index; right: Jamba-v0.1 layer index; all block and layer indices are zero-based):

0: 0,  # Block 0, layer 0 (mamba)
1: 1,  # Block 0, layer 1 (mamba MoE)
2: 6,  # Block 0, layer 6 (mamba)
3: 9,  # Block 1, layer 1 (mamba MoE)
4: 12, # Block 1, layer 4 (transformer)
5: 15, # Block 1, layer 7 (mamba MoE)
6: 24, # Block 3, layer 0 (mamba)
7: 31  # Block 3, layer 7 (mamba MoE)
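The mapping can be read as a state-dict rewrite. The minimal Python sketch below is not the author's conversion script: it assumes per-layer parameter keys of the form `model.layers.<idx>.<rest>` and that non-layer weights (embeddings, final norm, LM head) are copied unchanged, both of which are assumptions. It also recomputes each block/layer annotation as `divmod(index, 8)` and checks that every copied layer lands in a slot of the same kind (Mamba vs. attention, dense vs. MoE), using an assumed Jamba layout. `remap_state_dict` and `layer_kind` are hypothetical helpers.

```python
# Jamba-Small layer index -> Jamba-v0.1 layer index (from the mapping above).
LAYER_MAP = {0: 0, 1: 1, 2: 6, 3: 9, 4: 12, 5: 15, 6: 24, 7: 31}
LAYERS_PER_BLOCK = 8


def layer_kind(i: int) -> str:
    # Assumed Jamba layout: attention at index 4 of each 8-layer block,
    # MoE on every odd-indexed layer.
    mixer = "attention" if i % LAYERS_PER_BLOCK == 4 else "mamba"
    return f"{mixer} MoE" if i % 2 == 1 else mixer


def remap_state_dict(src: dict, prefix: str = "model.layers.") -> dict:
    """Build a Jamba-Small state dict from a Jamba-v0.1 one.

    Hypothetical: assumes per-layer keys look like '<prefix><idx>.<rest>'
    and copies every key without that prefix unchanged.
    """
    src_to_dst = {s: d for d, s in LAYER_MAP.items()}
    dst = {}
    for key, tensor in src.items():
        if not key.startswith(prefix):
            dst[key] = tensor  # non-layer weights copied as-is
            continue
        idx, rest = key[len(prefix):].split(".", 1)
        if int(idx) in src_to_dst:  # layers not in the map are dropped
            dst[f"{prefix}{src_to_dst[int(idx)]}.{rest}"] = tensor
    return dst


for d, s in LAYER_MAP.items():
    block, pos = divmod(s, LAYERS_PER_BLOCK)
    # Each copied layer must land at a position of the same kind.
    assert layer_kind(d) == layer_kind(s), (d, s)
    print(f"{d}: {s:>2}  # Block {block}, layer {pos} ({layer_kind(s)})")

# Toy stand-in for a real checkpoint: strings instead of tensors.
fake = {f"model.layers.{i}.weight": f"w{i}" for i in range(32)}
fake["model.embed_tokens.weight"] = "emb"
small = remap_state_dict(fake)
print(small["model.layers.7.weight"])
```

A real conversion would also need the model config's `num_hidden_layers` set to 8; that step is not shown.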

Note that no additional fine-tuning has been performed on this model after pruning, so its output quality is very poor. It should not be used in production without further training.

Model Description

  • Developed by: Nathan Brown (OxxoCodes)
  • Compute provided by: Clemson Palmetto Cluster
  • Model type: Joint Attention and Mamba (Jamba)
  • Language(s) (NLP): English
  • License: Apache 2.0
  • Original model: Jamba-v0.1
  • Jamba paper: https://arxiv.org/pdf/2403.19887.pdf

How to Use

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pruned model in bfloat16; it uses the original Jamba-v0.1 tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    "OxxoCodes/jamba-small-v2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

with torch.no_grad():
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer("There once was a", return_tensors="pt").to(model.device)
    # Pass both input_ids and attention_mask to generate().
    outputs = model.generate(**inputs, max_new_tokens=216)
    print(tokenizer.batch_decode(outputs))

Model size: 13.3B params (Safetensors; tensor types BF16, F32)