Upload FIMMJP
Browse files
mjp.py
CHANGED
@@ -124,6 +124,9 @@ class FIMMJP(AModel):
|
|
124 |
x (dict[str, Tensor]): A dictionary containing the input tensors:
|
125 |
- "observation_grid": Tensor representing the observation grid.
|
126 |
- "observation_values": Tensor representing the observation values.
|
|
|
|
|
|
|
127 |
- Optional keys for loss calculation:
|
128 |
- "intensity_matrices": Tensor representing the intensity matrices.
|
129 |
- "initial_distributions": Tensor representing the initial distributions.
|
@@ -133,8 +136,8 @@ class FIMMJP(AModel):
|
|
133 |
Returns:
|
134 |
dict: A dictionary containing the following keys:
|
135 |
- "im": Tensor representing the intensity matrix.
|
136 |
-
- "
|
137 |
-
- "
|
138 |
- "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
|
139 |
"""
|
140 |
|
@@ -155,13 +158,13 @@ class FIMMJP(AModel):
|
|
155 |
pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
|
156 |
|
157 |
out = {
|
158 |
-
"
|
159 |
pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
|
160 |
),
|
161 |
-
"
|
162 |
-
pred_offdiag_im_logvar, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
|
163 |
),
|
164 |
-
"
|
165 |
}
|
166 |
if "intensity_matrices" in x and "initial_distributions" in x:
|
167 |
out["losses"] = self.loss(
|
@@ -182,7 +185,7 @@ class FIMMJP(AModel):
|
|
182 |
pos_enc = self.pos_encodings(obs_grid_normalized)
|
183 |
path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
|
184 |
if isinstance(self.ts_encoder, TransformerEncoder):
|
185 |
-
padding_mask = create_padding_mask(x["
|
186 |
padding_mask[:, 0] = True
|
187 |
h = self.ts_encoder(path.view(B * P, L, -1), padding_mask)[:, 1, :].view(B, P, -1)
|
188 |
if isinstance(self.path_attention, nn.MultiheadAttention):
|
@@ -190,8 +193,8 @@ class FIMMJP(AModel):
|
|
190 |
else:
|
191 |
h = self.path_attention(h, h, h)
|
192 |
elif isinstance(self.ts_encoder, RNNEncoder):
|
193 |
-
h = self.ts_encoder(path.view(B * P, L, -1), x["
|
194 |
-
last_observation = x["
|
195 |
h = h[torch.arange(B * P), last_observation].view(B, P, -1)
|
196 |
h = self.path_attention(h, h, h)
|
197 |
|
|
|
124 |
x (dict[str, Tensor]): A dictionary containing the input tensors:
|
125 |
- "observation_grid": Tensor representing the observation grid.
|
126 |
- "observation_values": Tensor representing the observation values.
|
127 |
+
- "seq_lengths": Tensor representing the sequence lengths.
|
128 |
+
- Optional keys:
|
129 |
+
- "time_normalization_factors": Tensor representing the time normalization factors.
|
130 |
- Optional keys for loss calculation:
|
131 |
- "intensity_matrices": Tensor representing the intensity matrices.
|
132 |
- "initial_distributions": Tensor representing the initial distributions.
|
|
|
136 |
Returns:
|
137 |
dict: A dictionary containing the following keys:
|
138 |
- "im": Tensor representing the intensity matrix.
|
139 |
+
- "intensity_matrices_variance": Tensor representing the log variance of the intensity matrix.
|
140 |
+
- "initial_condition": Tensor representing the initial conditions.
|
141 |
- "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
|
142 |
"""
|
143 |
|
|
|
158 |
pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
|
159 |
|
160 |
out = {
|
161 |
+
"intensity_matrices": create_matrix_from_off_diagonal(
|
162 |
pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
|
163 |
),
|
164 |
+
"intensity_matrices_variance": create_matrix_from_off_diagonal(
|
165 |
+
torch.exp(pred_offdiag_im_logvar), self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
|
166 |
),
|
167 |
+
"initial_condition": init_cond,
|
168 |
}
|
169 |
if "intensity_matrices" in x and "initial_distributions" in x:
|
170 |
out["losses"] = self.loss(
|
|
|
185 |
pos_enc = self.pos_encodings(obs_grid_normalized)
|
186 |
path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
|
187 |
if isinstance(self.ts_encoder, TransformerEncoder):
|
188 |
+
padding_mask = create_padding_mask(x["seq_lengths"].view(B * P), L)
|
189 |
padding_mask[:, 0] = True
|
190 |
h = self.ts_encoder(path.view(B * P, L, -1), padding_mask)[:, 1, :].view(B, P, -1)
|
191 |
if isinstance(self.path_attention, nn.MultiheadAttention):
|
|
|
193 |
else:
|
194 |
h = self.path_attention(h, h, h)
|
195 |
elif isinstance(self.ts_encoder, RNNEncoder):
|
196 |
+
h = self.ts_encoder(path.view(B * P, L, -1), x["seq_lengths"].view(B * P))
|
197 |
+
last_observation = x["seq_lengths"].view(B * P) - 1
|
198 |
h = h[torch.arange(B * P), last_observation].view(B, P, -1)
|
199 |
h = self.path_attention(h, h, h)
|
200 |
|