|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Dense attention classes and mask/weighting functions.""" |
|
|
|
|
|
|
|
import dataclasses |
|
import functools |
|
import operator |
|
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from flax import linen as nn |
|
from flax.linen import partitioning as nn_partitioning |
|
from flax.linen.dtypes import promote_dtype |
|
from jax import lax, random |
|
|
|
|
|
|
|
# Short local names for Flax's logical-axis partitioning helpers; they are used
# when creating parameters and constraining activations throughout this file.
param_with_axes = nn_partitioning.param_with_axes
with_sharding_constraint = nn_partitioning.with_sharding_constraint

# Type aliases used in the annotations of this file. They are documentation
# only; nothing below checks them at runtime.
Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray
Shape = Iterable[int]
Activation = Callable[..., Array]
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
DotGeneralT = Callable[..., Array]
ConvGeneralDilatedT = Callable[..., Array]
# Padding as accepted from users vs. the canonical form handed to jax.lax
# (see canonicalize_padding).
PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
LaxPadding = Union[str, Sequence[Tuple[int, int]]]

# Initializer signatures. `NdInitializer` additionally receives the kernel's
# input and output axis (or axes); DenseGeneral passes these as extra
# positional arguments when it creates its kernel.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
InitializerAxis = Union[int, Tuple[int, ...]]
NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]

# Default initializer for Embed tables: variance scaling (scale 1.0, "fan_in",
# normal distribution) with `out_axis=0`.
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
|
|
|
|
|
def dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: DType = jnp.float32,
    float32_logits: bool = False,
):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.

    All axes before the last three are treated as batch axes. The shapes below
    show the usual single batch axis, but any number (including several) is
    accepted as long as q, k and v agree on them.

    No 1/sqrt(depth) scaling is applied here; callers are expected to fold it
    into the query (or the logits bias) themselves.

    Args:
        query: queries for calculating attention with shape of `[batch, q_length,
            num_heads, qk_depth_per_head]`.
        key: keys for calculating attention with shape of `[batch, kv_length,
            num_heads, qk_depth_per_head]`.
        value: values to be used in attention with shape of `[batch, kv_length,
            num_heads, v_depth_per_head]`.
        bias: bias for the attention weights. This should be broadcastable to the
            shape `[batch, num_heads, q_length, kv_length]`. This can be used for
            incorporating causal masks, padding masks, proximity bias, etc.
        dropout_rng: JAX PRNGKey to be used for dropout. Required when dropout
            is active (`deterministic` is False and `dropout_rate > 0`).
        dropout_rate: dropout rate applied to the attention weights.
        deterministic: bool, deterministic or not (to apply dropout).
        dtype: the dtype of the computation (default: float32).
        float32_logits: bool, if True then compute logits in float32 to avoid
            numerical issues with bfloat16.

    Returns:
        Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.

    Raises:
        ValueError: if dropout is active but `dropout_rng` is None.
    """
    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # Optionally upcast q and k so the logits (and hence softmax) are float32.
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    # `attn_weights`: [batch..., num_heads, q_length, kv_length]. The ellipsis
    # lets any number of leading batch axes through, matching the asserts above.
    attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key)

    # Additive bias: masking, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias.astype(attn_weights.dtype)

    # Normalize over the kv_length axis (softmax's default axis is -1).
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError("dropout_rng must be provided when attention dropout is enabled.")
        keep_prob = 1.0 - dropout_rate
        # The keep-mask has size 1 along the q_length axis, so the same key
        # positions are dropped for every query position.
        dropout_shape = list(attn_weights.shape)
        dropout_shape[-2] = 1
        keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        keep = jnp.broadcast_to(keep, attn_weights.shape)
        # Inverted dropout: rescale the kept weights by 1 / keep_prob.
        multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
        attn_weights = attn_weights * multiplier

    # Weighted sum of the values.
    return jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
|
|
|
|
|
# Vectorized `lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis)`:
# the operand, slice size and axis are shared (in_axes=None) while the start
# index is mapped over its leading axis, producing one slice per start index
# stacked along a new leading axis. The decode path of
# MultiHeadDotProductAttention uses it to pick the current query row of a bias.
dynamic_vector_slice_in_dim = jax.vmap(
    lax.dynamic_slice_in_dim,
    in_axes=(None, 0, None, None),
)
|
|
|
|
|
class MultiHeadDotProductAttention(nn.Module):
    """Multi-head dot-product attention.

    Inputs are projected to per-head queries, keys and values of shape
    `[batch, length, num_heads, head_dim]`, combined with
    `dot_product_attention`, and projected back to the query feature size by an
    output `DenseGeneral`.

    Attributes:
        num_heads: number of attention heads.
        head_dim: dimension of each head. The projection width is
            `num_heads * head_dim`; it need not equal the input feature size.
        dtype: the dtype of the computation.
        dropout_rate: dropout rate applied to the attention weights.
        kernel_init: initializer for the kernel of the Dense layers.
        float32_logits: bool, if True then compute logits in float32 to avoid
            numerical issues with bfloat16.
    """

    num_heads: int
    head_dim: int
    dtype: DType = jnp.float32
    dropout_rate: float = 0.0
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
    float32_logits: bool = False

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and projects the results to an output vector.

        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by the `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for an actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.

        Args:
            inputs_q: input queries of shape `[batch, q_length, q_features]`.
            inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
            mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`;
                entries > 0 may be attended to, all others are blocked.
            bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
            decode: Whether to prepare and use an autoregressive cache.
            deterministic: Disables dropout if set to True.

        Returns:
            output of shape `[batch, length, q_features]`.

        Raises:
            ValueError: during cached decoding, if the query shape does not match
                a single step for the cached batch/heads/head_dim.
        """
        # Shared config for the q/k/v projections: last input axis ->
        # (num_heads, head_dim).
        projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=(self.num_heads, self.head_dim),
            kernel_axes=("embed", "heads", "kv"),
            dtype=self.dtype,
        )

        # dot_product_attention applies no 1/sqrt(head_dim) scaling, so it is
        # folded into the query kernel's initialization instead.
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)

        def query_init(*args):
            return self.kernel_init(*args) / depth_scaling

        # Projected tensors: [batch, length, num_heads, head_dim].
        query = projection(kernel_init=query_init, name="query")(inputs_q)
        key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
        value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)

        query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
        key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
        value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))

        if decode:
            # The cache counts as initialized once `cached_key` exists. On the
            # first (initialization) call the variables below are created as
            # zeros and no cache update is performed.
            is_initialized = self.has_variable("cache", "cached_key")

            # Cached keys/values are stored as [..., num_heads, head_dim, length]
            # (length axis moved last) so that one decoding step can be written
            # into the cache with a one-hot broadcast-and-add below.
            def swap_dims(x):
                return x[:-3] + tuple(x[i] for i in [-2, -1, -3])

            cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
            cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
            cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape

                # Incremental decoding feeds exactly one query position per call.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        "expected query shape %s instead got %s." % (expected_shape, query.shape)
                    )

                # Cache slot written by this step; the index is advanced below.
                cur_index = cache_index.value
                one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)

                # Move the length axis of the new key/value last to match the
                # cache layout. NOTE(review): assumes inputs_kv also has length 1
                # here; only the query shape is checked above.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)

                # Scatter-free cache write: the one-hot vector places the new
                # entry at `cur_index`. Adding is correct because that slot still
                # holds its initial zeros (each index is written once, as
                # cache_index only increases).
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # Back to [..., length, num_heads, head_dim] for attention.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # The single query may only attend to slots 0..cur_index, i.e.
                # positions already written; later slots are still zero. The
                # same (batch, 1, 1, length) mask is used for every head.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
                        (batch, 1, 1, length),
                    ),
                )

                if bias is not None:
                    # Keep only the bias row of the current query position,
                    # i.e. bias[..., cur_index:cur_index + 1, :].
                    # NOTE(review): squeeze(axis=0) requires the bias to have a
                    # leading axis of size 1.
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)

        # Turn the mask into an additive bias: 0 where mask > 0, -1e10 elsewhere.
        if mask is not None:
            attention_bias = lax.select(
                mask > 0, jnp.full(mask.shape, 0.0).astype(self.dtype), jnp.full(mask.shape, -1e10).astype(self.dtype)
            )
        else:
            attention_bias = None

        # Add any caller-supplied bias (e.g. relative position biases).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        # Only draw a dropout key when dropout will actually be applied.
        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng("dropout")

        x = dot_product_attention(
            query,
            key,
            value,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )

        # Merge the (num_heads, head_dim) axes and project back to the query
        # feature size.
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            kernel_axes=("heads", "kv", "embed"),
            dtype=self.dtype,
            name="out",
        )(x)
        return out
|
|
|
|
|
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    """Convert negative axis indices to their non-negative equivalents.

    Args:
        axes: axis indices, possibly negative (counted from the end).
        ndim: rank of the array the axes refer to.

    Returns:
        A tuple with one entry per input axis, in the same order: negative
        entries become `ndim + ax`, non-negative entries are returned as-is.
        No range validation is performed.
    """
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
|
|
|
|
|
def _canonicalize_tuple(x):
    """Return `x` as a tuple.

    Iterables are materialized with `tuple(...)`; any other value is wrapped
    in a 1-tuple.
    """
    return tuple(x) if isinstance(x, Iterable) else (x,)
|
|
|
|
|
|
|
|
|
|
|
class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes and an optional bias.

    Attributes:
        features: int or tuple with numbers of output features.
        axis: int or tuple with axes to apply the transformation on.
        dtype: the dtype of the computation (default: float32).
        params_dtype: the dtype in which kernel and bias are stored
            (default: float32).
        kernel_init: initializer function for the weight matrix.
        kernel_axes: logical axis names for the kernel, one per kernel
            dimension (contracted input axes first, then output features). The
            trailing `len(features)` names are reused for the bias.
        use_bias: whether to add a learned bias (default: True).
        bias_init: initializer function for the bias.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    params_dtype: DType = jnp.float32
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Any = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
            inputs: The nd-array to be transformed. It is cast to `dtype` first.

        Returns:
            The transformed input: the axes listed in `axis` are contracted away
            and the `features` dimensions are appended at the end.
        """
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        # Kernel layout: [contracted input dims..., output feature dims...].
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_in_axis = np.arange(len(axis))
        kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
        kernel = param_with_axes(
            "kernel",
            self.kernel_init,
            kernel_shape,
            self.params_dtype,
            kernel_in_axis,
            kernel_out_axis,
            axes=self.kernel_axes,
        )
        if self.use_bias:
            # One logical axis name per bias dimension, taken from the kernel's
            # output-feature names. Slicing (instead of indexing [-1]) keeps the
            # names aligned with multi-dimensional `features` and avoids an
            # IndexError when `kernel_axes` is left at its empty default.
            bias_axes = tuple(self.kernel_axes[-len(features):])
            bias = param_with_axes("bias", self.bias_init, features, self.params_dtype, axes=bias_axes)
        kernel = jnp.asarray(kernel, self.dtype)

        contract_ind = tuple(range(len(axis)))
        y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
        if self.use_bias:
            bias = jnp.asarray(bias, self.dtype)
            # The feature axes are the trailing axes of y; prepend one singleton
            # axis per non-feature axis so the bias broadcasts over them.
            y = y + jnp.reshape(bias, (1,) * (y.ndim - len(features)) + bias.shape)
        return y
|
|
|
|
|
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Resolve an activation specification to a callable.

    Args:
        fn_or_string: ``"linear"`` for the identity function, the name of a
            function exported by ``flax.linen`` (looked up with ``getattr``, so an
            unknown name raises AttributeError), or any callable, which is
            returned unchanged.

    Returns:
        The activation function.

    Raises:
        ValueError: if ``fn_or_string`` is neither a string nor a callable.
    """
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if not callable(fn_or_string):
        raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
    return fn_or_string
|
|
|
|
|
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    For each entry of `activations` the input is projected to
    `intermediate_dim` features and passed through that activation. The
    resulting branches are multiplied element-wise, dropout is applied, and a
    final projection maps back to the input feature size.

    Attributes:
        intermediate_dim: Shared dimension of hidden layers.
        activations: Type of activations for each layer. Each element is either
            'linear', a string function name in flax.linen, or a function.
        kernel_init: Kernel function, passed to the dense layers.
        intermediate_dropout_rate: Dropout rate used after the intermediate layers.
        dtype: Type for the dense layer.
    """

    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
    intermediate_dropout_rate: float = 0.1
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
        """Applies Transformer MlpBlock module.

        Args:
            inputs: input activations; the last axis is the feature axis.
            decode: accepted for interface compatibility; not used here.
            deterministic: if True, dropout is disabled.

        Returns:
            Output whose last axis has the same size as that of `inputs`.
        """
        multi_branch = len(self.activations) > 1
        branches = []
        for branch_idx, activation in enumerate(self.activations):
            # A single branch keeps the plain "wi" name; several are numbered.
            layer_name = f"wi_{branch_idx}" if multi_branch else "wi"
            projected = DenseGeneral(
                features=self.intermediate_dim,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                kernel_axes=("embed", "mlp"),
                name=layer_name,
            )(inputs)
            branches.append(_convert_to_activation_function(activation)(projected))

        # Element-wise product of all branches (a gated MLP when there are several).
        hidden = functools.reduce(operator.mul, branches)

        # The dropout mask is broadcast along axis -2.
        hidden = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
            hidden, deterministic=deterministic
        )
        hidden = with_sharding_constraint(hidden, ("batch", "length", "mlp"))

        return DenseGeneral(
            features=inputs.shape[-1],
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("mlp", "embed"),
            name="wo",
        )(hidden)
|
|
|
|
|
class Embed(nn.Module):
    """A parameterized function from integers [0, n) to d-dimensional vectors.

    Attributes:
        num_embeddings: number of embeddings.
        features: number of feature dimensions for each embedding.
        cast_input_dtype: if not None, inputs are cast to this dtype before the
            lookup.
        dtype: the dtype of the embedding vectors (default: float32).
        params_dtype: the dtype in which the embedding table is stored
            (default: float32).
        attend_dtype: dtype used by `attend`; falls back to `dtype` when None.
        embedding_init: embedding initializer.
        one_hot: performs the gather with a one-hot contraction rather than a true
            gather. This is currently needed for SPMD partitioning.
    """

    num_embeddings: int
    features: int
    cast_input_dtype: Optional[DType] = None
    dtype: DType = jnp.float32
    params_dtype: DType = jnp.float32
    attend_dtype: Optional[DType] = None
    embedding_init: Initializer = default_embed_init
    one_hot: bool = True
    embedding: Array = dataclasses.field(init=False)

    def setup(self):
        # Embedding table of shape [num_embeddings, features] with logical axes
        # ("vocab", "embed").
        self.embedding = param_with_axes(
            "embedding",
            self.embedding_init,
            (self.num_embeddings, self.features),
            self.params_dtype,
            axes=("vocab", "embed"),
        )

    def __call__(self, inputs: Array) -> Array:
        """Embeds the inputs along the last dimension.

        Args:
            inputs: input data, all dimensions are considered batch dimensions.

        Returns:
            Output which is embedded input data. The output shape follows the input,
            with an additional `features` dimension appended.

        Raises:
            ValueError: if (after the optional cast) `inputs` is not integer-typed.
        """
        # `cast_input_dtype` is Optional: test against None explicitly instead of
        # relying on the truthiness of a dtype object.
        if self.cast_input_dtype is not None:
            inputs = inputs.astype(self.cast_input_dtype)
        if not jnp.issubdtype(inputs.dtype, jnp.integer):
            raise ValueError("Input type must be an integer or unsigned integer.")
        if self.one_hot:
            # Gather expressed as a one-hot [..., num_embeddings] matrix times
            # the table.
            iota = lax.iota(jnp.int32, self.num_embeddings)
            one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
            output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
        else:
            output = jnp.asarray(self.embedding, self.dtype)[inputs]
        output = with_sharding_constraint(output, ("batch", "length", "embed"))
        return output

    def attend(self, query: Array) -> Array:
        """Attend over the embedding using a query array.

        Args:
            query: array with last dimension equal the feature depth `features` of the
                embedding.

        Returns:
            An array with final dim `num_embeddings` corresponding to the batched
            inner-product of the array of query vectors against each embedding.
            Commonly used for weight-sharing between embeddings and logit transform
            in NLP models.
        """
        dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
        return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
|
|
|
|
|
class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Each (query, key) offset is mapped to one of `num_buckets` buckets, and each
    head learns one scalar bias per bucket.

    Attributes:
        num_buckets: Number of buckets to bucket distances between key and query
            positions into.
        max_distance: Maximum distance before everything is lumped into the last
            distance bucket.
        num_heads: Number of heads in the attention layer. Each head will get a
            different relative position weighting.
        dtype: Type of arrays through this module.
        embedding_init: initializer for relative embedding table.
    """

    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, positive relative positions are not
        distinguished: they are clamped to distance 0 and share bucket 0.
        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative
        positions <=-max_distance map to the same bucket. This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Implemented with NumPy, so `relative_position` must be a concrete
        (non-traced) integer array.

        Args:
            relative_position: an int32 array
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split the buckets in half by sign: n < 0 (key after query) uses
            # the upper half, offset by num_buckets // 2.
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            # Keys after the query (n < 0) are clamped to distance 0.
            n = np.maximum(n, 0)

        # Distances below max_exact each get their own bucket...
        max_exact = num_buckets // 2
        is_small = n < max_exact
        # ...larger distances are spaced logarithmically up to max_distance.
        # The eps keeps log() finite for n == 0; such entries are selected
        # from `n` by `is_small` below (when max_exact > 0).
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        # Everything at or beyond max_distance lands in the last bucket.
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
            qlen: attention query length (a concrete int; used with np.arange).
            klen: attention key length (a concrete int; used with np.arange).
            bidirectional: whether to allow positive memory-query relative position
                embeddings.

        Returns:
            output: `(1, num_heads, q_len, k_len)` attention bias
        """
        # relative_position[i, j] = j - i (key position minus query position).
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        # Learned table: one scalar per (head, bucket), stored in float32.
        relative_attention_bias = param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_heads, self.num_buckets),
            jnp.float32,
            axes=("heads", "relpos_buckets"),
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Gather via a one-hot contraction instead of indexing:
        # one_hot[b, q, k] = (rp_bucket[q, k] == b), so contracting the bucket
        # axis gives values[h, q, k] = table[h, rp_bucket[q, k]].
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )

        # Leading broadcast axis: (1, num_heads, qlen, klen).
        return values[jnp.newaxis, ...]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization (https://arxiv.org/abs/1607.06450).

    Operates on the last axis of the input data.
    It normalizes the activations of the layer for each given example in a
    batch independently, rather than across a batch like Batch Normalization.
    i.e. applies a transformation that maintains the mean activation within
    each example close to 0 and the activation standard deviation close to 1.

    Attributes:
        epsilon: A small float added to variance to avoid dividing by zero.
        dtype: the dtype of the output, and the dtype scale/bias are cast to
            before being applied (default: float32). Mean and variance are always
            computed in float32.
        params_dtype: the dtype in which scale and bias are stored (default: float32).
        use_bias: If True, bias (beta) is added.
        use_scale: If True, multiply by scale (gamma). When the next layer is linear
            (also e.g. nn.relu), this can be disabled since the scaling will be done
            by the next layer.
        bias_init: Initializer for bias, by default, zero.
        scale_init: Initializer for scale, by default, one.
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    params_dtype: DType = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input.

        Args:
            x: the inputs

        Returns:
            Normalized inputs (the same shape as inputs), in `dtype`.
        """
        # Statistics are computed in float32 regardless of the input dtype.
        x = jnp.asarray(x, jnp.float32)
        features = x.shape[-1]
        mean = jnp.mean(x, axis=-1, keepdims=True)
        mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation (e.g. large, nearly constant activations). Clamp at 0 so
        # rsqrt never receives a negative argument and returns NaN.
        var = jnp.maximum(0.0, mean2 - lax.square(mean))
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = param_with_axes("scale", self.scale_init, (features,), self.params_dtype, axes=("embed",))
            mul = mul * jnp.asarray(scale, self.dtype)
        y = (x - mean) * mul
        if self.use_bias:
            bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
            y = y + jnp.asarray(bias, self.dtype)
        return jnp.asarray(y, self.dtype)
|
|
|
|
|
|
|
|
|
|
|
def make_attention_mask(
    query_input: Array,
    key_input: Array,
    pairwise_fn: Callable = jnp.multiply,
    extra_batch_dims: int = 0,
    dtype: DType = jnp.float32,
) -> Array:
    """Build an attention mask by comparing every query entry with every key entry.

    For 1d inputs of shape `[batch, len_q]` and `[batch, len_kv]`, attention
    weights are `[batch, heads, len_q, len_kv]` and this returns a mask of
    shape `[batch, 1, len_q, len_kv]` that broadcasts over heads.

    Args:
        query_input: a batched, flat input of query_length size
        key_input: a batched, flat input of key_length size
        pairwise_fn: broadcasting elementwise comparison function, applied as
            `pairwise_fn(query[..., :, None], key[..., None, :])`
        extra_batch_dims: number of extra leading singleton batch axes to add
            (none by default)
        dtype: mask return dtype

    Returns:
        A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
    """
    # [..., len_q, 1] against [..., 1, len_kv] broadcasts to [..., len_q, len_kv].
    queries = jnp.expand_dims(query_input, axis=-1)
    keys = jnp.expand_dims(key_input, axis=-2)
    pairwise = pairwise_fn(queries, keys)

    # Insert the singleton heads axis, then any requested leading batch axes.
    with_heads = jnp.expand_dims(pairwise, axis=-3)
    leading_axes = tuple(range(extra_batch_dims))
    return jnp.expand_dims(with_heads, axis=leading_axes).astype(dtype)
|
|
|
|
|
def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
    """Build a causal (lower-triangular) self-attention mask.

    For 1d inputs of shape `[batch, len]`, self-attention weights are
    `[batch, heads, len, len]` and this returns a mask of shape
    `[batch, 1, len, len]` in which query position i may attend to key
    positions j <= i.

    Only the shape of `x` is used, not its values, so padding elements receive
    no special treatment.

    Args:
        x: input array of shape `[batch, len]`
        extra_batch_dims: number of extra leading singleton batch axes to add
            (none by default)
        dtype: mask return dtype

    Returns:
        A `[batch, 1, len, len]` shaped causal mask for 1d attention.
    """
    seq_len = x.shape[-1]
    # Position indices 0..len-1, tiled to the shape of x.
    positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), x.shape)
    # Query i may attend key j exactly when i >= j.
    return make_attention_mask(
        positions,
        positions,
        jnp.greater_equal,
        extra_batch_dims=extra_batch_dims,
        dtype=dtype,
    )
|
|
|
|
|
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """Combine attention masks with a logical AND.

    Args:
        *masks: attention masks to combine; None entries are ignored. All
            remaining masks must have the same rank.
        dtype: dtype of the combined mask.

    Returns:
        The element-wise logical AND of the given masks cast to `dtype`, or
        None if no (non-None) masks were given.
    """
    present = [m for m in masks if m is not None]
    if len(present) == 0:
        return None
    assert all(
        (x.ndim == present[0].ndim for x in present)
    ), f"masks must have same rank: {tuple((x.ndim for x in present))}"
    combined = functools.reduce(jnp.logical_and, present)
    return combined.astype(dtype)
|
|
|
|
|
def combine_biases(*masks: Optional[Array]):
    """Combine attention biases by summation.

    Args:
        *masks: attention biases to combine; None entries are ignored. All
            remaining biases must have the same rank.

    Returns:
        The element-wise sum of the given biases (the single bias itself when
        only one is given), or None if no (non-None) biases were given.
    """
    biases = [b for b in masks if b is not None]
    if not biases:
        return None
    assert all(
        (x.ndim == biases[0].ndim for x in biases)
    ), f"masks must have same rank: {tuple((x.ndim for x in biases))}"
    return functools.reduce(operator.add, biases)
|
|
|
|
|
def make_decoder_mask(
    decoder_target_tokens: Array,
    dtype: DType,
    decoder_causal_attention: Optional[Array] = None,
    decoder_segment_ids: Optional[Array] = None,
) -> Array:
    """Compute the self-attention mask for a decoder.

    Decoder mask is formed by combining a causal mask, a padding mask and an
    optional packing mask. If decoder_causal_attention is passed, it makes the
    masking non-causal for positions that have value of 1.

    A prefix LM is applied to a dataset which has a notion of "inputs" and
    "targets", e.g., a machine translation task. The inputs and targets are
    concatenated to form a new target. `decoder_target_tokens` is the concatenated
    decoder output tokens.

    The "inputs" portion of the concatenated sequence can attend to other "inputs"
    tokens, even those at later time steps. In order to control this
    behavior, `decoder_causal_attention` is necessary. This is a binary mask with
    a value of 1 indicating that the position belonged to "inputs" portion of the
    original dataset.

    Example:

        Suppose we have a dataset with two examples.

        ds = [{"inputs": [6, 7], "targets": [8]},
              {"inputs": [3, 4], "targets": [5]}]

        After the data preprocessing with packing, the two examples are packed into
        one example with the following three fields (some fields are skipped for
        simplicity).

        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]

        where each array has [batch, length] shape with batch size being 1. Then,
        this function computes the following mask.

        mask = [[[[1, 1, 0, 0, 0, 0, 0],
                  [1, 1, 0, 0, 0, 0, 0],
                  [1, 1, 1, 0, 0, 0, 0],
                  [0, 0, 0, 1, 1, 0, 0],
                  [0, 0, 0, 1, 1, 0, 0],
                  [0, 0, 0, 1, 1, 1, 0],
                  [0, 0, 0, 0, 0, 0, 0]]]]

        mask[b, 0, :, :] represents the mask for the example `b` in the batch
        (axis 1 is a singleton heads axis). Because mask is for a self-attention
        layer, the mask's shape is a square of shape [query length, key length].

        mask[b, 0, i, j] = 1 means that the query token at position i can attend to
        the key token at position j.

    Args:
        decoder_target_tokens: decoder output tokens. [batch, length]. Token id 0
            is treated as padding.
        dtype: dtype of the output mask.
        decoder_causal_attention: a binary mask of shape [batch, length].
            Positions with value 1 (the "inputs" portion) may additionally attend
            to every other position with value 1, including later ones; all
            attention is otherwise causal.
        decoder_segment_ids: decoder segmentation info for packed examples. [batch,
            length]

    Returns:
        the combined decoder mask of shape [batch, 1, length, length].
    """
    masks = []

    # Causal mask: [batch, 1, length, length]; the singleton axis broadcasts
    # over attention heads.
    causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)

    if decoder_causal_attention is not None:
        # Positions marked 1 may also see each other bidirectionally.
        inputs_mask = make_attention_mask(
            decoder_causal_attention, decoder_causal_attention, jnp.logical_and, dtype=dtype
        )
        masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
    else:
        masks.append(causal_mask)

    # Padding mask: neither the query nor the key may be a padding token (id 0).
    is_token = decoder_target_tokens > 0
    masks.append(make_attention_mask(is_token, is_token, dtype=dtype))

    # Packing mask: attend only within the same segment.
    if decoder_segment_ids is not None:
        masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))

    return combine_masks(*masks, dtype=dtype)
|
|
|
|
|
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
    """Canonicalize conv padding into a format accepted by `jax.lax`.

    Args:
        padding: a padding string (returned unchanged), a single int (used for
            both sides of every spatial dim), or a sequence of length `rank`
            whose entries are ints (same padding on both sides) or `(low, high)`
            tuples.
        rank: number of spatial dimensions.

    Returns:
        The padding string, or a list of `rank` `(low, high)` pairs.

    Raises:
        ValueError: if `padding` matches none of the accepted forms.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding)] * rank
    if isinstance(padding, Sequence) and len(padding) == rank:
        pairs = []
        for entry in padding:
            if isinstance(entry, int):
                pairs.append((entry, entry))
            elif isinstance(entry, tuple) and len(entry) == 2:
                pairs.append(entry)
            else:
                break
        else:
            # Reached only when every entry was accepted.
            return pairs
    raise ValueError(
        f"Invalid padding format: {padding}, should be str, int,"
        f" or a sequence of len {rank} where each element is an"
        f" int or pair of ints."
    )
|
|
|
|
|
def _conv_dimension_numbers(input_shape):
    """Build channels-last `lax.ConvDimensionNumbers` for an input of this rank.

    Inputs and outputs are laid out as [batch, spatial..., features]. The kernel
    is laid out as [spatial..., in_features, out_features].
    """
    rank = len(input_shape)
    # lhs/out spec: (batch axis, feature axis, spatial axes...).
    lhs = (0, rank - 1, *range(1, rank - 1))
    # rhs spec: (out-feature axis, in-feature axis, spatial axes...).
    rhs = (rank - 1, rank - 2, *range(rank - 2))
    return lax.ConvDimensionNumbers(lhs, rhs, lhs)
|
|
|
|
|
class _Conv(nn.Module):
    """Convolution Module wrapping `lax.conv_general_dilated[_local]`.

    Subclasses select a regular (shared-weight) convolution or an unshared
    ("locally connected") convolution by overriding `shared_weights`.

    Attributes:
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel, given as a sequence of
        integers with one entry per spatial dimension (e.g. ``(3,)`` for a 1D
        convolution). A bare integer is rejected with a ``TypeError``.
      strides: an integer or a sequence of `n` integers, representing the
        inter-window strides (default: 1). ``None`` is treated as 1.
      padding: either the string `'SAME'`, the string `'VALID'`, the string
        `'CIRCULAR'` (periodic boundary conditions), the string `'CAUSAL'`,
        or a sequence of `n` `(low, high)` integer pairs that give the padding
        to apply before and after each spatial dimension. A single int is
        interpreted as applying the same padding in all dims and passing a
        single int in a sequence causes the same padding to be used on both
        sides. `'CAUSAL'` padding (1D convolutions only) left-pads the
        convolution axis, resulting in same-sized output for unit strides.
      input_dilation: an integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
      kernel_dilation: an integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups. Only supported with shared weights.
      use_bias: whether to add a bias to the output (default: True).
      mask: Optional mask for the weights during masked convolution. The mask
        must be the same shape as the convolution weight matrix.
      dtype: the dtype of the computation (default: infer from input and params).
      params_dtype: the dtype passed to parameter initializers (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      conv_general_dilated: convolution function used for the shared-weight
        path and for output-shape inference in the unshared path
        (default: `lax.conv_general_dilated`).
      kernel_axes: logical axis names passed to `param_with_axes` for the
        kernel. The last name, if any, is also used for the bias.
    """

    features: int
    kernel_size: Sequence[int]
    strides: Union[None, int, Sequence[int]] = 1
    padding: PaddingLike = "SAME"
    input_dilation: Union[None, int, Sequence[int]] = 1
    kernel_dilation: Union[None, int, Sequence[int]] = 1
    feature_group_count: int = 1
    use_bias: bool = True
    mask: Optional[Array] = None
    dtype: Optional[DType] = None
    params_dtype: DType = jnp.float32
    precision: PrecisionLike = None
    kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
    bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
    conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
    kernel_axes: Tuple[str, ...] = ()

    @property
    def shared_weights(self) -> bool:
        """Defines whether weights are shared or not between different pixels.

        Returns:
          `True` to use shared weights in convolution (regular convolution).
          `False` to use different weights at different pixels, a.k.a.
          "locally connected layer", "unshared convolution", or "local convolution".
        """
        ...

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a (potentially unshared) convolution to the inputs.

        Args:
          inputs: input data with dimensions (*batch_dims, spatial_dims...,
            features). This is the channels-last convention, i.e. NHWC for a 2d
            convolution and NDHWC for a 3D convolution. Note: this is different
            from the input convention used by `lax.conv_general_dilated`, which
            puts the spatial dimensions last.
            Note: If the input has more than 1 batch dimension, all batch
            dimensions are flattened into a single dimension for the
            convolution and restored before returning. In some cases directly
            vmap'ing the layer may yield better performance than this default
            flattening approach. If the input lacks a batch dimension it will be
            added for the convolution and removed on return, an allowance made
            to enable writing single-example code.

        Returns:
          The convolved data.

        Raises:
          TypeError: if `kernel_size` is a bare int.
          ValueError: if `inputs` has fewer than `len(kernel_size) + 1`
            dimensions, if `'CAUSAL'` padding is used for a non-1D convolution,
            if the input features are not divisible by `feature_group_count`,
            or if `mask` does not match the kernel shape.
          NotImplementedError: if unshared weights are combined with
            `feature_group_count != 1`.
        """
        if isinstance(self.kernel_size, int):
            raise TypeError(
                "Expected Conv kernel_size to be a"
                " tuple/list of integers (eg.: [3, 3]) but got"
                f" {self.kernel_size}."
            )
        kernel_size = tuple(self.kernel_size)

        def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
            # `None` means "use the default of 1"; a scalar applies to every
            # spatial dimension.
            if x is None:
                x = 1
            if isinstance(x, int):
                return (x,) * len(kernel_size)
            return tuple(x)

        # Every dim before the trailing (spatial..., features) dims is a batch
        # dim; lax needs exactly one, so flatten several (or add a missing one).
        num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
        if num_batch_dimensions < 0:
            raise ValueError(
                f"Conv with kernel_size {kernel_size} expects inputs with at least "
                f"{len(kernel_size) + 1} dimensions (spatial dims + features), "
                f"got shape {inputs.shape}."
            )
        if num_batch_dimensions != 1:
            input_batch_shape = inputs.shape[:num_batch_dimensions]
            total_batch_size = int(np.prod(input_batch_shape))
            flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
            inputs = jnp.reshape(inputs, flat_input_shape)

        strides = maybe_broadcast(self.strides)
        input_dilation = maybe_broadcast(self.input_dilation)
        kernel_dilation = maybe_broadcast(self.kernel_dilation)

        padding_lax = canonicalize_padding(self.padding, len(kernel_size))
        if padding_lax == "CIRCULAR":
            # Wrap-pad each spatial dim by the dilated kernel extent, then run
            # a VALID convolution over the padded input.
            kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
            zero_pad: List[Tuple[int, int]] = [(0, 0)]
            pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
            inputs = jnp.pad(inputs, pads, mode="wrap")
            padding_lax = "VALID"
        elif padding_lax == "CAUSAL":
            if len(kernel_size) != 1:
                raise ValueError("Causal padding is only implemented for 1D convolutions.")
            # Left-pad only, so each output position sees only current/past inputs.
            left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
            pads = [(0, 0), (left_pad, 0), (0, 0)]
            inputs = jnp.pad(inputs, pads)
            padding_lax = "VALID"

        dimension_numbers = _conv_dimension_numbers(inputs.shape)
        in_features = jnp.shape(inputs)[-1]

        if self.shared_weights:
            # Regular convolution: one kernel of shape
            # (*kernel_size, in_features // groups, features).
            if in_features % self.feature_group_count != 0:
                raise ValueError(
                    f"Input features ({in_features}) must be divisible by "
                    f"feature_group_count ({self.feature_group_count})."
                )
            kernel_shape = kernel_size + (in_features // self.feature_group_count, self.features)
        else:
            if self.feature_group_count != 1:
                raise NotImplementedError(
                    f"`lax.conv_general_dilated_local` does not support "
                    f"`feature_group_count != 1`, got `{self.feature_group_count}`."
                )

            # Infer the output spatial shape abstractly (no FLOPs) so the
            # unshared kernel can hold one filter per output position.
            conv_output_shape = jax.eval_shape(
                lambda lhs, rhs: self.conv_general_dilated(
                    lhs=lhs,
                    rhs=rhs,
                    window_strides=strides,
                    padding=padding_lax,
                    dimension_numbers=dimension_numbers,
                    lhs_dilation=input_dilation,
                    rhs_dilation=kernel_dilation,
                ),
                inputs,
                jax.ShapeDtypeStruct(kernel_size + (in_features, self.features), inputs.dtype),
            ).shape

            # Unshared kernel: (*output_spatial, prod(kernel_size) * in_features, features).
            kernel_shape = conv_output_shape[1:-1] + (
                int(np.prod(kernel_size)) * in_features,
                self.features,
            )

        if self.mask is not None and self.mask.shape != kernel_shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
            )

        kernel = param_with_axes(
            "kernel",
            self.kernel_init,
            kernel_shape,
            self.params_dtype,
            axes=self.kernel_axes,
        )

        if self.mask is not None:
            kernel *= self.mask

        if self.use_bias:
            if self.shared_weights:
                # One bias value per output feature.
                bias_shape = (self.features,)
            else:
                # Unshared: one bias per output position and feature.
                bias_shape = conv_output_shape[1:]

            bias = param_with_axes(
                "bias",
                self.bias_init,
                bias_shape,
                self.params_dtype,
                # Reuse the kernel's last (output-feature) axis name. Slicing
                # rather than indexing keeps the default `kernel_axes=()` from
                # raising IndexError.
                # NOTE(review): for unshared weights the bias has more than one
                # dim but only one axis name is attached — confirm intent.
                axes=tuple(self.kernel_axes[-1:]),
            )
        else:
            bias = None

        inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
        if self.shared_weights:
            y = self.conv_general_dilated(
                inputs,
                kernel,
                strides,
                padding_lax,
                lhs_dilation=input_dilation,
                rhs_dilation=kernel_dilation,
                dimension_numbers=dimension_numbers,
                feature_group_count=self.feature_group_count,
                precision=self.precision,
            )
        else:
            y = lax.conv_general_dilated_local(
                lhs=inputs,
                rhs=kernel,
                window_strides=strides,
                padding=padding_lax,
                filter_shape=kernel_size,
                lhs_dilation=input_dilation,
                rhs_dilation=kernel_dilation,
                dimension_numbers=dimension_numbers,
                precision=self.precision,
            )

        if self.use_bias:
            # Left-pad the bias shape with singleton dims so it broadcasts over y.
            bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
            y += bias

        if num_batch_dimensions != 1:
            # Restore the original (possibly empty or multi-dim) batch shape.
            output_shape = input_batch_shape + y.shape[1:]
            y = jnp.reshape(y, output_shape)
        return y
|
|
|
|
|
class Conv(_Conv):
    """Convolution Module wrapping `lax.conv_general_dilated`.

    A regular convolution: the same kernel is shared across all spatial
    positions (`shared_weights` is always `True`).

    Attributes:
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel, given as a sequence of
        integers with one entry per spatial dimension (e.g. ``(3,)`` for a 1D
        convolution). A bare integer is rejected with a ``TypeError``.
      strides: an integer or a sequence of `n` integers, representing the
        inter-window strides (default: 1).
      padding: either the string `'SAME'`, the string `'VALID'`, the string
        `'CIRCULAR'` (periodic boundary conditions), the string `'CAUSAL'`,
        or a sequence of `n` `(low, high)` integer pairs that give the padding
        to apply before and after each spatial dimension. A single int is
        interpreted as applying the same padding in all dims and passing a
        single int in a sequence causes the same padding to be used on both
        sides. `'CAUSAL'` padding (1D convolutions only) left-pads the
        convolution axis, resulting in same-sized output for unit strides.
      input_dilation: an integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
      kernel_dilation: an integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      use_bias: whether to add a bias to the output (default: True).
      mask: Optional mask for the weights during masked convolution. The mask
        must be the same shape as the convolution weight matrix.
      dtype: the dtype of the computation (default: infer from input and params).
      params_dtype: the dtype passed to parameter initializers (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      conv_general_dilated: convolution function to call
        (default: `lax.conv_general_dilated`).
      kernel_axes: logical axis names attached to the kernel parameter.
    """

    @property
    def shared_weights(self) -> bool:
        """Returns `True`: one kernel is shared by every spatial position."""
        return True