{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/shrey/Desktop/Kidney-Disease-Classifcation'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the parent of the current working directory the new working directory.\n",
    "# NOTE(review): relative chdir is not idempotent - re-running this cell moves up again.\n",
    "os.chdir(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/shrey/Desktop/Kidney-Disease-Classifcation'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enter the `src` folder relative to the current working directory.\n",
    "# Presumably this is what lets the `kidney_classification` imports in a later\n",
    "# cell resolve - TODO confirm (an installed package would make it unnecessary).\n",
    "os.chdir(\"src\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class PrepareBaseModelConfig:\n",
    "    \"\"\"Immutable bundle of settings for the prepare-base-model stage.\n",
    "\n",
    "    Instances are built by ConfigurationManager.get_prepare_base_model_config:\n",
    "    the three paths come from the `prepare_base_model` section of the config\n",
    "    file, the `params_*` fields from the params file.\n",
    "    \"\"\"\n",
    "\n",
    "    # Stage directory; handed to create_directories before this object is built.\n",
    "    root_dir: Path\n",
    "    # Path configured for the base model.\n",
    "    base_model_path: Path\n",
    "    # Path the finished model is saved to by PrepareBaseModel.update_base_model.\n",
    "    updated_base_model_path: Path\n",
    "    # Value of IMAGE_SIZE from the params file.\n",
    "    # Presumably [height, width, channels] - TODO confirm against params.yaml.\n",
    "    params_image_size: list\n",
    "    # Value of INCLUDE_TOP from the params file.\n",
    "    params_include_top: bool\n",
    "    # Value of WEIGHTS from the params file.\n",
    "    params_weights: str\n",
    "    # Value of CLASSES from the params file.\n",
    "    params_classes: int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kidney_classification.constants import *\n",
    "from kidney_classification.utils.common import read_yaml, create_directories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConfigurationManager:\n",
    "    \"\"\"Loads the config and params YAML files and builds per-stage config objects.\"\"\"\n",
    "\n",
    "    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):\n",
    "        \"\"\"Load both YAML files, then hand the artifacts root to create_directories.\n",
    "\n",
    "        Args:\n",
    "            config_filepath: path passed to read_yaml for the main configuration.\n",
    "            params_filepath: path passed to read_yaml for the parameters.\n",
    "        \"\"\"\n",
    "        self.config = read_yaml(config_filepath)\n",
    "        self.params = read_yaml(params_filepath)\n",
    "\n",
    "        artifacts_root = self.config.artifacts_root\n",
    "        create_directories([artifacts_root])\n",
    "\n",
    "    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:\n",
    "        \"\"\"Build the PrepareBaseModelConfig for the prepare_base_model stage.\n",
    "\n",
    "        The stage's root_dir is passed to create_directories before the\n",
    "        config object is assembled.\n",
    "\n",
    "        Returns:\n",
    "            A PrepareBaseModelConfig whose paths come from the\n",
    "            `prepare_base_model` section of the config file and whose\n",
    "            params_* fields come from the params file.\n",
    "        \"\"\"\n",
    "        stage = self.config.prepare_base_model\n",
    "        create_directories([stage.root_dir])\n",
    "\n",
    "        hyper = self.params\n",
    "        return PrepareBaseModelConfig(\n",
    "            root_dir=Path(stage.root_dir),\n",
    "            base_model_path=Path(stage.base_model_path),\n",
    "            updated_base_model_path=Path(stage.updated_base_model_path),\n",
    "            params_image_size=hyper.IMAGE_SIZE,\n",
    "            params_include_top=hyper.INCLUDE_TOP,\n",
    "            params_weights=hyper.WEIGHTS,\n",
    "            params_classes=hyper.CLASSES,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout\n",
    "from tensorflow.keras.models import Sequential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrepareBaseModel:\n",
    "    \"\"\"Builds the VGG16-based classifier used by this stage and saves it to disk.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def prepare_full_model(\n",
    "        input_shape=(150, 150, 3),\n",
    "        classes=4,\n",
    "        weights='imagenet',\n",
    "        dense_units=512,\n",
    "        dropout_rate=0.5,\n",
    "    ):\n",
    "        \"\"\"Assemble and compile a frozen VGG16 backbone topped with a dense head.\n",
    "\n",
    "        Every argument defaults to the value that was previously hard-coded,\n",
    "        so calling this with no arguments builds the same architecture as before.\n",
    "\n",
    "        Args:\n",
    "            input_shape: input shape given to VGG16; a list is accepted and\n",
    "                converted to a tuple.\n",
    "            classes: number of units in the final softmax layer.\n",
    "            weights: forwarded to VGG16 as its `weights` argument.\n",
    "            dense_units: width of the hidden Dense layer.\n",
    "            dropout_rate: rate of the Dropout layer placed before the output layer.\n",
    "\n",
    "        Returns:\n",
    "            The compiled Sequential model (adam optimizer,\n",
    "            sparse categorical cross-entropy loss, accuracy metric).\n",
    "        \"\"\"\n",
    "        # include_top=False drops VGG16's own classifier. Keras documents the\n",
    "        # VGG16 `classes` argument as applying only when include_top=True, so it\n",
    "        # is not passed here; the class count is set by the final Dense layer.\n",
    "        pretrained_model = tf.keras.applications.VGG16(\n",
    "            include_top=False,\n",
    "            input_shape=tuple(input_shape),\n",
    "            pooling='max',\n",
    "            weights=weights,\n",
    "        )\n",
    "        # Freeze the backbone before compiling so only the new head is trainable.\n",
    "        pretrained_model.trainable = False\n",
    "\n",
    "        VGG_model = Sequential([\n",
    "            pretrained_model,\n",
    "            Flatten(),\n",
    "            Dense(dense_units, activation='relu'),\n",
    "            BatchNormalization(),\n",
    "            Dropout(dropout_rate),\n",
    "            Dense(classes, activation='softmax'),\n",
    "        ])\n",
    "\n",
    "        VGG_model.compile(\n",
    "            optimizer='adam',\n",
    "            loss='sparse_categorical_crossentropy',\n",
    "            metrics=['accuracy'],\n",
    "        )\n",
    "\n",
    "        return VGG_model\n",
    "\n",
    "    def update_base_model(self, config: PrepareBaseModelConfig):\n",
    "        \"\"\"Build the full model, print its summary and save it.\n",
    "\n",
    "        Args:\n",
    "            config: stage configuration; only `updated_base_model_path` is read.\n",
    "\n",
    "        NOTE(review): the `params_*` fields on `config` are not applied here -\n",
    "        the model is built with prepare_full_model's defaults. Confirm whether\n",
    "        params.yaml is meant to drive the image size, weights and class count.\n",
    "        \"\"\"\n",
    "        full_model = self.prepare_full_model()\n",
    "\n",
    "        full_model.summary()\n",
    "        # Go through the shared helper so saving happens in one place.\n",
    "        self.save_model(path=config.updated_base_model_path, model=full_model)\n",
    "\n",
    "    @staticmethod\n",
    "    def save_model(path: Path, model: tf.keras.Model):\n",
    "        \"\"\"Save `model` to `path` by calling `model.save(path)`.\"\"\"\n",
    "        model.save(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move back up one directory level.\n",
    "# Presumably needed so the relative paths `config/config.yaml` and `params.yaml`\n",
    "# (seen in the run cell's log output) resolve - TODO confirm.\n",
    "os.chdir(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-01-03 15:58:32,485: INFO: common yaml file: config/config.yaml loaded successfully]\n",
      "[2024-01-03 15:58:32,487: INFO: common yaml file: params.yaml loaded successfully]\n",
      "Model: \"sequential_3\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " vgg16 (Functional)          (None, 512)               14714688  \n",
      "                                                                 \n",
      " flatten_3 (Flatten)         (None, 512)               0         \n",
      "                                                                 \n",
      " dense_5 (Dense)             (None, 512)               262656    \n",
      "                                                                 \n",
      " batch_normalization_3 (Bat  (None, 512)               2048      \n",
      " chNormalization)                                                \n",
      "                                                                 \n",
      " dropout_3 (Dropout)         (None, 512)               0         \n",
      "                                                                 \n",
      " dense_6 (Dense)             (None, 4)                 2052      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 14981444 (57.15 MB)\n",
      "Trainable params: 265732 (1.01 MB)\n",
      "Non-trainable params: 14715712 (56.14 MB)\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Run the prepare-base-model stage: load configuration, build the stage config,\n",
    "# then build, summarise and save the model.\n",
    "# No try/except wrapper: catching an exception only to re-raise it adds nothing,\n",
    "# so any error simply propagates with its original traceback.\n",
    "config = ConfigurationManager()\n",
    "prepare_base_model_config = config.get_prepare_base_model_config()\n",
    "prepare_base_model = PrepareBaseModel()\n",
    "prepare_base_model.update_base_model(prepare_base_model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}