cvejoski committed on
Commit
c686e91
1 Parent(s): 9ee5338

Upload FIMMJP

Browse files
Files changed (1) hide show
  1. mjp.py +12 -9
mjp.py CHANGED
@@ -124,6 +124,9 @@ class FIMMJP(AModel):
124
  x (dict[str, Tensor]): A dictionary containing the input tensors:
125
  - "observation_grid": Tensor representing the observation grid.
126
  - "observation_values": Tensor representing the observation values.
 
 
 
127
  - Optional keys for loss calculation:
128
  - "intensity_matrices": Tensor representing the intensity matrices.
129
  - "initial_distributions": Tensor representing the initial distributions.
@@ -133,8 +136,8 @@ class FIMMJP(AModel):
133
  Returns:
134
  dict: A dictionary containing the following keys:
135
  - "im": Tensor representing the intensity matrix.
136
- - "log_var_im": Tensor representing the log variance of the intensity matrix.
137
- - "init_cond": Tensor representing the initial conditions.
138
  - "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
139
  """
140
 
@@ -155,13 +158,13 @@ class FIMMJP(AModel):
155
  pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
156
 
157
  out = {
158
- "im": create_matrix_from_off_diagonal(
159
  pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
160
  ),
161
- "log_var_im": create_matrix_from_off_diagonal(
162
- pred_offdiag_im_logvar, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
163
  ),
164
- "init_cond": init_cond,
165
  }
166
  if "intensity_matrices" in x and "initial_distributions" in x:
167
  out["losses"] = self.loss(
@@ -182,7 +185,7 @@ class FIMMJP(AModel):
182
  pos_enc = self.pos_encodings(obs_grid_normalized)
183
  path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
184
  if isinstance(self.ts_encoder, TransformerEncoder):
185
- padding_mask = create_padding_mask(x["mask_seq_lengths"].view(B * P), L)
186
  padding_mask[:, 0] = True
187
  h = self.ts_encoder(path.view(B * P, L, -1), padding_mask)[:, 1, :].view(B, P, -1)
188
  if isinstance(self.path_attention, nn.MultiheadAttention):
@@ -190,8 +193,8 @@ class FIMMJP(AModel):
190
  else:
191
  h = self.path_attention(h, h, h)
192
  elif isinstance(self.ts_encoder, RNNEncoder):
193
- h = self.ts_encoder(path.view(B * P, L, -1), x["mask_seq_lengths"].view(B * P))
194
- last_observation = x["mask_seq_lengths"].view(B * P) - 1
195
  h = h[torch.arange(B * P), last_observation].view(B, P, -1)
196
  h = self.path_attention(h, h, h)
197
 
 
124
  x (dict[str, Tensor]): A dictionary containing the input tensors:
125
  - "observation_grid": Tensor representing the observation grid.
126
  - "observation_values": Tensor representing the observation values.
127
+ - "seq_lengths": Tensor representing the sequence lengths.
128
+ - Optional keys:
129
+ - "time_normalization_factors": Tensor representing the time normalization factors.
130
  - Optional keys for loss calculation:
131
  - "intensity_matrices": Tensor representing the intensity matrices.
132
  - "initial_distributions": Tensor representing the initial distributions.
 
136
  Returns:
137
  dict: A dictionary containing the following keys:
138
  - "im": Tensor representing the intensity matrix.
139
+ - "intensity_matrices_variance": Tensor representing the variance of the intensity matrix.
140
+ - "initial_condition": Tensor representing the initial conditions.
141
  - "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
142
  """
143
 
 
158
  pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
159
 
160
  out = {
161
+ "intensity_matrices": create_matrix_from_off_diagonal(
162
  pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
163
  ),
164
+ "intensity_matrices_variance": create_matrix_from_off_diagonal(
165
+ torch.exp(pred_offdiag_im_logvar), self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
166
  ),
167
+ "initial_condition": init_cond,
168
  }
169
  if "intensity_matrices" in x and "initial_distributions" in x:
170
  out["losses"] = self.loss(
 
185
  pos_enc = self.pos_encodings(obs_grid_normalized)
186
  path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
187
  if isinstance(self.ts_encoder, TransformerEncoder):
188
+ padding_mask = create_padding_mask(x["seq_lengths"].view(B * P), L)
189
  padding_mask[:, 0] = True
190
  h = self.ts_encoder(path.view(B * P, L, -1), padding_mask)[:, 1, :].view(B, P, -1)
191
  if isinstance(self.path_attention, nn.MultiheadAttention):
 
193
  else:
194
  h = self.path_attention(h, h, h)
195
  elif isinstance(self.ts_encoder, RNNEncoder):
196
+ h = self.ts_encoder(path.view(B * P, L, -1), x["seq_lengths"].view(B * P))
197
+ last_observation = x["seq_lengths"].view(B * P) - 1
198
  h = h[torch.arange(B * P), last_observation].view(B, P, -1)
199
  h = self.path_attention(h, h, h)
200