diff --git "a/test.ipynb" "b/test.ipynb"
--- "a/test.ipynb"
+++ "b/test.ipynb"
@@ -10,21 +10,19 @@
"%autoreload 2\n",
"import os\n",
"\n",
- "os.environ['TORCH_LOGS'] = 'dynamic'\n",
- "\n",
+ "os.environ['TORCH_LOGS'] = '+dynamic'\n",
"import pylab as pl"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "WARNING:root:Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES=1\n",
"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
" WeightNorm.apply(module, name, dim)\n",
"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n",
@@ -32,27 +30,19 @@
]
},
{
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "hˌaʊ kʊd aɪ nˈoʊ? ɪts ɐn ʌnˈænsɚɹəbəl kwˈɛstʃən. lˈaɪk ˈæskɪŋ ɐn ʌnbˈɔːɹn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈævənt ˈiːvən bˌɪn bˈɔːɹn.\n"
+ "ename": "TypeError",
+ "evalue": "CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[3], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mkokoro\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m generate\n\u001b[1;32m 9\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHow could I know? It\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms an unanswerable question. Like asking an unborn child if they\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mll lead a good life. They haven\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt even been born.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 10\u001b[0m audio, out_ps \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvoicepack\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# 4️⃣ Display the 24khz audio and print the output phonemes\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m display, Audio\n",
+ "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:147\u001b[0m, in \u001b[0;36mgenerate\u001b[0;34m(model, text, voicepack, lang, speed)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTruncated to 510 tokens\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 146\u001b[0m ref_s \u001b[38;5;241m=\u001b[39m voicepack[\u001b[38;5;28mlen\u001b[39m(tokens)]\n\u001b[0;32m--> 147\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mref_s\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspeed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 148\u001b[0m ps \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mnext\u001b[39m(k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m VOCAB\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m v) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tokens)\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, ps\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:119\u001b[0m, in \u001b[0;36mforward\u001b[0;34m(model, tokens, ref_s, speed)\u001b[0m\n\u001b[1;32m 117\u001b[0m input_lengths \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor([tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]])\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 118\u001b[0m text_mask \u001b[38;5;241m=\u001b[39m length_to_mask(input_lengths)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m--> 119\u001b[0m bert_dur \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m~\u001b[39;49m\u001b[43mtext_mask\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mint\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 120\u001b[0m d_en \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mbert_encoder(bert_dur)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 121\u001b[0m s \u001b[38;5;241m=\u001b[39m ref_s[:, \u001b[38;5;241m128\u001b[39m:]\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
+ "\u001b[0;31mTypeError\u001b[0m: CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'"
]
}
],
@@ -60,13 +50,13 @@
"from models import build_model\n",
"import torch\n",
"device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n",
- "MODEL = build_model('kokoro-v0_19.pth', device)\n",
+ "model = build_model('kokoro-v0_19.pth', device)\n",
"voicepack = torch.load('voices/af.pt', weights_only=True).to(device)\n",
"\n",
"# 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes\n",
"from kokoro import generate\n",
"text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
- "audio, out_ps = generate(MODEL, text, voicepack)\n",
+ "audio, out_ps = generate(model, text, voicepack)\n",
"\n",
"# 4️⃣ Display the 24khz audio and print the output phonemes\n",
"from IPython.display import display, Audio\n",
@@ -114,7 +104,7 @@
"source": [
"from kokoro import phonemize, tokenize, length_to_mask\n",
"import torch.nn.functional as F\n",
-    "model = MODEL\n",
"speed = 1.\n",
"\n",
"ps = phonemize(text, \"a\")\n",
@@ -161,6 +151,37 @@
"print(out_ps)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "Error",
+ "evalue": "Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m scrpt \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscript\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1429\u001b[0m, in \u001b[0;36mscript\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m 1427\u001b[0m prev \u001b[38;5;241m=\u001b[39m _TOPLEVEL\n\u001b[1;32m 1428\u001b[0m _TOPLEVEL \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 1429\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43m_script_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1430\u001b[0m \u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1431\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1432\u001b[0m \u001b[43m \u001b[49m\u001b[43m_frames_up\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_frames_up\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1433\u001b[0m \u001b[43m \u001b[49m\u001b[43m_rcb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_rcb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1434\u001b[0m \u001b[43m \u001b[49m\u001b[43mexample_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexample_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1435\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prev:\n\u001b[1;32m 1438\u001b[0m log_torchscript_usage(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscript\u001b[39m\u001b[38;5;124m\"\u001b[39m, model_id\u001b[38;5;241m=\u001b[39m_get_model_id(ret))\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1154\u001b[0m, in \u001b[0;36m_script_impl\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m 1151\u001b[0m obj \u001b[38;5;241m=\u001b[39m obj\u001b[38;5;241m.\u001b[39m__prepare_scriptable__() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(obj, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__prepare_scriptable__\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m obj \u001b[38;5;66;03m# type: ignore[operator]\u001b[39;00m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m-> 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m create_script_dict(obj)\n\u001b[1;32m 1155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 1156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m create_script_list(obj)\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1066\u001b[0m, in \u001b[0;36mcreate_script_dict\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_script_dict\u001b[39m(obj):\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;124;03m Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.\u001b[39;00m\n\u001b[1;32m 1056\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;124;03m zero copy overhead.\u001b[39;00m\n\u001b[1;32m 1065\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1066\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mScriptDict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[0;31mError\u001b[0m: Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module"
+ ]
+ }
+ ],
+ "source": [
+ "scrpt = torch.jit.script(model)"
+ ]
+ },
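+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch (assumption: build_model returns a plain dict of nn.Module, which is\n",
+    "# what the error above suggests): torch.jit.script cannot infer a dict of\n",
+    "# Modules, so try scripting each submodule on its own instead; some of them\n",
+    "# may still fail if they use unscriptable ops.\n",
+    "scripted = {}\n",
+    "for name, module in model.items():\n",
+    "    try:\n",
+    "        scripted[name] = torch.jit.script(module)\n",
+    "    except Exception as e:\n",
+    "        print(f\"could not script {name}: {e}\")"
+   ]
+  },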
{
"cell_type": "code",
"execution_count": 11,
@@ -968,6 +989,244 @@
"tokens = tokenize(ps)\n",
"print(tokens)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from models import build_model\n",
+ "import torch\n",
+ "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model = build_model('kokoro-v0_19.pth', device)\n",
+ "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)"
+ ]
+ },
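+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick sanity check (sketch): build_model appears to return a dict of\n",
+    "# submodules rather than a single nn.Module, so list what it contains\n",
+    "# before pulling out individual parts.\n",
+    "print(type(model))\n",
+    "if isinstance(model, dict):\n",
+    "    for name, module in model.items():\n",
+    "        print(name, type(module).__name__)"
+   ]
+  },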
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bert = model[\"bert\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "embeddings.word_embeddings.weight torch.Size([178, 128])\n",
+ "embeddings.position_embeddings.weight torch.Size([512, 128])\n",
+ "embeddings.token_type_embeddings.weight torch.Size([2, 128])\n",
+ "embeddings.LayerNorm.weight torch.Size([128])\n",
+ "embeddings.LayerNorm.bias torch.Size([128])\n",
+ "encoder.embedding_hidden_mapping_in.weight torch.Size([768, 128])\n",
+ "encoder.embedding_hidden_mapping_in.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight torch.Size([768, 768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.query.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight torch.Size([768, 768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.key.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight torch.Size([768, 768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.value.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight torch.Size([768, 768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.weight torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.bias torch.Size([768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.ffn.weight torch.Size([2048, 768])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.ffn.bias torch.Size([2048])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight torch.Size([768, 2048])\n",
+ "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias torch.Size([768])\n",
+ "pooler.weight torch.Size([768, 768])\n",
+ "pooler.bias torch.Size([768])\n"
+ ]
+ }
+ ],
+ "source": [
+    "# show all parameters of the bert text encoder and their shapes\n",
+    "for name, param in bert.named_parameters():\n",
+    "    print(name, param.shape)"
+ ]
+ },
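+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Rough size check (sketch): total parameter count of the ALBERT text encoder.\n",
+    "n_params = sum(p.numel() for p in bert.parameters())\n",
+    "print(f\"{n_params / 1e6:.2f}M parameters\")"
+   ]
+  },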
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Testing export of a standalone bidirectional LSTM to ONNX with a dynamic input axis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "x1.shape=torch.Size([1, 300, 256])\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:4279: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. \n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Exported graph: graph(%x : Float(*, 300, 128, strides=[38400, 128, 1], requires_grad=0, device=cpu),\n",
+ " %onnx::LSTM_194 : Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cpu),\n",
+ " %onnx::LSTM_195 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu),\n",
+ " %onnx::LSTM_196 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu)):\n",
+ " %/lstm/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/lstm/Shape\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
+ " %/lstm/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name=\"/lstm/Constant\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
+ " %/lstm/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/lstm/Gather\"](%/lstm/Shape_output_0, %/lstm/Constant_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
+ " %/lstm/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/lstm/Constant_1\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
+ " %onnx::Unsqueeze_16 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
+ " %/lstm/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/lstm/Unsqueeze\"](%/lstm/Gather_output_0, %onnx::Unsqueeze_16), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
+ " %/lstm/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={128}, onnx_name=\"/lstm/Constant_2\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
+ " %/lstm/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/lstm/Concat\"](%/lstm/Constant_1_output_0, %/lstm/Unsqueeze_output_0, %/lstm/Constant_2_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n",
+ " %/lstm/ConstantOfShape_output_0 : Float(*, *, *, strides=[128, 128, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name=\"/lstm/ConstantOfShape\"](%/lstm/Concat_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n",
+ " %/lstm/Transpose_output_0 : Float(300, *, 128, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %onnx::LSTM_23 : Tensor? = prim::Constant(), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %/lstm/LSTM_output_0 : Float(300, 2, *, 128, device=cpu), %/lstm/LSTM_output_1 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu), %/lstm/LSTM_output_2 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu) = onnx::LSTM[direction=\"bidirectional\", hidden_size=128, onnx_name=\"/lstm/LSTM\"](%/lstm/Transpose_output_0, %onnx::LSTM_195, %onnx::LSTM_196, %onnx::LSTM_194, %onnx::LSTM_23, %/lstm/ConstantOfShape_output_0, %/lstm/ConstantOfShape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %/lstm/Transpose_1_output_0 : Float(300, *, 2, 128, device=cpu) = onnx::Transpose[perm=[0, 2, 1, 3], onnx_name=\"/lstm/Transpose_1\"](%/lstm/LSTM_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %/lstm/Constant_3_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 0 0 -1 [ CPULongType{3} ], onnx_name=\"/lstm/Constant_3\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %/lstm/Reshape_output_0 : Float(300, *, 256, device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/lstm/Reshape\"](%/lstm/Transpose_1_output_0, %/lstm/Constant_3_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " %151 : Float(*, 300, 256, strides=[256, 256, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose_2\"](%/lstm/Reshape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
+ " return (%151)\n",
+ "\n"
+ ]
+ },
+ {
+ "ename": "AttributeError",
+ "evalue": "'NoneType' object has no attribute 'graph'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[2], line 37\u001b[0m\n\u001b[1;32m 34\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39monnx\u001b[38;5;241m.\u001b[39mexport(model\u001b[38;5;241m=\u001b[39mmodel, args\u001b[38;5;241m=\u001b[39m( xa, ), dynamic_axes\u001b[38;5;241m=\u001b[39mdynamic_shapes, input_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m], f\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, dynamo\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# export_mod.save(\"model.onnx\")\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mexport_mod\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m)\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'graph'"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"Eq(s0, 384)\"\n",
+ "\n",
+ "# model class containing a single bidirectional LSTM layer\n",
+ "class Model(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.lstm = torch.nn.LSTM(128, 128, 1, bidirectional=True, batch_first=True)\n",
+    "        # initialize LSTM weights: orthogonal weight matrices, zero biases\n",
+ " for name, param in self.lstm.named_parameters():\n",
+ " if 'weight' in name:\n",
+ " torch.nn.init.orthogonal_(param)\n",
+ " elif 'bias' in name:\n",
+ " torch.nn.init.zeros_(param)\n",
+ "\n",
+ " def forward(self, x):\n",
+    "        # x is already batch_first (batch, seq, features), so it goes to the LSTM directly\n",
+ " x2, _ = self.lstm(x)\n",
+ " return x2\n",
+ "\n",
+ "model = Model()\n",
+ "model = model.to(\"cpu\")\n",
+ "model.eval()\n",
+ "\n",
+    "# initial input to the LSTM, stored in xa\n",
+ "xa = torch.zeros((1, 300, 128)).to(\"cpu\")\n",
+ "x1 = model(xa)\n",
+ "print(f\"{x1.shape=}\")\n",
+ "ntokens = torch.export.Dim(\"ntokens\", min=3)\n",
+    "dynamic_shapes = {\"x\": {0: \"ntokens\"}}\n",
+ "\n",
+ "# scripted = torch.jit.script(model)\n",
+    "torch.onnx.export(model=model, args=(xa,), dynamic_axes=dynamic_shapes, input_names=[\"x\"], f=\"model.onnx\", verbose=True, dynamo=False)\n",
+ "# export_mod.save(\"model.onnx\")\n",
+ "# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\n",
+ "# print(export_mod.graph)"
+ ]
+ },
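+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# With dynamo=False the legacy exporter returns None and just writes the file,\n",
+    "# which is what caused the AttributeError above. Sketch: load the saved\n",
+    "# model.onnx and inspect its graph instead.\n",
+    "import onnx\n",
+    "lstm_onnx = onnx.load(\"model.onnx\")\n",
+    "onnx.checker.check_model(lstm_onnx)\n",
+    "print(onnx.helper.printable_graph(lstm_onnx.graph))"
+   ]
+  },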
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 143])\n"
+ ]
+ }
+ ],
+ "source": [
+ "from kokoro import phonemize, tokenize\n",
+ "from models_scripting import load_plbert\n",
+ "bert = load_plbert()\n",
+ "\n",
+ "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
+ "ps = phonemize(text, \"a\")\n",
+ "tokens = tokenize(ps)\n",
+ "tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)\n",
+    "dynamic_shapes = {\"tokens\": {1: \"ntokens\"}}\n",
+    "print(tokens.shape)\n",
+    "torch.onnx.export(model=bert, args=(tokens,), dynamic_axes=dynamic_shapes, input_names=[\"tokens\"], f=\"bert.onnx\", verbose=False, dynamo=False)\n"
+ ]
+ },
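+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: smoke-test the exported bert.onnx with onnxruntime, feeding the same\n",
+    "# token ids that were used for the export above.\n",
+    "import onnxruntime as ort\n",
+    "bert_session = ort.InferenceSession(\"bert.onnx\")\n",
+    "bert_outputs = bert_session.run(None, {\"tokens\": tokens.numpy()})\n",
+    "print([o.shape for o in bert_outputs])"
+   ]
+  },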
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "Fail",
+ "evalue": "[ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mFail\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[6], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstyle_model.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01monnxruntime\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mort\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m ort_session \u001b[38;5;241m=\u001b[39m \u001b[43mort\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstyle_model.onnx\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m outputs \u001b[38;5;241m=\u001b[39m ort_session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: tokens\u001b[38;5;241m.\u001b[39mnumpy()})\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:465\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m 462\u001b[0m disabled_optimizers \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisabled_optimizers\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 464\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 465\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create_inference_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproviders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprovider_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisabled_optimizers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_fallback:\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:526\u001b[0m, in \u001b[0;36mInferenceSession._create_inference_session\u001b[0;34m(self, providers, provider_options, disabled_optimizers)\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_register_ep_custom_ops(session_options, providers, provider_options, available_providers)\n\u001b[1;32m 525\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_path:\n\u001b[0;32m--> 526\u001b[0m sess \u001b[38;5;241m=\u001b[39m \u001b[43mC\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[43msession_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read_config_from_model\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 528\u001b[0m sess \u001b[38;5;241m=\u001b[39m C\u001b[38;5;241m.\u001b[39mInferenceSession(session_options, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes, \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_config_from_model)\n",
+ "\u001b[0;31mFail\u001b[0m: [ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}"
+ ]
+ }
+ ],
+ "source": [
+ "import onnx\n",
+ "\n",
+ "onnx_model = onnx.load(\"style_model.onnx\")\n",
+ "import onnxruntime as ort\n",
+ "\n",
+ "ort_session = ort.InferenceSession(\"style_model.onnx\")\n",
+ "outputs = ort_session.run(None, {\"tokens\": tokens.numpy()})"
+ ]
}
],
"metadata": {