lipnet / model.py
milselarch's picture
Upload folder using huggingface_hub
3a3c68a
raw
history blame
3.11 kB
from typing import Any
import keras
from helpers import *
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv3D, LSTM, Dense, Dropout, Bidirectional, MaxPool3D, Activation, Reshape, SpatialDropout3D, BatchNormalization, TimeDistributed, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
class Predictor(keras.Model):
    """Lip-reading network: 3-D conv front end, two BiLSTMs, per-step softmax.

    The actual layers live in a ``Sequential`` built by :meth:`create_model`
    and stored on ``self.model``; this subclass delegates calls to it and
    groups the training helpers (learning-rate schedule and CTC loss).

    NOTE(review): ``tf`` and ``char_to_num`` are not imported explicitly in
    this file; presumably they come from ``from helpers import *`` — confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.create_model()

    @classmethod
    def create_model(cls, input_shape=(75, 46, 140, 1)):
        """Build and return the (uncompiled) ``Sequential`` network.

        Args:
            input_shape: shape of one sample, excluding the batch axis, given
                to the first ``Conv3D``. The default is the shape the original
                code hard-coded; presumably (frames, height, width, channels)
                — TODO confirm against the data pipeline.

        Returns:
            A ``Sequential`` whose output keeps axis 1 of the input (via
            ``TimeDistributed`` and ``return_sequences=True``) and applies a
            softmax over ``char_to_num.vocabulary_size() + 1`` classes at each
            position. The extra class is presumably the CTC blank.
        """
        model = Sequential()

        # Three conv -> ReLU -> pool stages. The (1, 2, 2) pool leaves axis 1
        # (presumably time) untouched and halves the other two spatial axes.
        model.add(Conv3D(128, 3, input_shape=input_shape, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(256, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        model.add(Conv3D(75, 3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPool3D((1, 2, 2)))

        # Collapse every remaining axis except axis 1 into one feature vector
        # per step so the recurrent layers see a (batch, steps, features) input.
        model.add(TimeDistributed(Flatten()))

        model.add(Bidirectional(LSTM(
            128, kernel_initializer='Orthogonal', return_sequences=True
        )))
        model.add(Dropout(.5))

        model.add(Bidirectional(LSTM(
            128, kernel_initializer='Orthogonal', return_sequences=True
        )))
        model.add(Dropout(.5))

        model.add(Dense(
            char_to_num.vocabulary_size() + 1,
            kernel_initializer='he_normal',
            activation='softmax'
        ))
        return model

    def call(self, *args, **kwargs):
        """Run the wrapped ``Sequential`` model on the given inputs.

        Calls the inner model through ``__call__`` rather than its ``.call``
        method so Keras' call-time handling (input conversion, ``training`` /
        ``mask`` propagation, build checks) applies to it.
        """
        return self.model(*args, **kwargs)

    @classmethod
    def scheduler(cls, epoch, lr):
        """Learning-rate schedule for ``keras.callbacks.LearningRateScheduler``.

        Args:
            epoch: zero-based epoch index.
            lr: the learning rate supplied by the caller.

        Returns:
            ``lr`` unchanged for epochs below 30. Otherwise ``lr`` scaled by
            ``exp(-0.1)``. When driven by ``LearningRateScheduler`` (which
            passes the current rate), the decay compounds epoch after epoch.
            The result is a Python ``float``: Keras 3's scheduler rejects
            non-float return values such as a ``tf.Tensor``.
        """
        import math  # local import keeps this fix self-contained

        if epoch < 30:
            return lr
        return float(lr) * math.exp(-0.1)

    @classmethod
    def CTCLoss(cls, y_true, y_pred):
        """CTC loss usable as a Keras ``loss=`` function.

        Args:
            y_true: batch of padded integer label sequences; axis 0 is the
                batch and axis 1 the label position.
            y_pred: per-step class probabilities from the model; axis 1 is the
                time axis fed to CTC.

        Returns:
            The per-sample loss from ``tf.keras.backend.ctc_batch_cost``.

        NOTE(review): every sample is assumed to use all ``y_pred`` steps and
        all ``y_true`` positions (padding included) — the lengths below are
        the padded sizes, not per-sample lengths. Confirm padding is
        compatible with this.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # Broadcast the scalar lengths to the (batch, 1) shape
        # ctc_batch_cost expects.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        return loss
class ProduceExample(tf.keras.callbacks.Callback):
    """Keras callback that prints sample transcriptions after each epoch.

    At the end of every epoch it takes the next batch from ``dataset``, runs
    ``self.model.predict`` on the batch inputs, and decodes the predictions
    with beam-search CTC. For each sample it prints the reference text and
    the predicted text.
    """

    def __init__(self, dataset) -> None:
        """
        Args:
            dataset: a ``tf.data.Dataset`` whose elements are indexable batches
                with inputs at index 0 and integer label sequences at index 1
                (the layout ``on_epoch_end`` reads).
        """
        # Initialise the Keras Callback base state (the original skipped this).
        super().__init__()
        # Keep the source dataset so the iterator can be restarted when it runs
        # out, instead of raising StopIteration at the end of an epoch.
        self._source_dataset = dataset
        self.dataset = dataset.as_numpy_iterator()

    def _next_batch(self):
        """Return the next batch, restarting from the first one when exhausted."""
        try:
            return next(self.dataset)
        except StopIteration:
            self.dataset = self._source_dataset.as_numpy_iterator()
            return next(self.dataset)

    def on_epoch_end(self, epoch, logs=None) -> None:
        """Predict on one batch and print reference vs. decoded transcriptions."""
        data = self._next_batch()
        inputs, labels = data[0], data[1]

        yhat = self.model.predict(inputs)

        # ctc_decode needs one input length per sample; every prediction spans
        # all output steps. The original hard-coded [75, 75], which was only
        # valid for batches of exactly two samples.
        input_length = [yhat.shape[1]] * yhat.shape[0]
        decoded = tf.keras.backend.ctc_decode(
            yhat, input_length, greedy=False
        )[0][0].numpy()  # top beam, dense (batch, max_decoded_len)

        for label, prediction in zip(labels, decoded):
            print('Original:', tf.strings.reduce_join(
                num_to_char(label)
            ).numpy().decode('utf-8'))
            print('Prediction:', tf.strings.reduce_join(
                num_to_char(prediction)
            ).numpy().decode('utf-8'))
            print('~' * 100)