""" Full definition of a RWKV Language Model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py 2) huggingface/transformers PyTorch implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py """ import math,time import os import inspect from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F PREV_X_TIME = 0 NUM_STATE = 1 DEN_STATE = 2 MAX_STATE = 3 PREV_X_CHANNEL = 4 # copied from nanoGPT class LayerNorm(nn.Module): """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) # learn from GPT-4 from unittest.mock import patch class CudaNotAvailable: def __enter__(self): self.patcher = patch("torch.cuda.is_available", return_value=False) self.patcher.start() def __exit__(self, exc_type, exc_value, traceback): self.patcher.stop() # https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21 class L2Wrap(torch.autograd.Function): @staticmethod def forward(ctx, loss, y): ctx.save_for_backward(y) return loss @staticmethod def backward(ctx, grad_output): y = ctx.saved_tensors[0] # to encourage the logits to be close to 0 factor = 1e-4 / (y.shape[0] * y.shape[1]) maxx, ids = torch.max(y, -1, keepdim=True) gy = torch.zeros_like(y) gy.scatter_(-1, ids, maxx * factor) return (grad_output, gy) class ChannelMixing(nn.Module): def __init__(self,config,layer_id): super().__init__() self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.layer_id = layer_id n_embd = config.n_embd intermediate_size = ( config.intermediate_size if config.intermediate_size is not None else 4 * n_embd ) ## Learnable Matrix self.key_proj = nn.Linear(n_embd,intermediate_size,bias=False) self.value_proj = nn.Linear(intermediate_size,n_embd,bias=False) self.receptance_proj = nn.Linear(n_embd,n_embd,bias=False) ## Learnable Vector self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd)) self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd)) def forward(self,x,state=None): # x = (Batch,Time,Channel) if state is not None: prev_x = state[self.layer_id,:,[PREV_X_CHANNEL],:] state[self.layer_id,:,[PREV_X_CHANNEL],:] = x else: prev_x = self.time_shift(x) ## R receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance) receptance = self.receptance_proj(receptance) receptance = F.sigmoid(receptance) # K key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key) key = self.key_proj(key) # V value = self.value_proj(torch.square(torch.relu(key))) ## output out = receptance * value return out, state class TimeMixing(nn.Module): def __init__(self,config,layer_id): super().__init__() self.config = config self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.layer_id = layer_id n_embd = config.n_embd attn_sz = n_embd ## learnable matrix self.key_proj = nn.Linear(n_embd, attn_sz, bias=False) self.value_proj = nn.Linear(n_embd, attn_sz, bias=False) self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False) self.output_proj = nn.Linear(attn_sz, n_embd, bias=False) ## learnable vector self.time_decay = nn.Parameter(torch.empty(attn_sz)) self.time_first = 
class TimeMixing(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd

        ## learnable matrix
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

        ## learnable vector
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
        # x = (Batch,Time,Channel)
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

        ## K
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        ## V
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)

        ## R
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
        receptance = F.sigmoid(receptance)

        ## WKV
        wkv, state = self.wkv_function(
            key, value,
            use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
            state=state,
        )

        ## RWKV
        rwkv = receptance * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state

    def wkv_function(self, key, value, use_customized_cuda_kernel, state=None):

        ## essentially, the customized cuda kernel delivers a faster for-loop across time steps;
        ## it is only used for training and for evaluating loss/ppl
        if state is None and use_customized_cuda_kernel:
            B, T, C = key.size()
            return WKVKernel.apply(B, T, C, self.time_decay, self.time_first, key, value), None

        ## raw wkv function (from the Huggingface implementation);
        ## only used for generation (training with a raw PyTorch for-loop would be far too slow)
        else:
            _, seq_length, _ = key.size()
            output = torch.zeros_like(key)

            debug_mode = False
            if state is None:
                ## only for debugging purposes, when use_customized_cuda_kernel=False and state is None
                debug_mode = True
                num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
                den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
                max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
            else:
                num_state = state[self.layer_id, :, NUM_STATE, :]
                den_state = state[self.layer_id, :, DEN_STATE, :]
                max_state = state[self.layer_id, :, MAX_STATE, :]

            time_decay = -torch.exp(self.time_decay)

            for current_index in range(seq_length):
                current_key = key[:, current_index].float()
                current_value = value[:, current_index]

                # wkv computation at time t
                max_for_output = torch.maximum(max_state, current_key + self.time_first)
                e1 = torch.exp(max_state - max_for_output)
                e2 = torch.exp(current_key + self.time_first - max_for_output)
                numerator = e1 * num_state + e2 * current_value
                denominator = e1 * den_state + e2
                output[:, current_index] = (numerator / denominator).to(output.dtype)

                # update state for next iteration
                max_for_state = torch.maximum(max_state + time_decay, current_key)
                e1 = torch.exp(max_state + time_decay - max_for_state)
                e2 = torch.exp(current_key - max_for_state)
                num_state = e1 * num_state + e2 * current_value
                den_state = e1 * den_state + e2
                max_state = max_for_state

            if debug_mode:
                return output, None
            else:
                state[self.layer_id, :, NUM_STATE, :] = num_state
                state[self.layer_id, :, DEN_STATE, :] = den_state
                state[self.layer_id, :, MAX_STATE, :] = max_state
                return output, state
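# Illustrative sketch (not part of the original model): the loop in wkv_function above is a
# numerically stabilized version of the following naive recurrence, where w = -exp(time_decay)
# is the per-channel decay and u = time_first is the per-channel bonus for the current token:
#
#     wkv_t     = (num_t + exp(u + k_t) * v_t) / (den_t + exp(u + k_t))
#     num_{t+1} = exp(w) * num_t + exp(k_t) * v_t
#     den_{t+1} = exp(w) * den_t + exp(k_t)
#
# The helper below is a hypothetical reference implementation of exactly that recurrence; it
# will overflow for realistic magnitudes, which is why the model tracks a running max instead.
def _naive_wkv_reference(key, value, time_decay, time_first):
    # key, value: (Batch, Time, Channel); time_decay, time_first: (Channel,)
    B, T, C = key.shape
    w = -torch.exp(time_decay)  # decay applied once per time step
    num = torch.zeros(B, C)
    den = torch.zeros(B, C)
    out = torch.zeros(B, T, C)
    for t in range(T):
        k_t, v_t = key[:, t], value[:, t]
        out[:, t] = (num + torch.exp(time_first + k_t) * v_t) / (den + torch.exp(time_first + k_t))
        num = torch.exp(w) * num + torch.exp(k_t) * v_t
        den = torch.exp(w) * den + torch.exp(k_t)
    return out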
class Block(nn.Module):

    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):
        # state: [n_layer, batch_size, 5, n_embd]

        # time mixing
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual

        # channel mixing
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual

        return x, state

@dataclass
class RWKVConfig:
    block_size: int = 1024   # same as nanoGPT
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_embd: int = 768
    bias: bool = True        # bias in LayerNorms; in RWKV, all Linear layers use bias=False
    intermediate_size: int = None  # intermediate_size in channel-mixing
    use_customized_cuda_kernel: bool = True
    dtype: str = "float16"   ## bfloat16 is not supported on V100
    rescale_every: int = 6   ## layer-rescaling trick (see scale_parameters), only applied at inference
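# Illustrative sketch (not part of the original model): a hypothetical debug-sized config.
# With use_customized_cuda_kernel=False the model falls back to the pure-PyTorch loop in
# wkv_function, so it can run on CPU; in this file, dtype is only consulted by
# load_cuda_kernel and estimate_mfu.
def _debug_config():
    return RWKVConfig(
        block_size=64,
        vocab_size=256,
        n_layer=2,
        n_embd=64,
        use_customized_cuda_kernel=False,  # skip compiling the CUDA kernel
        dtype="float32",
    )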
class RWKV(nn.Module):

    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  ## used to initialize embedding parameters
        self.scaled = False     ## whether scale_parameters() has been applied (checked by unscale_parameters)

        self.rwkv = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            ln_p = LayerNorm(config.n_embd, bias=config.bias),
            h = nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

        if self.config.use_customized_cuda_kernel:
            ## load customized cuda kernel
            self.load_cuda_kernel(config.dtype)

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the token embeddings get subtracted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.rwkv.wte.weight.numel()
        return n_params

    def _init_weights(self, module):

        ## initialize Vector Parameters in TimeMixing
        if isinstance(module, TimeMixing):
            layer_id = module.layer_id
            n_layer = self.config.n_layer
            n_embd = self.config.n_embd
            attn_sz = n_embd

            with torch.no_grad():
                ratio_0_to_1 = layer_id / (n_layer - 1)          # 0 to 1
                ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
                ddd = torch.ones(1, 1, n_embd)
                for i in range(n_embd):
                    ddd[0, 0, i] = i / n_embd

                decay_speed = torch.ones(attn_sz)
                for h in range(attn_sz):
                    decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
                module.time_decay = nn.Parameter(decay_speed)

                zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
                module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

                module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
                module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

        ## initialize Vector Parameters in ChannelMixing
        elif isinstance(module, ChannelMixing):
            layer_id = module.layer_id
            n_layer = self.config.n_layer
            n_embd = self.config.n_embd

            with torch.no_grad():
                # fancy init of time_mix
                ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
                ddd = torch.ones(1, 1, n_embd)
                for i in range(n_embd):
                    ddd[0, 0, i] = i / n_embd
                module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                module.time_mix_receptance = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        ## initialize Linear Layer and Embedding Layer
        elif isinstance(module, (nn.Embedding, nn.Linear)):
            weight = module.weight
            shape = weight.shape
            gain = 1.0
            scale = 1.0

            ## get the name of the current parameter
            for _name, _parameters in self.named_parameters():
                if id(_parameters) == id(weight):
                    current_module_name = _name
            # print(current_module_name)

            ## Embedding
            if isinstance(module, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                scale = -1 * self.lr_init

            ## Linear
            elif isinstance(module, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])

                ## initialize some matrices to be all ZEROS
                for name in [".attn.key_proj.", ".attn.receptance_proj.", ".attn.output_proj.",
                             ".ffn.value_proj.", ".ffn.receptance_proj."]:
                    if name in current_module_name:
                        scale = 0

                if current_module_name == 'lm_head.weight':
                    scale = 0.5

            if scale == 0:
                nn.init.zeros_(weight)
            elif scale < 0:
                nn.init.uniform_(weight, a=scale, b=-scale)
            else:
                nn.init.orthogonal_(weight, gain=gain * scale)

    def forward(self, idx, targets=None, state=None, return_state=False):

        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        x = self.rwkv.wte(idx)
        x = self.rwkv.ln_p(x)
        # x = self.rwkv.drop(x)

        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
            if state is not None:  ## in generation mode
                if (
                    self.config.rescale_every > 0
                    and (block_idx + 1) % self.config.rescale_every == 0
                ):
                    x = x / 2

        x = self.rwkv.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)  # from RWKV-LM
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        if return_state:
            return logits, loss, state
        else:
            return logits, loss

    def crop_block_size(self, block_size):
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
    @classmethod
    def from_pretrained(cls, model_type, use_customized_cuda_kernel=True, dtype="float16"):

        assert model_type in {
            "RWKV/rwkv-4-169m-pile",
            "RWKV/rwkv-4-430m-pile",
            "RWKV/rwkv-4-1b5-pile",
            "RWKV/rwkv-4-3b-pile",
            "RWKV/rwkv-4-7b-pile",
            "RWKV/rwkv-raven-7b",
            "RWKV/rwkv-raven-1b5",
            "RWKV/rwkv-raven-3b",
            "RWKV/rwkv-4-14b-pile",
        }
        print("loading weights from pretrained RWKV: %s" % model_type)

        # init a huggingface/transformers model
        from transformers import RwkvForCausalLM, RwkvConfig
        hf_config = RwkvConfig.from_pretrained(model_type)
        with CudaNotAvailable():  ## prevent HF from loading its own CUDA kernel
            hf_model = RwkvForCausalLM.from_pretrained(model_type)

        # create a from-scratch initialized RWKV model
        config = {
            "vocab_size": 50277,
            "n_layer": hf_config.num_hidden_layers,
            "n_embd": hf_config.hidden_size,
            "intermediate_size": hf_config.intermediate_size,
            "use_customized_cuda_kernel": use_customized_cuda_kernel,
            "dtype": dtype,
        }
        config = RWKVConfig(**config)
        model = RWKV(config)
        num_layers = config.n_layer

        ## create a mapping from parameter names in RWKV to those in HF-RWKV
        mapping = {
            "rwkv.wte.weight": "rwkv.embeddings.weight",
            "rwkv.ln_p.weight": "rwkv.blocks.0.pre_ln.weight",
            "rwkv.ln_p.bias": "rwkv.blocks.0.pre_ln.bias",
            "rwkv.ln_f.weight": "rwkv.ln_out.weight",
            "rwkv.ln_f.bias": "rwkv.ln_out.bias",
            "lm_head.weight": "head.weight",
            **{f"rwkv.h.{layer_id}.ln_{norm_id}.weight": f"rwkv.blocks.{layer_id}.ln{norm_id}.weight"
               for layer_id in range(num_layers) for norm_id in [1, 2]},
            **{f"rwkv.h.{layer_id}.ln_{norm_id}.bias": f"rwkv.blocks.{layer_id}.ln{norm_id}.bias"
               for layer_id in range(num_layers) for norm_id in [1, 2]},
            **{f"rwkv.h.{layer_id}.attn.{_type}": f"rwkv.blocks.{layer_id}.attention.{_type}"
               for layer_id in range(num_layers)
               for _type in ["time_decay", "time_first", "time_mix_key", "time_mix_value", "time_mix_receptance"]},
            **{f"rwkv.h.{layer_id}.attn.{_type}_proj.weight": f"rwkv.blocks.{layer_id}.attention.{_type}.weight"
               for layer_id in range(num_layers) for _type in ["key", "value", "receptance", "output"]},
            **{f"rwkv.h.{layer_id}.ffn.{_type}": f"rwkv.blocks.{layer_id}.feed_forward.{_type}"
               for layer_id in range(num_layers) for _type in ["time_mix_key", "time_mix_receptance"]},
            **{f"rwkv.h.{layer_id}.ffn.{_type}_proj.weight": f"rwkv.blocks.{layer_id}.feed_forward.{_type}.weight"
               for layer_id in range(num_layers) for _type in ["key", "value", "receptance"]},
        }

        mapped_set = [mapping[x] for x in model.state_dict().keys()]
        assert set(mapped_set) == set(hf_model.state_dict().keys())

        sd = model.state_dict()
        hf_sd = hf_model.state_dict()
        for k1, k2 in mapping.items():
            assert sd[k1].shape == hf_sd[k2].shape, (k1, k2)
            sd[k1].copy_(hf_sd[k2])

        return model

    # def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
    #     # lr_1x = set()
    #     # lr_2x = set()
    #     # lr_3x = set()
    #     # for n, p in self.named_parameters():
    #     #     if "time_mix" in n: lr_1x.add(n)
    #     #     elif "time_decay" in n: lr_2x.add(n)
    #     #     elif "time_first" in n: lr_3x.add(n)
    #     #     else: lr_1x.add(n)
    #     # lr_1x = sorted(list(lr_1x))
    #     # lr_2x = sorted(list(lr_2x))
    #     # lr_3x = sorted(list(lr_3x))
    #     # param_dict = {n: p for n, p in self.named_parameters()}
    #     # optim_groups = [
    #     #     {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
    #     #     {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
    #     #     {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
    #     # ]
    #     optim_groups = [{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},]
    #     fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    #     use_fused = fused_available and device_type == 'cuda'
    #     extra_args = dict(fused=True) if use_fused else dict()
    #     optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas, eps=1e-8,
    #                                  weight_decay=weight_decay, amsgrad=False, **extra_args)
    #     return optimizer
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see RWKV paper Appendix C as ref: https://arxiv.org/abs/2305.13048
        cfg = self.config
        L, V, D = cfg.n_layer, cfg.vocab_size, cfg.n_embd
        # Note there is a typo in the RWKV paper. Forward pass is 2*fn, forward
        # and backward is 6*fn.
        flops_per_token = 2*(V*D + 13*(D**2)*L)
        flops_per_fwdbwd = 3*flops_per_token
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet.pdf
        if cfg.dtype == 'bfloat16':
            flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        elif cfg.dtype == 'float16':
            flops_promised = 312e12  # A100 GPU float16 peak flops is 312 TFLOPS
        else:  # dtype == float32
            flops_promised = 19.5e12  # A100 GPU float32 peak flops is 19.5 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    def init_state(self, batch_size, device):
        n_state = len([PREV_X_TIME, NUM_STATE, DEN_STATE, MAX_STATE, PREV_X_CHANNEL])
        state = torch.zeros(
            (self.config.n_layer, batch_size, n_state, self.config.n_embd),
            dtype=torch.float32,
            device=device,
        )
        state[:, :, MAX_STATE, :] -= 1e30
        return state

    def scale_parameters(self):
        ## divide output weights by 2**(layer // rescale_every) so that the x = x/2 in forward
        ## keeps fp16 activations in range during inference
        if self.config.rescale_every > 0:
            with torch.no_grad():
                for block_id, block in enumerate(self.rwkv.h):
                    block.attn.output_proj.weight.div_(2 ** int(block_id // self.config.rescale_every))
                    block.ffn.value_proj.weight.div_(2 ** int(block_id // self.config.rescale_every))
        self.scaled = True

    def unscale_parameters(self):
        if self.config.rescale_every > 0 and self.scaled:
            with torch.no_grad():
                for block_id, block in enumerate(self.rwkv.h):
                    block.attn.output_proj.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                    block.ffn.value_proj.weight.mul_(2 ** int(block_id // self.config.rescale_every))

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        idx: (batch_size, seq_len)
        """
        batch_size, seq_len = idx.shape

        ## feed the prompt through the model token by token to build up the recurrent state
        state = self.init_state(batch_size, idx.device)
        for seq_id in range(seq_len):
            logits, _, state = self(idx[:, [seq_id]], state=state, return_state=True)

        for _ in range(max_new_tokens):
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
            logits, _, state = self(idx_next, state=state, return_state=True)

        return idx
    def load_cuda_kernel(self, dtype):

        from torch.utils.cpp_extension import load
        T_MAX = self.config.block_size
        RWKV_FLOAT_MODE = dtype

        if RWKV_FLOAT_MODE == "bfloat16":
            wkv_cuda = load(name=f"wkv_{T_MAX}_bf16",
                            sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"],
                            verbose=True,
                            extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60",
                                               "--use_fast_math", "-O3", "-Xptxas -O3",
                                               "--extra-device-vectorization", f"-DTmax={T_MAX}"])

            class WKV(torch.autograd.Function):

                @staticmethod
                def forward(ctx, B, T, C, w, u, k, v):
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w = -torch.exp(w.float().contiguous())
                    u = u.contiguous().bfloat16()
                    k = k.contiguous()
                    v = v.contiguous()
                    y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    wkv_cuda.forward(B, T, C, w, u, k, v, y)
                    ctx.save_for_backward(w, u, k, v, y)
                    return y

                @staticmethod
                def backward(ctx, gy):
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w, u, k, v, y = ctx.saved_tensors
                    gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
                    wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
                    gw = torch.sum(gw, dim=0)
                    gu = torch.sum(gu, dim=0)
                    return (None, None, None, gw, gu, gk, gv)

        else:
            wkv_cuda = load(name=f"wkv_{T_MAX}",
                            sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                            verbose=True,
                            extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math",
                                               "-O3", "-Xptxas -O3", "--extra-device-vectorization",
                                               f"-DTmax={T_MAX}"])

            class WKV(torch.autograd.Function):

                @staticmethod
                def forward(ctx, B, T, C, w, u, k, v):
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    if "32" in RWKV_FLOAT_MODE:
                        w = -torch.exp(w.contiguous())
                        u = u.contiguous()
                        k = k.contiguous()
                        v = v.contiguous()
                    else:
                        w = -torch.exp(w.float().contiguous())
                        u = u.float().contiguous()
                        k = k.float().contiguous()
                        v = v.float().contiguous()
                    y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
                    wkv_cuda.forward(B, T, C, w, u, k, v, y)
                    ctx.save_for_backward(w, u, k, v, y)
                    if "32" in RWKV_FLOAT_MODE:
                        return y
                    elif RWKV_FLOAT_MODE == "float16":
                        return y.half()
                @staticmethod
                def backward(ctx, gy):
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    assert T <= T_MAX
                    assert B * C % min(C, 32) == 0
                    w, u, k, v, y = ctx.saved_tensors
                    gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
                    gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
                    gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
                    gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
                    if "32" in RWKV_FLOAT_MODE:
                        wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
                    else:
                        wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
                    gw = torch.sum(gw, dim=0)
                    gu = torch.sum(gu, dim=0)
                    if "32" in RWKV_FLOAT_MODE:
                        return (None, None, None, gw, gu, gk, gv)
                    elif RWKV_FLOAT_MODE == "float16":
                        return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())

        ## make the compiled kernel visible to TimeMixing.wkv_function
        global WKVKernel
        WKVKernel = WKV
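# Illustrative sketch (not part of the original model): a minimal smoke test, assuming a
# CPU-only environment. It uses the pure-PyTorch WKV path (use_customized_cuda_kernel=False)
# and tiny, hypothetical hyper-parameters; it is not a training or evaluation recipe.
if __name__ == "__main__":
    config = RWKVConfig(block_size=32, vocab_size=128, n_layer=2, n_embd=32,
                        use_customized_cuda_kernel=False, dtype="float32")
    model = RWKV(config)

    # training-style call: full-sequence forward pass with targets -> logits and loss
    idx = torch.randint(0, config.vocab_size, (2, 16))
    targets = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(idx, targets=targets)
    print(logits.shape, loss.item())  # torch.Size([2, 16, 128]) and a scalar loss

    # inference-style call: recurrent generation from a short prompt
    prompt = torch.randint(0, config.vocab_size, (1, 4))
    out = model.generate(prompt, max_new_tokens=8, temperature=1.0, top_k=20)
    print(out.shape)  # torch.Size([1, 12])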