|
from tensorflow.python.keras.constraints import Constraint |
|
from tensorflow.python.ops import math_ops, array_ops |
|
|
|
|
|
class TightFrame(Constraint):
    """Parseval (tight) frame constraint, as introduced in https://arxiv.org/abs/1704.08847

    Constrains the weight matrix to be a tight frame, so that the Lipschitz
    constant of the layer is <= 1. This increases the robustness of the network
    to adversarial noise.

    Warning: This constraint simply performs the update step on the weight matrix
    (or the unfolded weight matrix for convolutional layers). Thus, it does not
    handle the necessary scalings for convolutional layers.

    Args:
        scale (float): Retraction parameter (length of retraction step).
        num_passes (int): Number of retraction steps.
    """

    def __init__(self, scale, num_passes=1):
        """Store the retraction hyper-parameters.

        Args:
            scale (float): Retraction parameter (length of retraction step).
            num_passes (int, optional): Number of retraction steps performed
                on every call. Defaults to 1.

        Raises:
            ValueError: If ``num_passes`` is smaller than 1.
        """
        if num_passes < 1:
            raise ValueError(
                "Number of passes cannot be non-positive! (got {})".format(num_passes)
            )

        self.scale = scale
        self.num_passes = num_passes

    def __call__(self, w):
        """Apply ``num_passes`` retraction steps to the weights.

        Each step computes ``W <- (1 + scale) * W - scale * W @ (W^T @ W)``.

        Args:
            w: Weight of a conv or linear layer. A rank-4 tensor is unfolded
                into a matrix of shape ``(-1, w.shape[3])`` before the update;
                a tensor of any other rank is passed to ``matmul`` as-is.

        Returns:
            The updated weights. For rank-4 input the result is reshaped back
            to ``w.shape``.
        """
        is_rank4_kernel = len(w.shape) == 4

        if is_rank4_kernel:
            # Unfold the kernel so that its last axis forms the matrix columns.
            current = array_ops.reshape(w, (-1, w.shape[3]))
        else:
            current = w

        for _ in range(self.num_passes):
            gram = math_ops.matmul(current, current, transpose_a=True)
            # Retract the *current* iterate, not the initial weights, so that
            # several passes compose the same update step instead of only
            # refreshing the Gram matrix. A single pass is unaffected.
            current = (1 + self.scale) * current - self.scale * math_ops.matmul(
                current, gram
            )

        if is_rank4_kernel:
            return array_ops.reshape(current, w.shape)
        return current

    def get_config(self):
        """Return the constructor arguments (``scale``, ``num_passes``) as a dict."""
        return {"scale": self.scale, "num_passes": self.num_passes}
|
|
|
|
|
|
|
# Lowercase alias so the constraint can also be referenced as `tight_frame`.
tight_frame = TightFrame
|
|