import copy |
import inspect |
import math |
import numpy as np |
import os |
from dataclasses import dataclass, field |
from typing import List, Optional |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from elm.utils import * |
from elm.positional_embeddings import * |
def get_elm_model_map(model_name):
    """Return the transformer-block class registered under ``model_name``.

    Any name that is not registered falls back to ``RambutanSlice``.
    """
    fallback_cls = RambutanSlice
    registry = {"rambutan": RambutanSlice}
    return registry.get(model_name, fallback_cls)
@dataclass
class ModelArgs:
    """Configuration values used to build an ELM model."""

    # Not read by the code visible in this module.
    model_name_or_path: str = "ELM"
    # Not read by the code visible in this module.
    compile_model: bool = False
    # Key looked up by get_elm_model_map to choose the transformer-block class.
    elm_model_class: Optional[str] = "rambutan"
    # Width of embeddings, LayerNorms and attention projections.
    hidden_size: Optional[int] = 2048
    # Size of the learned position-embedding table and the context window
    # that ELM.generate crops to.
    max_inp_len: Optional[int] = 2048
    # NOTE: this default is evaluated once while the class body runs, so it is
    # the literal 2048 above; it does NOT follow a max_inp_len passed to the
    # constructor. Not read by the code visible in this module.
    attn_window_size: Optional[int] = max_inp_len
    # Number of attention heads; must divide hidden_size.
    num_attention_heads: Optional[int] = 32
    # eps passed to every LayerNorm.
    layernorm_eps: float = 1e-5
    # Dropout probability used inside the attention module.
    attention_dropout: float = 0.1
    # Dropout probability applied to the embedding output in ELM.forward.
    hidden_dropout: float = 0.1
    # Number of stacked transformer blocks.
    num_layers: Optional[int] = 16
    # Passed to RambutanMLP as ``bits`` (gate width / expert table size).
    bits: Optional[int] = 256
    # Rows of the token embedding and output size of the LM head.
    vocab_size: Optional[int] = 50304
    # A probability, hence float (was mistakenly annotated as int).
    # Not read by the code visible in this module.
    dropout: Optional[float] = 0.1
    # When True, rotary embeddings are used and no learned position table
    # is created.
    use_rotary_embeddings: Optional[bool] = True
    # Not read by the code visible in this module.
    tokenizer: Optional[str] = None
class ELM(torch.nn.Module):
    """ELM (SliceX GPT) model.

    Token embeddings (plus a learned position table when rotary embeddings
    are disabled) feed a stack of transformer blocks, a final LayerNorm and
    a bias-free linear LM head.
    """

    def __init__(self,
                 model_args: ModelArgs):
        """Build the embedding tables, the block stack and the LM head.

        Args:
            model_args: Configuration; the instance is stored on the module
                and also handed to every transformer block.
        """
        super().__init__()

        self.model_args = model_args

        hidden_size = model_args.hidden_size
        vocab_size = model_args.vocab_size
        use_rotary_embeddings = model_args.use_rotary_embeddings

        layer_class = get_elm_model_map(model_args.elm_model_class)

        self.slice_transformer = torch.nn.ModuleDict(dict(
            temb=torch.nn.Embedding(vocab_size, hidden_size),
            # No learned position table when rotary embeddings are in use.
            pemb=(torch.nn.Embedding(model_args.max_inp_len, hidden_size)
                  if not use_rotary_embeddings else None),
            drop=torch.nn.Dropout(model_args.hidden_dropout),
            h=torch.nn.ModuleList(
                [layer_class(model_args=model_args) for _ in range(model_args.num_layers)]
            ),
            ln_f=torch.nn.LayerNorm(hidden_size, eps=model_args.layernorm_eps),
        ))

        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        print("Number of model parameters: %.2fM" % (self.get_num_params(False)/1e6,))

    def forward(self,
                x: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                targets: Optional[torch.Tensor] = None):
        """Run the model and optionally compute the next-token loss.

        Args:
            x: Token ids of shape (batch, seqlen).
            attention_mask: Optional mask forwarded unchanged to every block.
            targets: Optional token ids aligned with ``x``. They are shifted
                by one position internally; entries equal to -100 are ignored.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every
            position; without, only the last position is projected
            (sequence dimension of size 1). ``loss`` is always a float32
            tensor of shape (1,), and is zero when no targets were given.
        """
        device = x.device
        _, seqlen = x.size()

        tok_emb = self.slice_transformer.temb(x)
        if not self.model_args.use_rotary_embeddings:
            pos = torch.arange(0, seqlen, dtype=torch.long, device=device)
            pos_emb = self.slice_transformer.pemb(pos)
            x = self.slice_transformer.drop(tok_emb + pos_emb)
        else:
            x = self.slice_transformer.drop(tok_emb)

        for tlayer in self.slice_transformer.h:
            x = tlayer(x, attention_mask=attention_mask)

        x = self.slice_transformer.ln_f(x)

        ignore_index_id = -100
        loss = torch.zeros(1, device=device)

        if targets is not None:
            logits = self.lm_head(x)
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_targets = targets[..., 1:].contiguous()
            curr_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                        shift_targets.view(-1),
                                        ignore_index=ignore_index_id)
            loss = loss + curr_loss.float()
        else:
            # Inference: only the final position is needed for sampling.
            # The loss stays at zero here; the previous code divided by a
            # zero denominator on this path and returned NaN.
            logits = self.lm_head(x[:, [-1], :])

        return logits, loss

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        Args:
            non_embedding: When True (default) and the model uses a learned
                position table (rotary embeddings disabled), that table's
                parameters are subtracted from the total.

        Token-embedding and LM-head weights are always counted; this class
        does not tie them together.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and not self.model_args.use_rotary_embeddings:
            n_params -= self.slice_transformer.pemb.weight.numel()
        return n_params

    @torch.no_grad()
    def generate(self, x, max_new_tokens, temperature=0.8, top_k=200, top_p=0.9,
                 return_gen_only=False):
        """Autoregressively append ``max_new_tokens`` token ids to ``x``.

        Args:
            x: Prompt token ids of shape (batch, seqlen).
            max_new_tokens: Number of decoding steps.
            temperature: Values <= 0 select greedy argmax decoding; otherwise
                the logits are divided by it before sampling.
            top_k: If not None, keep only the ``top_k`` largest logits.
            top_p: If not None, sample with ``sample_top_p``; otherwise
                sample from the full softmax distribution.
            return_gen_only: Return only the newly generated ids.

        Returns:
            Prompt plus generated ids, or just the generated ids when
            ``return_gen_only`` is set.
        """
        max_inp_len = self.model_args.max_inp_len
        prompt_len = x.size(1)

        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum input length.
            x_ctxt = x if x.size(1) <= max_inp_len else x[:, -max_inp_len:]

            logits, _ = self(x_ctxt)

            if temperature <= 0:
                # Logits have a length-1 sequence dim, so argmax over the
                # vocabulary already yields a (batch, 1) tensor.
                next_id = torch.argmax(logits, dim=-1)
            else:
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)

                if top_p is None:
                    next_id = torch.multinomial(probs, num_samples=1)
                else:
                    next_id = sample_top_p(probs, top_p)

            x = torch.cat((x, next_id), dim=1)

        if return_gen_only:
            # Slice from the prompt length: x[:, -max_new_tokens:] would
            # return the entire prompt when max_new_tokens == 0.
            return x[:, prompt_len:]

        return x
class RambutanMLP(torch.nn.Module):
    """RambutanMLP version of MLP module used in the ELM (SliceX GPT) Transformer block.

    A softmax gate over ``bits`` learned embeddings selects the top
    ``num_experts`` entries. Their embeddings are weighted by the gate
    probabilities, mixed by a small linear layer into one per-feature scale,
    and multiplied element-wise with the input.
    """

    def __init__(self, dim=768, bits=32, dropout=0.0, num_experts=4):
        """
        Args:
            dim: Feature width of the input and output.
            bits: Number of gate logits and rows in the expert embedding table.
            dropout: Dropout probability applied to the output.
            num_experts: How many top-scoring experts are mixed per position.

        Raises:
            ValueError: If ``num_experts`` exceeds ``bits``; top-k selection
                over the gate would otherwise fail on the first forward call.
        """
        super().__init__()
        if num_experts > bits:
            raise ValueError(
                f"num_experts ({num_experts}) cannot exceed bits ({bits})"
            )
        self.dim = dim
        self.bits = bits
        self.dropout = torch.nn.Dropout(dropout)
        # Attribute names are kept as-is so existing checkpoints still load.
        self.A1_c_w = torch.nn.Linear(self.dim, self.bits, bias=True)
        self.Hexperts = num_experts
        self.Hexpertemb = torch.nn.Embedding(self.bits, self.dim)
        self.expert_aggr = torch.nn.Linear(self.Hexperts, 1)

    def forward(self, x):
        """Scale ``x`` by the mixture of its top-scoring expert embeddings.

        Args:
            x: Tensor of shape (..., dim) with any number of leading
                dimensions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate = torch.nn.functional.softmax(self.A1_c_w(x), dim=-1)
        top_w, top_idx = torch.topk(gate, self.Hexperts)  # (..., k)
        # Broadcasting replaces the rank-specific expand() calls:
        # (..., k, 1) * (..., k, dim) -> (..., k, dim)
        weighted = top_w.unsqueeze(-1) * self.Hexpertemb(top_idx)
        # Mix the k experts per feature: (..., dim, k) -> (..., dim, 1) -> (..., dim)
        scale = self.expert_aggr(weighted.transpose(-2, -1)).squeeze(-1)
        return self.dropout(x * scale)
class RambutanSlice(torch.nn.Module):
    """Rambutan transformer block for ELM (SliceX GPT).

    LayerNorm is applied before each sub-layer; causal self-attention and
    the RambutanMLP are each wrapped in a residual connection.
    """

    def __init__(self,
                 model_args: ModelArgs):
        """Create the two LayerNorms, the MLP and the attention module."""
        super().__init__()

        width = model_args.hidden_size
        heads = model_args.num_attention_heads
        eps = model_args.layernorm_eps

        self.model_args = model_args
        self.num_attention_heads = heads
        self.kv_channels = width // heads

        self.ln1 = torch.nn.LayerNorm(width, eps=eps)
        self.ln2 = torch.nn.LayerNorm(width, eps=eps)

        # The MLP is built without a dropout argument, so it uses its default.
        self.mlp = RambutanMLP(dim=width, bits=model_args.bits)
        self.cattn = RambutanCausalSelfAttention(model_args=model_args)

    def forward(self,
                x: torch.Tensor,
                attention_mask: torch.Tensor = None):
        """Apply attention then MLP, each added back onto its own input."""
        attended = self.cattn(self.ln1(x), attention_mask=attention_mask)
        hidden = x + attended
        return self.mlp(self.ln2(hidden)) + hidden
class RambutanCausalSelfAttention(torch.nn.Module):
    """Rambutan version of self-attention module used in the ELM (SliceX GPT) transformer block.

    Multi-head causal self-attention with an optional rotary embedding on
    queries and keys. Uses ``torch.nn.functional.scaled_dot_product_attention``
    when available and a manual softmax implementation otherwise.
    """

    def __init__(self,
                 model_args: ModelArgs):
        """Create the fused QKV projection, output projection and dropouts."""
        super().__init__()

        self.model_args = model_args

        n_embd = model_args.hidden_size
        n_head = model_args.num_attention_heads
        bias = False
        dropout = model_args.attention_dropout

        assert n_embd % n_head == 0

        # Key, query and value projections for all heads, in one matmul.
        self.c_attn = torch.nn.Linear(n_embd, 3 * n_embd, bias=bias)

        # Output projection.
        self.c_proj = torch.nn.Linear(n_embd, n_embd, bias=bias)

        # Regularization.
        self.attn_dropout = torch.nn.Dropout(dropout)
        self.resid_dropout = torch.nn.Dropout(dropout)
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

        self.rotary_embeddings = (
            RotaryEmbedding(n_embd // n_head) if model_args.use_rotary_embeddings else None
        )

    def forward(self, x, attention_mask: torch.Tensor = None):
        """Compute causal self-attention over ``x``.

        Args:
            x: Activations of shape (B, T, C).
            attention_mask: Optional (B, T) tensor; zero entries mark
                positions that must not be attended to. It is combined with
                the causal mask and applied along the key axis.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        device = x.device
        head_dim = C // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Explicit None check rather than relying on module truthiness.
        if self.rotary_embeddings is not None:
            q, k = self.rotary_embeddings(q=q, k=k)

        # Boolean keep-mask of shape (B, 1, T, T); True = may attend.
        # Built only when a padding mask is supplied.
        keep = None
        if attention_mask is not None:
            causal = torch.ones(T, T, dtype=torch.bool, device=device).tril()
            # Broadcast along the KEY axis so padded tokens are hidden from
            # every query. The earlier code broadcast along the query axis,
            # which left padded keys visible to real tokens.
            key_keep = attention_mask.to(device=device).ne(0)[:, None, None, :]
            keep = causal[None, None, :, :] & key_keep

        if self.flash:
            dropout_p = self.dropout if self.training else 0.0
            if keep is None:
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
            else:
                # SDPA rejects an explicit mask combined with is_causal=True,
                # and treats a float mask as an additive bias. Causality is
                # therefore folded into one additive bias. A large finite
                # negative (not -inf, not a bool mask) is used so rows with no
                # visible key give a finite softmax instead of NaN.
                attn_bias = torch.zeros(keep.shape, dtype=q.dtype, device=device)
                attn_bias = attn_bias.masked_fill(~keep, torch.finfo(q.dtype).min)
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
        else:
            # Manual implementation of attention.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            if keep is None:
                keep = torch.ones(T, T, dtype=torch.bool, device=device).tril()[None, None, :, :]
            att = att.masked_fill(~keep, torch.finfo(att.dtype).min)
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Re-assemble all head outputs side by side.
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection.
        y = self.resid_dropout(self.c_proj(y))

        return y
def init_elm_model(model_args=None, device="cuda", model_config_dict=None):
    """Initialize ELM model using default or model_config parameters.

    Args:
        model_args: Configuration to build from. ``None`` (default) creates a
            fresh ``ModelArgs()`` on each call instead of sharing one
            instance created at import time.
        device: Used only to choose the parameter dtype; the model is NOT
            moved to this device here.
        model_config_dict: If non-empty, overrides ``model_args`` with
            ``ModelArgs(**model_config_dict)``.

    Returns:
        A newly constructed ``ELM`` cast to the selected dtype.
    """
    if model_config_dict:
        model_args = ModelArgs(**model_config_dict)
    elif model_args is None:
        model_args = ModelArgs()

    # dtype selection (same outcomes as before):
    #   no CUDA available                         -> bfloat16
    #   device == "cuda" and bf16 is supported    -> bfloat16
    #   anything else while CUDA is available     -> float16
    # NOTE(review): the comparison is an exact string match, so "cuda:0" or a
    # torch.device falls into the float16 branch - confirm this is intended.
    if not torch.cuda.is_available():
        dtype = torch.bfloat16
    elif device == "cuda" and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    model = ELM(model_args=model_args).to(dtype=dtype)

    return model
def get_h_layers_in_ckpt(ckpt_state_dict,
                         layer_name_template = 'slice_transformer.h.{layer_num}.'):
    """Group checkpoint entries by transformer layer index.

    Layer indices are probed upward from 0. A key belongs to layer ``n``
    when it contains ``layer_name_template`` formatted with ``layer_num=n``.
    Probing stops at the first index that matches no key, so only a
    contiguous run of layers starting at 0 is returned.

    Returns:
        A nested defaultdict: ``{layer index: {state-dict key: value}}``.
    """
    from collections import defaultdict

    grouped = defaultdict(lambda: defaultdict(list))
    layer_idx = 0

    while True:
        hits = [
            key for key in ckpt_state_dict.keys()
            if layer_name_template.format(layer_num=layer_idx) in key
        ]
        if not hits:
            break
        for key in hits:
            grouped[layer_idx][key] = ckpt_state_dict[key]
        layer_idx += 1

    return grouped
def load_elm_model_from_ckpt(ckpt_dir, device='cuda', load_partial=False, model_args=None, get_num_layers_from_ckpt=True):
    """Load ELM model from local checkpoint.

    Args:
        ckpt_dir: Directory containing ``ckpt.pt``.
        device: ``map_location`` for loading and the device the model is
            finally moved to.
        load_partial: When True, checkpoint tensors whose shape differs from
            the model's are copied into the leading slice of the matching
            1-D or 2-D model tensor, and checkpoint keys unknown to the model
            are dropped. When False, the checkpoint is loaded as-is.
        model_args: Configuration to build from. ``None`` (default) creates a
            fresh ``ModelArgs()`` per call, so the layer count written below
            no longer leaks between calls through a shared default. An
            explicitly passed instance is still updated in place.
        get_num_layers_from_ckpt: When True, ``model_args.num_layers`` is
            overwritten with the number of layers found in the checkpoint.

    Returns:
        The ELM model with weights loaded, moved to ``device``.
    """
    if model_args is None:
        model_args = ModelArgs()

    print(f"Loading ELM checkpoint from {ckpt_dir}")
    ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)
    ckpt_state_dict = checkpoint['model']

    if get_num_layers_from_ckpt:
        # Substring matching means keys that still carry the '_orig_mod.'
        # prefix are counted as well.
        layer_name_template = 'slice_transformer.h.{layer_num}.'
        ckpt_layer_wise_dict = get_h_layers_in_ckpt(ckpt_state_dict,
                                                    layer_name_template = layer_name_template)
        model_args.num_layers = len(ckpt_layer_wise_dict)

    model = init_elm_model(model_args=model_args, device=device)

    # Strip the '_orig_mod.' prefix from the saved keys.
    unwanted_prefix = '_orig_mod.'
    for key in list(ckpt_state_dict):
        if key.startswith(unwanted_prefix):
            ckpt_state_dict[key[len(unwanted_prefix):]] = ckpt_state_dict.pop(key)

    if load_partial:
        mod_state_dict = model.state_dict()
        for key, ckpt_tensor in list(ckpt_state_dict.items()):
            if key not in mod_state_dict:
                continue
            ckpt_size = ckpt_tensor.size()
            if ckpt_size == mod_state_dict[key].size():
                mod_state_dict[key] = ckpt_tensor
            elif len(ckpt_size) == 1:
                mod_state_dict[key][:ckpt_size[-1]] = ckpt_tensor
            elif len(ckpt_size) == 2:
                mod_state_dict[key][:ckpt_size[-2], :ckpt_size[-1]] = ckpt_tensor
            # Mismatched tensors of any other rank keep the model's
            # freshly initialised values.
        ckpt_state_dict = mod_state_dict

    load_status = model.load_state_dict(ckpt_state_dict)
    print(load_status)
    model.to(device)

    return model
def sample_top_p(probs, threshold):
    """Draw one index per row using top-p (nucleus) sampling.

    Probabilities are sorted in descending order along the last dimension.
    An entry is zeroed when the cumulative mass before it already exceeds
    ``threshold``. The remainder is renormalised and sampled, and the
    result is mapped back to the original index space.

    Returns:
        Tensor of sampled indices with a trailing dimension of size 1.
    """
    ranked_probs, ranked_idx = probs.sort(dim=-1, descending=True)
    mass_before = ranked_probs.cumsum(dim=-1) - ranked_probs
    ranked_probs = ranked_probs.masked_fill(mass_before > threshold, 0.0)
    ranked_probs = ranked_probs / ranked_probs.sum(dim=-1, keepdim=True)
    picked = torch.multinomial(ranked_probs, num_samples=1)
    return ranked_idx.gather(-1, picked)