Upstream model config
{
  "_name_or_path": "output/hermes-llama2-4k/checkpoint-2259",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.32.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}
Dataset
DATASET = "abideen/Cosmopedia-100k-pretrain" # @param
from datasets import load_dataset
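The tokenizer, tokenized dataset, and data collator used by the Trainer below are not shown in this excerpt. A minimal sketch follows; the tokenizer checkpoint, context length, and column name are assumptions, not the card's actual choices.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

CONTEXT_LENGTH = 2048  # assumed context length
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")  # assumed tokenizer checkpoint
tokenizer.pad_token = tokenizer.eos_token

raw_data = load_dataset(DATASET)

def tokenize(batch):
    # The raw documents are assumed to live in a "text" column
    return tokenizer(batch["text"], truncation=True, max_length=CONTEXT_LENGTH)

tokenized_data = raw_data.map(tokenize, batched=True, remove_columns=raw_data["train"].column_names)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)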
# Linear layers are converted to BitLinear
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaRMSNorm

class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        # A trick for implementing the Straight-Through Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
### Create the Llama model with our custom config and convert it to BitNet.
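The custom config passed to LlamaForCausalLM is not shown in this excerpt. A minimal sketch of a small LlamaConfig is given below; all sizes are illustrative assumptions rather than the card's actual values for the 70M model.

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=512,                # assumed
    intermediate_size=1024,         # assumed
    num_hidden_layers=8,            # assumed
    num_attention_heads=8,          # assumed
    num_key_value_heads=8,          # assumed
    max_position_embeddings=2048,   # assumed
)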
model = LlamaForCausalLM(config)
convert_to_bitnet(model, copy_weights=False)
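Note that copy_weights=False here: the BitLinear layers keep their freshly initialized weights because the model is pretrained from scratch rather than converted from an existing checkpoint.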
Training
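The notebook-style hyperparameters referenced below (output_path, BATCH_SIZE, EPOCHS, LEARNING_RATE) are not defined in this excerpt; the placeholder values here are assumptions only.

output_path = "./llama2-70m-bitnet"  # assumed output directory
BATCH_SIZE = 8                       # assumed
EPOCHS = 2                           # assumed
LEARNING_RATE = 1.5e-3               # assumed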
args = TrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=BATCH_SIZE,
    logging_steps=100,
    gradient_accumulation_steps=2,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    warmup_ratio=0.1,  # warmup_steps expects an int; a fractional warmup is expressed via warmup_ratio
    lr_scheduler_type="cosine",
    learning_rate=LEARNING_RATE,
    # max_steps=5000,
    save_steps=0.25,  # a float in (0, 1) is interpreted as a fraction of total training steps
    fp16=True,
    report_to="wandb",
)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_data["train"],
)
trainer.train()
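After training, the model and tokenizer can be saved locally (or pushed to the Hub); output_path is the assumed directory from above.

trainer.save_model(output_path)
tokenizer.save_pretrained(output_path)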
Inference
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import *
# Load the pretrained BitNet checkpoint from the Hub
model_id = "saadnaeem/Llama2-70M-Cosmopedia-100k-Pretrain"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
def activation_quant(x):
    # Per-token absmax quantization of activations to 8 bits, then rescale (fake quantization)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    y = y / scale
    return y

def weight_quant(w):
    # Ternary (1.58-bit) weight quantization: scale by the mean absolute value,
    # round and clip to {-1, 0, +1}, then rescale (fake quantization)
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1)
    u = u / scale
    return u
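As a quick illustration (not part of the original card), weight_quant scales by the mean absolute weight, rounds and clips into {-1, 0, +1}, and rescales:

w_demo = torch.tensor([[0.4, -0.9], [0.05, 1.2]])
# mean |w| = 0.6375, so scale ≈ 1.57; round-and-clip gives [[1, -1], [0, 1]],
# which rescales back to approximately [[0.64, -0.64], [0.00, 0.64]]
print(weight_quant(w_demo))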
class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        # A trick for implementing the Straight-Through Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
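A small check (illustrative only) of the detach() trick: the forward value equals the quantized tensor, while the gradient flows through the unquantized path, which is exactly the straight-through estimator.

t = torch.tensor([0.3, -0.7], requires_grad=True)
q = t + (weight_quant(t) - t).detach()
q.sum().backward()
print(t.grad)  # tensor([1., 1.]) -- the gradient of the identity, not of the rounding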
def convert_to_bitnet(model, copy_weights):
    for name, module in model.named_modules():
        # Replace linear layers with BitNet
        if isinstance(module, LlamaSdpaAttention) or isinstance(module, LlamaMLP):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    bitlinear = BitLinear(child_module.in_features, child_module.out_features,
                                          child_module.bias is not None).to(device="cuda:0")
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        # Remove redundant input_layernorms (BitLinear applies its own RMSNorm to its input)
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity().to(device="cuda:0"))
convert_to_bitnet(model, copy_weights=True)
model.to(device="cuda:0")
prompt = "What is Machine Learning?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=50)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
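The decoded string can also be printed directly, and sampled generation works the same way; the generation settings below are illustrative assumptions, not the card's.

generate_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.8,
    top_p=0.95,
)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])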