{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n",
"#### SD2.x系付属の.yamlも併せて変換します\n",
"#### オプションでfp16として保存できます"
],
"metadata": {
"id": "fAIY_GORNEYa"
}
},
{
"cell_type": "markdown",
"source": [
"最初に以下のコードを実行"
],
"metadata": {
"id": "OnuCk_wNLM_D"
}
},
{
"cell_type": "code",
"source": [
"!pip install torch safetensors\n",
"!pip install pytorch-lightning\n",
"!pip install wget"
],
"metadata": {
"id": "pXr7oNJzwwgU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Google Drive上のファイルを読み書きしたい場合は、以下のコードを実行"
],
"metadata": {
"id": "NsncqZOha2e0"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount(\"/content/drive\")"
],
"metadata": {
"id": "liEiK8Iioscq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 変換したモデルをHugging Faceに投稿したい場合は、以下のコードを実行\n",
"#@markdown 1. [このページ](https://huggingface.co/settings/tokens)にアクセスしてNew tokenからName=適当, Role=writeでAccess Tokenを取得\n",
"#@markdown 2. 取得したTokenをコピー & 以下の欄に貼り付け & 実行\n",
"!pip install huggingface_hub\n",
"from huggingface_hub import login\n",
"token = \"\" #@param {type:\"string\"}\n",
"login(token=token)"
],
"metadata": {
"cellView": "form",
"id": "mJO8RdvIINA-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行"
],
"metadata": {
"id": "7Ils-K70k15Y"
}
},
{
"cell_type": "code",
"source": [
"#@title モデルをダウンロード\n",
"#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n",
"#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n",
"#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n",
"import shutil\n",
"import urllib.parse\n",
"import urllib.request\n",
"import wget\n",
"import os\n",
"\n",
"models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"models = [m.strip() for m in models.split(\",\")]\n",
"for model in models:\n",
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
" wget.download(model)\n",
" elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\", \".pt\")):\n",
" shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + os.path.basename(model)) # get the model from mydrive\n",
" else:\n",
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
],
"metadata": {
"id": "4vd3A09AxJE0",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title モデルを変換\n",
"#@markdown 変換するモデルをカンマ区切りで任意個指定
\n",
"#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される
\n",
"#@markdown `a.ckpt, -, b.safetensors`のような形式でモデルの引き算ができます\n",
"import os\n",
"import glob\n",
"import torch\n",
"import safetensors.torch\n",
"\n",
"from sys import modules\n",
"if \"huggingface_hub\" in modules:\n",
" from huggingface_hub import HfApi, Repository\n",
"\n",
"models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"pruning = True #@param {type:\"boolean\"}\n",
"as_fp16 = True #@param {type:\"boolean\"}\n",
"clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
"uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
"save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
"merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"anything-v4.0.vae.pt\"] {allow-input: true}\n",
"save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
"#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定
\n",
"#@markdown 投稿しない場合は何も入力しない
\n",
"# 5GB以上のファイルを投稿する場合は、投稿先リポジトリを丸ごとダウンロードする工程が挟まるので、時間がかかる場合があります\n",
"repo_id = \"\" #@param {type:\"string\"}\n",
"\n",
"vae_preset = {\n",
" \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
" \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
" \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
" \"anything-v4.0.vae.pt\": \"https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0.vae.pt\"}\n",
"if (merge_vae in vae_preset) and (not os.path.exists(merge_vae)):\n",
" wget.download(vae_preset[merge_vae])\n",
"\n",
"def upload_to_hugging_face(file_name):\n",
" api = HfApi()\n",
" api.upload_file(path_or_fileobj=file_name,\n",
" path_in_repo=file_name,\n",
" repo_id=repo_id,\n",
" )\n",
"\n",
"def convert_yaml(file_name):\n",
" with open(file_name) as f:\n",
" yaml = f.read()\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" is_safe = save_type == \".safetensors\"\n",
" yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
" if as_fp16:\n",
" yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
" file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
" with open(file_name, mode=\"w\") as f:\n",
" f.write(yaml)\n",
" if repo_id != \"\":\n",
" upload_to_hugging_face(file_name)\n",
" os.chdir(\"/content\")\n",
"\n",
"#use `str.removeprefix(p)` in python 3.9+\n",
"def remove_prefix(input_string, prefix):\n",
" if prefix and input_string.startswith(prefix):\n",
" return input_string[len(prefix):]\n",
" return input_string\n",
"\n",
"load_model = lambda m: safetensors.torch.load_file(m, device=\"cpu\") if os.path.splitext(m)[1] == \".safetensors\" else torch.load(m, map_location=torch.device(\"cpu\"))\n",
"save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
"\n",
"# --- def merge ---#\n",
"@torch.no_grad()\n",
"def merged(model_a, model_b, fab, fa, fb):\n",
" weights_a = load_model(model_a) if isinstance(model_a, str) else model_a\n",
" weights_b = load_model(model_b) if isinstance(model_b, str) else model_b\n",
" if \"state_dict\" in weights_a:\n",
" weights_a = weights_a[\"state_dict\"]\n",
" if \"state_dict\" in weights_b:\n",
" weights_b = weights_b[\"state_dict\"]\n",
" for key in list(weights_a.keys() or weights_b.keys()):\n",
" if isinstance(weights_a[key], dict):\n",
" del weights_a[key]\n",
" if isinstance(weights_b[key], dict):\n",
" del weights_b[key]\n",
" if key.startswith(\"model.\") or key.startswith(\"model_ema.\"):\n",
" if (key in weights_a) and (key in weights_b):\n",
" weights_a[key] = fab(weights_a[key], weights_b[key])\n",
" del weights_b[key]\n",
" elif key in weights_a:\n",
" weights_a[key] = fa(weights_a[key])\n",
" elif key in weights_b:\n",
" weights_a[key] = fb(weights_b[key])\n",
" del weights_b[key]\n",
" del weights_b\n",
" return weights_a\n",
"\n",
"def add(model_a, model_b):\n",
" return merged(model_a, model_b, lambda a, b: a + b, lambda a: a, lambda b: b)\n",
"\n",
"def difference(model_a, model_b):\n",
" return merged(model_a, model_b, lambda a, b: a - b, lambda a: a, lambda b: -b)\n",
"\n",
"def add_difference(model_a, model_b, model_c):\n",
" return add(model_a, difference(model_b, model_c))\n",
"# --- end merge ---#\n",
"\n",
"if models == \"\":\n",
" models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\") if not os.path.basename(m) in vae_preset]\n",
"else:\n",
" models = [m.strip() for m in models.split(\",\")]\n",
"\n",
"for i, model in enumerate(models):\n",
" model_name, model_ext = os.path.splitext(model)\n",
" # a.ckpt, - ,b.ckpt # - or b.ckpt\n",
" if (models[i] == \"-\") or (models[i - 1] == \"-\"):\n",
" continue\n",
" if model_ext == \".yaml\":\n",
" convert_yaml(model)\n",
" elif (model_ext != \".safetensors\") and (model_ext != \".ckpt\"):\n",
" print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
" else:\n",
" # convert model\n",
" with torch.no_grad():\n",
" # a.ckpt, - ,b.ckpt # a.ckpt\n",
" if (i < len(models) - 1) and (models[i + 1] == \"-\"):\n",
" weights = difference(model, models[i + 2])\n",
" model_name = f\"{model_name}-{os.path.splitext(models[i + 2])[0]}\"\n",
" # otherwise\n",
" else:\n",
" weights = load_model(model)\n",
" if \"state_dict\" in weights:\n",
" weights = weights[\"state_dict\"]\n",
" for key in list(weights.keys()):\n",
" if isinstance(weights[key], dict):\n",
" del weights[key] # to fix the broken model\n",
" if pruning:\n",
" model_name += \"-pruned\"\n",
" for key in list(weights.keys()):\n",
" if key.startswith(\"model_ema.\"):\n",
" del weights[key]\n",
" if as_fp16:\n",
" model_name += \"-fp16\"\n",
" for key in weights.keys():\n",
" weights[key] = weights[key].half()\n",
" if uninvited_key in weights:\n",
" if clip_fix == \"del err key\":\n",
" del weights[uninvited_key]\n",
" if clip_fix == \"fix err key\":\n",
" weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
" if merge_vae != \"\":\n",
" vae_weights = load_model(merge_vae)\n",
" if \"state_dict\" in vae_weights:\n",
" vae_weights = vae_weights[\"state_dict\"]\n",
" for key in weights.keys():\n",
" if key.startswith(\"first_stage_model.\"):\n",
" weights[key] = vae_weights[remove_prefix(key, \"first_stage_model.\")]\n",
" del vae_weights\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" save_model(weights, saved_model := model_name + save_type)\n",
" if repo_id != \"\":\n",
" if os.path.getsize(saved_model) >= 5*1000*1000*1000:\n",
" with Repository(os.path.basename(repo_id), clone_from=repo_id, skip_lfs_files=True, token=True).commit(commit_message=f\"Upload {saved_model} with huggingface_hub\", blocking=False):\n",
" save_model(weights, saved_model)\n",
" else:\n",
" upload_to_hugging_face(saved_model)\n",
" os.chdir(\"/content\")\n",
" del weights\n",
"\n",
"!reset"
],
"metadata": {
"cellView": "form",
"id": "QSzZqGygdXM9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n",
"\n",
"指定方法はモデルと同じです"
],
"metadata": {
"id": "SWTFKmGFLec6"
}
},
{
"cell_type": "markdown",
"source": [
"メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n",
"\n",
"標準では10GBまでのモデルを変換できます"
],
"metadata": {
"id": "0SUK6Alv2ItS"
}
},
{
"cell_type": "markdown",
"source": [
"[モデルのリンク集](https://huggingface.co/models?other=stable-diffusion)等から好きなモデルを選ぼう"
],
"metadata": {
"id": "yaLq5Nqe6an6"
}
}
]
}