import copy
from typing import Any, Dict

import torch
from torch import Tensor, nn
from transformers import AutoConfig, AutoModel, PretrainedConfig

from fim.models.blocks import AModel, ModelFactory, RNNEncoder, TransformerEncoder
from fim.models.utils import create_matrix_from_off_diagonal, create_padding_mask, get_off_diagonal_elements
from fim.utils.helper import create_class_instance


class FIMMJPConfig(PretrainedConfig):
    """Configuration of :class:`FIMMJP`.

    Every sub-module entry (``ts_encoder``, ``pos_encodings``, ``path_attention``, ``intensity_matrix_decoder``,
    ``initial_distribution_decoder``) is a dict handed to ``create_class_instance`` by the model: it must contain a
    ``"name"`` key with the import path of the class to build; the remaining keys are constructor kwargs.

    Args:
        n_states: Number of states of the Markov jump process.
        use_adjacency_matrix: If True, the off-diagonal entries of the adjacency matrix are appended to the encoded
            representation before decoding.
        ts_encoder: Config of the per-path time-series encoder.
        pos_encodings: Config of the positional (observation-time) encoding.
        path_attention: Config of the module that aggregates the per-path representations.
        intensity_matrix_decoder: Config of the decoder of the off-diagonal rates.
        initial_distribution_decoder: Config of the decoder of the initial distribution.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "fimmjp"

    def __init__(
        self,
        n_states: int = 2,
        use_adjacency_matrix: bool = False,
        ts_encoder: dict = None,
        pos_encodings: dict = None,
        path_attention: dict = None,
        intensity_matrix_decoder: dict = None,
        initial_distribution_decoder: dict = None,
        **kwargs,
    ):
        self.n_states = n_states
        self.use_adjacency_matrix = use_adjacency_matrix
        self.ts_encoder = ts_encoder
        self.pos_encodings = pos_encodings
        self.path_attention = path_attention
        self.intensity_matrix_decoder = intensity_matrix_decoder
        self.initial_distribution_decoder = initial_distribution_decoder
        super().__init__(**kwargs)


class FIMMJP(AModel):
    """
    FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes

    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs) on
    bounded state spaces from noisy and sparse observations. It is the model from the paper
    "Foundation Inference Models for Markov Jump Processes" --- https://arxiv.org/abs/2406.06419, whose abstract reads:

    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in
    discrete state spaces. These processes find wide application in the natural sciences and machine learning, but
    their inference is known to be far from trivial. In this work we introduce a methodology for zero-shot inference
    of Markov jump processes (MJPs), on bounded state spaces, from noisy and sparse observations, which consists of
    two components. First, a broad probability distribution over families of MJPs, as well as over possible
    observation times and noise mechanisms, with which we simulate a synthetic dataset of hidden MJPs and their noisy
    observations. Second, a neural recognition model that processes subsets of the simulated observations, and that
    is trained to output the initial condition and rate matrix of the target MJP in a supervised way. We empirically
    demonstrate that one and the same (pretrained) recognition model can infer, in a zero-shot fashion, hidden MJPs
    evolving in state spaces of different dimensionalities. Specifically, we infer MJPs which describe (i) discrete
    flashing ratchet systems, which are a type of Brownian motors, and the conformational dynamics in (ii) molecular
    simulations, (iii) experimental ion channel data and (iv) simple protein folding models. What is more, we show
    that our model performs on par with state-of-the-art models which are trained on the target datasets.

    Attributes:
        n_states (int): Number of states in the Markov jump process.
        use_adjacency_matrix (bool): Whether the off-diagonal adjacency entries are appended to the encoding.
        total_offdiagonal_transitions (int): ``n_states**2 - n_states``.
        pos_encodings (nn.Module): Positional (observation-time) encoding, built from the config.
        ts_encoder (nn.Module): Per-path time-series encoder (``TransformerEncoder`` or ``RNNEncoder``).
        path_attention (nn.Module): Module aggregating the per-path representations.
        intensity_matrix_decoder (nn.Module): Decoder of the off-diagonal [log-rate | log-std] values.
        initial_distribution_decoder (nn.Module): Decoder of the initial-distribution scores (fed to cross-entropy).
        gaussian_nll (nn.GaussianNLLLoss): Gaussian negative log-likelihood loss.
        init_cross_entropy (nn.CrossEntropyLoss): Cross-entropy loss for the initial distribution.

    Methods:
        forward(x, n_states=None, schedulers=None, step=None) -> dict: Forward pass (and losses if targets are given).
        __encode(x) -> Tensor: Encode the observed paths into one representation per batch element.
        __decode(h) -> tuple[Tensor, Tensor]: Decode the representation into rate and initial-condition outputs.
        __denormalize_offdiag_mean_logstd(norm_constants, pred) -> tuple[Tensor, Tensor]: Undo time normalization.
        __normalize_obs_grid(obs_grid) -> tuple[Tensor, Tensor]: Normalize the observation grid.
        loss(pred_im, pred_logstd_im, pred_init_cond, target, normalization_constants, schedulers=None, step=None)
            -> dict: Compute the training losses.
        metric(y, y_target) -> Dict: Compute the metric for the model.
    """

    config_class = FIMMJPConfig

    def __init__(self, config: FIMMJPConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.n_states = config.n_states
        self.use_adjacency_matrix = config.use_adjacency_matrix
        self.total_offdiagonal_transitions = self.n_states**2 - self.n_states

        self.__create_modules()

        self.gaussian_nll = nn.GaussianNLLLoss(full=True, reduction="none")
        self.init_cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def __create_modules(self):
        """Instantiate all sub-modules from deep copies of the config dicts."""
        # Deep copies: popping "name" and adjusting sizes below must not mutate self.config.
        pos_encodings = copy.deepcopy(self.config.pos_encodings)
        ts_encoder = copy.deepcopy(self.config.ts_encoder)
        path_attention = copy.deepcopy(self.config.path_attention)
        intensity_matrix_decoder = copy.deepcopy(self.config.intensity_matrix_decoder)
        initial_distribution_decoder = copy.deepcopy(self.config.initial_distribution_decoder)

        if ts_encoder["name"] == "fim.models.blocks.base.TransformerEncoder":
            # The encoder input is [pos_enc, one_hot(state)]; shrinking the positional encoding by n_states keeps the
            # configured pos_encodings width as the total transformer input width.
            pos_encodings["out_features"] -= self.n_states
        self.pos_encodings = create_class_instance(pos_encodings.pop("name"), pos_encodings)

        ts_encoder["in_features"] = self.n_states + self.pos_encodings.out_features
        self.ts_encoder = create_class_instance(ts_encoder.pop("name"), ts_encoder)

        self.path_attention = create_class_instance(path_attention.pop("name"), path_attention)

        # __encode appends one scalar feature (number of paths) and, optionally, the off-diagonal adjacency entries
        # to the encoder output; this is the default decoder input width unless the config overrides it.
        default_in_features = (
            self.ts_encoder.out_features + 1 + (self.total_offdiagonal_transitions if self.use_adjacency_matrix else 0)
        )

        intensity_matrix_decoder["in_features"] = intensity_matrix_decoder.get("in_features", default_in_features)
        # One log-rate and one log-std per off-diagonal entry.
        intensity_matrix_decoder["out_features"] = 2 * self.total_offdiagonal_transitions
        self.intensity_matrix_decoder = create_class_instance(
            intensity_matrix_decoder.pop("name"), intensity_matrix_decoder
        )

        initial_distribution_decoder["in_features"] = initial_distribution_decoder.get(
            "in_features", default_in_features
        )
        initial_distribution_decoder["out_features"] = self.n_states
        self.initial_distribution_decoder = create_class_instance(
            initial_distribution_decoder.pop("name"), initial_distribution_decoder
        )

    def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
        """
        Forward pass for the model.

        Note: ``x`` is modified in place; the keys "time_normalization_factors" (if absent),
        "observation_grid_normalized" and "observation_values_one_hot" are written into it.

        Args:
            x (dict[str, Tensor]): A dictionary containing the input tensors:
                - "observation_grid": Observation times; 4-D, the maximum is taken over the last three axes.
                - "observation_values": Observed state indices (cast to long, last axis squeezed, then one-hot encoded).
                - "seq_lengths": Number of valid observations per path.
                - Optional keys:
                    - "time_normalization_factors": If given, "observation_grid" is used as-is (it is expected to be
                      already normalized by these factors).
                    - "adjacency_matrix" (falls back to "adjacency_matrices"): Read by the encoder when
                      ``use_adjacency_matrix`` is True.
                - Keys required for the loss ("losses" is only computed if the first two are present):
                    - "intensity_matrices": Target intensity matrices.
                    - "initial_distributions": Target initial distributions.
                    - "adjacency_matrices": Target adjacency matrices.
            n_states (int, optional): Forwarded as ``n_states`` to ``create_matrix_from_off_diagonal``.
                Defaults to ``self.n_states``.
            schedulers (dict, optional): Loss-weight schedulers keyed by "gauss_nll", "init_cross_entropy" and
                "missing_link"; each is called with ``step``. Missing entries default to weight 1.0.
            step (int, optional): The current training step passed to the schedulers.

        Returns:
            dict: A dictionary containing the following keys:
                - "intensity_matrices": Rate matrices built from the denormalized off-diagonal rates
                  (``create_matrix_from_off_diagonal`` with mode "negative_sum_row").
                - "intensity_matrices_variance": Built the same way from ``exp(log-std)``. The loss squares this
                  quantity to get the variance, so these entries are standard deviations; the key name is kept for
                  backward compatibility.
                - "initial_condition": Raw output of the initial-distribution decoder.
                - "losses" (optional): Output of :meth:`loss`, if the required target keys are present in ``x``.
        """
        obs_grid = x["observation_grid"]
        if "time_normalization_factors" not in x:
            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
            x["time_normalization_factors"] = norm_constants
        else:
            # NOTE(review): the grid is used as-is here, i.e. callers providing the factors are expected to pass an
            # already-normalized grid — confirm with the data pipeline.
            norm_constants = x["time_normalization_factors"]
        x["observation_grid_normalized"] = obs_grid

        x["observation_values_one_hot"] = torch.nn.functional.one_hot(
            x["observation_values"].long().squeeze(-1), num_classes=self.n_states
        )

        h = self.__encode(x)
        pred_offdiag_im_mean_logstd, init_cond = self.__decode(h)
        pred_offdiag_im_mean, pred_offdiag_im_logstd = self.__denormalize_offdiag_mean_logstd(
            norm_constants, pred_offdiag_im_mean_logstd
        )

        matrix_n_states = self.n_states if n_states is None else n_states
        out = {
            "intensity_matrices": create_matrix_from_off_diagonal(
                pred_offdiag_im_mean, self.n_states, mode="negative_sum_row", n_states=matrix_n_states
            ),
            "intensity_matrices_variance": create_matrix_from_off_diagonal(
                torch.exp(pred_offdiag_im_logstd), self.n_states, mode="negative_sum_row", n_states=matrix_n_states
            ),
            "initial_condition": init_cond,
        }
        if "intensity_matrices" in x and "initial_distributions" in x:
            out["losses"] = self.loss(
                pred_offdiag_im_mean, pred_offdiag_im_logstd, init_cond, x, norm_constants.view(-1, 1), schedulers, step
            )
        return out

    def __decode(self, h: Tensor) -> tuple[Tensor, Tensor]:
        """Decode the encoded representation.

        Returns:
            (off-diagonal [log-rate | log-std] predictions, initial-condition scores), both in normalized time.
        """
        pred_offdiag_logmean_logstd = self.intensity_matrix_decoder(h)
        init_cond = self.initial_distribution_decoder(h)
        return pred_offdiag_logmean_logstd, init_cond

    def __encode(self, x: dict[str, Tensor]) -> Tensor:
        """Encode all observed paths of each batch element into a single representation.

        The result is the aggregated path representation concatenated with one scalar feature ``P / 100``
        (P = number of paths) and, if ``use_adjacency_matrix``, the off-diagonal adjacency entries.

        Raises:
            TypeError: If ``ts_encoder`` is neither a ``TransformerEncoder`` nor an ``RNNEncoder``.
        """
        obs_grid_normalized = x["observation_grid_normalized"]
        obs_values_one_hot = x["observation_values_one_hot"]
        B, P, L = obs_grid_normalized.shape[:3]

        pos_enc = self.pos_encodings(obs_grid_normalized)
        path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
        # Each path is encoded independently: fold the path axis into the batch axis.
        flat_path = path.view(B * P, L, -1)
        flat_seq_lengths = x["seq_lengths"].view(B * P)

        if isinstance(self.ts_encoder, TransformerEncoder):
            padding_mask = create_padding_mask(flat_seq_lengths, L)
            # NOTE(review): position 0 is forced True in the mask and the encoder output at position 1 is used as the
            # path summary; the mask semantics are defined by create_padding_mask — confirm this pairing is intended.
            padding_mask[:, 0] = True
            h = self.ts_encoder(flat_path, padding_mask)[:, 1, :].view(B, P, -1)
            if isinstance(self.path_attention, nn.MultiheadAttention):
                # [0] is the attention output; [:, -1] keeps the last position (assumes batch_first=True — confirm).
                h = self.path_attention(h, h, h)[0][:, -1]
            else:
                h = self.path_attention(h, h, h)
        elif isinstance(self.ts_encoder, RNNEncoder):
            h = self.ts_encoder(flat_path, flat_seq_lengths)
            # Use the hidden state at each path's last valid observation.
            last_observation = flat_seq_lengths - 1
            h = h[torch.arange(B * P, device=h.device), last_observation].view(B, P, -1)
            h = self.path_attention(h, h, h)
        else:
            raise TypeError(
                f"Unsupported ts_encoder type {type(self.ts_encoder).__name__}; "
                "expected TransformerEncoder or RNNEncoder."
            )

        # Number of paths, scaled by 1/100, as an extra scalar feature.
        h = torch.cat([h, torch.ones(B, 1, device=h.device) / 100.0 * P], dim=-1)
        if self.use_adjacency_matrix:
            # "adjacency_matrices" is the key used by the loss; accept it when the singular key is absent.
            adjacency = x["adjacency_matrix"] if "adjacency_matrix" in x else x["adjacency_matrices"]
            h = torch.cat([h, get_off_diagonal_elements(adjacency)], dim=-1)
        return h

    def __denormalize_offdiag_mean_logstd(
        self, norm_constants: Tensor, pred_offdiag_im_logmean_logstd: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Map decoder outputs from normalized time back to the original time scale.

        The last axis is split in half into log-rates and log-stds. Time was divided by ``norm_constants`` (see
        ``__normalize_obs_grid``), so rates and their standard deviations are divided by it as well.

        Returns:
            (rates, log-stds) on the original time scale.
        """
        log_rate, log_std = pred_offdiag_im_logmean_logstd.chunk(2, dim=-1)
        scale = norm_constants.view(-1, 1)
        return torch.exp(log_rate) / scale, log_std - torch.log(scale)

    def __normalize_obs_grid(self, obs_grid: Tensor) -> tuple[Tensor, Tensor]:
        """Divide each batch element's observation grid by its largest observation time.

        ``obs_grid`` must be 4-D; the maximum is taken over its last three axes.

        Returns:
            (norm_constants with one value per batch element, normalized grid).
        """
        norm_constants = obs_grid.amax(dim=[-3, -2, -1])
        obs_grid_normalized = obs_grid / norm_constants.view(-1, 1, 1, 1)
        return norm_constants, obs_grid_normalized

    def _schedule_weight(self, schedulers: dict, name: str, step: int) -> Tensor:
        """Return the weight of loss term ``name`` at ``step`` as a tensor on the model device.

        Falls back to 1.0 when no schedulers are given or none is registered under ``name``; non-tensor scheduler
        outputs (e.g. Python floats) are converted with ``torch.as_tensor``.
        """
        scheduler = schedulers.get(name) if schedulers else None
        weight = scheduler(step) if scheduler is not None else 1.0
        return torch.as_tensor(weight).to(self.device)

    def loss(
        self,
        pred_im: Tensor,
        pred_logstd_im: Tensor,
        pred_init_cond: Tensor,
        target: dict,
        normalization_constants: Tensor,
        schedulers: dict = None,
        step: int = None,
    ) -> dict:
        """Compute the training losses.

        Args:
            pred_im: Denormalized off-diagonal rates, last axis of size ``n_states**2 - n_states``.
            pred_logstd_im: Denormalized off-diagonal log standard deviations, same shape as ``pred_im``.
            pred_init_cond: Initial-distribution scores, passed to ``nn.CrossEntropyLoss``.
            target: Must contain "intensity_matrices", "initial_distributions", "adjacency_matrices" and
                "observation_grid".
            normalization_constants: Time normalization constants, broadcastable against ``pred_im``.
            schedulers: Optional loss-weight schedulers ("gauss_nll", "init_cross_entropy", "missing_link").
            step: Current training step passed to the schedulers.

        Returns:
            dict with the total "loss", its components, the RMSE of the rates (monitoring only), the applied weights
            and the number of paths.
        """
        target_mean = get_off_diagonal_elements(target["intensity_matrices"])
        # Assumes 0/1 adjacency entries: they act as a mask selecting existing transitions.
        adjacency = get_off_diagonal_elements(target["adjacency_matrices"])
        target_init_cond = torch.argmax(target["initial_distributions"], dim=-1).long()
        P = target["observation_grid"].shape[1]

        pred_im_std = torch.exp(pred_logstd_im)

        # Gaussian NLL on existing transitions only, averaged over their count.
        loss_gauss = adjacency * self.gaussian_nll(pred_im, target_mean, torch.pow(pred_im_std, 2))
        loss_gauss = loss_gauss.sum() / (adjacency.sum() + 1e-8)

        loss_initial = self.init_cross_entropy(pred_init_cond, target_init_cond).mean()

        # Penalize predicted mean and std on transitions absent from the adjacency matrix.
        zero_entries = 1.0 - adjacency
        loss_missing_link = normalization_constants * zero_entries * (torch.pow(pred_im, 2) + torch.pow(pred_im_std, 2))
        loss_missing_link = loss_missing_link.sum() / (zero_entries.sum() + 1e-8)

        # Reported for monitoring; not part of the optimized loss.
        rmse_loss = torch.sqrt(torch.mean((target_mean - pred_im) ** 2))

        gaus_cons = self._schedule_weight(schedulers, "gauss_nll", step)
        init_cons = self._schedule_weight(schedulers, "init_cross_entropy", step)
        missing_link_cons = self._schedule_weight(schedulers, "missing_link", step)

        loss = gaus_cons * loss_gauss + init_cons * loss_initial + missing_link_cons * loss_missing_link

        return {
            "loss": loss,
            "loss_gauss": loss_gauss,
            "loss_initial": loss_initial,
            "loss_missing_link": loss_missing_link,
            "rmse_loss": rmse_loss,
            "beta_gauss_nll": gaus_cons,
            "beta_init_cross_entropy": init_cons,
            "beta_missing_link": missing_link_cons,
            "number_of_paths": torch.tensor(P, device=self.device),
        }

    def metric(self, y: Any, y_target: Any) -> Dict:
        """Delegate metric computation to the base model."""
        return super().metric(y, y_target)


ModelFactory.register(FIMMJPConfig.model_type, FIMMJP)
AutoConfig.register(FIMMJPConfig.model_type, FIMMJPConfig)
AutoModel.register(FIMMJPConfig, FIMMJP)