|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from tensorflow.keras import regularizers |
|
|
|
assert 'COLAB_TPU_ADDR' in os.environ, 'Missin TPU?' |
|
if('COLAB_TPU_ADDR') in os.environ: |
|
TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR']) |
|
else: |
|
TF_MASTER = '' |
|
tpu_address = TF_MASTER |
|
|
|
|
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address) |
|
tf.config.experimental_connect_to_cluster(resolver) |
|
tf.tpu.experimental.initialize_tpu_system(resolver) |
|
|
|
strategy = tf.distribute.TPUStrategy(resolver) |
|
|
|
|
|
def create_model(): |
|
return tf.keras.Sequential([ |
|
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)), |
|
tf.keras.layers.MaxPooling2D((2, 2)), |
|
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), |
|
tf.keras.layers.MaxPooling2D((2, 2)), |
|
tf.keras.layers.Conv2D(128, (3, 3), activation='relu'), |
|
tf.keras.layers.MaxPooling2D((2, 2)), |
|
tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
|
tf.keras.layers.MaxPooling2D((2, 2)), |
|
tf.keras.layers.Flatten(), |
|
tf.keras.layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
|
tf.keras.layers.Dropout(0.5), |
|
tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
|
tf.keras.layers.Dropout(0.5), |
|
tf.keras.layers.Dense(5, activation='softmax') |
|
|
|
]) |
|
|
|
|
|
def get_train_and_val_dataset(batch_size, is_training=True): |
|
if(is_training): |
|
dataset, info = tfds.load(name='tf_flowers', |
|
split='train[:80%]', |
|
with_info = True, |
|
as_supervised=True, |
|
try_gcs=True) |
|
else: |
|
dataset, info = tfds.load(name='tf_flowers', |
|
split='train[80%:90%]', |
|
with_info = True, |
|
as_supervised=True, |
|
try_gcs=True) |
|
|
|
def scale(image, label): |
|
image = tf.cast(image, tf.float32) |
|
image = tf.image.resize(image, [224, 224]) |
|
image /= 255.0 |
|
return image, label |
|
|
|
dataset = dataset.map(scale) |
|
|
|
if is_training: |
|
dataset = dataset.shuffle(2936) |
|
dataset = dataset.repeat() |
|
|
|
dataset = dataset.batch(batch_size) |
|
return dataset |
|
|
|
def get_final_dataset(batch_size): |
|
dataset, info = tfds.load(name='tf_flowers', |
|
split='train[90%:]', |
|
with_info = True, |
|
as_supervised=True, |
|
try_gcs=True) |
|
def scale(image, label): |
|
image = tf.cast(image, tf.float32) |
|
image = tf.image.resize(image, [224, 224]) |
|
image /= 255.0 |
|
return image, label |
|
dataset = dataset.map(scale) |
|
|
|
|
|
|
|
|
|
dataset = dataset.batch(batch_size) |
|
return dataset |
|
|
|
|
|
def create_xception_model(input_shape=(224, 224, 3), num_classes=5): |
|
|
|
base_model = tf.keras.applications.Xception(include_top=False, input_shape=input_shape) |
|
|
|
|
|
x = base_model.output |
|
x = tf.keras.layers.GlobalAveragePooling2D()(x) |
|
x = tf.keras.layers.Dense(1024, activation='relu')(x) |
|
x = tf.keras.layers.Dropout(0.5)(x) |
|
x = tf.keras.layers.Dense(num_classes, activation='softmax')(x) |
|
|
|
|
|
model = tf.keras.models.Model(inputs=base_model.input, outputs=x) |
|
|
|
|
|
for layer in base_model.layers: |
|
layer.trainable = False |
|
|
|
return model |
|
|
|
batch_size = 1024 |
|
epochs = 1000 |
|
execution_steps = 1000 |
|
|
|
train_dataset = get_train_and_val_dataset(batch_size, True) |
|
validation_dataset = get_train_and_val_dataset(batch_size, False) |
|
steps_per_epoch = 2936 // batch_size |
|
validation_steps = len(validation_dataset) // batch_size |
|
|
|
|
|
with strategy.scope(): |
|
xmodel = create_xception_model() |
|
xmodel.compile(optimizer='adagrad', steps_per_execution=execution_steps, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy']) |
|
x_history = xmodel.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=validation_dataset) |
|
|
|
|
|
|
|
acc = x_history.history['sparse_categorical_accuracy'] |
|
val_acc = x_history.history['val_sparse_categorical_accuracy'] |
|
loss = x_history.history['loss'] |
|
val_loss = x_history.history['val_loss'] |
|
epochs_range = range(epochs) |
|
|
|
|
|
|
|
plt.figure(figsize=(15, 15)) |
|
plt.subplot(2, 2, 1) |
|
plt.plot(epochs_range, acc, label='Тренировочная точность') |
|
plt.plot(epochs_range, val_acc, label='Валидационная точность') |
|
plt.legend(loc='lower right') |
|
plt.title('Тренировочная и валидационная точность') |
|
|
|
plt.subplot(2, 2, 2) |
|
plt.plot(epochs_range, loss, label='Тренировочная потеря') |
|
plt.plot(epochs_range, val_loss, label='Валидационная потеря') |
|
plt.legend(loc='upper right') |
|
plt.title('Тренировочная и валидационная точность') |
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
test_dataset = get_final_dataset(batch_size) |
|
test_images, test_labels = next(iter(test_dataset.take(10))) |
|
|
|
class_names = ['Одуванчик', 'Ромашка', 'Тюльпаны', 'Подсолнухи', 'Розы'] |
|
|
|
test_loss, test_accuracy = xmodel.evaluate(test_dataset) |
|
print('Test loss: {}, Test accuracy: {}'.format(test_loss, test_accuracy)) |
|
|
|
|
|
predictions = xmodel.predict(test_images) |
|
|
|
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6), |
|
subplot_kw={'xticks': [], 'yticks': []}) |
|
for i, ax in enumerate(axes.flat): |
|
|
|
ax.imshow(test_images[i]) |
|
|
|
true_label = class_names[test_labels[i]] |
|
pred_label = class_names[np.argmax(predictions[i])] |
|
if true_label == pred_label: |
|
ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='green') |
|
else: |
|
ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='red') |
|
|
|
plt.tight_layout() |
|
plt.show() |