Shourya Bose committed
Commit a9073bb · 1 Parent(s): 4cc7625

add timesfm weights
README.md CHANGED
@@ -14,6 +14,14 @@ When using the companion [dataset](https://huggingface.co/datasets/APPFL/Illinoi
 - All models accept normalized inputs and produce normalized outputs, i.e. set `normalize = True` when generating the datasets.
 - For Transformer, Autoformer, Informer, and TimesNet set `transformer = True`, while for LSTM, LSTNet, and PatchTST set `transformer = False`.
 
+## Packages
+
+Executing the code requires only the `numpy` and `torch` (PyTorch) packages. You can either have them in your base Python installation or use a `conda` environment.
+
+## Example
+
+To see how to use the model definitions and load the pretrained weights into them, see `example.py`.
+
 ## Credits
 
 Some model definitions have been adapted from the code provided in the [TSLib Library](https://github.com/thuml/Time-Series-Library).
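Since all checkpoints expect normalized series, a forward pass on raw data has to standardize first and un-standardize afterwards. Below is a minimal sketch of that convention; `mean` and `std` are illustrative placeholders, since the companion dataset computes its own statistics when `normalize = True`:

```python
import torch

# Illustrative stand-ins for the statistics the dataset pipeline
# computes when `normalize = True` (assumed z-score normalization).
mean, std = 2.31, 0.87

raw_history = torch.randn(8, 512) * std + mean   # [batch, lookback] in raw units
x = (raw_history - mean) / std                   # normalized model input
# ... y_norm = model(x) would produce normalized forecasts ...
y_norm = torch.randn(8, 48)                      # placeholder for a model output
y = y_norm * std + mean                          # back to raw units
```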
example.py ADDED
@@ -0,0 +1,33 @@
+import os
+import torch
+
+# import models
+from models.LSTM import LSTM
+from models.LSTNet import LSTNet
+from models.Transformer import Transformer
+from models.Autoformer import Autoformer
+from models.Informer import Informer
+from models.PatchTST import PatchTST
+from models.TimesNet import TimesNet
+from models.TimesFM import TimesFM
+
+# import keyword args
+from model_kwargs import *
+
+# set lookback and lookahead: lookback is fixed to 512, while lookahead can be one of 4, 48, 96;
+# heterogeneity can be 'HET' or 'HOM'
+lookback, lookahead, heterogeneity = 512, 48, 'HET'
+
+if __name__ == "__main__":
+
+    models = [LSTM, LSTNet, Transformer, Autoformer, Informer, PatchTST, TimesNet, TimesFM]
+    kw_fns = [lstm_kwargs, lstnet_kwargs, transformer_kwargs, autoformer_kwargs, informer_kwargs, patchtst_kwargs, timesnet_kwargs, timesfm_kwargs]
+
+    # loop over models and their keyword functions
+    for model_class, kw_fn in zip(models, kw_fns):
+        # instantiate the model class
+        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))
+        # load the weights into the model
+        result = model.load_state_dict(torch.load(os.path.join(os.getcwd(), 'weights', f'{model_class.__name__}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'), map_location='cpu'))
+        # print the outcome
+        print(f"Loaded weights for {model_class.__name__} (lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}); result: {result}")
model_kwargs.py CHANGED
@@ -63,4 +63,10 @@ patchtst_kwargs = lambda lookback,lookahead:{
     'd_model': 32*4,
     'data_idx': [0,3,4,5,6,7],
     'time_idx': [1,2]
+}
+
+timesfm_kwargs = lambda lookback, lookahead:{
+    'lookback': lookback,
+    'lookahead': lookahead,
+    'context_len': 512
+}
 }
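The new factory mirrors the existing `*_kwargs` lambdas: it packages the constructor arguments for the `TimesFM` wrapper, with the context length pinned to the 512-step lookback:

```python
from model_kwargs import timesfm_kwargs

print(timesfm_kwargs(lookback=512, lookahead=48))
# {'lookback': 512, 'lookahead': 48, 'context_len': 512}
```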
models/TimesFM.py ADDED
@@ -0,0 +1,841 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch version of the patched decoder."""
+
+import dataclasses
+import math
+from typing import List, Tuple
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def _create_quantiles() -> list[float]:
+  return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+
+
+@dataclasses.dataclass
+class TimesFMConfig:
+  """Config for initializing timesfm patched_decoder class."""
+
+  # The number of blocks in the model.
+  num_layers: int = 20
+  # The number of attention heads used in the attention layers of the model.
+  num_heads: int = 16
+  # The number of key-value heads for implementing attention.
+  num_kv_heads: int = 16
+  # The hidden size of the model.
+  hidden_size: int = 1280
+  # The dimension of the MLP representations.
+  intermediate_size: int = 1280
+  # The number of head dimensions.
+  head_dim: int = 80
+  # The epsilon used by the rms normalization layers.
+  rms_norm_eps: float = 1e-6
+  # Patch length
+  patch_len: int = 32
+  # Horizon length
+  horizon_len: int = 128
+  # Quantiles
+  quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
+  # Padding value
+  pad_val: float = 1123581321.0
+  # Tolerance
+  tolerance: float = 1e-6
+  # The dtype of the weights.
+  dtype: str = "bfloat32"
+  # Use positional embedding
+  use_positional_embedding: bool = True
+
+
+def _masked_mean_std(
+    inputs: torch.Tensor,
+    padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+  """Calculates mean and standard deviation of `inputs` across axis 1.
+
+  It excludes values where `padding` is 1.
+
+  Args:
+    inputs: A PyTorch tensor of shape [b, n, p].
+    padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+  Returns:
+    A tuple containing the mean and standard deviation.
+    We return the statistics of the first patch with at least three non-padded
+    values.
+  """
+  # Selecting the first patch with at least 3 unpadded values.
+  pad_sum = torch.sum(1 - padding, dim=2)
+
+  def _get_patch_index(arr: torch.Tensor):
+    indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+    row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+    return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+  patch_indices = _get_patch_index(pad_sum)
+  bidxs = torch.arange(inputs.shape[0])
+
+  arr = inputs[bidxs, patch_indices, :]
+  pad = padding[bidxs, patch_indices, :]
+
+  # Create a mask where padding is 0
+  mask = 1 - pad
+
+  # Calculate the number of valid elements
+  num_valid_elements = torch.sum(mask, dim=1)
+  num_valid_elements = torch.where(
+      num_valid_elements == 0,
+      torch.tensor(1,
+                   dtype=num_valid_elements.dtype,
+                   device=num_valid_elements.device),
+      num_valid_elements,
+  )
+
+  # Calculate the masked sum and squared sum
+  masked_sum = torch.sum(arr * mask, dim=1)
+  masked_squared_sum = torch.sum((arr * mask)**2, dim=1)
+
+  # Calculate the masked mean and standard deviation
+  masked_mean = masked_sum / num_valid_elements
+  masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+  masked_var = torch.where(
+      masked_var < 0.0,
+      torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+      masked_var,
+  )
+  masked_std = torch.sqrt(masked_var)
+
+  return masked_mean, masked_std
+
+
+def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+  """Shifts rows of seq based on the first 0 in each row of the mask.
+
+  Args:
+    mask: mask tensor of shape [B, N]
+    seq: seq tensor of shape [B, N, P]
+
+  Returns:
+    The shifted sequence.
+  """
+  batch_size, num_seq, feature_dim = seq.shape
+
+  new_mask: torch.BoolTensor = mask == 0
+
+  # Use argmax to find the first True value in each row
+  indices = new_mask.to(torch.int32).argmax(dim=1)
+
+  # Handle rows with all zeros
+  indices[~new_mask.any(dim=1)] = -1
+
+  # Create index ranges for each sequence in the batch
+  idx_range = (torch.arange(num_seq).to(
+      seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
+                                                    feature_dim))
+
+  # Calculate shifted indices for each element in each sequence
+  shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+  # Gather values from seq using shifted indices
+  shifted_seq = seq.gather(1, shifted_idx)
+
+  return shifted_seq
+
+
+def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
+  """Returns a large negative value for the given dtype."""
+  if dtype.is_floating_point:
+    dtype_max = torch.finfo(dtype).max
+  else:
+    dtype_max = torch.iinfo(dtype).max
+  return torch.tensor(-0.7 * dtype_max, dtype=dtype)
+
+
+def apply_mask_to_logits(logits: torch.Tensor,
+                         mask: torch.Tensor) -> torch.Tensor:
+  """Applies a floating-point mask to a set of logits.
+
+  Args:
+    logits: A torch.Tensor of logit values.
+    mask: A torch.Tensor (float32) of mask values, where masked positions
+      hold large negative values.
+
+  Returns:
+    Masked logits.
+  """
+
+  min_value = get_large_negative_number(logits.dtype)
+
+  return torch.where((mask >= min_value * 0.5), logits, min_value)
+
+
+def convert_paddings_to_mask(
+    paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+  """Converts binary paddings to a logit mask ready to add to attention matrix.
+
+  Args:
+    paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
+      token.
+    dtype: data type of the input.
+
+  Returns:
+    A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.
+  """
+  attention_mask = paddings.detach().clone()
+  attention_mask = attention_mask[:, None, None, :]  # Equivalent to jnp.newaxis
+  attention_mask *= get_large_negative_number(dtype)
+  return attention_mask
+
+
+def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
+  """Computes and returns causal mask.
+
+  Args:
+    input_t: A torch.Tensor of shape [B, T, D].
+
+  Returns:
+    An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has
+    already been converted to large negative values.
+  """
+  assert input_t.dtype.is_floating_point, input_t.dtype
+  large_negative_number = get_large_negative_number(input_t.dtype)
+  t = input_t.shape[1]
+  col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1)
+  row_idx = torch.arange(t).unsqueeze(1).repeat(1, t)
+  mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number
+  return (mask.unsqueeze(0).unsqueeze(0).to(input_t.device)
+         )  # Equivalent to jnp.newaxis
+
+
+def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+  """Merges two masks.
+
+  A log-scale mask is expected, but a 0/1 mask is also fine.
+
+  Args:
+    a: torch.Tensor of shape [1|B, 1, 1|T, S].
+    b: torch.Tensor of shape [1|B, 1, 1|T, S].
+
+  Returns:
+    torch.Tensor of shape [1|B, 1, 1|T, S].
+  """
+
+  def expand_t(key_mask):
+    query_mask = key_mask.transpose(-1, -2)  # Equivalent of jnp.transpose
+    return torch.minimum(query_mask, key_mask)
+
+  if a.shape[2] != b.shape[2]:
+    if a.shape[2] == 1:
+      a = expand_t(a)
+    else:
+      assert b.shape[2] == 1
+      b = expand_t(b)
+
+  assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
+  return torch.minimum(a, b)  # Element-wise minimum, similar to jnp.minimum
+
+
+class ResidualBlock(nn.Module):
+  """TimesFM residual block."""
+
+  def __init__(
+      self,
+      input_dims,
+      hidden_dims,
+      output_dims,
+  ):
+    super(ResidualBlock, self).__init__()
+    self.input_dims = input_dims
+    self.hidden_dims = hidden_dims
+    self.output_dims = output_dims
+
+    # Hidden Layer
+    self.hidden_layer = nn.Sequential(
+        nn.Linear(input_dims, hidden_dims),
+        nn.SiLU(),
+    )
+
+    # Output Layer
+    self.output_layer = nn.Linear(hidden_dims, output_dims)
+    # Residual Layer
+    self.residual_layer = nn.Linear(input_dims, output_dims)
+
+  def forward(self, x):
+    hidden = self.hidden_layer(x)
+    output = self.output_layer(hidden)
+    residual = self.residual_layer(x)
+    return output + residual
+
+
+class RMSNorm(torch.nn.Module):
+  """Pax rms norm in pytorch."""
+
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-6,
+      add_unit_offset: bool = False,
+  ):
+    super().__init__()
+    self.eps = eps
+    self.add_unit_offset = add_unit_offset
+    self.weight = nn.Parameter(torch.zeros(dim))
+
+  def _norm(self, x):
+    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+  def forward(self, x):
+    output = self._norm(x.float())
+    if self.add_unit_offset:
+      output = output * (1 + self.weight.float())
+    else:
+      output = output * self.weight.float()
+    return output.type_as(x)
+
+
+class TransformerMLP(nn.Module):
+  """Pax transformer MLP in pytorch."""
+
+  def __init__(
+      self,
+      hidden_size: int,
+      intermediate_size: int,
+  ):
+    super().__init__()
+    self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+    self.down_proj = nn.Linear(intermediate_size, hidden_size)
+    self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+  def forward(self, x, paddings=None):
+    gate_inp = self.layer_norm(x)
+    gate = self.gate_proj(gate_inp)
+    gate = F.relu(gate)
+    outputs = self.down_proj(gate)
+    if paddings is not None:
+      outputs = outputs * (1.0 - paddings[:, :, None])
+    return outputs + x
+
+
+class TimesFMAttention(nn.Module):
+  """Implements the attention used in TimesFM."""
+
+  def __init__(
+      self,
+      hidden_size: int,
+      num_heads: int,
+      num_kv_heads: int,
+      head_dim: int,
+  ):
+    super().__init__()
+
+    self.num_heads = num_heads
+    self.num_kv_heads = num_kv_heads
+
+    assert self.num_heads % self.num_kv_heads == 0
+    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+    self.hidden_size = hidden_size
+    self.head_dim = head_dim
+
+    self.q_size = self.num_heads * self.head_dim
+    self.kv_size = self.num_kv_heads * self.head_dim
+    self.scaling = nn.Parameter(
+        torch.empty((self.head_dim,), dtype=torch.float32),)
+
+    self.qkv_proj = nn.Linear(
+        self.hidden_size,
+        (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
+    )
+    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+  def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
+    # [batch_size, n_local_heads, input_len, head_dim]
+    r_softplus_0 = 1.442695041
+    softplus_func = torch.nn.Softplus()
+    scale = r_softplus_0 / math.sqrt(self.head_dim)
+    scale = scale * softplus_func(self.scaling)
+    return query * scale[None, None, None, :]
+
+  def forward(
+      self,
+      hidden_states: torch.Tensor,
+      mask: torch.Tensor,
+      kv_write_indices: torch.Tensor | None = None,
+      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    hidden_states_shape = hidden_states.shape
+    assert len(hidden_states_shape) == 3
+
+    batch_size, input_len, _ = hidden_states_shape
+
+    qkv = self.qkv_proj(hidden_states)
+    xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+    xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
+    xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
+    xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
+    xq = self._per_dim_scaling(xq)
+
+    # Write new kv cache.
+    # [batch_size, input_len, n_local_kv_heads, head_dim]
+    if kv_cache is not None and kv_write_indices is not None:
+      k_cache, v_cache = kv_cache
+      k_cache.index_copy_(1, kv_write_indices, xk)
+      v_cache.index_copy_(1, kv_write_indices, xv)
+
+      key = k_cache
+      value = v_cache
+    else:
+      key = xk
+      value = xv
+    if self.num_kv_heads != self.num_heads:
+      # [batch_size, max_seq_len, n_local_heads, head_dim]
+      key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
+      value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
+
+    # [batch_size, n_local_heads, input_len, head_dim]
+    q = xq.transpose(1, 2)
+    # [batch_size, n_local_heads, max_seq_len, head_dim]
+    k = key.transpose(1, 2)
+    v = value.transpose(1, 2)
+
+    # [batch_size, n_local_heads, input_len, max_seq_len]
+    scores = torch.matmul(q, k.transpose(2, 3))
+    scores = scores + mask
+    scores = F.softmax(scores.float(), dim=-1).type_as(q)
+
+    # [batch_size, n_local_heads, input_len, head_dim]
+    output = torch.matmul(scores, v)
+    # return scores, output.transpose(1, 2).contiguous()
+
+    # [batch_size, input_len, hidden_dim]
+    output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
+    output = self.o_proj(output)
+    return scores, output
+
+
+class TimesFMDecoderLayer(nn.Module):
+  """Transformer layer."""
+
+  def __init__(
+      self,
+      hidden_size: int,
+      intermediate_size: int,
+      num_heads: int,
+      num_kv_heads: int,
+      head_dim: int,
+      rms_norm_eps: float = 1e-6,
+  ):
+    super().__init__()
+    self.self_attn = TimesFMAttention(
+        hidden_size=hidden_size,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+    )
+    self.mlp = TransformerMLP(
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+    )
+    self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+  def forward(
+      self,
+      hidden_states: torch.Tensor,
+      mask: torch.Tensor,
+      paddings: torch.Tensor,
+      kv_write_indices: torch.Tensor | None = None,
+      kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Self Attention
+    residual = hidden_states
+    hidden_states = self.input_layernorm(hidden_states)
+    scores, hidden_states = self.self_attn(
+        hidden_states=hidden_states,
+        mask=mask,
+        kv_write_indices=kv_write_indices,
+        kv_cache=kv_cache,
+    )
+    hidden_states = residual + hidden_states
+
+    # MLP
+    hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+    return scores, hidden_states
+
+
+class StackedDecoder(nn.Module):
+  """Stacked transformer layer."""
+
+  def __init__(
+      self,
+      hidden_size: int,
+      intermediate_size: int,
+      num_heads: int,
+      num_kv_heads: int,
+      head_dim: int,
+      num_layers: int,
+      rms_norm_eps: float = 1e-6,
+  ):
+    super().__init__()
+
+    self.layers = nn.ModuleList()
+    for _ in range(num_layers):
+      self.layers.append(
+          TimesFMDecoderLayer(
+              hidden_size=hidden_size,
+              intermediate_size=intermediate_size,
+              num_heads=num_heads,
+              num_kv_heads=num_kv_heads,
+              head_dim=head_dim,
+              rms_norm_eps=rms_norm_eps,
+          ))
+
+  def forward(
+      self,
+      hidden_states: torch.Tensor,
+      paddings: torch.Tensor,
+      kv_write_indices: torch.Tensor | None = None,
+      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
+  ) -> torch.Tensor:
+    padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)
+    atten_mask = causal_mask(hidden_states)
+    mask = merge_masks(padding_mask, atten_mask)
+    for i in range(len(self.layers)):
+      layer = self.layers[i]
+      kv_cache = kv_caches[i] if kv_caches is not None else None
+      _, hidden_states = layer(
+          hidden_states=hidden_states,
+          mask=mask,
+          paddings=paddings,
+          kv_write_indices=kv_write_indices,
+          kv_cache=kv_cache,
+      )
+    return hidden_states
+
+
+class PositionalEmbedding(torch.nn.Module):
+  """Generates position embedding for a given 1-d sequence.
+
+  Attributes:
+    min_timescale: Start of the geometric index. Determines the periodicity of
+      the added signal.
+    max_timescale: End of the geometric index. Determines the frequency of the
+      added signal.
+    embedding_dims: Dimension of the embedding to be generated.
+  """
+
+  def __init__(
+      self,
+      embedding_dims: int,
+      min_timescale: int = 1,
+      max_timescale: int = 10_000,
+  ) -> None:
+    super().__init__()
+    self.min_timescale = min_timescale
+    self.max_timescale = max_timescale
+    self.embedding_dims = embedding_dims
+
+  def forward(self, seq_length=None, position=None):
+    """Generates a Tensor of sinusoids with different frequencies.
+
+    Args:
+      seq_length: an optional Python int defining the output sequence length.
+        Required if the `position` argument is not specified.
+      position: [B, seq_length], optional position for each token in the
+        sequence, only required when the sequence is packed.
+
+    Returns:
+      [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+    """
+    if position is None:
+      assert seq_length is not None
+      # [1, seqlen]
+      position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
+    else:
+      assert position.ndim == 2, position.shape
+
+    num_timescales = self.embedding_dims // 2
+    log_timescale_increment = math.log(
+        float(self.max_timescale) / float(self.min_timescale)) / max(
+            num_timescales - 1, 1)
+    inv_timescales = self.min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float32) *
+        -log_timescale_increment)
+    scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(
+        0)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+    # Padding to ensure correct embedding dimension
+    signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+    return signal
+
+
+class PatchedTimeSeriesDecoder(nn.Module):
+  """Patched time-series decoder."""
+
+  def __init__(self, config: TimesFMConfig):
+    super().__init__()
+    self.config = config
+    self.input_ff_layer = ResidualBlock(
+        input_dims=2 * config.patch_len,
+        output_dims=config.hidden_size,
+        hidden_dims=config.intermediate_size,
+    )
+    self.freq_emb = nn.Embedding(num_embeddings=3,
+                                 embedding_dim=config.hidden_size)
+    self.horizon_ff_layer = ResidualBlock(
+        input_dims=config.hidden_size,
+        output_dims=config.horizon_len * (1 + len(config.quantiles)),
+        hidden_dims=config.intermediate_size,
+    )
+    self.stacked_transformer = StackedDecoder(
+        hidden_size=self.config.hidden_size,
+        intermediate_size=self.config.intermediate_size,
+        num_heads=self.config.num_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        num_layers=self.config.num_layers,
+        rms_norm_eps=self.config.rms_norm_eps,
+    )
+    if self.config.use_positional_embedding:
+      self.position_emb = PositionalEmbedding(self.config.hidden_size)
+
+  def _forward_transform(
+      self, inputs: torch.Tensor, patched_pads: torch.Tensor
+  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    """Input is of shape [B, N, P]."""
+    mu, sigma = _masked_mean_std(inputs, patched_pads)
+    sigma = torch.where(
+        sigma < self.config.tolerance,
+        torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+        sigma,
+    )
+
+    # Normalize each patch
+    outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+    outputs = torch.where(
+        torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+        torch.tensor(self.config.pad_val,
+                     dtype=outputs.dtype,
+                     device=outputs.device),
+        outputs,
+    )
+    return outputs, (mu, sigma)
+
+  def _reverse_transform(
+      self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
+                                                torch.Tensor]) -> torch.Tensor:
+    """Output is of shape [B, N, P, Q]."""
+    mu, sigma = stats
+    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
+
+  def _preprocess_input(
+      self,
+      input_ts: torch.Tensor,
+      input_padding: torch.Tensor,
+  ) -> tuple[
+      torch.Tensor,
+      torch.Tensor,
+      tuple[torch.Tensor, torch.Tensor] | None,
+      torch.Tensor,
+  ]:
+    """Preprocess input for stacked transformer."""
+
+    # Reshape into patches (using view for efficiency)
+    bsize = input_ts.shape[0]
+    patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)
+    patched_pads = input_padding.view(bsize, -1, self.config.patch_len)
+
+    patched_inputs = torch.where(
+        torch.abs(patched_pads - 1.0) < self.config.tolerance,
+        torch.tensor(0.0,
+                     dtype=patched_inputs.dtype,
+                     device=patched_inputs.device),
+        patched_inputs,
+    )
+    patched_pads = torch.where(
+        torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+        torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+        patched_pads,
+    )
+    patched_inputs, stats = self._forward_transform(patched_inputs,
+                                                    patched_pads)
+
+    # B x N x D
+    patched_inputs = patched_inputs * (1.0 - patched_pads)
+    concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+    model_input = self.input_ff_layer(concat_inputs)
+
+    # A patch is treated as non-padded if it has at least one non-padded value.
+    patched_padding = torch.min(patched_pads,
+                                dim=-1)[0]  # Get the values from the min result
+    if self.config.use_positional_embedding:
+      pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
+      pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+      pos_emb = _shift_padded_seq(patched_padding, pos_emb)
+      model_input += pos_emb
+
+    return model_input, patched_padding, stats, patched_inputs
+
+  def _postprocess_output(
+      self,
+      model_output: torch.Tensor,
+      num_outputs: int,
+      stats: tuple[torch.Tensor, torch.Tensor],
+  ) -> torch.Tensor:
+    """Postprocess output of stacked transformer."""
+
+    # B x N x (H.Q)
+    output_ts = self.horizon_ff_layer(model_output)
+
+    # Reshape using view
+    b, n, _ = output_ts.shape
+    output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)
+
+    return self._reverse_transform(output_ts, stats)
+
+  def forward(
+      self,
+      input_ts: torch.Tensor,
+      input_padding: torch.LongTensor,
+      freq: torch.Tensor,
+  ) -> torch.Tensor:
+    num_outputs = len(self.config.quantiles) + 1
+    model_input, patched_padding, stats, _ = self._preprocess_input(
+        input_ts=input_ts,
+        input_padding=input_padding,
+    )
+    f_emb = self.freq_emb(freq)  # B x 1 x D
+    model_input += f_emb
+    model_output = self.stacked_transformer(model_input, patched_padding)
+
+    output_ts = self._postprocess_output(model_output, num_outputs, stats)
+    return output_ts
+
+  def decode(
+      self,
+      input_ts: torch.Tensor,
+      paddings: torch.Tensor,
+      freq: torch.LongTensor,
+      horizon_len: int,
+      output_patch_len: int | None = None,
+      max_len: int = 512,
+      return_forecast_on_context: bool = False,
+  ) -> tuple[torch.Tensor, torch.Tensor]:
+    """Auto-regressive decoding without caching.
+
+    Args:
+      input_ts: input time-series of shape B x C.
+      paddings: padding of shape B x (C + H) where H is the prediction length.
+      freq: frequency of shape B x 1.
+      horizon_len: prediction length.
+      output_patch_len: output length to be fetched from one step of
+        auto-regressive decoding.
+      max_len: maximum training context length.
+      return_forecast_on_context: whether to return the model forecast on the
+        context except the first input patch.
+
+    Returns:
+      Tuple of two forecasting results:
+      - Point (mean) output predictions as a tensor with shape B x H'.
+      - Full predictions (mean and quantiles) as a tensor with shape
+        B x H' x (1 + # quantiles).
+      In particular, if return_forecast_on_context is True, H' is H plus
+      the forecastable context length, i.e. context_len - (first) patch_len.
+    """
+    final_out = input_ts
+    context_len = final_out.shape[1]
+    full_outputs = []
+    if paddings.shape[1] != final_out.shape[1] + horizon_len:
+      raise ValueError(
+          "Length of paddings must match length of input + horizon_len:"
+          f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
+    if output_patch_len is None:
+      output_patch_len = self.config.horizon_len
+    num_decode_patches = (horizon_len + output_patch_len -
+                          1) // output_patch_len
+    for step_index in range(num_decode_patches):
+      current_padding = paddings[:, 0:final_out.shape[1]]
+      input_ts = final_out[:, -max_len:]
+      input_padding = current_padding[:, -max_len:]
+      fprop_outputs = self(input_ts, input_padding, freq)
+      if return_forecast_on_context and step_index == 0:
+        # For the first decoding step, collect the model forecast on the
+        # context except the unavailable first input patch forecast.
+        new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
+        new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
+                                          new_full_ts.size(3))
+
+        full_outputs.append(new_full_ts)
+
+      # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+      new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+      # (full batch, last patch, output_patch_len, all output indices)
+      full_outputs.append(new_full_ts)
+      final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+    if return_forecast_on_context:
+      # `full_outputs` indexing starts right after the first input patch.
+      full_outputs = torch.concatenate(
+          full_outputs,
+          axis=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
+    else:
+      # `full_outputs` indexing starts at the forecast horizon.
+      full_outputs = torch.concatenate(full_outputs, axis=1)[:,
+                                                             0:horizon_len, :]
+
+    return (full_outputs[:, :, 0], full_outputs)
+
+class TimesFM(nn.Module):
+  """Wraps the TimesFM decoder behind the repo's (lookback, lookahead) API."""
+
+  def __init__(self, lookback: int = 512, lookahead: int = 96, context_len: int = 512):
+
+    super(TimesFM, self).__init__()
+
+    self.timesfm = PatchedTimeSeriesDecoder(TimesFMConfig())
+    self.lookback, self.lookahead = lookback, lookahead
+    self.context_len = context_len
+
+  def load_state_dict(self, state_dict, *args, **kwargs):
+    # Delegate to the wrapped decoder so checkpoints keep its key names.
+    return self.timesfm.load_state_dict(state_dict, *args, **kwargs)
+
+  def state_dict(self, *args, **kwargs):
+
+    return self.timesfm.state_dict(*args, **kwargs)
+
+  def pad_tensor(self, x):
+    """Pads/truncates x to context_len; builds the padding mask and freq."""
+    B, L = x.shape
+    device = x.device
+    dtype = x.dtype
+
+    if L < self.context_len:
+      padded_input = torch.zeros((B, self.context_len), device=device, dtype=dtype)
+      padded_input[:, -L:] = x
+      padding = torch.ones((B, self.context_len), device=device, dtype=dtype)
+      padding[:, -L:] = 0
+    else:
+      padded_input = x[:, -self.context_len:]
+      padding = torch.zeros((B, self.context_len), device=device, dtype=dtype)
+
+    freq = torch.zeros((B, 1), device=device, dtype=torch.long)
+
+    return padded_input, torch.cat((padding, torch.zeros((B, self.lookahead), device=device, dtype=dtype)), dim=-1), freq
+
+  def forward(self, x):
+    # x: [batch, lookback] normalized series -> [batch, lookahead] mean forecast.
+    padded_inp, padding, freq = self.pad_tensor(x)
+    return self.timesfm.decode(padded_inp, padding, freq, self.lookahead)[0]  # ignoring quantiles
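The wrapper's `pad_tensor` left-pads histories shorter than `context_len` with zeros, marks those positions with 1 in the padding mask, and appends `lookahead` zeros to cover the horizon, matching the `B x (C + H)` padding shape that `decode` expects. A quick shape check (a sketch using the classes above; random weights, since loading is covered in `example.py`):

```python
import torch
from models.TimesFM import TimesFM

model = TimesFM(lookback=512, lookahead=48, context_len=512)

x = torch.randn(2, 256)             # history shorter than context_len
padded, padding, freq = model.pad_tensor(x)
print(padded.shape)                 # torch.Size([2, 512])  context window
print(padding.shape)                # torch.Size([2, 560])  context + horizon mask
print(padding[0, :256].unique())    # tensor([1.])  left-padded positions flagged
print(freq.shape)                   # torch.Size([2, 1])    frequency category 0
```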
models/__pycache__/Autoformer.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/Autoformer.cpython-310.pyc and b/models/__pycache__/Autoformer.cpython-310.pyc differ
 
models/__pycache__/LSTM.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/LSTM.cpython-310.pyc and b/models/__pycache__/LSTM.cpython-310.pyc differ
 
models/__pycache__/LSTNet.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/LSTNet.cpython-310.pyc and b/models/__pycache__/LSTNet.cpython-310.pyc differ
 
weights/TimesFM_L_512_T_48_HET.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5dd216286c5493e6aaa9aa0f08ccf6e645423e83733e6e2c6be78920f5266cc4
+size 814365703
weights/TimesFM_L_512_T_48_HOM.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1eb69eeaa672c28212c5fe410d4b7d87c41a0868b8874f33308ab932f01ac89
+size 814365703
weights/TimesFM_L_512_T_4_HET.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:826f2c6d2f01218f55579997cd057257f6d5817b3856fb9ffd6e70d13c5d8e2a
+size 814365382
weights/TimesFM_L_512_T_4_HOM.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d710dc8d8012d226a63d4a983743b48dfeb21d10c1d2bc674b86ec6472b4a060
+size 814365382
weights/TimesFM_L_512_T_96_HET.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3987782b50e4e6119cd9d35df3815bb2895ec010100919862c35620d9459767d
+size 814365703
weights/TimesFM_L_512_T_96_HOM.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49f862e58bc92993cf966b06facaddc79a7e8875d8a525561b5ae3fc3b67a1fc
+size 814365703