import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class RWKV_TimeMix_x051a(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.head_size = config.n_embd // config.n_head
        self.n_head = config.n_head

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (config.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                ddd[0, 0, i] = i / config.n_embd

            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            decay_speed = torch.ones(self.n_head)
            for h in range(self.n_head):
                decay_speed[h] = -6 + 5 * (h / (self.n_head - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.unsqueeze(-1))

            tmp = torch.zeros(self.n_head)
            for h in range(self.n_head):
                tmp[h] = ratio_0_to_1 * (1 - (h / (self.n_head - 1)))
            self.time_faaaa = nn.Parameter(tmp.unsqueeze(-1))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.output = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.ln_x = nn.GroupNorm(self.n_head, config.n_embd, eps=(1e-5) * 64)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        H, N = self.n_head, self.head_size

        if T % 256 == 0:
            Q = 256
        elif T % 128 == 0:
            Q = 128
        else:
            Q = T
        assert T % Q == 0

        xx = self.time_shift(x) - x
        xk = x + xx * self.time_maa_k
        xv = x + xx * self.time_maa_v
        xr = x + xx * self.time_maa_r
        xg = x + xx * self.time_maa_g

        r = self.receptance(xr).view(B, T, H, N).transpose(1, 2)  # receptance
        k = self.key(xk).view(B, T, H, N).permute(0, 2, 3, 1)  # key
        v = self.value(xv).view(B, T, H, N).transpose(1, 2)  # value
        g = F.silu(self.gate(xg))  # extra gate

        w = torch.exp(-torch.exp(self.time_decay.float()))  # time_decay
        u = self.time_faaaa.float()  # time_first

        ws = w.pow(Q).view(1, H, 1, 1)
        ind = torch.arange(Q - 1, -1, -1, device=r.device).unsqueeze(0).repeat(H, 1)
        w = w.repeat(1, Q).pow(ind)
        wk = w.view(1, H, 1, Q)
        wb = wk.transpose(-2, -1).flip(2)

        w = torch.cat([w[:, 1:], u], dim=1)
        w = F.pad(w, (0, Q))
        w = torch.tile(w, [Q])
        w = w[:, :-Q].view(-1, Q, 2 * Q - 1)
        w = w[:, :, Q - 1:].view(1, H, Q, Q)

        w = w.to(dtype=r.dtype)  # the decay matrix
        wk = wk.to(dtype=r.dtype)
        wb = wb.to(dtype=r.dtype)
        ws = ws.to(dtype=r.dtype)

        state = torch.zeros(B, H, N, N, device=r.device, dtype=r.dtype)  # state
        y = torch.empty(B, H, T, N, device=r.device, dtype=r.dtype)  # output

        for i in range(T // Q):  # the rwkv-x051a operator
            rr = r[:, :, i*Q:i*Q+Q, :]
            kk = k[:, :, :, i*Q:i*Q+Q]
            vv = v[:, :, i*Q:i*Q+Q, :]
            y[:, :, i*Q:i*Q+Q, :] = ((rr @ kk) * w) @ vv + (rr @ state) * wb
            state = ws * state + (kk * wk) @ vv

        y = y.transpose(1, 2).contiguous().view(B * T, C)
        y = self.ln_x(y).view(B, T, C) * g

        # output projection
        y = self.dropout(self.output(y))
        return y
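

# A naive per-token reference for the chunked computation in
# RWKV_TimeMix_x051a.forward above. It is not called by the model and the
# function name is ours; it only documents the recurrence the chunked loop
# implements: a per-head decayed state plus a "bonus" u for the current token.
def _wkv_reference(r, k, v, w, u):
    # r, v: (B, H, T, N); k: (B, H, N, T); w, u: per-head decay / bonus of shape (H, 1)
    B, H, T, N = r.shape
    out = torch.zeros_like(r)
    state = torch.zeros(B, H, N, N, dtype=r.dtype, device=r.device)
    for t in range(T):
        rt = r[:, :, t:t + 1, :]   # (B, H, 1, N)
        kt = k[:, :, :, t:t + 1]   # (B, H, N, 1)
        vt = v[:, :, t:t + 1, :]   # (B, H, 1, N)
        at = kt @ vt               # (B, H, N, N) outer product k_t v_t^T
        out[:, :, t:t + 1, :] = rt @ (state + u.view(1, H, 1, 1) * at)
        state = w.view(1, H, 1, 1) * state + at
    return out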


class RWKV_ChannelMix_x051a(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)
            ddd = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                ddd[0, 0, i] = i / config.n_embd
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.value = nn.Linear(3 * config.n_embd, config.n_embd, bias=config.bias)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        xx = self.time_shift(x) - x
        xk = x + xx * self.time_maa_k
        xr = x + xx * self.time_maa_r

        x = self.key(xk)
        x = torch.relu(x) ** 2
        x = self.value(x)
        x = torch.sigmoid(self.receptance(xr)) * x
        x = self.dropout(x)
        return x


class RMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / (norm + self.eps)


class GroupedQAttention(nn.Module):

    def __init__(self, dim, num_heads, groups=4):
        super().__init__()
        self.num_heads = num_heads
        self.groups = groups
        self.qkvw = nn.Linear(dim, dim * 4, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        batch, seq_len, dim = x.shape
        # fused projection to query / key / value / gate (w)
        qkvw = self.qkvw(x)
        qkvw_groups = torch.chunk(qkvw, self.groups, dim=-1)
        # note: the unpacking below assumes groups == 4 so the chunks map to q, k, v, w
        q, k, v, w = [t.chunk(self.groups, dim=-1) for t in qkvw_groups]
        # regroup: each result stacks one q-, k-, v- and w-slice along the batch dim
        q, k, v, w = [
            torch.cat([qi, ki, vi, wi], dim=0)
            for qi, ki, vi, wi in zip(q, k, v, w)
        ]
        q, k, v = map(
            lambda t: t.view(batch * self.groups, self.num_heads, -1,
                             dim // self.num_heads // self.groups).transpose(1, 2),
            [q, k, v]
        )
        w = w.view(batch * self.groups, self.num_heads, -1,
                   dim // self.num_heads // self.groups)

        attn_output = (q @ k.transpose(-2, -1)) * (dim // self.num_heads // self.groups) ** -0.5
        attn_output = attn_output.softmax(dim=-1)
        attn_output = (attn_output @ v).transpose(1, 2).reshape(batch, seq_len, dim)
        # gate the attention output with w before the output projection
        return self.out(attn_output * w.reshape(batch, seq_len, dim))


class SlidingWindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.head_dim ** -0.5

        # Pad the sequence to a multiple of the window size
        padding = (self.window_size - N % self.window_size) % self.window_size
        q = F.pad(q, (0, 0, 0, padding))
        k = F.pad(k, (0, 0, 0, padding))
        v = F.pad(v, (0, 0, 0, padding))

        # Reshape to non-overlapping windows: (B * heads * windows, window, head_dim)
        num_windows = (N + padding) // self.window_size
        q = q.reshape(B * self.num_heads * num_windows, self.window_size, self.head_dim)
        k = k.reshape(B * self.num_heads * num_windows, self.window_size, self.head_dim)
        v = v.reshape(B * self.num_heads * num_windows, self.window_size, self.head_dim)

        # Attention within each window
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = attn @ v

        # Drop the padding and restore (B, N, C)
        attn = attn.reshape(B, self.num_heads, N + padding, self.head_dim)
        attn = attn[:, :, :N, :].permute(0, 2, 1, 3).reshape(B, N, C)
        return self.proj(attn)
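

# Hypothetical, uncalled shape check for the two attention variants above; the
# helper name and the sizes are illustrative only. GroupedQAttention assumes dim
# is divisible by num_heads * groups, and groups must be 4 for the q/k/v/w
# unpacking in its forward pass.
def _attention_shape_check():
    x = torch.randn(2, 64, 128)
    gqa = GroupedQAttention(dim=128, num_heads=4, groups=4)
    swa = SlidingWindowAttention(dim=128, window_size=16, num_heads=4)
    assert gqa(x).shape == x.shape
    assert swa(x).shape == x.shape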


class TinyMoE(nn.Module):

    def __init__(self, dim, num_experts, num_active_experts, expert_dim,
                 dropout=0.0, expert_capacity_scale=1.0, aux_loss_weight=0.1):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.num_active_experts = num_active_experts
        self.expert_dim = expert_dim
        self.dropout = nn.Dropout(dropout)
        self.gate = nn.Linear(dim, num_experts)
        self.expert_capacity_scale = expert_capacity_scale
        self.scaled_expert_dim = int(expert_dim * self.expert_capacity_scale)
        self.experts = nn.ModuleList(
            [nn.Linear(dim, self.scaled_expert_dim) for _ in range(num_active_experts)])
        self.fc = nn.Linear(self.scaled_expert_dim, dim)

        # Auxiliary loss
        self.aux_loss_weight = aux_loss_weight
        self.expert_diversity_loss = nn.MSELoss()

    def forward(self, x):
        b, n, d = x.shape

        # Compute gating scores over experts
        scores = self.gate(x).view(b, n, self.num_experts)
        scores = F.softmax(scores, dim=-1)

        # Apply dropout to the gating scores
        scores = self.dropout(scores)

        # Compute the weighted sum of expert outputs; only the first
        # num_active_experts gate scores are used, matching the experts instantiated
        expert_outputs = torch.stack(
            [exp(x.view(b * n, d)) for exp in self.experts], dim=1)
        expert_outputs = expert_outputs.view(
            b, n, self.num_active_experts, self.scaled_expert_dim)
        weighted_outputs = (
            expert_outputs * scores[:, :, :self.num_active_experts].unsqueeze(-1)).sum(dim=2)

        # Apply the final linear layer
        output = self.fc(weighted_outputs)

        # Auxiliary loss: expert diversity
        # expert_activations: (b, num_active_experts, scaled_expert_dim)
        expert_activations = expert_outputs.mean(dim=1)
        expert_diversity_loss = self.expert_diversity_loss(
            expert_activations.transpose(0, 1),
            torch.zeros_like(expert_activations.transpose(0, 1)))

        return output, expert_diversity_loss * self.aux_loss_weight

    def set_expert_capacity(self, expert_capacity_scale):
        self.expert_capacity_scale = expert_capacity_scale
        self.scaled_expert_dim = int(self.expert_dim * self.expert_capacity_scale)
        self.experts = nn.ModuleList(
            [nn.Linear(self.dim, self.scaled_expert_dim) for _ in range(self.num_active_experts)])
        self.fc = nn.Linear(self.scaled_expert_dim, self.dim)


class Block(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.ln_2 = RMSNorm(config.n_embd)
        # stays here because this is a core component
        self.tmix = RWKV_TimeMix_x051a(config, layer_id)
        # Add GroupedQAttention instance
        self.grouped_attn = GroupedQAttention(config.n_embd, config.n_head)
        # stays here because this is a core component
        self.cmix = RWKV_ChannelMix_x051a(config, layer_id)
        self.sliding_attn = SlidingWindowAttention(
            config.n_embd, window_size=256, num_heads=config.n_head)
        self.moe = TinyMoE(config.n_embd, config.num_experts, config.num_active_experts,
                           config.expert_dim, config.dropout,
                           expert_capacity_scale=1.2, aux_loss_weight=0.01)

    def forward(self, x):
        x = x + self.tmix(self.ln_1(x))
        x = x + self.cmix(self.ln_2(x))
        x = x + self.sliding_attn(x)  # Apply sliding window attention
        x = x + self.grouped_attn(self.tmix(x))  # Apply GroupedQAttention
        # Apply TinyMoE; the auxiliary loss is currently discarded here
        moe_output, aux_loss = self.moe(x)
        x = x + moe_output
        return x
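

# The training script, not this file, defines the real config object. The
# dataclass below is a hypothetical sketch listing every field read by the
# modules above and by GPT below; the name and default values are illustrative only.
from dataclasses import dataclass


@dataclass
class GPTConfigSketch:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768  # must be divisible by n_head (and by n_head * 4 for GroupedQAttention)
    dropout: float = 0.0
    bias: bool = False
    # fields read by TinyMoE inside Block
    num_experts: int = 8
    num_active_experts: int = 2
    expert_dim: int = 768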


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False)
        # weight tying between token embedding and LM head
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('tmix.output.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            # note: using list [-1] to preserve the time dim
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, top_k=None):
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step
            logits = logits[:, -1, :]
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
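

# Minimal smoke test, assuming the hypothetical GPTConfigSketch above; the real
# entry point is the training script. Sizes are kept small so this runs quickly on CPU.
if __name__ == "__main__":
    cfg = GPTConfigSketch(block_size=64, vocab_size=256, n_layer=2, n_head=4,
                          n_embd=64, expert_dim=64)
    model = GPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 32))
    logits, loss = model(idx, targets=idx)
    print("logits:", tuple(logits.shape), "loss:", float(loss))
    out = model.generate(idx[:, :8], max_new_tokens=4, top_k=10)
    print("generated:", tuple(out.shape))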