nouamanetazi's picture
nouamanetazi HF staff
quick fix
4ac4e3b
raw
history blame
861 Bytes
import tensorflow as tf
from tensorflow import keras
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalize deviation of per-sample square transforms from orthogonality.

    Reference: https://keras.io/examples/vision/pointnet/#build-a-model

    The regularized tensor is regrouped into ``num_features x num_features``
    matrices ``A`` (one per sample). The returned penalty is
    ``sum(l2reg * (A @ A^T - I) ** 2)`` summed over all matrices.

    Args:
        num_features: Side length ``F`` of each square matrix. The regularized
            tensor must contain a multiple of ``F * F`` elements.
        l2reg: Weight applied to the squared-error penalty.
    """

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        # Stored once; cast to the input dtype on every call so non-float32
        # inputs do not hit a dtype mismatch in the subtraction.
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for tensor ``x``.

        Args:
            x: Tensor whose elements can be grouped into ``F x F`` matrices.
                Presumably the ``(batch, F * F)`` output of a Dense layer when
                used as an ``activity_regularizer`` (as in the referenced
                example) -- confirm against the caller.

        Returns:
            A scalar tensor with the same dtype as ``x``.
        """
        identity = tf.cast(self.identity, x.dtype)
        # One (F, F) matrix per sample. Using -1 lets TensorFlow infer the
        # leading dimension, which also handles symbolic batch sizes and 1-D
        # inputs of exactly F * F elements.
        matrices = tf.reshape(x, (-1, self.num_features, self.num_features))
        # Batched A @ A^T, giving shape (B, F, F). Do not use
        # tf.tensordot(x, x, axes=(2, 2)) here: it contracts every sample
        # against every *other* sample, producing a (B, F, B, F) tensor of
        # cross-batch products that is only correct when B == 1.
        gram = tf.matmul(matrices, matrices, transpose_b=True)
        return tf.reduce_sum(self.l2reg * tf.square(gram - identity))

    def get_config(self):
        """Return the constructor arguments as a config dict."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}