# File size: 14,740 Bytes (revision 793da2f)
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
import copy
import inspect
import math
import numpy as np
import os
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from elm.utils import *
from elm.positional_embeddings import *
def get_elm_model_map(model_name):
    """Resolve a model-type name to the transformer block class that implements it.

    Args:
        model_name: Key identifying the block implementation (e.g. "rambutan").

    Returns:
        The class registered under ``model_name``; names that are not
        registered resolve to ``RambutanSlice``.
    """
    registry = {"rambutan": RambutanSlice}
    if model_name in registry:
        return registry[model_name]
    # Unknown names deliberately fall back to the Rambutan block.
    return RambutanSlice
@dataclass
class ModelArgs:
    """ELM Model Args.

    Plain configuration record consumed by the ELM model and its transformer
    blocks. Every field has a default, so ``ModelArgs()`` is a valid config.
    """

    # Model identifier; not referenced by the model code visible in this file.
    model_name_or_path: str = "ELM"
    # Not referenced by the model code visible in this file.
    compile_model: bool = False
    # Key passed to get_elm_model_map() to choose the transformer block class.
    elm_model_class: Optional[str] = "rambutan"
    # Width of token embeddings and of every transformer block.
    hidden_size: Optional[int] = 2048
    # Size of the learned position table (when rotary embeddings are off) and
    # the context length that ELM.generate() crops its input to.
    max_inp_len: Optional[int] = 2048
    # NOTE: this default is evaluated once, while the class body executes, so
    # it is the literal 2048 above; it does NOT follow a max_inp_len value
    # passed to the constructor.
    attn_window_size: Optional[int] = max_inp_len
    # Number of attention heads; hidden_size must be divisible by it.
    num_attention_heads: Optional[int] = 32
    # Epsilon used by every LayerNorm.
    layernorm_eps: float = 1e-5
    # Dropout probability inside the self-attention module.
    attention_dropout: float = 0.1
    # Dropout probability applied to the embeddings in ELM.forward().
    hidden_dropout: float = 0.1
    # Number of stacked transformer blocks.
    num_layers: Optional[int] = 16
    # Passed to RambutanMLP as ``bits``.
    bits: Optional[int] = 256
    # Number of rows in the token embedding / columns of the LM head.
    vocab_size: Optional[int] = 50304
    # A probability, hence float (was mis-annotated as int). Not referenced by
    # the model code visible in this file.
    dropout: Optional[float] = 0.1
    # When True, attention uses RotaryEmbedding and no learned position table
    # is created.
    use_rotary_embeddings: Optional[bool] = True
    # Not referenced by the model code visible in this file.
    tokenizer: Optional[str] = None
class ELM(torch.nn.Module):
    """ELM (SliceX GPT) model.

    A token embedding (plus a learned position embedding when rotary
    embeddings are disabled), a stack of transformer blocks chosen through
    ``get_elm_model_map``, a final LayerNorm and a bias-free LM head.
    """

    def __init__(self,
                 model_args: "ModelArgs"):
        """Initialize an ELM model instance.

        Args:
            model_args: Configuration; the same object is handed to every
                transformer block and kept as ``self.model_args``.
        """
        super().__init__()

        self.model_args = model_args

        hidden_size = model_args.hidden_size
        vocab_size = model_args.vocab_size
        use_rotary_embeddings = model_args.use_rotary_embeddings

        layer_class = get_elm_model_map(model_args.elm_model_class)

        # Sub-modules are created in a fixed order (temb, pemb, drop, h, ln_f,
        # lm_head) so that seeded initialisation stays reproducible.
        self.slice_transformer = torch.nn.ModuleDict(dict(
            temb=torch.nn.Embedding(vocab_size, hidden_size),
            # The learned position table only exists without rotary embeddings.
            pemb=(torch.nn.Embedding(model_args.max_inp_len, hidden_size)
                  if not use_rotary_embeddings else None),
            drop=torch.nn.Dropout(model_args.hidden_dropout),
            h=torch.nn.ModuleList(
                [layer_class(model_args=model_args) for _ in range(model_args.num_layers)]
            ),
            ln_f=torch.nn.LayerNorm(hidden_size, eps=model_args.layernorm_eps),
        ))

        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

        print("Number of model parameters: %.2fM" % (self.get_num_params(False)/1e6,))

    def forward(self,
                x: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                targets: Optional[torch.Tensor] = None):
        """Run the model and optionally compute the next-token loss.

        Args:
            x: Token ids of shape (batch, seqlen).
            attention_mask: Optional mask forwarded unchanged to every block.
            targets: Optional target ids. When given, logits at position t are
                scored against ``targets[..., t + 1]``; entries equal to -100
                are ignored.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every
            position and ``loss`` is a float tensor of shape (1,) holding the
            cross-entropy. Without ``targets`` only the last position is
            projected (logits of shape (batch, 1, vocab)) and ``loss`` is a
            zero tensor of shape (1,).
        """
        device = x.device
        _, seqlen = x.size()

        tok_emb = self.slice_transformer.temb(x)
        if not self.model_args.use_rotary_embeddings:
            pos = torch.arange(0, seqlen, dtype=torch.long, device=device)
            pos_emb = self.slice_transformer.pemb(pos)
            x = self.slice_transformer.drop(tok_emb + pos_emb)
        else:
            x = self.slice_transformer.drop(tok_emb)

        for tlayer in self.slice_transformer.h:
            x = tlayer(x, attention_mask=attention_mask)

        x = self.slice_transformer.ln_f(x)

        if targets is None:
            # Inference: only the last position is needed. The loss used to be
            # 0 / 0 (NaN) here because nothing had been accumulated; return a
            # well-defined zero instead.
            logits = self.lm_head(x[:, [-1], :])
            return logits, torch.zeros(1, device=device)

        ignore_index_id = -100
        logits = self.lm_head(x)
        # Shift so that position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_targets = targets[..., 1:].contiguous()
        curr_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                    shift_targets.view(-1),
                                    ignore_index=ignore_index_id)
        # Keep the historical shape (1,) / float32 of the returned loss.
        loss = torch.zeros(1, device=device) + curr_loss.float()
        return logits, loss

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        Args:
            non_embedding: When True (default) and the model has a learned
                position embedding (rotary embeddings disabled), that table is
                subtracted from the count. Token embedding and LM head are
                always counted; they are separate, untied parameters here.

        Returns:
            Total number of parameter elements as an int.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and not self.model_args.use_rotary_embeddings:
            n_params -= self.slice_transformer.pemb.weight.numel()
        return n_params

    @torch.no_grad()
    def generate(self, x, max_new_tokens, temperature=0.8, top_k=200, top_p=0.9,
                 return_gen_only=False):
        """Autoregressively extend ``x`` by ``max_new_tokens`` tokens.

        Args:
            x: Prompt token ids of shape (batch, seqlen).
            max_new_tokens: Number of tokens to append.
            temperature: ``<= 0`` selects greedy decoding (argmax); otherwise
                the last-position logits are divided by it before sampling.
            top_k: If not None, logits below the k-th largest are set to -inf.
            top_p: If not None, sample with ``sample_top_p`` using this
                threshold; if None, sample from the full softmax.
            return_gen_only: When True return only the newly generated tokens.

        Returns:
            Token ids: prompt plus generated tokens, or only the generated
            tokens when ``return_gen_only`` is True.
        """
        max_inp_len = self.model_args.max_inp_len
        # Remember where the prompt ends; slicing with -max_new_tokens would
        # return the whole sequence when max_new_tokens == 0.
        prompt_len = x.size(1)

        for _ in range(max_new_tokens):
            # Feed at most the last max_inp_len tokens to the model.
            x_ctxt = x if x.size(1) <= max_inp_len else x[:, -max_inp_len:]
            logits, _ = self(x_ctxt)

            if temperature <= 0:
                # logits is (batch, 1, vocab), so argmax yields (batch, 1).
                next_id = torch.argmax(logits, dim=-1)
            else:
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                if top_p is None:
                    next_id = torch.multinomial(probs, num_samples=1)
                else:
                    next_id = sample_top_p(probs, top_p)
            x = torch.cat((x, next_id), dim=1)

        if return_gen_only:
            return x[:, prompt_len:]

        return x
class RambutanMLP(torch.nn.Module):
    """RambutanMLP version of MLP module used in the ELM (SliceX GPT) Transformer block.

    The input is projected to ``bits`` scores and soft-maxed; the ``Hexperts``
    (4) highest-probability indices each select an embedding vector, which is
    weighted by its probability. A learned linear layer collapses the selected
    vectors into one per-feature scale that multiplies the input element-wise.
    """

    def __init__(self, dim=768, bits=32, dropout = 0.0):
        """Build the module.

        Args:
            dim: Feature width of the input and output.
            bits: Number of selectable embedding rows (size of the score vector).
            dropout: Dropout probability applied to the output.
        """
        super().__init__()
        self.dim = dim
        self.bits = bits
        self.dropout = torch.nn.Dropout(dropout)

        # Parameterised layers are created in this order on purpose: changing
        # it would change seeded initialisation.
        self.A1_c_w = torch.nn.Linear(dim, bits, bias=True)
        self.Hexperts = 4
        self.Hexpertemb = torch.nn.Embedding(bits, dim)
        self.expert_aggr = torch.nn.Linear(self.Hexperts, 1)

    def forward(self, x):
        """Scale ``x`` by a gate built from its top-scoring embedding rows.

        Args:
            x: Tensor whose last dimension is ``dim``; the code handles a
                2-D input and a 3-D input through separate branches.

        Returns:
            Tensor shaped like ``x``.
        """
        gate = torch.softmax(self.A1_c_w(x), dim=-1)
        top_w, top_idx = gate.topk(self.Hexperts)
        picked = self.Hexpertemb(top_idx)

        if x.dim() < 3:
            weighted = top_w.unsqueeze(-1).expand(-1, -1, self.dim) * picked
            # (N, dim, Hexperts) -> (N, dim, 1) -> (N, dim)
            scale = self.expert_aggr(weighted.transpose(1, 2)).reshape(weighted.size(0), -1)
        else:
            weighted = top_w.unsqueeze(-1).expand(-1, -1, -1, self.dim) * picked
            # (..., dim, Hexperts) -> (..., dim, 1) -> shape of x
            scale = self.expert_aggr(weighted.transpose(-2, -1)).reshape(x.size())

        return self.dropout(x * scale)
class RambutanSlice(torch.nn.Module):
    """Rambutan version of ELM (SliceX GPT) Transformer block.

    Pre-norm residual block: LayerNorm -> causal self-attention -> residual
    add, followed by LayerNorm -> RambutanMLP -> residual add.
    """

    def __init__(self,
                 model_args: ModelArgs):
        """Create the two LayerNorms, the MLP and the attention module.

        Args:
            model_args: Configuration supplying hidden size, head count,
                LayerNorm epsilon and the MLP ``bits`` value.
        """
        super().__init__()

        self.model_args = model_args
        width = model_args.hidden_size
        eps = model_args.layernorm_eps

        self.num_attention_heads = model_args.num_attention_heads
        self.kv_channels = width // model_args.num_attention_heads

        # Creation order (ln1, ln2, mlp, cattn) is kept so seeded
        # initialisation is unchanged.
        self.ln1 = torch.nn.LayerNorm(width, eps=eps)
        self.ln2 = torch.nn.LayerNorm(width, eps=eps)
        self.mlp = RambutanMLP(dim=width, bits=model_args.bits)
        self.cattn = RambutanCausalSelfAttention(model_args=model_args)

    def forward(self,
                x: torch.Tensor,
                attention_mask: torch.Tensor = None):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            x: Input activations.
            attention_mask: Optional mask handed to the attention module.

        Returns:
            Activations with the same shape as ``x``.
        """
        # Attention sub-layer.
        after_attn = x + self.cattn(self.ln1(x), attention_mask=attention_mask)
        # MLP sub-layer.
        return self.mlp(self.ln2(after_attn)) + after_attn
class RambutanCausalSelfAttention(torch.nn.Module):
    """Rambutan version of self-attention module used in the ELM (SliceX GPT) transformer block.

    Multi-head causal self-attention with a fused QKV projection. Uses
    ``torch.nn.functional.scaled_dot_product_attention`` when the installed
    PyTorch provides it and a manual softmax(QK^T)V implementation otherwise.
    """

    def __init__(self,
                 model_args: "ModelArgs"):
        """Create projections, dropouts and (optionally) rotary embeddings.

        Args:
            model_args: Configuration supplying ``hidden_size``,
                ``num_attention_heads``, ``attention_dropout`` and
                ``use_rotary_embeddings``.

        Raises:
            AssertionError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.model_args = model_args

        n_embd = model_args.hidden_size
        n_head = model_args.num_attention_heads
        bias = False
        dropout = model_args.attention_dropout

        assert n_embd % n_head == 0

        # Fused projection producing queries, keys and values in one matmul.
        self.c_attn = torch.nn.Linear(n_embd, 3 * n_embd, bias=bias)

        # Output projection.
        self.c_proj = torch.nn.Linear(n_embd, n_embd, bias=bias)

        self.attn_dropout = torch.nn.Dropout(dropout)
        self.resid_dropout = torch.nn.Dropout(dropout)
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

        self.rotary_embeddings = (
            RotaryEmbedding(n_embd // n_head) if model_args.use_rotary_embeddings else None
        )

    def forward(self, x, attention_mask: torch.Tensor = None):
        """Compute causal self-attention.

        Args:
            x: Activations of shape (B, T, C).
            attention_mask: Optional (B, T) tensor; non-zero marks positions
                to keep. It is combined with the causal (lower-triangular)
                mask.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        device = x.device
        head_dim = C // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        if self.rotary_embeddings:
            q, k = self.rotary_embeddings(q=q, k=k)

        # Boolean "may attend" mask of shape (B, 1, T, T), or None when only
        # plain causal masking is required.
        allowed = None
        if attention_mask is not None:
            causal = torch.ones(T, T, dtype=torch.bool, device=device).tril()
            # NOTE(review): as in the original code the (B, T) mask is
            # expanded along the *query* axis (row i is masked when
            # attention_mask[b, i] == 0). A key-padding mask would expand
            # along the key axis instead - confirm which one is intended.
            keep_rows = (attention_mask != 0).to(device).view(B, 1, T, 1)
            allowed = keep_rows & causal.view(1, 1, T, T)

        if self.flash:
            dropout_p = self.dropout if self.training else 0
            if allowed is None:
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
            else:
                # An explicit mask must not be combined with is_causal=True,
                # and a 0/1 float mask would be *added* to the scores. The
                # causal structure is already folded into `allowed`, so pass
                # it as an additive bias (0 = keep, finfo.min = drop), which
                # mirrors the masked_fill used by the manual path below.
                bias_mask = torch.zeros(B, 1, T, T, dtype=q.dtype, device=device)
                bias_mask = bias_mask.masked_fill(~allowed, torch.finfo(q.dtype).min)
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=bias_mask, dropout_p=dropout_p, is_causal=False)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

            if allowed is None:
                allowed = torch.ones(T, T, dtype=torch.bool, device=device).tril().view(1, 1, T, T)

            att = att.masked_fill(~allowed, torch.finfo(att.dtype).min)
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        # (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y
def init_elm_model(model_args=None, device="cuda", model_config_dict=None):
    """Initialize ELM model using default or model_config parameters.

    Args:
        model_args: Configuration to build the model from. When omitted a
            fresh ``ModelArgs()`` is created for this call, so models built
            with defaults never share one configuration object.
        device: Device hint (string or ``torch.device``). It is only used to
            choose the parameter dtype; the model is not moved to the device
            here.
        model_config_dict: Optional dict of ``ModelArgs`` keyword arguments;
            when non-empty it takes precedence over ``model_args``.

    Returns:
        An ``ELM`` instance cast to bfloat16 when a CUDA device is requested,
        available and supports bf16, and to float16 otherwise.
    """
    if model_config_dict:
        model_args = ModelArgs(**model_config_dict)
    elif model_args is None:
        model_args = ModelArgs()

    # Accept "cuda", "cuda:0" and torch.device("cuda", 0) alike; an exact
    # comparison with "cuda" sent indexed devices down the float16 branch.
    wants_cuda = str(device).startswith("cuda")
    use_bf16 = wants_cuda and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    model = ELM(model_args=model_args).to(dtype=dtype)

    return model
def get_h_layers_in_ckpt(ckpt_state_dict,
                         layer_name_template = 'slice_transformer.h.{layer_num}.'):
    """Group checkpoint entries by transformer layer index.

    Layer indices are probed upwards from 0; probing stops at the first index
    for which no key contains the formatted template, so entries behind a gap
    in the numbering are not collected.

    Args:
        ckpt_state_dict: Mapping from parameter name to value.
        layer_name_template: Format string with a ``{layer_num}`` placeholder;
            a key belongs to a layer when it contains the formatted string.

    Returns:
        A defaultdict mapping layer index -> {parameter name: value}.
    """
    from collections import defaultdict

    grouped = defaultdict(lambda: defaultdict(list))
    layer_idx = 0

    while True:
        marker = layer_name_template.format(layer_num=layer_idx)
        matching = [name for name in ckpt_state_dict.keys() if marker in name]
        if not matching:
            # First index without any parameter: stop probing.
            break
        for name in matching:
            grouped[layer_idx][name] = ckpt_state_dict[name]
        layer_idx += 1

    return grouped
def load_elm_model_from_ckpt(ckpt_dir, device='cuda', load_partial=False, model_args=None, get_num_layers_from_ckpt=True):
    """Load ELM model from local checkpoint.

    Args:
        ckpt_dir: Directory containing ``ckpt.pt``; the file must hold a dict
            with a ``'model'`` state dict.
        device: Device the checkpoint is mapped to and the model is moved to.
        load_partial: When True, copy only what fits: tensors with matching
            name and size are taken from the checkpoint, 1-D / 2-D tensors of
            a different size are written into the leading slice of the model
            tensor, checkpoint keys unknown to the model are ignored and any
            other mismatch keeps the model's initial values.
        model_args: Model configuration. When omitted a fresh ``ModelArgs()``
            is created for this call (a shared default instance would carry
            ``num_layers`` over from one call to the next). A configuration
            passed by the caller is updated in place, as before.
        get_num_layers_from_ckpt: When True, ``model_args.num_layers`` is set
            to the number of consecutive layers found in the checkpoint.

    Returns:
        The ``ELM`` model with the checkpoint weights loaded, on ``device``.
    """
    print(f"Loading ELM checkpoint from {ckpt_dir}")

    if model_args is None:
        model_args = ModelArgs()

    ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code - only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)
    ckpt_state_dict = checkpoint['model']

    if get_num_layers_from_ckpt:
        layer_name_template = 'slice_transformer.h.{layer_num}.'
        ckpt_layer_wise_dict = get_h_layers_in_ckpt(ckpt_state_dict,
                                                    layer_name_template = layer_name_template)
        model_args.num_layers = len(ckpt_layer_wise_dict)

    model = init_elm_model(model_args=model_args, device=device)

    # Strip the '_orig_mod.' prefix from parameter names so they match the
    # model's own state-dict keys.
    unwanted_prefix = '_orig_mod.'
    for k in list(ckpt_state_dict):
        if k.startswith(unwanted_prefix):
            ckpt_state_dict[k[len(unwanted_prefix):]] = ckpt_state_dict.pop(k)

    if load_partial:
        mod_state_dict = model.state_dict()
        for k, v in ckpt_state_dict.items():
            if k not in mod_state_dict:
                continue
            v_size = v.size()
            if v_size == mod_state_dict[k].size():
                mod_state_dict[k] = v
            elif len(v_size) == 1:
                mod_state_dict[k][:v_size[-1]] = v
            elif len(v_size) == 2:
                mod_state_dict[k][:v_size[-2], :v_size[-1]] = v

        ckpt_state_dict = mod_state_dict

    load_status = model.load_state_dict(ckpt_state_dict)
    print(load_status)
    model.to(device)

    return model
def sample_top_p(probs, threshold):
    """Perform top-p sampling on probability distribution using a threshold.

    Tokens are ranked by probability; a token is dropped once the total
    probability of the tokens ranked before it exceeds ``threshold``. The
    surviving probabilities are renormalised and one token is drawn per row.

    Args:
        probs: Probabilities with the vocabulary on the last dimension.
        threshold: Cumulative-probability cut-off.

    Returns:
        Sampled token indices with a trailing dimension of size 1.
    """
    ranked, order = probs.sort(dim=-1, descending=True)
    # Probability mass strictly ahead of each ranked token.
    mass_before = ranked.cumsum(dim=-1) - ranked
    ranked = ranked.masked_fill(mass_before > threshold, 0.0)
    ranked = ranked / ranked.sum(dim=-1, keepdim=True)
    pick = torch.multinomial(ranked, num_samples=1)
    # Translate positions in the ranking back to vocabulary indices.
    return order.gather(-1, pick)