# Source: commit 3f9d71f by sayakpaul (HF staff) — "add: files."
import einops
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from tensorflow.keras import backend as K
from tensorflow.keras import layers
@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Split a rank-4 image tensor into non-overlapping flattened patches.

    The einops pattern used in ``call`` treats the input as laid out
    ``(n, h, w, c)`` and produces ``(n, gh * gw, fh * fw, c)``, where
    ``(fh, fw) = patch_size`` and ``gh = h / fh``, ``gw = w / fw`` are the
    number of patches along each spatial axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        """Rearrange ``x`` into a grid of flattened patches.

        Args:
            x: Rank-4 tensor laid out as ``(n, h, w, c)``.
            patch_size: Indexable pair ``(fh, fw)`` giving the patch height
                and width. ``h`` must be divisible by ``fh`` and ``w`` by
                ``fw``; otherwise einops raises a shape-mismatch error.

        Returns:
            Tensor of shape ``(n, (h // fh) * (w // fw), fh * fw, c)``.
        """
        # Only the patch size is passed: einops infers the grid size (gh, gw)
        # from the input itself. Computing gh/gw by hand from the static
        # shape fails with ``None // int`` whenever a spatial dimension is
        # not known at trace time. When the shape is static the result is
        # identical, and non-divisible sizes are still rejected by einops.
        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            fh=patch_size[0],
            fw=patch_size[1],
        )

    def get_config(self):
        # No constructor arguments of its own; serialize the base config only.
        config = super().get_config().copy()
        return config
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Reassemble a grid of flattened patches into a rank-4 image tensor.

    Input is treated as ``(n, gh * gw, fh * fw, c)`` and output as
    ``(n, gh * fh, gw * fw, c)`` per the einops pattern in ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        """Fold patches back into image layout.

        Args:
            x: Tensor laid out as ``(n, gh * gw, fh * fw, c)``.
            grid_size: Indexable pair ``(gh, gw)`` — patches per spatial axis.
            patch_size: Indexable pair ``(fh, fw)`` — patch height and width.

        Returns:
            Tensor of shape ``(n, gh * fh, gw * fw, c)``.
        """
        rows_of_patches = grid_size[0]
        cols_of_patches = grid_size[1]
        patch_h = patch_size[0]
        patch_w = patch_size[1]
        pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
        return einops.rearrange(
            x,
            pattern,
            gh=rows_of_patches,
            gw=cols_of_patches,
            fh=patch_h,
            fw=patch_w,
        )

    def get_config(self):
        # Nothing layer-specific to record; return a copy of the base config.
        return dict(super().get_config())
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer wrapper around ``tnp.swapaxes``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return ``x`` with axes ``axis_one`` and ``axis_two`` interchanged."""
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        # The layer has no constructor arguments beyond the base ones.
        base_config = dict(super().get_config())
        return base_config
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Resize images to a fixed ``(height, width)`` with ``tf.image.resize``.

    Args:
        height: Target output height.
        width: Target output width.
        antialias: Forwarded to ``tf.image.resize``.
        method: Resize method name forwarded to ``tf.image.resize``.
        **kwargs: Passed through to ``layers.Layer``.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height = height
        self.width = width
        self.antialias = antialias
        self.method = method

    def call(self, x):
        target_size = (self.height, self.width)
        return tf.image.resize(
            x,
            size=target_size,
            method=self.method,
            antialias=self.antialias,
        )

    def get_config(self):
        # Base config first, then the constructor arguments needed to rebuild
        # the layer on deserialization.
        return {
            **super().get_config(),
            "height": self.height,
            "width": self.width,
            "antialias": self.antialias,
            "method": self.method,
        }