diff --git "a/test.ipynb" "b/test.ipynb" --- "a/test.ipynb" +++ "b/test.ipynb" @@ -10,21 +10,19 @@ "%autoreload 2\n", "import os\n", "\n", - "os.environ['TORCH_LOGS'] = 'dynamic'\n", - "\n", + "os.environ['TORCH_LOGS'] = '+dynamic'\n", "import pylab as pl" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES=1\n", "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", " WeightNorm.apply(module, name, dim)\n", "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n", @@ -32,27 +30,19 @@ ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hˌaʊ kʊd aɪ nˈoʊ? ɪts ɐn ʌnˈænsɚɹəbəl kwˈɛstʃən. lˈaɪk ˈæskɪŋ ɐn ʌnbˈɔːɹn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈævənt ˈiːvən bˌɪn bˈɔːɹn.\n" + "ename": "TypeError", + "evalue": "CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mkokoro\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m generate\n\u001b[1;32m 9\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHow could I know? It\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms an unanswerable question. Like asking an unborn child if they\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mll lead a good life. 
They haven\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt even been born.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 10\u001b[0m audio, out_ps \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvoicepack\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# 4️⃣ Display the 24khz audio and print the output phonemes\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m display, Audio\n", + "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:147\u001b[0m, in \u001b[0;36mgenerate\u001b[0;34m(model, text, voicepack, lang, speed)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTruncated to 510 tokens\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 146\u001b[0m ref_s \u001b[38;5;241m=\u001b[39m voicepack[\u001b[38;5;28mlen\u001b[39m(tokens)]\n\u001b[0;32m--> 147\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mref_s\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspeed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 148\u001b[0m ps \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mnext\u001b[39m(k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m VOCAB\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m v) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tokens)\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, ps\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:119\u001b[0m, in \u001b[0;36mforward\u001b[0;34m(model, tokens, ref_s, speed)\u001b[0m\n\u001b[1;32m 117\u001b[0m input_lengths \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor([tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]])\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 118\u001b[0m text_mask \u001b[38;5;241m=\u001b[39m length_to_mask(input_lengths)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m--> 
119\u001b[0m bert_dur \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m~\u001b[39;49m\u001b[43mtext_mask\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mint\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 120\u001b[0m d_en \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mbert_encoder(bert_dur)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 121\u001b[0m s \u001b[38;5;241m=\u001b[39m ref_s[:, \u001b[38;5;241m128\u001b[39m:]\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[0;31mTypeError\u001b[0m: CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'" ] } ], @@ -60,13 +50,13 @@ "from models import build_model\n", "import torch\n", "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n", - "MODEL = build_model('kokoro-v0_19.pth', device)\n", + "model = build_model('kokoro-v0_19.pth', device)\n", "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)\n", "\n", "# 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes\n", "from kokoro import generate\n", "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n", - "audio, out_ps = generate(MODEL, text, voicepack)\n", + "audio, out_ps = generate(model, text, voicepack)\n", "\n", "# 4️⃣ Display the 24khz audio and print the output phonemes\n", "from IPython.display import display, Audio\n", @@ -114,7 +104,7 @@ "source": [ "from kokoro import phonemize, tokenize, length_to_mask\n", "import torch.nn.functional as F\n", - "model = MODEL\n", + "model = model\n", "speed = 1.\n", "\n", "ps = phonemize(text, \"a\")\n", @@ -161,6 +151,37 @@ "print(out_ps)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "Error", + "evalue": "Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m scrpt \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscript\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1429\u001b[0m, in \u001b[0;36mscript\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m 1427\u001b[0m prev \u001b[38;5;241m=\u001b[39m _TOPLEVEL\n\u001b[1;32m 1428\u001b[0m _TOPLEVEL \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 1429\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43m_script_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1430\u001b[0m \u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1431\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1432\u001b[0m \u001b[43m \u001b[49m\u001b[43m_frames_up\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_frames_up\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1433\u001b[0m \u001b[43m \u001b[49m\u001b[43m_rcb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_rcb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1434\u001b[0m \u001b[43m \u001b[49m\u001b[43mexample_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexample_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 
1435\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prev:\n\u001b[1;32m 1438\u001b[0m log_torchscript_usage(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscript\u001b[39m\u001b[38;5;124m\"\u001b[39m, model_id\u001b[38;5;241m=\u001b[39m_get_model_id(ret))\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1154\u001b[0m, in \u001b[0;36m_script_impl\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m 1151\u001b[0m obj \u001b[38;5;241m=\u001b[39m obj\u001b[38;5;241m.\u001b[39m__prepare_scriptable__() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(obj, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__prepare_scriptable__\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m obj \u001b[38;5;66;03m# type: ignore[operator]\u001b[39;00m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m-> 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m create_script_dict(obj)\n\u001b[1;32m 1155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 1156\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m create_script_list(obj)\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1066\u001b[0m, in \u001b[0;36mcreate_script_dict\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_script_dict\u001b[39m(obj):\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;124;03m Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.\u001b[39;00m\n\u001b[1;32m 1056\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;124;03m zero copy overhead.\u001b[39;00m\n\u001b[1;32m 1065\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1066\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mScriptDict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mError\u001b[0m: Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module" + ] + } + ], + "source": [ + "scrpt = torch.jit.script(model)" + ] + }, { "cell_type": "code", "execution_count": 11, @@ -968,6 +989,244 @@ "tokens = tokenize(ps)\n", "print(tokens)" ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from models import build_model\n", + "import torch\n", + "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n", + "model = build_model('kokoro-v0_19.pth', device)\n", + "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "bert = model[\"bert\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embeddings.word_embeddings.weight torch.Size([178, 128])\n", + "embeddings.position_embeddings.weight torch.Size([512, 128])\n", + "embeddings.token_type_embeddings.weight torch.Size([2, 128])\n", + 
"embeddings.LayerNorm.weight torch.Size([128])\n", + "embeddings.LayerNorm.bias torch.Size([128])\n", + "encoder.embedding_hidden_mapping_in.weight torch.Size([768, 128])\n", + "encoder.embedding_hidden_mapping_in.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight torch.Size([768, 768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.query.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight torch.Size([768, 768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.key.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight torch.Size([768, 768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.value.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight torch.Size([768, 768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.weight torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.bias torch.Size([768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.ffn.weight torch.Size([2048, 768])\n", + "encoder.albert_layer_groups.0.albert_layers.0.ffn.bias torch.Size([2048])\n", + "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight torch.Size([768, 2048])\n", + "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias torch.Size([768])\n", + "pooler.weight torch.Size([768, 768])\n", + "pooler.bias torch.Size([768])\n" + ] + } + ], + "source": [ + "# show all parameters of model bert\n", + "for name, param in bert.named_parameters():\n", + " print(name, param.requires_grad())\n", + " # print(param)\n", + " # print(param.shape)\n", + " # break" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Testing LSTM export" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x1.shape=torch.Size([1, 300, 256])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:4279: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. 
\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exported graph: graph(%x : Float(*, 300, 128, strides=[38400, 128, 1], requires_grad=0, device=cpu),\n", + " %onnx::LSTM_194 : Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cpu),\n", + " %onnx::LSTM_195 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu),\n", + " %onnx::LSTM_196 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu)):\n", + " %/lstm/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/lstm/Shape\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", + " %/lstm/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name=\"/lstm/Constant\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", + " %/lstm/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/lstm/Gather\"](%/lstm/Shape_output_0, %/lstm/Constant_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", + " %/lstm/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/lstm/Constant_1\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", + " %onnx::Unsqueeze_16 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n", + " %/lstm/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/lstm/Unsqueeze\"](%/lstm/Gather_output_0, %onnx::Unsqueeze_16), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", + " %/lstm/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={128}, onnx_name=\"/lstm/Constant_2\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", + " %/lstm/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/lstm/Concat\"](%/lstm/Constant_1_output_0, %/lstm/Unsqueeze_output_0, %/lstm/Constant_2_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n", + " %/lstm/ConstantOfShape_output_0 : Float(*, *, *, strides=[128, 128, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name=\"/lstm/ConstantOfShape\"](%/lstm/Concat_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n", + " %/lstm/Transpose_output_0 : Float(300, *, 128, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %onnx::LSTM_23 : Tensor? 
= prim::Constant(), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %/lstm/LSTM_output_0 : Float(300, 2, *, 128, device=cpu), %/lstm/LSTM_output_1 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu), %/lstm/LSTM_output_2 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu) = onnx::LSTM[direction=\"bidirectional\", hidden_size=128, onnx_name=\"/lstm/LSTM\"](%/lstm/Transpose_output_0, %onnx::LSTM_195, %onnx::LSTM_196, %onnx::LSTM_194, %onnx::LSTM_23, %/lstm/ConstantOfShape_output_0, %/lstm/ConstantOfShape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %/lstm/Transpose_1_output_0 : Float(300, *, 2, 128, device=cpu) = onnx::Transpose[perm=[0, 2, 1, 3], onnx_name=\"/lstm/Transpose_1\"](%/lstm/LSTM_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %/lstm/Constant_3_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 0 0 -1 [ CPULongType{3} ], onnx_name=\"/lstm/Constant_3\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %/lstm/Reshape_output_0 : Float(300, *, 256, device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/lstm/Reshape\"](%/lstm/Transpose_1_output_0, %/lstm/Constant_3_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " %151 : Float(*, 300, 256, strides=[256, 256, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose_2\"](%/lstm/Reshape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", + " return (%151)\n", + "\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'graph'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 37\u001b[0m\n\u001b[1;32m 34\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39monnx\u001b[38;5;241m.\u001b[39mexport(model\u001b[38;5;241m=\u001b[39mmodel, args\u001b[38;5;241m=\u001b[39m( xa, ), dynamic_axes\u001b[38;5;241m=\u001b[39mdynamic_shapes, input_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m], f\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, dynamo\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 35\u001b[0m \u001b[38;5;66;03m# export_mod.save(\"model.onnx\")\u001b[39;00m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m 
\u001b[38;5;28mprint\u001b[39m(\u001b[43mexport_mod\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'graph'" + ] + } + ], + "source": [ + "import torch\n", + "# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"Eq(s0, 384)\"\n", + "\n", + "# model class containing a single bidirectional LSTM layer\n", + "class Model(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.lstm = torch.nn.LSTM(128, 128, 1, bidirectional=True, batch_first=True)\n", + " #initialize lstm weights\n", + " for name, param in self.lstm.named_parameters():\n", + " if 'weight' in name:\n", + " torch.nn.init.orthogonal_(param)\n", + " elif 'bias' in name:\n", + " torch.nn.init.zeros_(param)\n", + "\n", + " def forward(self, x):\n", + " x1 = x.transpose(-1,-2)\n", + " # print(f\"{x.shape=} {x1.shape=}\")\n", + " x2, _ = self.lstm(x)\n", + " return x2\n", + "\n", + "model = Model()\n", + "model = model.to(\"cpu\")\n", + "model.eval()\n", + "\n", + "#inital input to LSTM in variable x\n", + "xa = torch.zeros((1, 300, 128)).to(\"cpu\")\n", + "x1 = model(xa)\n", + "print(f\"{x1.shape=}\")\n", + "ntokens = torch.export.Dim(\"ntokens\", min=3)\n", + "dynamic_shapes= {\"x\":{0:\"ntokens\"}}\n", + "\n", + "# scripted = torch.jit.script(model)\n", + "torch.onnx.export(model=model, args=( xa, ), dynamic_axes=dynamic_shapes, input_names=[\"x\"], f=\"model.onnx\", verbose=True, dynamo=False)\n", + "# export_mod.save(\"model.onnx\")\n", + "# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\n", + "# print(export_mod.graph)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 143])\n" + ] + } + ], + "source": [ + "from kokoro import phonemize, tokenize\n", + "from models_scripting import load_plbert\n", + "bert = load_plbert()\n", + "\n", + "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. 
They haven't even been born.\"\n", + "ps = phonemize(text, \"a\")\n", + "tokens = tokenize(ps)\n", + "tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)\n", + "dynamic_shapes = {\"tokens\":{1:'ntokens'}}\n", + "print(tokens.shape)\n", + "torch.onnx.export(model=bert, args=( tokens, ), dynamic_axes=dynamic_shapes, input_names=[\"tokens\"], f=\"bert.onnx\", verbose=False, dynamo=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "Fail", + "evalue": "[ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFail\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstyle_model.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01monnxruntime\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mort\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m ort_session \u001b[38;5;241m=\u001b[39m \u001b[43mort\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstyle_model.onnx\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m outputs \u001b[38;5;241m=\u001b[39m ort_session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: tokens\u001b[38;5;241m.\u001b[39mnumpy()})\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:465\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m 462\u001b[0m disabled_optimizers \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisabled_optimizers\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 464\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 465\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create_inference_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproviders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprovider_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisabled_optimizers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_fallback:\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:526\u001b[0m, in \u001b[0;36mInferenceSession._create_inference_session\u001b[0;34m(self, providers, provider_options, disabled_optimizers)\u001b[0m\n\u001b[1;32m 523\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_register_ep_custom_ops(session_options, providers, provider_options, available_providers)\n\u001b[1;32m 525\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_path:\n\u001b[0;32m--> 526\u001b[0m sess \u001b[38;5;241m=\u001b[39m \u001b[43mC\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[43msession_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read_config_from_model\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 528\u001b[0m sess \u001b[38;5;241m=\u001b[39m C\u001b[38;5;241m.\u001b[39mInferenceSession(session_options, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes, \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_config_from_model)\n", + "\u001b[0;31mFail\u001b[0m: [ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}" + ] + } + ], + "source": [ + "import onnx\n", + "\n", + "onnx_model = onnx.load(\"style_model.onnx\")\n", + "import onnxruntime as ort\n", + "\n", + "ort_session = ort.InferenceSession(\"style_model.onnx\")\n", + "outputs = ort_session.run(None, {\"tokens\": tokens.numpy()})" + ] } ], "metadata": {
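A note on the torch.jit.script(MODEL) error recorded above: build_model() returns a plain Python dict of nn.Modules, so _script_impl falls into create_script_dict() and fails because a ScriptDict cannot hold values of type torch.nn.Module. A minimal workaround sketch, assuming the dict values are ordinary nn.Modules as in this notebook, is to script the components individually rather than the container:

    import torch

    # Hypothetical sketch: script each entry of the model dict on its own;
    # torch.jit.script can handle individual nn.Modules but not a dict of them.
    scripted = {}
    for name, module in model.items():
        if isinstance(module, torch.nn.Module):
            try:
                scripted[name] = torch.jit.script(module)
            except Exception as exc:
                # components built on unsupported ops will still refuse to script
                print(f"skipping {name}: {exc}")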
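The parameter-listing cell for the ALBERT component prints name and shape in its recorded output, but its source calls param.requires_grad(), which raises a TypeError because requires_grad is a bool attribute rather than a method. The loop that matches the recorded output is simply:

    # print every parameter name and shape of the ALBERT ("bert") component
    for name, param in bert.named_parameters():
        print(name, param.shape)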
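In the LSTM export experiment, the stale AttributeError comes from assigning the result of torch.onnx.export(..., dynamo=False): the legacy exporter writes the model straight to f and returns None, so export_mod.graph can never work (the torch.export.Dim defined in that cell is also only consumed by the dynamo=True path; string names in dynamic_axes are what the legacy path expects). A sketch of inspecting the exported graph from the saved file instead, reusing the Model and xa defined in that cell:

    import onnx
    import torch

    # legacy exporter: returns None and writes the graph to the given file
    torch.onnx.export(model, (xa,), "model.onnx",
                      input_names=["x"],
                      dynamic_axes={"x": {0: "ntokens"}},
                      dynamo=False)
    onnx_model = onnx.load("model.onnx")
    print(onnx.helper.printable_graph(onnx_model.graph))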
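For the final onnxruntime failure: the session refuses to load style_model.onnx because a Transpose node carries perm {1, -1, 0}, and ONNX shape/type inference expects non-negative axes there, which usually points at a transpose()/permute() written with negative dimensions in the module that was exported. A small diagnostic sketch (assuming style_model.onnx is the file exported earlier in this notebook) to list the offending nodes before re-exporting with positive axes:

    import onnx

    model_proto = onnx.load("style_model.onnx")
    for node in model_proto.graph.node:
        if node.op_type != "Transpose":
            continue
        for attr in node.attribute:
            if attr.name == "perm" and any(p < 0 for p in attr.ints):
                print(node.name, list(attr.ints))  # e.g. /Transpose_9 [1, -1, 0]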