FlowerNet / FlowerNet.py
Innokentiy's picture
Upload 6 files
80be6f3
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import regularizers
assert 'COLAB_TPU_ADDR' in os.environ, 'Missin TPU?'
if('COLAB_TPU_ADDR') in os.environ:
TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
else:
TF_MASTER = ''
tpu_address = TF_MASTER
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
def create_model():
return tf.keras.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.001)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(5, activation='softmax')# всего пять классов цветов
])
def get_train_and_val_dataset(batch_size, is_training=True):
if(is_training):
dataset, info = tfds.load(name='tf_flowers',
split='train[:80%]',
with_info = True,
as_supervised=True,
try_gcs=True)
else:
dataset, info = tfds.load(name='tf_flowers',
split='train[80%:90%]',
with_info = True,
as_supervised=True,
try_gcs=True)
def scale(image, label):
image = tf.cast(image, tf.float32)
image = tf.image.resize(image, [224, 224]) # изменение всех изображений на вход до (None, 224, 224)
image /= 255.0 # Нормализация
return image, label
dataset = dataset.map(scale)
if is_training:
dataset = dataset.shuffle(2936)#Перемешивание обучающей выборки
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
return dataset
def get_final_dataset(batch_size):
dataset, info = tfds.load(name='tf_flowers',
split='train[90%:]',
with_info = True,
as_supervised=True,
try_gcs=True)
def scale(image, label):
image = tf.cast(image, tf.float32)
image = tf.image.resize(image, [224, 224]) # изменение всех изображений на вход до (None, 224, 224)
image /= 255.0 # Нормализация
return image, label
dataset = dataset.map(scale)
#dataset = dataset.shuffle(2936)#Перемешивание обучающей выборки
#dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
return dataset
def create_xception_model(input_shape=(224, 224, 3), num_classes=5):
#Загрузка предварительно обученной модели Xception без головной части
base_model = tf.keras.applications.Xception(include_top=False, input_shape=input_shape)
#Добавление головной части
x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
#Объединение предварительно обученной модели и головной части в единую модель
model = tf.keras.models.Model(inputs=base_model.input, outputs=x)
#Заморозка слоев предварительно обученной модели
for layer in base_model.layers:
layer.trainable = False
return model
batch_size = 1024 #Размер пакета
epochs = 1000 #Количество эпох, на тензорных процессорах можно делать много проверок
execution_steps = 1000 #Количество шагов перед обновлением весов
#Загрузка и создание обучающей и проверочной(валидационной) выборки
train_dataset = get_train_and_val_dataset(batch_size, True)
validation_dataset = get_train_and_val_dataset(batch_size, False)
steps_per_epoch = 2936 // batch_size
validation_steps = len(validation_dataset) // batch_size
with strategy.scope():
xmodel = create_xception_model()
xmodel.compile(optimizer='adagrad', steps_per_execution=execution_steps, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy'])
x_history = xmodel.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=validation_dataset)
#Переменные для графика
acc = x_history.history['sparse_categorical_accuracy']
val_acc = x_history.history['val_sparse_categorical_accuracy']
loss = x_history.history['loss']
val_loss = x_history.history['val_loss']
epochs_range = range(epochs)
#График при помощи matplotlib
plt.figure(figsize=(15, 15))
plt.subplot(2, 2, 1)
plt.plot(epochs_range, acc, label='Тренировочная точность')
plt.plot(epochs_range, val_acc, label='Валидационная точность')
plt.legend(loc='lower right')
plt.title('Тренировочная и валидационная точность')
plt.subplot(2, 2, 2)
plt.plot(epochs_range, loss, label='Тренировочная потеря')
plt.plot(epochs_range, val_loss, label='Валидационная потеря')
plt.legend(loc='upper right')
plt.title('Тренировочная и валидационная точность')
plt.show()
#всего три выборки: тренировочная(train_dataset), валидационная(validation_dataset) и тестовая(test_dataset)
#тренировочная 0:80
#валидационная 80:90
#тестовая 90:100
test_dataset = get_final_dataset(batch_size)
test_images, test_labels = next(iter(test_dataset.take(10)))
#Можно использоать информацию о классах из info, но мне нужно было перевести названия классов и их не слишком много, поэтому я решил их инициализировать. Если количество классов большое, например их 100 или больше, то лучше обращаться к ним через info.
class_names = ['Одуванчик', 'Ромашка', 'Тюльпаны', 'Подсолнухи', 'Розы']
test_loss, test_accuracy = xmodel.evaluate(test_dataset)
print('Test loss: {}, Test accuracy: {}'.format(test_loss, test_accuracy))
# Получение предсказаний нейросети для 10 изображений
predictions = xmodel.predict(test_images)
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6),
subplot_kw={'xticks': [], 'yticks': []})
for i, ax in enumerate(axes.flat):
# Отображение изображения
ax.imshow(test_images[i])
# Отображение меток и предсказаний
true_label = class_names[test_labels[i]]
pred_label = class_names[np.argmax(predictions[i])]
if true_label == pred_label:
ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='green')
else:
ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='red')
plt.tight_layout()
plt.show()