KoRWKV

RWKV-Runner์—์„œ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด ๋ณ€ํ™˜ํ•œ ๋ชจ๋ธ ํŒŒ์ผ

import re

import torch

from transformers import RwkvForCausalLM

def convert_state_dict(state_dict):
    """Rename short-form checkpoint keys to the long, ``rwkv.``-prefixed form.

    The mapping is rewritten in place: every entry is popped and stored
    again under its new key, so the relative order of entries is kept and
    the values themselves are never touched (no reshaping happens here).

    Rename rules, applied in this order to each key:
      * leading ``emb.``            -> ``embeddings.``
      * leading ``blocks.0.ln0``    -> ``blocks.0.pre_ln``
      * ``blocks.<n>.att``          -> ``blocks.<n>.attention``
      * ``blocks.<n>.ffn``          -> ``blocks.<n>.feed_forward``
      * trailing ``.time_mix_k``    -> ``.time_mix_key``
      * trailing ``.time_mix_v``    -> ``.time_mix_value``
      * trailing ``.time_mix_r``    -> ``.time_mix_receptance``
      * every key except ``head.weight`` gains the ``rwkv.`` prefix

    Args:
        state_dict: Mutable mapping from parameter name to weight.

    Returns:
        The same ``state_dict`` object, now holding the renamed keys.
    """
    block_rewrites = (
        (r"blocks\.(\d+)\.att", r"blocks.\1.attention"),
        (r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward"),
    )
    suffix_rewrites = (
        (".time_mix_k", ".time_mix_key"),
        (".time_mix_v", ".time_mix_value"),
        (".time_mix_r", ".time_mix_receptance"),
    )

    # Snapshot the keys up front: entries are removed and re-added below.
    for old_key in tuple(state_dict.keys()):
        tensor = state_dict.pop(old_key)
        new_key = old_key

        # Token embedding table.
        if new_key.startswith("emb."):
            new_key = new_key.replace("emb.", "embeddings.")

        # The extra layer norm exists on block 0 only.
        if new_key.startswith("blocks.0.ln0"):
            new_key = new_key.replace("blocks.0.ln0", "blocks.0.pre_ln")

        # Expand the abbreviated sub-module names inside each block.
        for pattern, replacement in block_rewrites:
            new_key = re.sub(pattern, replacement, new_key)

        # Expand the abbreviated time-mix parameter names.
        for abbreviated, expanded in suffix_rewrites:
            if new_key.endswith(abbreviated):
                new_key = new_key.replace(abbreviated, expanded)

        # The output head is the only key that stays outside the "rwkv." namespace.
        if new_key == "head.weight":
            state_dict[new_key] = tensor
        else:
            state_dict["rwkv." + new_key] = tensor

    return state_dict


def revert_state_dict(state_dict):
    """Rename long, ``rwkv.``-prefixed keys back to the short checkpoint form.

    The mapping is rewritten in place: every entry is popped and stored
    again under its new key, so the relative order of entries is kept and
    the values themselves are never touched (no reshaping happens here).

    Rename rules, applied in this order to each key:
      * leading ``rwkv.`` is stripped
      * leading ``embeddings.``          -> ``emb.``
      * leading ``blocks.0.pre_ln``      -> ``blocks.0.ln0``
      * ``blocks.<n>.attention``         -> ``blocks.<n>.att``
      * ``blocks.<n>.feed_forward``      -> ``blocks.<n>.ffn``
      * trailing ``.time_mix_key``       -> ``.time_mix_k``
      * trailing ``.time_mix_value``     -> ``.time_mix_v``
      * trailing ``.time_mix_receptance``-> ``.time_mix_r``

    Args:
        state_dict: Mutable mapping from parameter name to weight.

    Returns:
        The same ``state_dict`` object, now holding the renamed keys.
    """
    block_rewrites = (
        (r"blocks\.(\d+)\.attention", r"blocks.\1.att"),
        (r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn"),
    )
    suffix_rewrites = (
        (".time_mix_key", ".time_mix_k"),
        (".time_mix_value", ".time_mix_v"),
        (".time_mix_receptance", ".time_mix_r"),
    )

    # Snapshot the keys up front: entries are removed and re-added below.
    for old_key in tuple(state_dict.keys()):
        tensor = state_dict.pop(old_key)

        # Drop the namespace prefix; keys without it pass through unchanged.
        new_key = old_key.removeprefix("rwkv.")

        # Token embedding table.
        if new_key.startswith("embeddings."):
            new_key = new_key.replace("embeddings.", "emb.")

        # The extra layer norm exists on block 0 only.
        if new_key.startswith("blocks.0.pre_ln"):
            new_key = new_key.replace("blocks.0.pre_ln", "blocks.0.ln0")

        # Abbreviate the sub-module names inside each block.
        for pattern, replacement in block_rewrites:
            new_key = re.sub(pattern, replacement, new_key)

        # Abbreviate the time-mix parameter names.
        for expanded, abbreviated in suffix_rewrites:
            if new_key.endswith(expanded):
                new_key = new_key.replace(expanded, abbreviated)

        state_dict[new_key] = tensor

    return state_dict


if __name__ == "__main__":
    # Script entry point: load the checkpoint named by `repo` in bfloat16,
    # rename its state-dict keys with revert_state_dict, and save the result
    # in the current working directory as "<model name>.bf16.pth".
    # Alternative checkpoint: "beomi/KoRWKV-6B"
    repo = "beomi/KoAlpaca-KoRWKV-6B"

    hf_model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
    reverted_weights = revert_state_dict(hf_model.state_dict())

    # Only the part of the repo id after the last "/" goes into the file name.
    model_name = repo.rpartition("/")[2]
    torch.save(reverted_weights, f"{model_name}.bf16.pth")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.