h3110Fr13nd's picture
Initial
47488ce
raw
history blame
1.36 kB
import tensorflow as tf
from pathlib import Path
from kidney_classification.entity.config_entity import PrepareBaseModelConfig
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Dropout
class PrepareBaseModel:
    """Build and save a VGG16-based image classifier.

    The model is a frozen VGG16 convolutional backbone followed by a small
    trainable classification head (Dense -> BatchNorm -> Dropout -> softmax).
    """

    @staticmethod
    def prepare_full_model(
        input_shape: tuple = (150, 150, 3),
        num_classes: int = 4,
        weights="imagenet",
    ) -> tf.keras.Model:
        """Create and compile the transfer-learning model.

        Args:
            input_shape: ``(height, width, channels)`` fed to the VGG16
                backbone. Defaults to ``(150, 150, 3)``.
            num_classes: Number of units in the final softmax layer.
                Defaults to ``4``.
            weights: Passed straight to ``tf.keras.applications.VGG16``;
                ``"imagenet"`` (default) loads the pretrained weights,
                ``None`` gives random initialisation.

        Returns:
            A compiled ``Sequential`` model using the Adam optimizer,
            sparse categorical cross-entropy loss and an accuracy metric.
        """
        # NOTE: the ``classes`` argument of VGG16 is intentionally not passed.
        # It only applies when ``include_top=True`` (and no pretrained weights
        # are requested); the number of output classes here is determined by
        # the final Dense layer below.
        backbone = tf.keras.applications.VGG16(
            include_top=False,
            input_shape=input_shape,
            pooling="max",
            weights=weights,
        )
        # Freeze the backbone so that only the classification head is trained.
        backbone.trainable = False

        model = Sequential()
        model.add(backbone)
        # Kept to preserve the saved model's layer layout; the backbone is
        # already globally pooled via ``pooling="max"``.
        model.add(Flatten())
        model.add(Dense(512, activation="relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))
        model.add(Dense(num_classes, activation="softmax"))

        # The sparse loss expects integer class labels rather than one-hot
        # vectors.
        model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    def update_base_model(self, config: PrepareBaseModelConfig):
        """Build the full model, print its summary and save it.

        Args:
            config: Configuration object; its ``updated_base_model_path``
                attribute is the destination the model is saved to.
        """
        full_model = self.prepare_full_model()
        full_model.summary()
        # Go through the shared helper so all saving happens in one place.
        self.save_model(path=config.updated_base_model_path, model=full_model)

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Save ``model`` to ``path`` using ``model.save``.

        Args:
            path: Destination passed unchanged to ``model.save``.
            model: The Keras model to save.
        """
        model.save(path)