|
from typing import Any

import keras
import tensorflow as tf

# NOTE(review): the wildcard import is kept BEFORE the explicit Keras imports
# so that the explicit names below take precedence if helpers exports
# overlapping names (preserves the original shadowing order).
from helpers import *

from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.layers import (
    LSTM,
    Activation,
    BatchNormalization,
    Bidirectional,
    Conv3D,
    Dense,
    Dropout,
    Flatten,
    MaxPool3D,
    Reshape,
    SpatialDropout3D,
    TimeDistributed,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
|
|
|
|
|
class Predictor(keras.Model):
    """3D-CNN + bidirectional-LSTM sequence model trained with CTC loss.

    Wraps a ``Sequential`` network built by :meth:`create_model`. Input is a
    clip of shape ``(75, 46, 140, 1)`` (75 frames of 46x140 single-channel
    images); output is a per-frame softmax over the character vocabulary
    plus one extra class for the CTC blank token.

    Relies on ``char_to_num`` provided by ``helpers``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build the underlying network once at construction time.
        self.model = self.create_model()

    @classmethod
    def create_model(cls):
        """Build and return the uncompiled Sequential network.

        Returns:
            A ``Sequential`` model mapping ``(batch, 75, 46, 140, 1)`` inputs
            to ``(batch, 75, vocab_size + 1)`` per-frame class probabilities.
        """
        model = Sequential()

        # Three conv blocks. Pooling is (1, 2, 2): spatial dims are halved
        # each time, but the 75-frame time axis is preserved for CTC.
        model.add(Conv3D(128, 3, input_shape=(75, 46, 140, 1), padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(256, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(75, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        # Collapse each frame's feature volume to a vector: (batch, 75, features).
        model.add(TimeDistributed(Flatten()))

        # Two Bi-LSTM layers keep the full sequence (return_sequences=True)
        # so every frame gets a prediction for CTC alignment.
        model.add(Bidirectional(LSTM(
            128, kernel_initializer='Orthogonal', return_sequences=True
        )))
        model.add(Dropout(0.5))

        model.add(Bidirectional(LSTM(
            128, kernel_initializer='Orthogonal', return_sequences=True
        )))
        model.add(Dropout(0.5))

        # +1 output class for the CTC blank token.
        model.add(Dense(
            char_to_num.vocabulary_size() + 1,
            kernel_initializer='he_normal',
            activation='softmax'
        ))

        return model

    def call(self, *args, **kwargs):
        """Forward pass: delegate to the wrapped Sequential model.

        Invokes the inner model via ``__call__`` rather than ``.call`` so
        Keras's input building, masking, and training-phase handling run.
        """
        return self.model(*args, **kwargs)

    @classmethod
    def scheduler(cls, epoch, lr):
        """Learning-rate schedule for ``LearningRateScheduler``.

        Holds the rate flat for the first 30 epochs, then decays it
        exponentially by a factor of exp(-0.1) per epoch.
        """
        if epoch < 30:
            return lr
        return lr * tf.math.exp(-0.1)

    @classmethod
    def CTCLoss(cls, y_true, y_pred):
        """Batch CTC loss between integer labels and per-frame softmax output.

        Args:
            y_true: ``(batch, label_len)`` integer label sequences.
            y_pred: ``(batch, time, classes)`` softmax probabilities.

        Returns:
            Per-sample CTC loss tensor from ``ctc_batch_cost``.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost expects per-sample lengths of shape (batch, 1);
        # broadcast the (uniform) time and label lengths across the batch.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        return loss
|
|
|
|
|
class ProduceExample(tf.keras.callbacks.Callback):
    """Callback that prints decoded predictions vs. ground truth each epoch.

    Pulls one batch per epoch from the given dataset, runs the model, beam-
    decodes the CTC output, and prints the original and predicted strings.

    Relies on ``num_to_char`` provided by ``helpers``.
    """

    def __init__(self, dataset) -> None:
        # Initialize base Callback state (original skipped this).
        super().__init__()
        # NOTE(review): the iterator is consumed one batch per epoch; if the
        # dataset yields fewer batches than there are epochs this will raise
        # StopIteration mid-training — consider dataset.repeat() upstream.
        self.dataset = dataset.as_numpy_iterator()

    def on_epoch_end(self, epoch, logs=None) -> None:
        """Print original vs. decoded prediction for one batch.

        Args:
            epoch: Epoch index (unused; required by the Callback API).
            logs: Metric dict from Keras (unused).
        """
        # data[0] = input frames, data[1] = integer label sequences.
        data = next(self.dataset)
        yhat = self.model.predict(data[0])
        # Beam-search CTC decode. Sequence lengths derived from the model
        # output so this works for any batch size / time length (the original
        # hard-coded [75, 75], i.e. a batch of exactly two 75-frame clips).
        decoded = tf.keras.backend.ctc_decode(
            yhat, [yhat.shape[1]] * len(yhat), greedy=False
        )[0][0].numpy()

        for idx in range(len(yhat)):
            original = tf.strings.reduce_join(
                num_to_char(data[1][idx])
            ).numpy().decode('utf-8')
            prediction = tf.strings.reduce_join(
                num_to_char(decoded[idx])
            ).numpy().decode('utf-8')
            print('Original:', original)
            print('Prediction:', prediction)
            print('~' * 100)