import torch
import torch.nn as nn
from typing import Union, List, Tuple

# Written by Shourya Bose, shbose@ucsc.edu

class LSTM(nn.Module):

    def __init__(
        self,
        input_size: int = 8,
        hidden_size: int = 40,
        num_layers: int = 2,
        dropout: float = 0.1,
        lookback: int = 8,  # this will not be used, but keeping it here for consistency
    ):
        super(LSTM, self).__init__()

        # save values for use outside init
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # lstm: batch_first, unidirectional, no output projection
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
            proj_size=0,
            device=None,
        )

        # projector: map the final hidden state to a scalar prediction
        self.proj = nn.Linear(in_features=hidden_size, out_features=1, bias=False)

        # dropout, applied before the projection
        self.dropout = nn.Dropout(p=dropout)

    def init_h_c_(self, B, device, dtype):
        # zero-initialized hidden and cell states of shape (num_layers, B, hidden_size)
        h = torch.zeros((self.num_layers, B, self.hidden_size), dtype=dtype, device=device)
        c = torch.zeros((self.num_layers, B, self.hidden_size), dtype=dtype, device=device)
        return h, c

    def forward(self, x, fut_time):
        # x: (B, lookback, input_size); fut_time is unused, kept for interface consistency
        B, dev, dt = x.shape[0], x.device, x.dtype
        # generate states
        h, c = self.init_h_c_(B, dev, dt)
        # iterate
        out, _ = self.lstm(x, (h, c))
        # take the output at the last time step and project to shape (B, 1)
        return self.proj(self.dropout(out[:, -1, :]))
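

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test, assuming a batched input of shape
# (batch, lookback, input_size) and the default hyperparameters above.
# The tensor values and the None passed for `fut_time` are placeholders;
# forward() ignores `fut_time` entirely.
if __name__ == "__main__":
    model = LSTM(input_size=8, hidden_size=40, num_layers=2, dropout=0.1, lookback=8)
    x = torch.randn(4, 8, 8)     # (B=4, lookback=8, input_size=8)
    y = model(x, fut_time=None)  # unused argument, see comment in forward()
    print(y.shape)               # expected: torch.Size([4, 1])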