|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.fft |
|
import math |
|
|
|
|
|
|
|
|
|
class Inception_Block_V1(nn.Module):
    """Average of parallel 2-D convolutions with odd kernel sizes 1, 3, 5, ...

    Branch ``r`` (0 <= r < num_kernels) is ``Conv2d(kernel_size=2*r+1, padding=r)``
    with the default stride of 1, so every branch keeps the input's spatial size
    and the branch outputs can be averaged element-wise.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every branch (and by the block).
        num_kernels: number of parallel branches.
        init_weight: if True, re-initialize conv weights (Kaiming normal,
            fan_out / relu) and zero the biases.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        # Branches are created in increasing kernel size, which keeps the
        # construction (and RNG) order stable.
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=2 * radius + 1, padding=radius)
            for radius in range(num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights (fan_out, relu) and zero biases."""
        conv_layers = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Run every branch on ``x`` and return the mean of their outputs."""
        branch_outputs = [branch(x) for branch in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
|
|
|
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Column ``2j`` holds ``sin(pos * w_j)`` and column ``2j + 1`` holds
    ``cos(pos * w_j)``, with ``w_j = 10000 ** (-2j / d_model)``. The table is
    stored as a registered buffer of shape ``(1, max_len, d_model)``. Buffers
    are not returned by ``parameters()``, so optimizers never update it.

    Args:
        d_model: embedding width (odd values are supported).
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()

        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns. For odd d_model that is one fewer
        # than the number of frequencies, so trim div_term to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only ``x.size(1)`` is read. The result has shape ``(1, x.size(1), d_model)``
        and broadcasts over the batch dimension when added to an embedding.
        """
        return self.pe[:, :x.size(1)]
|
|
|
class FixedEmbedding(nn.Module):
    """Lookup table initialized with sinusoidal codes and frozen.

    Row ``i`` of the table holds ``sin(i * w_j)`` in column ``2j`` and
    ``cos(i * w_j)`` in column ``2j + 1``, with
    ``w_j = 10000 ** (-2j / d_model)``. The weight is an ``nn.Parameter`` with
    ``requires_grad=False``, and ``forward`` also detaches its output.

    Args:
        c_in: number of distinct indices (table rows).
        d_model: embedding width (odd values are supported).
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns. Trim div_term so odd d_model works.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Freezing is done here via requires_grad=False on the Parameter.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up the rows indexed by integer tensor ``x`` (result is detached)."""
        return self.emb(x).detach()
|
|
|
class TemporalEmbedding(nn.Module):
    """Sum of two index embeddings read from the last axis of the time marks.

    ``x[..., 0]`` indexes a 96-row table (``hour_embed``) and ``x[..., 1]``
    indexes a 7-row table (``weekday_embed``). The 96-row size presumably
    encodes sub-hour time slots; confirm against the data pipeline.

    Args:
        d_model: embedding width.
        embed_type: ``'fixed'`` uses frozen sinusoidal tables (FixedEmbedding);
            any other value uses trainable ``nn.Embedding``.
        freq: accepted for interface compatibility; not used here.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()

        hour_rows = 96
        weekday_rows = 7

        table_cls = nn.Embedding
        if embed_type == 'fixed':
            table_cls = FixedEmbedding

        # Creation order (hour first) is kept so parameter init is reproducible.
        self.hour_embed = table_cls(hour_rows, d_model)
        self.weekday_embed = table_cls(weekday_rows, d_model)

    def forward(self, x):
        """Embed marks ``x`` (cast to long); x is indexed as ``x[:, :, col]``."""
        marks = x.long()
        hour_part = self.hour_embed(marks[:, :, 0])
        weekday_part = self.weekday_embed(marks[:, :, 1])
        return hour_part + weekday_part
|
|
|
class TokenEmbedding(nn.Module):
    """Project raw values to ``d_model`` with a circular 1-D convolution over time.

    ``forward`` expects a 3-D input laid out as ``(batch, time, c_in)``. It
    permutes to channels-first for ``nn.Conv1d`` and transposes back, so the
    output is ``(batch, time, d_model)`` when padding is 1.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # padding=2 is kept for torch < 1.5 (presumably older circular-padding
        # semantics). The version is compared numerically: as plain strings,
        # '1.10.0' >= '1.5.0' would be False.
        padding = 1 if self._torch_version_at_least(1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def _torch_version_at_least(major, minor):
        """Return True if the installed torch is >= ``major.minor``.

        Only the leading ``X.Y`` of ``torch.__version__`` is parsed, so local or
        pre-release suffixes such as ``+cu118`` or ``a0`` are ignored. An
        unparseable version is assumed to be modern.
        """
        import re

        match = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
        if match is None:
            return True
        return (int(match.group(1)), int(match.group(2))) >= (major, minor)

    def forward(self, x):
        """Apply the conv along time: (B, T, c_in) -> (B, T, d_model)."""
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
|
|
|
class DataEmbedding(nn.Module):
    """Value + (optional temporal) + positional embeddings, followed by dropout.

    Args:
        c_in: number of input value channels.
        d_model: embedding width shared by all three embeddings.
        embed_type: forwarded to TemporalEmbedding (``'fixed'`` or other).
        freq: forwarded to TemporalEmbedding.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        # Registration order is kept so state_dict keys and init order match.
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed values ``x``, adding temporal codes from ``x_mark`` unless it is None."""
        summed = self.value_embedding(x)
        if x_mark is not None:
            # Summation order (value, temporal, position) matches the original
            # so floating-point results are identical.
            summed = summed + self.temporal_embedding(x_mark)
        summed = summed + self.position_embedding(x)
        return self.dropout(summed)
|
|
|
def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of ``x`` from its real FFT over dim 1.

    Amplitudes are averaged over dim 0 and the last dim to rank frequencies. As
    called from TimesBlock, ``x`` is ``(B, T, N)``.

    Args:
        x: real tensor; the FFT runs along dim 1 (time).
        k: number of frequencies/periods to return.

    Returns:
        period: numpy integer array of length ``k``, ``x.shape[1] // f`` for
            each selected frequency index ``f``.
        weights: amplitudes averaged over the last dim at the selected
            frequencies, shape ``(B, k)`` for a 3-D input. Gradients flow
            through this tensor.
    """
    xf = torch.fft.rfft(x, dim=1)
    # Compute the amplitude once and reuse it for ranking and for the weights.
    amplitude = abs(xf)

    # Only the indices are used from the ranking, so it can be detached.
    frequency_list = amplitude.mean(0).mean(-1).detach()
    # Exclude the DC bin. -inf (rather than 0) means it cannot tie with other
    # zero-amplitude bins, which would otherwise produce a period of T // 0.
    frequency_list[0] = float('-inf')
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.cpu().numpy()

    period = x.shape[1] // top_list
    return period, amplitude.mean(-1)[:, top_list]
|
|
|
|
|
class TimesBlock(nn.Module):
    """TimesNet block: fold the series into 2-D by each dominant period and convolve.

    For each of the ``top_k`` periods returned by ``FFT_for_Period``:
      1. zero-pad time so its length is a multiple of the period;
      2. reshape ``(B, T', N)`` to ``(B, N, T' // period, period)``;
      3. apply two Inception blocks (d_model -> d_ff -> d_model) with GELU;
      4. unfold back and crop to the original length.
    The k results are fused with softmax weights derived from the FFT
    amplitudes, and a residual connection is added.

    Args:
        seq_len, pred_len: stored for reference. ``forward`` uses the actual
            input length, so any time length works.
        top_k: number of periods to use.
        d_model: channel count N of the input/output.
        d_ff: hidden channel count inside the conv stack.
        num_kernels: branches per Inception block.
    """

    def __init__(self, seq_len, pred_len, top_k, d_model, d_ff, num_kernels):
        super(TimesBlock, self).__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.k = top_k

        self.conv = nn.Sequential(
            Inception_Block_V1(d_model, d_ff, num_kernels=num_kernels),
            nn.GELU(),
            Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(B, T, N)`` to a tensor of the same shape."""
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = int(period_list[i])

            # Pad with zeros up to the next multiple of period. Use the real
            # input length T, and allocate with x's dtype directly on x's device.
            pad_len = (-T) % period
            if pad_len:
                padding = torch.zeros(B, pad_len, N, dtype=x.dtype, device=x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                out = x
            length = T + pad_len

            # (B, length, N) -> (B, N, length // period, period): channels first for Conv2d.
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()

            out = self.conv(out)

            # Back to (B, length, N), then drop the padded tail.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])
        res = torch.stack(res, dim=-1)  # (B, T, N, k)

        # Softmax over the k periods. Broadcasting replaces an explicit repeat.
        period_weight = F.softmax(period_weight, dim=1)
        res = torch.sum(res * period_weight[:, None, None, :], -1)

        # Residual connection.
        return res + x
|
|
|
|
|
class TimesNet(nn.Module):
    """
    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq

    ``forward(x, fut_time)`` splits ``x`` along its last axis into value
    channels (``data_idx``) and time-mark channels (``time_idx``). It runs the
    forecaster and returns the last time step of output channel 0, shape
    ``(B, 1)``.

    Args:
        enc_in: number of value channels fed to the embedding (presumably
            ``len(data_idx)``; confirm with callers).
        dec_in: channel count of the placeholder decoder input built in
            ``forward``. ``forecast`` does not read that tensor.
        c_out: output channels of the final projection.
        pred_len, seq_len: ``x`` must have ``seq_len`` time steps. The
            internal sequence has ``seq_len + pred_len`` steps.
        output_attention: accepted for interface compatibility; unused.
        data_idx: indices of value channels in ``x`` (default ``[0, 3, 4, 5, 6, 7]``).
        time_idx: indices of time-mark channels in ``x`` (default ``[1, 2]``).
        d_model, d_ff, e_layers, top_k, num_kernels, dropout: TimesBlock /
            embedding hyper-parameters.
    """

    def __init__(
        self,
        enc_in,
        dec_in,
        c_out,
        pred_len,
        seq_len,
        output_attention=False,
        data_idx=None,
        time_idx=None,
        d_model=16,
        d_ff=64,
        e_layers=2,
        top_k=5,
        num_kernels=2,
        dropout=0.1,
    ):
        super(TimesNet, self).__init__()

        # Fresh lists per instance. Mutable list defaults would be shared by
        # every TimesNet constructed with the defaults.
        self.data_idx = [0, 3, 4, 5, 6, 7] if data_idx is None else data_idx
        self.time_idx = [1, 2] if time_idx is None else time_idx
        self.dec_in = dec_in

        self.seq_len = seq_len
        self.pred_len = pred_len
        # Module registration order is unchanged (affects init RNG order and state_dict).
        self.model = nn.ModuleList([TimesBlock(seq_len, pred_len, top_k, d_model, d_ff, num_kernels)
                                    for _ in range(e_layers)])
        self.enc_embedding = DataEmbedding(enc_in, d_model, 'fixed', 'h',
                                           dropout)
        self.layer = e_layers
        # A single LayerNorm is shared by all TimesBlocks.
        self.layer_norm = nn.LayerNorm(d_model)
        # Maps the time axis from seq_len to seq_len + pred_len.
        self.predict_linear = nn.Linear(
            self.seq_len, self.pred_len + self.seq_len)
        self.projection = nn.Linear(
            d_model, c_out, bias=True)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Normalize, embed, run the TimesBlocks, project, and de-normalize.

        ``x_dec`` and ``x_mark_dec`` are accepted for interface parity but not
        read. Returns ``(B, seq_len + pred_len, c_out)``.
        """
        # Per-series normalization over time (dim 1). Means are detached.
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place division: torch.var saved x_enc for backward, so
        # `x_enc /= stdev` would break autograd when the input requires grad.
        x_enc = x_enc / stdev

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        # Linear over the time axis: seq_len -> seq_len + pred_len.
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
            0, 2, 1)

        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))

        dec_out = self.projection(enc_out)

        # De-normalize. The (B, 1, C) statistics broadcast over every time step,
        # which is equivalent to the previous explicit repeat.
        dec_out = dec_out * stdev[:, 0, :].unsqueeze(1)
        dec_out = dec_out + means[:, 0, :].unsqueeze(1)
        return dec_out

    def forward(self, x, fut_time):
        """Forecast from ``x`` and return the last step of output channel 0, shape (B, 1)."""
        x_enc = x[:, :, self.data_idx]
        x_mark_enc = x[:, :, self.time_idx]
        # Placeholder decoder input. forecast() does not read it.
        x_dec = torch.zeros((fut_time.shape[0], fut_time.shape[1], self.dec_in),
                            dtype=fut_time.dtype, device=fut_time.device)
        x_mark_dec = fut_time

        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)[:, -1, [0]]