Update as_safetensors+fp16_en.ipynb
as_safetensors+fp16_en.ipynb  CHANGED  (+78 -96)
@@ -14,6 +14,17 @@
   }
  },
  "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### This is a script that converts the format of the model (.ckpt/.safetensors)\n",
+    "#### It also converts the .yaml file included with the SD2.x series\n",
+    "#### It can also be saved as fp16 as an option"
+   ],
+   "metadata": {
+    "id": "fAIY_GORNEYa"
+   }
+  },
   {
    "cell_type": "markdown",
    "source": [
@@ -39,6 +50,7 @@
    "cell_type": "code",
    "source": [
     "!pip install torch safetensors\n",
+    "!pip install pytorch-lightning\n",
     "!pip install wget"
    ],
    "metadata": {
@@ -63,27 +75,20 @@
     "#@markdown Please specify the model name or download link for Google Drive, separated by commas\n",
     "#@markdown - If it is the model name on Google Drive, specify it as a relative path to My Drive\n",
     "#@markdown - If it is a download link, copy the link address by right-clicking and paste it in place of the link below\n",
-    "\n",
     "import shutil\n",
     "import urllib.parse\n",
     "import urllib.request\n",
     "import wget\n",
     "\n",
-    "models = \"
-    "models = [m.strip() for m in models.split(\",\")
+    "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
+    "models = [m.strip() for m in models.split(\",\")]\n",
     "for model in models:\n",
     "    if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
     "        wget.download(model)\n",
-    "
-    "
-    "        ## with open(os.path.basename(model), mode=\"wb\") as f:\n",
-    "        ## f.write(model_data)\n",
-    "    elif model.endswith((\".ckpt\", \".safetensors\", \".pt\", \".pth\")):\n",
-    "        from_ = \"/content/drive/MyDrive/\" + model\n",
-    "        to_ = \"/content/\" + model\n",
-    "        shutil.copy(from_, to_)\n",
+    "    elif model.endswith((\".ckpt\", \".safetensors\")):\n",
+    "        shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
     "    else:\n",
-    "        print(f\"\\\"{model}\\\"
+    "        print(f\"\\\"{model}\\\" is not a URL and is also not a file with a proper extension\")"
    ],
    "metadata": {
     "cellView": "form",
@@ -92,66 +97,65 @@
    "execution_count": null,
    "outputs": []
   },
-  {
-   "cell_type": "markdown",
-   "source": [
-    "if you use a relatively newer model such as SD2.1, run the following code"
-   ],
-   "metadata": {
-    "id": "m1mHzOMjcDhz"
-   }
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "!pip install pytorch-lightning"
-   ],
-   "metadata": {
-    "id": "TkrmByc0aYVN"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "markdown",
-   "source": [
-    "Run either of the following two codes. If you run out of memory and crash, use a smaller model or a paid high-memory runtime"
-   ],
-   "metadata": {
-    "id": "0SUK6Alv2ItS"
-   }
-  },
   {
    "cell_type": "code",
    "source": [
-    "#@title <font size=\"-0\">
+    "#@title <font size=\"-0\">Convert the Models</font>\n",
+    "#@markdown Specify the models to be converted, separated by commas<br>\n",
+    "#@markdown If nothing is inputted, all loaded models will be converted\n",
     "import os\n",
+    "import glob\n",
     "import torch\n",
     "import safetensors.torch\n",
+    "from functools import partial\n",
     "\n",
-    "
-    "model_name, model_ext = os.path.splitext(model)\n",
+    "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
     "as_fp16 = True #@param {type:\"boolean\"}\n",
     "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
+    "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
     "\n",
-    "
-    "
-    "
-    "    elif model_ext == \".ckpt\":\n",
-    "        weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n",
-    "    else:\n",
-    "        raise Exception(\"対応形式は.ckptと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
-    "    if as_fp16:\n",
-    "        model_name = model_name + \"-fp16\"\n",
-    "        for key in weights.keys():\n",
-    "            weights[key] = weights[key].half()\n",
+    "def convert_yaml(file_name):\n",
+    "    with open(file_name) as f:\n",
+    "        yaml = f.read()\n",
     "    if save_directly_to_Google_Drive:\n",
    "        os.chdir(\"/content/drive/MyDrive\")\n",
-    "
-    "
+    "    is_safe = save_type == \".safetensors\"\n",
+    "    yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
+    "    if as_fp16:\n",
+    "        yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
+    "        file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
+    "    with open(file_name, mode=\"w\") as f:\n",
+    "        f.write(yaml)\n",
+    "    os.chdir(\"/content\")\n",
+    "\n",
+    "if models == \"\":\n",
+    "    models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
+    "else:\n",
+    "    models = [m.strip() for m in models.split(\",\")]\n",
+    "\n",
+    "for model in models:\n",
+    "    model_name, model_ext = os.path.splitext(model)\n",
+    "    if model_ext == \".yaml\":\n",
+    "        convert_yaml(model)\n",
+    "    elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
+    "        print(\"The supported formats are only .ckpt, .safetensors, and .yaml\\n\" + f\"\\\"{model}\\\" is not a supported format\")\n",
     "    else:\n",
-    "        safetensors.torch.
-    "
+    "        load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n",
+    "        save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
+    "        # convert model\n",
+    "        with torch.no_grad():\n",
+    "            weights = load_model(model)\n",
+    "            if \"state_dict\" in weights:\n",
+    "                weights = weights[\"state_dict\"]\n",
+    "            if as_fp16:\n",
+    "                model_name = model_name + \"-fp16\"\n",
+    "                for key in weights.keys():\n",
+    "                    weights[key] = weights[key].half()\n",
+    "            if save_directly_to_Google_Drive:\n",
+    "                os.chdir(\"/content/drive/MyDrive\")\n",
+    "            save_model(weights, model_name + save_type)\n",
+    "            os.chdir(\"/content\")\n",
+    "            del weights\n",
     "\n",
     "!reset"
    ],
@@ -163,48 +167,26 @@
    "outputs": []
   },
   {
-   "cell_type": "
+   "cell_type": "markdown",
    "source": [
-    "
-    "import os\n",
-    "import glob\n",
-    "import torch\n",
-    "import safetensors.torch\n",
-    "\n",
-    "as_fp16 = True #@param {type:\"boolean\"}\n",
-    "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
+    "If you are converting SD2.x series models, etc., be sure to download/convert the accompanying configuration file (a .yaml file with the same name as the model) at the same time.\n",
     "\n",
-    "
-
-
-    "
-
-
-
-
-
-    "
-    "        break\n",
-    "    if as_fp16:\n",
-    "        model_name = model_name + \"-fp16\"\n",
-    "        for key in weights.keys():\n",
-    "            weights[key] = weights[key].half()\n",
-    "    if save_directly_to_Google_Drive:\n",
-    "        os.chdir(\"/content/drive/MyDrive\")\n",
-    "        safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
-    "        os.chdir(\"/content\")\n",
-    "    else:\n",
-    "        safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
-    "    del weights\n",
+    "It can be converted in the same way as the model."
+   ],
+   "metadata": {
+    "id": "SWTFKmGFLec6"
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "If you run out of memory and crash, you can use a smaller model or a paid high memory runtime.\n",
     "\n",
-    "
+    "With the free ~12GB runtime, you can convert models up to ~10GB."
    ],
    "metadata": {
-    "id": "
-
-   },
-   "execution_count": null,
-   "outputs": []
+    "id": "0SUK6Alv2ItS"
+   }
   },
   {
    "cell_type": "markdown",
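For readers skimming the diff, the conversion cell added in this commit boils down to the standalone sketch below. It mirrors the `+` lines above for a single model; the file name and option values are illustrative stand-ins for the notebook's Colab form fields, and the Google Drive handling is omitted.

```python
import os
from functools import partial

import torch
import safetensors.torch

# Illustrative values; the notebook reads these from #@param form fields.
model = "wd-1-4-anime_e1.ckpt"   # input: a .ckpt or .safetensors file
save_type = ".safetensors"       # output container: ".safetensors" or ".ckpt"
as_fp16 = True                   # also cast the weights to half precision

model_name, model_ext = os.path.splitext(model)

# Pick the loader/saver matching the input and output formats.
load_model = (partial(safetensors.torch.load_file, device="cpu")
              if model_ext == ".safetensors"
              else partial(torch.load, map_location=torch.device("cpu")))
save_model = safetensors.torch.save_file if save_type == ".safetensors" else torch.save

with torch.no_grad():
    weights = load_model(model)
    # .ckpt checkpoints usually wrap the tensors in a "state_dict" entry.
    if "state_dict" in weights:
        weights = weights["state_dict"]
    if as_fp16:
        model_name += "-fp16"
        for key in weights.keys():
            weights[key] = weights[key].half()
    save_model(weights, model_name + save_type)
```

The `convert_yaml` helper added in the same cell applies the matching edits to an SD2.x config file: it flips `use_checkpoint` to suit the chosen container and, when casting to fp16, sets `use_fp16: True` and renames the file with a `-fp16` suffix.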