from typing import Any

# BUG FIX: `tf` is used throughout (tf.cast, tf.math.exp, tf.keras.backend.*,
# tf.keras.callbacks.Callback) but was never imported.
import tensorflow as tf
import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv3D,
    LSTM,
    Dense,
    Dropout,
    Bidirectional,
    MaxPool3D,
    Activation,
    Reshape,
    SpatialDropout3D,
    BatchNormalization,
    TimeDistributed,
    Flatten,
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler

# Provides `char_to_num` / `num_to_char` lookup layers used below
# (star-import kept from the original; contents not visible here).
from helpers import *


class Predictor(keras.Model):
    """Video-to-character-sequence model (3D CNN front end + BiLSTM decoder).

    Wraps a `Sequential` network that consumes clips shaped
    (75 frames, 46, 140, 1 channel) and emits, per frame, a softmax over the
    character vocabulary plus one extra class (the CTC blank token).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # The underlying network is built once and reused by `call`.
        self.model = self.create_model()

    @classmethod
    def create_model(cls) -> Sequential:
        """Build and return the CNN + BiLSTM network.

        Returns:
            A `Sequential` model mapping (75, 46, 140, 1) inputs to
            (75, vocab_size + 1) per-frame character probabilities.
        """
        model = Sequential()

        # Spatiotemporal feature extractor: each MaxPool3D halves only the
        # spatial dims (pool (1, 2, 2)), preserving the 75-frame time axis
        # that CTC decoding requires.
        model.add(Conv3D(128, 3, input_shape=(75, 46, 140, 1), padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(256, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(75, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        # Collapse spatial dims per frame -> (75, features) for the RNN stack.
        model.add(TimeDistributed(Flatten()))

        model.add(Bidirectional(LSTM(
            128,
            kernel_initializer='Orthogonal',
            return_sequences=True,
        )))
        model.add(Dropout(.5))

        model.add(Bidirectional(LSTM(
            128,
            kernel_initializer='Orthogonal',
            return_sequences=True,
        )))
        model.add(Dropout(.5))

        # +1 output class for the CTC blank token.
        model.add(Dense(
            char_to_num.vocabulary_size() + 1,
            kernel_initializer='he_normal',
            activation='softmax',
        ))
        return model

    def call(self, *args: Any, **kwargs: Any):
        """Delegate the forward pass to the wrapped Sequential model."""
        return self.model.call(*args, **kwargs)

    @classmethod
    def scheduler(cls, epoch: int, lr: float):
        """LearningRateScheduler hook: hold LR for 30 epochs, then decay.

        Args:
            epoch: Current epoch index (0-based).
            lr: Current learning rate.

        Returns:
            The unchanged LR for epochs < 30, otherwise lr * exp(-0.1).
        """
        if epoch < 30:
            return lr
        return lr * tf.math.exp(-0.1)

    @classmethod
    def CTCLoss(cls, y_true, y_pred):
        """Connectionist Temporal Classification loss.

        Args:
            y_true: (batch, label_len) integer label sequences.
            y_pred: (batch, time, classes) per-frame softmax output.

        Returns:
            Per-sample CTC loss tensor from `ctc_batch_cost`.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost wants per-sample lengths; every sample here uses
        # the full (padded) time/label lengths.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        return tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )


class ProduceExample(tf.keras.callbacks.Callback):
    """After each epoch, print ground-truth vs. CTC-decoded predictions."""

    def __init__(self, dataset) -> None:
        # BUG FIX: the original skipped base-class initialization, leaving
        # the Keras Callback internals (model/params plumbing) unset.
        super().__init__()
        self.dataset = dataset.as_numpy_iterator()

    def on_epoch_end(self, epoch, logs=None) -> None:
        """Decode one batch and print original/prediction pairs.

        Args:
            epoch: Epoch index (unused, required by the Callback API).
            logs: Metrics dict supplied by Keras (unused).
        """
        data = self.dataset.next()
        yhat = self.model.predict(data[0])
        # NOTE(review): sequence lengths [75, 75] assume a batch size of 2
        # and full 75-frame inputs — confirm against the dataset pipeline.
        decoded = tf.keras.backend.ctc_decode(
            yhat, [75, 75], greedy=False
        )[0][0].numpy()
        for x in range(len(yhat)):
            print('Original:', tf.strings.reduce_join(
                num_to_char(data[1][x])
            ).numpy().decode('utf-8'))
            print('Prediction:', tf.strings.reduce_join(
                num_to_char(decoded[x])
            ).numpy().decode('utf-8'))
            print('~' * 100)