{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n",
"#### SD2.x系付属の.yamlも併せて変換します\n",
"#### オプションでfp16として保存できます"
],
"metadata": {
"id": "fAIY_GORNEYa"
}
},
{
"cell_type": "markdown",
"source": [
"以下のコードを上から順番に両方とも実行"
],
"metadata": {
"id": "OnuCk_wNLM_D"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive \n",
"drive.mount(\"/content/drive\")"
],
"metadata": {
"id": "liEiK8Iioscq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install torch safetensors\n",
"!pip install pytorch-lightning\n",
"!pip install wget"
],
"metadata": {
"id": "pXr7oNJzwwgU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行"
],
"metadata": {
"id": "7Ils-K70k15Y"
}
},
{
"cell_type": "code",
"source": [
"#@title モデルをダウンロード\n",
"#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n",
"#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n",
"#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n",
"import shutil\n",
"import urllib.parse\n",
"import urllib.request\n",
"import wget\n",
"\n",
"models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"models = [m.strip() for m in models.split(\",\")]\n",
"for model in models:\n",
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
" wget.download(model)\n",
" elif model.endswith((\".ckpt\", \".safetensors\")):\n",
" shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
" else:\n",
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
],
"metadata": {
"cellView": "form",
"id": "4vd3A09AxJE0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title モデルを変換\n",
"#@markdown 変換するモデルをカンマ区切りで任意個指定
\n",
"#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n",
"import os\n",
"import glob\n",
"import torch\n",
"import safetensors.torch\n",
"from functools import partial\n",
"\n",
"models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"as_fp16 = True #@param {type:\"boolean\"}\n",
"save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
"save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
"\n",
"def convert_yaml(file_name):\n",
" with open(file_name) as f:\n",
" yaml = f.read()\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" is_safe = save_type == \".safetensors\"\n",
" yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
" if as_fp16:\n",
" yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
" file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
" with open(file_name, mode=\"w\") as f:\n",
" f.write(yaml)\n",
" os.chdir(\"/content\")\n",
"\n",
"if models == \"\":\n",
" models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
"else:\n",
" models = [m.strip() for m in models.split(\",\")]\n",
"\n",
"for model in models:\n",
" model_name, model_ext = os.path.splitext(model)\n",
" if model_ext == \".yaml\":\n",
" convert_yaml(model)\n",
" elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
" print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
" else:\n",
" load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n",
" save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
" # convert model\n",
" with torch.no_grad():\n",
" weights = load_model(model)\n",
" if \"state_dict\" in weights:\n",
" weights = weights[\"state_dict\"]\n",
" if as_fp16:\n",
" model_name = model_name + \"-fp16\"\n",
" for key in weights.keys():\n",
" weights[key] = weights[key].half()\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" save_model(weights, model_name + save_type)\n",
" os.chdir(\"/content\")\n",
" del weights\n",
"\n",
"!reset"
],
"metadata": {
"id": "9OmSG98HxJg2",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n",
"\n",
"指定方法はモデルと同じです"
],
"metadata": {
"id": "SWTFKmGFLec6"
}
},
{
"cell_type": "markdown",
"source": [
"メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n",
"\n",
"標準では10GBまでのモデルを変換できます"
],
"metadata": {
"id": "0SUK6Alv2ItS"
}
},
{
"cell_type": "markdown",
"source": [
"モデルのリンク集: https://huggingface.co/models?other=stable-diffusion 等から好きなモデルを選ぼう"
],
"metadata": {
"id": "yaLq5Nqe6an6"
}
}
]
}