|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
import flax.linen as nn |
|
import jax.numpy as jnp |
|
|
|
|
|
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Half of the output channels hold the sine and the other half the cosine of
    `scale * timesteps * inv_timescale`. The inverse timescales form a geometric
    progression `min_timescale * (max_timescale / min_timescale) ** (-i / (embedding_dim // 2 - freq_shift))`
    for `i = 0 .. embedding_dim // 2 - 1`.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        embedding_dim: The number of output channels. Must be even.
        freq_shift: Value subtracted from the number of frequencies in the
            denominator of the log-spacing. With `1`, the first and last inverse
            timescales differ by exactly the factor `max_timescale / min_timescale`.
        min_timescale: Multiplier of the first inverse timescale. Must not be 0,
            because the ratio `max_timescale / min_timescale` is taken.
        max_timescale: The largest time unit.
        flip_sin_to_cos: If True the output is ordered `[cos, sin]` instead of
            `[sin, cos]`.
        scale: Multiplier applied to the argument of the sinusoids.

    Returns:
        a Tensor of timing signals [N, embedding_dim]

    Raises:
        AssertionError: if `timesteps` is not 1-D or `embedding_dim` is odd.
        ZeroDivisionError: if `embedding_dim // 2 == freq_shift`, except in the
            single-frequency case (`embedding_dim == 2`), which is handled.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    half_dim = embedding_dim // 2
    # Evaluated unconditionally so invalid timescales fail the same way as before.
    log_timescale_ratio = math.log(max_timescale / min_timescale)

    if half_dim == 1:
        # Only frequency index 0 exists and the increment is multiplied by that
        # index, so its value is irrelevant. Skipping the division avoids the
        # ZeroDivisionError the default freq_shift=1 used to trigger here.
        log_timescale_increment = 0.0
    else:
        denominator = half_dim - freq_shift
        if denominator == 0:
            # Same exception type the bare division raised, with a usable message.
            raise ZeroDivisionError(
                f"embedding_dim // 2 ({half_dim}) must differ from freq_shift ({freq_shift})"
            )
        log_timescale_increment = log_timescale_ratio / denominator

    frequency_index = jnp.arange(half_dim, dtype=jnp.float32)
    inv_timescales = min_timescale * jnp.exp(frequency_index * -log_timescale_increment)

    # Outer product [N, 1] * [1, half_dim] -> [N, half_dim].
    scaled_time = scale * (jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0))

    sin_part = jnp.sin(scaled_time)
    cos_part = jnp.cos(scaled_time)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]

    # Concatenating two [N, half_dim] arrays already yields [N, embedding_dim];
    # no reshape is needed.
    return jnp.concatenate(halves, axis=1)
|
|
|
|
|
class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module. Projects an input time step embedding through two dense layers.

    The input goes through the dense layer `linear_1`, a SiLU activation and the dense layer `linear_2`. Both
    dense layers have `time_embed_dim` output features.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
            Number of output features of each dense layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Forwarded as the `dtype` argument of both dense layers
    """

    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        # The layer names are passed explicitly and are kept exactly as "linear_1" / "linear_2".
        hidden = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        activated = nn.silu(hidden)
        projected = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(activated)
        return projected
|
|
|
|
|
class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Calling the module forwards `timesteps` and the fields below to `get_sinusoidal_embeddings`; every other
    argument of that function keeps its default.

    Args:
        dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension, passed as `embedding_dim`
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Passed through as `flip_sin_to_cos`
        freq_shift (`float`, *optional*, defaults to `1`):
            Passed through as `freq_shift`
    """

    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        sinusoid_options = {
            "embedding_dim": self.dim,
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "freq_shift": self.freq_shift,
        }
        return get_sinusoidal_embeddings(timesteps, **sinusoid_options)
|
|