Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from pathlib import Path | |
from kidney_classification.entity.config_entity import TrainingConfig | |
class Training:
    """Train a previously saved Keras model on a directory of images.

    Intended call order: ``get_base_model()``, ``train_valid_generator()``,
    then ``define_and_train_model()``; the last step reads the attributes
    set by the first two.
    """

    def __init__(self, config: TrainingConfig):
        """Store the training configuration.

        Args:
            config: Object providing ``updated_base_model_path``,
                ``training_data``, ``params_image_size``, ``params_epochs``
                and ``trained_model_path``, which the other methods read.
        """
        self.config = config

    def get_base_model(self):
        """Load the model at ``config.updated_base_model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )

    def train_valid_generator(self, validation_split: float = 0.2, seed: int = 123):
        """Build the training and validation datasets from ``config.training_data``.

        Sets ``self.train_dataset`` and ``self.val_dataset``. Pixel values
        are divided by 255 and both datasets are cached and prefetched.

        Args:
            validation_split: Fraction of the images held out for
                validation. The same value is used for both subsets.
            seed: Seed passed to both subset calls.
        """
        # Assumes params_image_size is (height, width, channels) and drops
        # the last entry -- TODO confirm against the config schema.
        img_height, img_width = self.config.params_image_size[:-1]

        # Both subsets must be carved from the same directory with the SAME
        # split fraction and seed. Using different fractions makes the two
        # calls partition the file list differently, so some images end up
        # in both training and validation.
        shared_kwargs = dict(
            image_size=(img_height, img_width),
            validation_split=validation_split,
            seed=seed,
        )
        train = tf.keras.utils.image_dataset_from_directory(
            self.config.training_data,
            subset="training",
            **shared_kwargs,
        )
        val = tf.keras.utils.image_dataset_from_directory(
            self.config.training_data,
            subset="validation",
            **shared_kwargs,
        )

        def _rescale(images, labels):
            # Divide pixel values by 255; labels pass through unchanged.
            return images / 255, labels

        autotune = tf.data.AUTOTUNE
        self.train_dataset = train.map(_rescale).cache().prefetch(buffer_size=autotune)
        self.val_dataset = val.map(_rescale).cache().prefetch(buffer_size=autotune)

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Save ``model`` to ``path`` via ``model.save``.

        Declared static because it takes no ``self``; without the decorator
        the call ``self.save_model(path=..., model=...)`` raises ``TypeError``.
        """
        model.save(path)

    def define_and_train_model(self):
        """Fit ``self.model`` and save it to ``config.trained_model_path``.

        Trains on ``self.train_dataset`` for ``config.params_epochs`` epochs,
        validating against ``self.val_dataset``. Requires ``get_base_model()``
        and ``train_valid_generator()`` to have been called first.
        """
        self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=self.config.params_epochs,
        )
        self.save_model(path=self.config.trained_model_path, model=self.model)