{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "h3_RFpXiPQR6"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import Model\n",
        "from tensorflow.keras import Input\n",
        "from tensorflow.keras.layers import Conv2D, ReLU, ELU, LeakyReLU, Dropout, Dense, MaxPooling2D, Flatten, BatchNormalization\n",
        "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "from tensorflow.keras.optimizers import Adam\n",
        "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping\n",
        "from tensorflow.keras.utils import plot_model\n",
        "from tensorflow.keras.models import load_model\n",
        "from sklearn.metrics import classification_report\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "ADU3Hu_TAFvG"
      },
      "outputs": [],
      "source": [
        "IMG_WIDTH = 256"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "NyTzubIUjeCR"
      },
      "outputs": [],
      "source": [
        "def get_datagen(use_default_augmentation=True, **kwargs):\n",
        "    kwargs.update({'rescale': 1./255})\n",
        "    if use_default_augmentation:\n",
        "        kwargs.update({\n",
        "            'rotation_range': 15,\n",
        "            'zoom_range': 0.2,\n",
        "            'brightness_range': (0.8, 1.2),\n",
        "            'channel_shift_range': 30,\n",
        "            'horizontal_flip': True,\n",
        "        })\n",
        "    return ImageDataGenerator(**kwargs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "nrH7Fz6EM4mk"
      },
      "outputs": [],
      "source": [
        "def get_train_data_generator(\n",
        "    train_data_dir, \n",
        "    batch_size, \n",
        "    validation_split=None, \n",
        "    use_default_augmentation=True,\n",
        "    augmentations=None\n",
        "):\n",
        "    if not augmentations:\n",
        "        augmentations = {}\n",
        "\n",
        "    train_datagen = get_datagen(\n",
        "        use_default_augmentation=use_default_augmentation,\n",
        "        validation_split=validation_split if validation_split else 0.0,\n",
        "        **augmentations\n",
        "    )\n",
        "   \n",
        "    train_generator = train_datagen.flow_from_directory(\n",
        "        directory=train_data_dir,\n",
        "        target_size=(IMG_WIDTH, IMG_WIDTH),\n",
        "        batch_size=batch_size,\n",
        "        class_mode='binary',\n",
        "        subset='training',\n",
        "    )\n",
        "\n",
        "    validation_generator = None\n",
        "\n",
        "    if validation_split:\n",
        "        validation_generator = train_datagen.flow_from_directory(\n",
        "            directory=train_data_dir,\n",
        "            target_size=(IMG_WIDTH, IMG_WIDTH),\n",
        "            batch_size=batch_size,\n",
        "            class_mode='binary',\n",
        "            subset='validation'\n",
        "        )\n",
        "\n",
        "    return train_generator, validation_generator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "6G7tVf0wNHvd"
      },
      "outputs": [],
      "source": [
        "def get_test_data_generator(test_data_dir, batch_size, shuffle=False):\n",
        "    test_datagen = get_datagen(use_default_augmentation=False)\n",
        "    return test_datagen.flow_from_directory(\n",
        "        directory=test_data_dir,\n",
        "        target_size=(IMG_WIDTH, IMG_WIDTH),\n",
        "        batch_size=batch_size,\n",
        "        class_mode='binary',\n",
        "        shuffle=shuffle\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "kSWlbqMv-TK4"
      },
      "outputs": [],
      "source": [
        "def activation_layer(ip, activation, *args):\n",
        "    return {'relu': ReLU(*args)(ip),\n",
        "            'elu': ELU(*args)(ip),\n",
        "            'lrelu': LeakyReLU(*args)(ip)}[activation]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "CvF1f4Y28oPM"
      },
      "outputs": [],
      "source": [
        "def conv2D(ip,\n",
        "           filters,\n",
        "           kernel_size,\n",
        "           activation,\n",
        "           padding='same',\n",
        "           pool_size=(2, 2)):\n",
        "    layer = Conv2D(filters,\n",
        "                   kernel_size=kernel_size,\n",
        "                   padding=padding)(ip)\n",
        "\n",
        "    layer = activation_layer(layer, activation=activation)\n",
        "\n",
        "    layer = BatchNormalization()(layer)\n",
        "\n",
        "    return MaxPooling2D(pool_size=pool_size, padding=padding)(layer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "d-4--jRd-bz1"
      },
      "outputs": [],
      "source": [
        "def fully_connected_layer(ip,\n",
        "                          hidden_activation,\n",
        "                          dropout):\n",
        "    layer = Dense(16)(ip)\n",
        "    layer = activation_layer(layer, hidden_activation, *[0.1,])\n",
        "    return Dropout(rate=dropout)(layer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "1Cp48aFy_k4G"
      },
      "outputs": [],
      "source": [
        "def build_model(ip=Input(shape=(IMG_WIDTH, IMG_WIDTH, 3)),\n",
        "                activation='relu',\n",
        "                dropout=0.5,\n",
        "                hidden_activation='lrelu'):\n",
        "    \n",
        "    layer = conv2D(ip, filters=8, kernel_size=(3, 3), activation=activation)\n",
        "\n",
        "    layer = conv2D(layer, filters=8, kernel_size=(5, 5), activation=activation)\n",
        "\n",
        "    layer = conv2D(layer, filters=16, kernel_size=(5, 5), activation=activation)\n",
        "\n",
        "    layer = conv2D(layer, filters=16, kernel_size=(5, 5), activation=activation, pool_size=(4, 4))\n",
        "\n",
        "    layer = Flatten()(layer)\n",
        "    layer = Dropout(rate=dropout)(layer)\n",
        "\n",
        "    layer = fully_connected_layer(layer, hidden_activation=hidden_activation, dropout=dropout)\n",
        "\n",
        "    op_layer = Dense(1, activation='sigmoid')(layer)\n",
        "\n",
        "    model = Model(ip, op_layer)\n",
        "\n",
        "    return model"
      ]
    },
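    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical usage sketch: instantiate the architecture and inspect it.\n",
        "# The variable name `mesonet` is illustrative only.\n",
        "mesonet = build_model()\n",
        "mesonet.summary()\n",
        "\n",
        "# Uncomment to write an architecture diagram (requires pydot and graphviz):\n",
        "# plot_model(mesonet, to_file='mesonet.png', show_shapes=True)"
      ]
    },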
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "ZdoMu0LbDGMC"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(model, test_data_dir, batch_size):\n",
        "    data = get_test_data_generator(test_data_dir, batch_size)\n",
        "    return model.evaluate(data)\n",
        "\n",
        "\n",
        "def predict(model, data, steps=None, threshold=0.5):\n",
        "    predictions = model.predict(data, steps=steps, verbose=1)\n",
        "    return predictions, np.where(predictions >= threshold, 1, 0)\n",
        "\n",
        "\n",
        "def save_model_history(history, filename):\n",
        "    with open(filename, 'wb') as f:\n",
        "        pickle.dump(history.history, f)"
      ]
    },
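    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch of the report helper used below; it is not defined elsewhere in\n",
        "# this notebook. It runs the model over the test directory in file order\n",
        "# (shuffle=False) and compares predictions against the generator's labels.\n",
        "def get_classification_report(model, test_data_dir, batch_size=64):\n",
        "    data = get_test_data_generator(test_data_dir, batch_size)\n",
        "    _, y_pred = predict(model, data)\n",
        "    return classification_report(data.classes, y_pred.ravel())"
      ]
    },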
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "fuXsZWxke_ic"
      },
      "outputs": [],
      "source": [
        "def get_activation_model(model, conv_idx):\n",
        "    conv_layers = [layer for layer in model.layers if 'conv' in layer.name]\n",
        "    selected_layers = [layer for index, layer in enumerate(conv_layers) if index in conv_idx]\n",
        "    activation_model = Model(\n",
        "        inputs=model.inputs,\n",
        "        outputs=[layer.output for layer in selected_layers]\n",
        "    )\n",
        "    return activation_model"
      ]
    },
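    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper (not part of the original notebook): plot a few feature maps\n",
        "# from the selected conv layers for the first image in a batch, using the\n",
        "# activation model defined above.\n",
        "def show_feature_maps(model, image_batch, conv_idx=(0,), n_maps=8):\n",
        "    activation_model = get_activation_model(model, conv_idx)\n",
        "    outputs = activation_model.predict(image_batch)\n",
        "    if not isinstance(outputs, list):  # a single selected layer returns one array\n",
        "        outputs = [outputs]\n",
        "    for feature_maps in outputs:\n",
        "        n = min(n_maps, feature_maps.shape[-1])\n",
        "        fig, axes = plt.subplots(1, n, figsize=(2 * n, 2))\n",
        "        for i, ax in enumerate(np.atleast_1d(axes)):\n",
        "            ax.imshow(feature_maps[0, :, :, i], cmap='viridis')\n",
        "            ax.axis('off')\n",
        "        plt.show()"
      ]
    },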
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "W0Sda_34HCzQ"
      },
      "outputs": [],
      "source": [
        "def train_model(model,\n",
        "                train_data_dir,\n",
        "                validation_split=None,\n",
        "                batch_size=32,\n",
        "                use_default_augmentation=True,\n",
        "                augmentations=None,\n",
        "                epochs=30,\n",
        "                lr=1e-3,\n",
        "                loss='binary_crossentropy',\n",
        "                compile=True,\n",
        "                lr_decay=True,\n",
        "                decay_rate=0.10,\n",
        "                decay_limit=1e-6,\n",
        "                checkpoint=True,\n",
        "                stop_early=True,\n",
        "                monitor='val_accuracy',\n",
        "                mode='max',\n",
        "                patience=20,\n",
        "                tensorboard=True,\n",
        "                loss_curve=True):\n",
        "    \n",
        "\n",
        "    train_generator, validation_generator = get_train_data_generator(\n",
        "        train_data_dir=train_data_dir,\n",
        "        batch_size=batch_size,\n",
        "        validation_split=validation_split,\n",
        "        use_default_augmentation=use_default_augmentation,\n",
        "        augmentations=augmentations\n",
        "    )\n",
        "\n",
        "    callbacks = []\n",
        "    if checkpoint:\n",
        "        filepath = f'mesonet_trained.hdf5'\n",
        "        model_checkpoint = ModelCheckpoint(\n",
        "            filepath, monitor='val_accuracy', verbose=1,\n",
        "            save_best_only=True\n",
        "        )\n",
        "        callbacks.append(model_checkpoint)\n",
        "\n",
        "    if stop_early:\n",
        "        callbacks.append(\n",
        "            EarlyStopping(\n",
        "                monitor=monitor,\n",
        "                mode=mode,\n",
        "                patience=patience,\n",
        "                verbose=1\n",
        "            )\n",
        "        )\n",
        "\n",
        "\n",
        "    history = model.fit(\n",
        "        train_generator,\n",
        "        epochs=epochs,\n",
        "        verbose=1,\n",
        "        callbacks=callbacks,\n",
        "        validation_data=validation_generator,\n",
        "        steps_per_epoch=train_generator.samples // batch_size,\n",
        "        validation_steps=validation_generator.samples // batch_size if validation_generator else None,\n",
        "    )\n",
        "\n",
        "    return history"
      ]
    },
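    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical end-to-end sketch (left commented out): the cells below load an\n",
        "# already-trained checkpoint instead of training here. The 'data/train' directory\n",
        "# is an assumption, mirroring the 'data/test' path used later.\n",
        "#\n",
        "# model = build_model()\n",
        "# history = train_model(\n",
        "#     model,\n",
        "#     train_data_dir='data/train',\n",
        "#     validation_split=0.2,\n",
        "#     batch_size=32,\n",
        "#     epochs=30,\n",
        "# )\n",
        "# save_model_history(history, 'mesonet_history.pkl')"
      ]
    },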
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "aZtXFYPDoLZp"
      },
      "outputs": [
        {
          "ename": "OSError",
          "evalue": "Unable to open file (file signature not found)",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mOSError\u001b[0m                                   Traceback (most recent call last)",
            "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model_exp \u001b[38;5;241m=\u001b[39m \u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/Users/jarvis/pymycod/Deepfakes/Meso_4.ipynb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m/opt/anaconda3/envs/tensor/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:  \u001b[38;5;66;03m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m     66\u001b[0m   filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[0;32m---> 67\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     68\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     69\u001b[0m   \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
            "File \u001b[0;32m/opt/anaconda3/envs/tensor/lib/python3.8/site-packages/h5py/_hl/files.py:507\u001b[0m, in \u001b[0;36mFile.__init__\u001b[0;34m(self, name, mode, driver, libver, userblock_size, swmr, rdcc_nslots, rdcc_nbytes, rdcc_w0, track_order, fs_strategy, fs_persist, fs_threshold, fs_page_size, page_buf_size, min_meta_keep, min_raw_keep, locking, **kwds)\u001b[0m\n\u001b[1;32m    502\u001b[0m     fapl \u001b[38;5;241m=\u001b[39m make_fapl(driver, libver, rdcc_nslots, rdcc_nbytes, rdcc_w0,\n\u001b[1;32m    503\u001b[0m                      locking, page_buf_size, min_meta_keep, min_raw_keep, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[1;32m    504\u001b[0m     fcpl \u001b[38;5;241m=\u001b[39m make_fcpl(track_order\u001b[38;5;241m=\u001b[39mtrack_order, fs_strategy\u001b[38;5;241m=\u001b[39mfs_strategy,\n\u001b[1;32m    505\u001b[0m                      fs_persist\u001b[38;5;241m=\u001b[39mfs_persist, fs_threshold\u001b[38;5;241m=\u001b[39mfs_threshold,\n\u001b[1;32m    506\u001b[0m                      fs_page_size\u001b[38;5;241m=\u001b[39mfs_page_size)\n\u001b[0;32m--> 507\u001b[0m     fid \u001b[38;5;241m=\u001b[39m make_fid(name, mode, userblock_size, fapl, fcpl, swmr\u001b[38;5;241m=\u001b[39mswmr)\n\u001b[1;32m    509\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(libver, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m    510\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_libver \u001b[38;5;241m=\u001b[39m libver\n",
            "File \u001b[0;32m/opt/anaconda3/envs/tensor/lib/python3.8/site-packages/h5py/_hl/files.py:220\u001b[0m, in \u001b[0;36mmake_fid\u001b[0;34m(name, mode, userblock_size, fapl, fcpl, swmr)\u001b[0m\n\u001b[1;32m    218\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m swmr \u001b[38;5;129;01mand\u001b[39;00m swmr_support:\n\u001b[1;32m    219\u001b[0m         flags \u001b[38;5;241m|\u001b[39m\u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mACC_SWMR_READ\n\u001b[0;32m--> 220\u001b[0m     fid \u001b[38;5;241m=\u001b[39m \u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflags\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfapl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfapl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    221\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr+\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    222\u001b[0m     fid \u001b[38;5;241m=\u001b[39m h5f\u001b[38;5;241m.\u001b[39mopen(name, h5f\u001b[38;5;241m.\u001b[39mACC_RDWR, fapl\u001b[38;5;241m=\u001b[39mfapl)\n",
            "File \u001b[0;32mh5py/_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
            "File \u001b[0;32mh5py/_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
            "File \u001b[0;32mh5py/h5f.pyx:106\u001b[0m, in \u001b[0;36mh5py.h5f.open\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;31mOSError\u001b[0m: Unable to open file (file signature not found)"
          ]
        }
      ],
      "source": [
        "model_exp = load_model('/Users/jarvis/pymycod/Deepfakes/Meso_4.ipynb')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IC4V7HflFZC2",
        "outputId": "c86ee986-baeb-449c-b024-d6ad7ecdaf2a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found 1945 images belonging to 2 classes.\n",
            "31/31 [==============================] - 7s 235ms/step - loss: 0.0998 - accuracy: 0.9625\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "[0.09982584416866302, 0.9624678492546082]"
            ]
          },
          "execution_count": 124,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "evaluate_model(model_exp, 'data/test', 64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fG6V0lOzFeg2",
        "outputId": "db1fec59-cd05-440f-9ae6-acaea21db0ea"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found 1945 images belonging to 2 classes.\n",
            "31/31 [==============================] - 7s 235ms/step\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.96      0.94      0.95       773\n",
            "           1       0.96      0.97      0.97      1172\n",
            "\n",
            "    accuracy                           0.96      1945\n",
            "   macro avg       0.96      0.96      0.96      1945\n",
            "weighted avg       0.96      0.96      0.96      1945\n",
            "\n"
          ]
        }
      ],
      "source": [
        "print(get_classification_report(model_exp, 'data/test'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_z00HkaPG76d",
        "outputId": "f3d07ca5-29dd-4191-ef04-e6ef04bf592a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found 1945 images belonging to 2 classes.\n",
            "31/31 [==============================] - 8s 241ms/step - loss: 0.2321 - accuracy: 0.9080\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "[0.23209184408187866, 0.9079691767692566]"
            ]
          },
          "execution_count": 129,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "evaluate_model(model_exp, 'data/test', 64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v3ZQN53bHBEf",
        "outputId": "14bd2bd7-0c83-4e52-ee12-054a4e52ca9f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found 1945 images belonging to 2 classes.\n",
            "31/31 [==============================] - 7s 234ms/step\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.90      0.87      0.88       773\n",
            "           1       0.91      0.93      0.92      1172\n",
            "\n",
            "    accuracy                           0.91      1945\n",
            "   macro avg       0.91      0.90      0.90      1945\n",
            "weighted avg       0.91      0.91      0.91      1945\n",
            "\n"
          ]
        }
      ],
      "source": [
        "print(get_classification_report(model_exp, 'data/test'))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Meso-4.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}