"""Pytorch version of patched decoder.""" |
|
|
|
import dataclasses |
|
import math |
|
from typing import List, Tuple |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def _create_quantiles() -> list[float]: |
|
return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] |
|
|
|
|
|
@dataclasses.dataclass |
|
class TimesFMConfig: |
|
"""Config for initializing timesfm patched_decoder class.""" |
|
|
|
|
|
num_layers: int = 20 |
|
|
|
num_heads: int = 16 |
|
|
|
num_kv_heads: int = 16 |
|
|
|
hidden_size: int = 1280 |
|
|
|
intermediate_size: int = 1280 |
|
|
|
head_dim: int = 80 |
|
|
|
rms_norm_eps: float = 1e-6 |
|
|
|
patch_len: int = 32 |
|
|
|
horizon_len: int = 128 |
|
|
|
quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles) |
|
|
|
pad_val: float = 1123581321.0 |
|
|
|
tolerance: float = 1e-6 |
|
|
|
  dtype: str = "bfloat16"
|
|
|
use_positional_embedding: bool = True |
|
|
|
|
|
def _masked_mean_std( |
|
inputs: torch.Tensor, |
|
padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Calculates mean and standard deviation of `inputs` across axis 1. |
|
|
|
It excludes values where `padding` is 1. |
|
|
|
Args: |
|
inputs: A PyTorch tensor of shape [b, n, p]. |
|
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. |
|
|
|
Returns: |
|
A tuple containing the mean and standard deviation. |
|
    We return the statistics of the first patch with at least three non-padded
|
values. |
|
""" |
|
|
|
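  # Number of non-padded values in each patch (padding == 1 marks padded entries).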
pad_sum = torch.sum(1 - padding, dim=2) |
|
|
|
def _get_patch_index(arr: torch.Tensor): |
|
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) |
|
row_sum = (arr >= 3).to(torch.int32).sum(dim=1) |
|
return torch.where(row_sum == 0, arr.shape[1] - 1, indices) |
|
|
|
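  # For each example, pick the first patch with at least 3 valid values
  # (falling back to the last patch when no patch qualifies).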
patch_indices = _get_patch_index(pad_sum) |
|
bidxs = torch.arange(inputs.shape[0]) |
|
|
|
arr = inputs[bidxs, patch_indices, :] |
|
pad = padding[bidxs, patch_indices, :] |
|
|
|
|
|
mask = 1 - pad |
|
|
|
|
|
num_valid_elements = torch.sum(mask, dim=1) |
|
num_valid_elements = torch.where( |
|
num_valid_elements == 0, |
|
torch.tensor(1, |
|
dtype=num_valid_elements.dtype, |
|
device=num_valid_elements.device), |
|
num_valid_elements, |
|
) |
|
|
|
|
|
masked_sum = torch.sum(arr * mask, dim=1) |
|
masked_squared_sum = torch.sum((arr * mask)**2, dim=1) |
|
|
|
|
|
masked_mean = masked_sum / num_valid_elements |
|
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 |
|
masked_var = torch.where( |
|
masked_var < 0.0, |
|
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), |
|
masked_var, |
|
) |
|
masked_std = torch.sqrt(masked_var) |
|
|
|
return masked_mean, masked_std |
|
|
|
|
|
def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: |
|
"""Shifts rows of seq based on the first 0 in each row of the mask. |
|
|
|
Args: |
|
mask: mask tensor of shape [B, N] |
|
seq: seq tensor of shape [B, N, P] |
|
|
|
Returns: |
|
Returns the shifted sequence. |
|
""" |
|
batch_size, num_seq, feature_dim = seq.shape |
|
|
|
new_mask: torch.BoolTensor = mask == 0 |
|
|
|
|
|
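  # Index of the first zero (i.e. the first non-padded position) in each row of the mask.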
indices = new_mask.to(torch.int32).argmax(dim=1) |
|
|
|
|
|
indices[~new_mask.any(dim=1)] = -1 |
|
|
|
|
|
idx_range = (torch.arange(num_seq).to( |
|
seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, |
|
feature_dim)) |
|
|
|
|
|
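  # Rotate each row of seq right by the index of the first zero in its mask row.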
shifted_idx = (idx_range - indices[:, None, None]) % num_seq |
|
|
|
|
|
shifted_seq = seq.gather(1, shifted_idx) |
|
|
|
return shifted_seq |
|
|
|
|
|
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: |
|
"""Returns a large negative value for the given dtype.""" |
|
if dtype.is_floating_point: |
|
dtype_max = torch.finfo(dtype).max |
|
else: |
|
dtype_max = torch.iinfo(dtype).max |
|
return torch.tensor(-0.7 * dtype_max, dtype=dtype) |
|
|
|
|
|
def apply_mask_to_logits(logits: torch.Tensor, |
|
mask: torch.Tensor) -> torch.Tensor: |
|
"""Applies a floating-point mask to a set of logits. |
|
|
|
Args: |
|
logits: A torch.Tensor of logit values. |
|
    mask: A torch.Tensor (float32) of additive mask values, where positions to

      be masked out hold a large negative value and positions to keep hold 0.
|
|
|
Returns: |
|
Masked logits. |
|
""" |
|
|
|
min_value = get_large_negative_number(logits.dtype) |
|
|
|
return torch.where((mask >= min_value * 0.5), logits, min_value) |
|
|
|
|
|
def convert_paddings_to_mask( |
|
paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: |
|
"""Converts binary paddings to a logit mask ready to add to attention matrix. |
|
|
|
Args: |
|
paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding |
|
token. |
|
dtype: data type of the input. |
|
|
|
Returns: |
|
A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. |
|
""" |
|
attention_mask = paddings.detach().clone() |
|
attention_mask = attention_mask[:, None, None, :] |
|
attention_mask *= get_large_negative_number(dtype) |
|
return attention_mask |
|
|
|
|
|
def causal_mask(input_t: torch.Tensor) -> torch.Tensor: |
|
"""Computes and returns causal mask. |
|
|
|
Args: |
|
input_t: A torch.Tensor of shape [B, T, D]. |
|
|
|
Returns: |
|
An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has |
|
already been converted to large negative values. |
|
""" |
|
assert input_t.dtype.is_floating_point, input_t.dtype |
|
large_negative_number = get_large_negative_number(input_t.dtype) |
|
t = input_t.shape[1] |
|
col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) |
|
row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) |
|
mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number |
|
return (mask.unsqueeze(0).unsqueeze(0).to(input_t.device) |
|
) |
|
|
|
|
|
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
|
"""Merges 2 masks. |
|
|
|
logscale mask is expected but 0/1 mask is also fine. |
|
|
|
Args: |
|
a: torch.Tensor of shape [1|B, 1, 1|T, S]. |
|
b: torch.Tensor of shape [1|B, 1, 1|T, S]. |
|
|
|
Returns: |
|
torch.Tensor of shape [1|B, 1, 1|T, S]. |
|
""" |
|
|
|
def expand_t(key_mask): |
|
query_mask = key_mask.transpose(-1, -2) |
|
return torch.minimum(query_mask, key_mask) |
|
|
|
if a.shape[2] != b.shape[2]: |
|
if a.shape[2] == 1: |
|
a = expand_t(a) |
|
else: |
|
assert b.shape[2] == 1 |
|
b = expand_t(b) |
|
|
|
assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." |
|
return torch.minimum(a, b) |
|
|
|
|
|
class ResidualBlock(nn.Module): |
|
"""TimesFM residual block.""" |
|
|
|
def __init__( |
|
self, |
|
input_dims, |
|
hidden_dims, |
|
output_dims, |
|
): |
|
super(ResidualBlock, self).__init__() |
|
self.input_dims = input_dims |
|
self.hidden_dims = hidden_dims |
|
self.output_dims = output_dims |
|
|
|
|
|
self.hidden_layer = nn.Sequential( |
|
nn.Linear(input_dims, hidden_dims), |
|
nn.SiLU(), |
|
) |
|
|
|
|
|
self.output_layer = nn.Linear(hidden_dims, output_dims) |
|
|
|
self.residual_layer = nn.Linear(input_dims, output_dims) |
|
|
|
def forward(self, x): |
|
hidden = self.hidden_layer(x) |
|
output = self.output_layer(hidden) |
|
residual = self.residual_layer(x) |
|
return output + residual |
|
|
|
|
|
class RMSNorm(torch.nn.Module): |
|
"""Pax rms norm in pytorch.""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
eps: float = 1e-6, |
|
add_unit_offset: bool = False, |
|
): |
|
super().__init__() |
|
self.eps = eps |
|
self.add_unit_offset = add_unit_offset |
|
self.weight = nn.Parameter(torch.zeros(dim)) |
|
|
|
def _norm(self, x): |
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
|
def forward(self, x): |
|
output = self._norm(x.float()) |
|
if self.add_unit_offset: |
|
output = output * (1 + self.weight.float()) |
|
else: |
|
output = output * self.weight.float() |
|
return output.type_as(x) |
|
|
|
|
|
class TransformerMLP(nn.Module): |
|
"""Pax transformer MLP in pytorch.""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
): |
|
super().__init__() |
|
self.gate_proj = nn.Linear(hidden_size, intermediate_size) |
|
self.down_proj = nn.Linear(intermediate_size, hidden_size) |
|
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) |
|
|
|
def forward(self, x, paddings=None): |
|
gate_inp = self.layer_norm(x) |
|
gate = self.gate_proj(gate_inp) |
|
gate = F.relu(gate) |
|
outputs = self.down_proj(gate) |
|
if paddings is not None: |
|
outputs = outputs * (1.0 - paddings[:, :, None]) |
|
return outputs + x |
|
|
|
|
|
class TimesFMAttention(nn.Module): |
|
"""Implements the attention used in TimesFM.""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
num_heads: int, |
|
num_kv_heads: int, |
|
head_dim: int, |
|
): |
|
super().__init__() |
|
|
|
self.num_heads = num_heads |
|
self.num_kv_heads = num_kv_heads |
|
|
|
assert self.num_heads % self.num_kv_heads == 0 |
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads |
|
|
|
self.hidden_size = hidden_size |
|
self.head_dim = head_dim |
|
|
|
self.q_size = self.num_heads * self.head_dim |
|
self.kv_size = self.num_kv_heads * self.head_dim |
|
self.scaling = nn.Parameter( |
|
torch.empty((self.head_dim,), dtype=torch.float32),) |
|
|
|
self.qkv_proj = nn.Linear( |
|
self.hidden_size, |
|
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim, |
|
) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) |
|
|
|
def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: |
|
|
|
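    # Per-dimension learned query scaling; r_softplus_0 = 1 / softplus(0), so the
    # scale starts at 1 / sqrt(head_dim) when the parameter is zero.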
r_softplus_0 = 1.442695041 |
|
softplus_func = torch.nn.Softplus() |
|
scale = r_softplus_0 / math.sqrt(self.head_dim) |
|
scale = scale * softplus_func(self.scaling) |
|
return query * scale[None, None, None, :] |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
mask: torch.Tensor, |
|
kv_write_indices: torch.Tensor | None = None, |
|
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, |
|
) -> torch.Tensor: |
|
hidden_states_shape = hidden_states.shape |
|
assert len(hidden_states_shape) == 3 |
|
|
|
batch_size, input_len, _ = hidden_states_shape |
|
|
|
qkv = self.qkv_proj(hidden_states) |
|
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
|
|
|
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) |
|
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) |
|
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) |
|
xq = self._per_dim_scaling(xq) |
|
|
|
|
|
|
|
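    # During incremental decoding, write the new keys/values into the cache at
    # kv_write_indices and attend over the full cache.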
if kv_cache is not None and kv_write_indices is not None: |
|
k_cache, v_cache = kv_cache |
|
k_cache.index_copy_(1, kv_write_indices, xk) |
|
v_cache.index_copy_(1, kv_write_indices, xv) |
|
|
|
key = k_cache |
|
value = v_cache |
|
else: |
|
key = xk |
|
value = xv |
|
if self.num_kv_heads != self.num_heads: |
|
|
|
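      # Grouped-query attention: replicate K/V heads across their query groups.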
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) |
|
value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) |
|
|
|
|
|
q = xq.transpose(1, 2) |
|
|
|
k = key.transpose(1, 2) |
|
v = value.transpose(1, 2) |
|
|
|
|
|
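    # No extra 1/sqrt(head_dim) factor here: queries were already scaled in
    # _per_dim_scaling. The additive mask blocks padded and future positions.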
scores = torch.matmul(q, k.transpose(2, 3)) |
|
scores = scores + mask |
|
scores = F.softmax(scores.float(), dim=-1).type_as(q) |
|
|
|
|
|
output = torch.matmul(scores, v) |
|
|
|
|
|
|
|
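    # [B, n_heads, T, head_dim] -> [B, T, n_heads * head_dim] for the output projection.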
output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) |
|
output = self.o_proj(output) |
|
return scores, output |
|
|
|
|
|
class TimesFMDecoderLayer(nn.Module): |
|
"""Transformer layer.""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
num_heads: int, |
|
num_kv_heads: int, |
|
head_dim: int, |
|
rms_norm_eps: float = 1e-6, |
|
): |
|
super().__init__() |
|
self.self_attn = TimesFMAttention( |
|
hidden_size=hidden_size, |
|
num_heads=num_heads, |
|
num_kv_heads=num_kv_heads, |
|
head_dim=head_dim, |
|
) |
|
self.mlp = TransformerMLP( |
|
hidden_size=hidden_size, |
|
intermediate_size=intermediate_size, |
|
) |
|
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
mask: torch.Tensor, |
|
paddings: torch.Tensor, |
|
kv_write_indices: torch.Tensor | None = None, |
|
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, |
|
) -> torch.Tensor: |
|
|
|
residual = hidden_states |
|
hidden_states = self.input_layernorm(hidden_states) |
|
scores, hidden_states = self.self_attn( |
|
hidden_states=hidden_states, |
|
mask=mask, |
|
kv_write_indices=kv_write_indices, |
|
kv_cache=kv_cache, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
hidden_states = self.mlp(hidden_states, paddings=paddings) |
|
|
|
return scores, hidden_states |
|
|
|
|
|
class StackedDecoder(nn.Module): |
|
"""Stacked transformer layer.""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
num_heads: int, |
|
num_kv_heads: int, |
|
head_dim: int, |
|
num_layers: int, |
|
rms_norm_eps: float = 1e-6, |
|
): |
|
super().__init__() |
|
|
|
self.layers = nn.ModuleList() |
|
for _ in range(num_layers): |
|
self.layers.append( |
|
TimesFMDecoderLayer( |
|
hidden_size=hidden_size, |
|
intermediate_size=intermediate_size, |
|
num_heads=num_heads, |
|
num_kv_heads=num_kv_heads, |
|
head_dim=head_dim, |
|
rms_norm_eps=rms_norm_eps, |
|
)) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
paddings: torch.Tensor, |
|
kv_write_indices: torch.Tensor | None = None, |
|
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None, |
|
) -> torch.Tensor: |
|
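    # Combine the padding mask and the causal mask into one additive attention mask.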
padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) |
|
atten_mask = causal_mask(hidden_states) |
|
mask = merge_masks(padding_mask, atten_mask) |
|
for i in range(len(self.layers)): |
|
layer = self.layers[i] |
|
kv_cache = kv_caches[i] if kv_caches is not None else None |
|
_, hidden_states = layer( |
|
hidden_states=hidden_states, |
|
mask=mask, |
|
paddings=paddings, |
|
kv_write_indices=kv_write_indices, |
|
kv_cache=kv_cache, |
|
) |
|
return hidden_states |
|
|
|
|
|
class PositionalEmbedding(torch.nn.Module): |
|
"""Generates position embedding for a given 1-d sequence. |
|
|
|
Attributes: |
|
min_timescale: Start of the geometric index. Determines the periodicity of |
|
the added signal. |
|
max_timescale: End of the geometric index. Determines the frequency of the |
|
added signal. |
|
embedding_dims: Dimension of the embedding to be generated. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
embedding_dims: int, |
|
min_timescale: int = 1, |
|
max_timescale: int = 10_000, |
|
) -> None: |
|
super().__init__() |
|
self.min_timescale = min_timescale |
|
self.max_timescale = max_timescale |
|
self.embedding_dims = embedding_dims |
|
|
|
def forward(self, seq_length=None, position=None): |
|
"""Generates a Tensor of sinusoids with different frequencies. |
|
|
|
Args: |
|
      seq_length: an optional Python int defining the output sequence length.

        Not required if the `position` argument is specified.
|
position: [B, seq_length], optional position for each token in the |
|
sequence, only required when the sequence is packed. |
|
|
|
Returns: |
|
[B, seqlen, D] if `position` is specified, else [1, seqlen, D] |
|
""" |
|
if position is None: |
|
assert seq_length is not None |
|
|
|
position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) |
|
else: |
|
assert position.ndim == 2, position.shape |
|
|
|
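    # Geometric progression of timescales from min_timescale to max_timescale,
    # as in the standard transformer sinusoidal position embedding.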
num_timescales = self.embedding_dims // 2 |
|
log_timescale_increment = math.log( |
|
float(self.max_timescale) / float(self.min_timescale)) / max( |
|
num_timescales - 1, 1) |
|
inv_timescales = self.min_timescale * torch.exp( |
|
torch.arange(num_timescales, dtype=torch.float32) * |
|
-log_timescale_increment) |
|
scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze( |
|
0) |
|
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) |
|
|
|
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) |
|
return signal |
|
|
|
|
|
class PatchedTimeSeriesDecoder(nn.Module): |
|
"""Patched time-series decoder.""" |
|
|
|
def __init__(self, config: TimesFMConfig): |
|
super().__init__() |
|
self.config = config |
|
self.input_ff_layer = ResidualBlock( |
|
input_dims=2 * config.patch_len, |
|
output_dims=config.hidden_size, |
|
hidden_dims=config.intermediate_size, |
|
) |
|
self.freq_emb = nn.Embedding(num_embeddings=3, |
|
embedding_dim=config.hidden_size) |
|
self.horizon_ff_layer = ResidualBlock( |
|
input_dims=config.hidden_size, |
|
output_dims=config.horizon_len * (1 + len(config.quantiles)), |
|
hidden_dims=config.intermediate_size, |
|
) |
|
self.stacked_transformer = StackedDecoder( |
|
hidden_size=self.config.hidden_size, |
|
intermediate_size=self.config.intermediate_size, |
|
num_heads=self.config.num_heads, |
|
num_kv_heads=self.config.num_kv_heads, |
|
head_dim=self.config.head_dim, |
|
num_layers=self.config.num_layers, |
|
rms_norm_eps=self.config.rms_norm_eps, |
|
) |
|
if self.config.use_positional_embedding: |
|
self.position_emb = PositionalEmbedding(self.config.hidden_size) |
|
|
|
def _forward_transform( |
|
self, inputs: torch.Tensor, patched_pads: torch.Tensor |
|
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: |
|
"""Input is of shape [B, N, P].""" |
|
mu, sigma = _masked_mean_std(inputs, patched_pads) |
|
sigma = torch.where( |
|
sigma < self.config.tolerance, |
|
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), |
|
sigma, |
|
) |
|
|
|
|
|
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] |
|
outputs = torch.where( |
|
torch.abs(inputs - self.config.pad_val) < self.config.tolerance, |
|
torch.tensor(self.config.pad_val, |
|
dtype=outputs.dtype, |
|
device=outputs.device), |
|
outputs, |
|
) |
|
return outputs, (mu, sigma) |
|
|
|
def _reverse_transform( |
|
self, outputs: torch.Tensor, stats: tuple[torch.Tensor, |
|
torch.Tensor]) -> torch.Tensor: |
|
"""Output is of shape [B, N, P, Q].""" |
|
mu, sigma = stats |
|
return outputs * sigma[:, None, None, None] + mu[:, None, None, None] |
|
|
|
def _preprocess_input( |
|
self, |
|
input_ts: torch.Tensor, |
|
input_padding: torch.Tensor, |
|
) -> tuple[ |
|
torch.Tensor, |
|
torch.Tensor, |
|
tuple[torch.Tensor, torch.Tensor] | None, |
|
torch.Tensor, |
|
]: |
|
"""Preprocess input for stacked transformer.""" |
|
|
|
|
|
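    # Reshape the flat series and its padding into patches of length patch_len.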
bsize = input_ts.shape[0] |
|
patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) |
|
patched_pads = input_padding.view(bsize, -1, self.config.patch_len) |
|
|
|
patched_inputs = torch.where( |
|
torch.abs(patched_pads - 1.0) < self.config.tolerance, |
|
torch.tensor(0.0, |
|
dtype=patched_inputs.dtype, |
|
device=patched_inputs.device), |
|
patched_inputs, |
|
) |
|
patched_pads = torch.where( |
|
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, |
|
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), |
|
patched_pads, |
|
) |
|
patched_inputs, stats = self._forward_transform(patched_inputs, |
|
patched_pads) |
|
|
|
|
|
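    # Zero out padded positions and append the padding indicator as extra input channels.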
patched_inputs = patched_inputs * (1.0 - patched_pads) |
|
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) |
|
model_input = self.input_ff_layer(concat_inputs) |
|
|
|
|
|
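    # A patch counts as padded only if every position inside it is padded.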
patched_padding = torch.min(patched_pads, |
|
dim=-1)[0] |
|
if self.config.use_positional_embedding: |
|
pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) |
|
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) |
|
pos_emb = _shift_padded_seq(patched_padding, pos_emb) |
|
model_input += pos_emb |
|
|
|
return model_input, patched_padding, stats, patched_inputs |
|
|
|
def _postprocess_output( |
|
self, |
|
model_output: torch.Tensor, |
|
num_outputs: int, |
|
stats: tuple[torch.Tensor, torch.Tensor], |
|
) -> torch.Tensor: |
|
"""Postprocess output of stacked transformer.""" |
|
|
|
|
|
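    # Map each patch embedding to horizon_len * (1 + len(quantiles)) outputs.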
output_ts = self.horizon_ff_layer(model_output) |
|
|
|
|
|
b, n, _ = output_ts.shape |
|
output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) |
|
|
|
return self._reverse_transform(output_ts, stats) |
|
|
|
def forward( |
|
self, |
|
input_ts: torch.Tensor, |
|
input_padding: torch.LongTensor, |
|
freq: torch.Tensor, |
|
) -> torch.Tensor: |
|
num_outputs = len(self.config.quantiles) + 1 |
|
model_input, patched_padding, stats, _ = self._preprocess_input( |
|
input_ts=input_ts, |
|
input_padding=input_padding, |
|
) |
|
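    # Add a learned embedding of the frequency category (one of 3 classes).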
f_emb = self.freq_emb(freq) |
|
model_input += f_emb |
|
model_output = self.stacked_transformer(model_input, patched_padding) |
|
|
|
output_ts = self._postprocess_output(model_output, num_outputs, stats) |
|
return output_ts |
|
|
|
def decode( |
|
self, |
|
input_ts: torch.Tensor, |
|
paddings: torch.Tensor, |
|
freq: torch.LongTensor, |
|
horizon_len: int, |
|
output_patch_len: int | None = None, |
|
max_len: int = 512, |
|
return_forecast_on_context: bool = False, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Auto-regressive decoding without caching. |
|
|
|
Args: |
|
      input_ts: input time-series of shape B x C.
|
paddings: padding shape B x (C + H) where H is the prediction length. |
|
freq: frequency shape B x 1 |
|
horizon_len: prediction length. |
|
output_patch_len: output length to be fetched from one step of |
|
auto-regressive decoding. |
|
max_len: maximum training context length. |
|
return_forecast_on_context: whether to return the model forecast on the |
|
context except the first input patch. |
|
|
|
Returns: |
|
Tuple of two forecasting results: |
|
- Point (mean) output predictions as a tensor with shape B x H'. |
|
- Full predictions (mean and quantiles) as a tensor with shape |
|
B x H' x (1 + # quantiles). |
|
In particular, if return_forecast_on_context is True, H' is H plus |
|
the forecastable context length, i.e. context_len - (first) patch_len. |
|
""" |
|
final_out = input_ts |
|
context_len = final_out.shape[1] |
|
full_outputs = [] |
|
if paddings.shape[1] != final_out.shape[1] + horizon_len: |
|
raise ValueError( |
|
"Length of paddings must match length of input + horizon_len:" |
|
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}") |
|
if output_patch_len is None: |
|
output_patch_len = self.config.horizon_len |
|
num_decode_patches = (horizon_len + output_patch_len - |
|
1) // output_patch_len |
|
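    # Each step forecasts output_patch_len points; the point forecast is appended
    # to the context so the next step can condition on it.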
for step_index in range(num_decode_patches): |
|
current_padding = paddings[:, 0:final_out.shape[1]] |
|
input_ts = final_out[:, -max_len:] |
|
input_padding = current_padding[:, -max_len:] |
|
fprop_outputs = self(input_ts, input_padding, freq) |
|
if return_forecast_on_context and step_index == 0: |
|
|
|
|
|
        # Forecast on the context: the output at patch i predicts patch i + 1,
        # so drop the last patch's output (it predicts beyond the context) and
        # flatten the remaining per-patch forecasts along the time axis.

        new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]

        new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
                                          new_full_ts.size(3))
|
|
|
full_outputs.append(new_full_ts) |
|
|
|
|
|
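      # Point forecast (index 0) and full forecast (mean + quantiles) for the
      # next output_patch_len steps.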
new_ts = fprop_outputs[:, -1, :output_patch_len, 0] |
|
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] |
|
|
|
full_outputs.append(new_full_ts) |
|
final_out = torch.concatenate([final_out, new_ts], axis=-1) |
|
|
|
if return_forecast_on_context: |
|
|
|
full_outputs = torch.concatenate( |
|
full_outputs, |
|
axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :] |
|
else: |
|
|
|
full_outputs = torch.concatenate(full_outputs, axis=1)[:, |
|
0:horizon_len, :] |
|
|
|
return (full_outputs[:, :, 0], full_outputs) |
|
|
|
class TimesFM(nn.Module): |
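  """Wraps the decoder: left-pads/truncates to context_len, forecasts lookahead steps."""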
|
|
|
def __init__(self, lookback: int = 512, lookahead: int = 96, context_len: int = 512): |
|
|
|
super(TimesFM, self).__init__() |
|
|
|
self.timesfm = PatchedTimeSeriesDecoder(TimesFMConfig()) |
|
self.lookback, self.lookahead = lookback, lookahead |
|
self.context_len = context_len |
|
|
|
def load_state_dict(self, state_dict, *args, **kwargs): |
|
|
|
return self.timesfm.load_state_dict(state_dict, *args, **kwargs) |
|
|
|
def state_dict(self, *args, **kwargs): |
|
|
|
return self.timesfm.state_dict(*args, **kwargs) |
|
|
|
def pad_tensor(self, x): |
|
|
|
B, L = x.shape |
|
device = x.device |
|
dtype = x.dtype |
|
|
|
if L < self.context_len: |
|
padded_input = torch.zeros((B, self.context_len), device=device, dtype=dtype) |
|
padded_input[:, -L:] = x |
|
padding = torch.ones((B, self.context_len), device=device, dtype=dtype) |
|
padding[:, -L:] = 0 |
|
else: |
|
padded_input = x[:, -self.context_len:] |
|
padding = torch.zeros((B, self.context_len), device=device, dtype=dtype) |
|
|
|
freq = torch.zeros((B, 1), device=device, dtype=torch.long) |
|
|
|
    padding = torch.cat(
        (padding, torch.zeros((B, self.lookahead), device=device, dtype=dtype)),
        dim=-1)
    return padded_input, padding, freq
|
|
|
def forward(self, x): |
|
|
|
padded_inp, padding, freq = self.pad_tensor(x) |
|
    return self.timesfm.decode(padded_inp, padding, freq, self.lookahead)[0]
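

# Minimal smoke-test sketch (not part of the original module): it runs the
# wrapper on random data with randomly initialized weights, so the forecast
# values are meaningless, but it illustrates the expected tensor shapes.
if __name__ == "__main__":
  model = TimesFM(lookback=512, lookahead=96, context_len=512)
  history = torch.randn(2, 512)  # [batch, context length]
  with torch.no_grad():
    point_forecast = model(history)  # point (mean) forecast only
  print(point_forecast.shape)  # torch.Size([2, 96])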