import torch
import torch.nn as nn

# Written by Shourya Bose, shbose@ucsc.edu

class LSTM(nn.Module):
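    """
    Stacked LSTM with a linear head: encodes an input sequence of shape
    (batch, seq_len, input_size) and projects the last time step's hidden
    output to a single scalar per batch element.
    """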
    
    def __init__(
        self,
        input_size: int = 8,
        hidden_size: int = 40,
        num_layers: int = 2,
        dropout: float = 0.1,
        lookback: int = 8, # unused by the model; kept for a consistent constructor signature
    ):
        
        super().__init__()
        
        # save values for use outside init
        self.hidden_size, self.num_layers = hidden_size, num_layers
        
        # lstm
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
            proj_size=0,
            device=None,
        )

        # projector: map the final hidden state to a single output value
        self.proj = nn.Linear(in_features=hidden_size, out_features=1, bias=False)

        # dropout
        self.dropout = nn.Dropout(p=dropout)
        
        
    def init_h_c_(self, B, device, dtype):

        # zero-initialized hidden and cell states, shape (num_layers, B, hidden_size)
        h = torch.zeros((self.num_layers, B, self.hidden_size), dtype=dtype, device=device)
        c = torch.zeros((self.num_layers, B, self.hidden_size), dtype=dtype, device=device)

        return h, c
    
    def forward(self, x, fut_time):

        # fut_time is unused here; like `lookback`, it is kept for interface consistency
        B, dev, dt = x.shape[0], x.device, x.dtype

        # generate initial hidden and cell states
        h, c = self.init_h_c_(B, dev, dt)

        # run the LSTM, then project the last time step's output
        out, _ = self.lstm(x, (h, c))
        return self.proj(self.dropout(out[:, -1, :]))
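
# Minimal usage sketch (assumption: not part of the original file); the batch
# size of 32 and sequence length of 8 below are illustrative values only.
if __name__ == "__main__":
    model = LSTM(input_size=8, hidden_size=40, num_layers=2, dropout=0.1)
    model.eval()  # disable dropout for inference
    x = torch.randn(32, 8, 8)  # (batch, seq_len, input_size)
    y = model(x, fut_time=1)   # fut_time is accepted but ignored
    print(y.shape)             # torch.Size([32, 1])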