---
tags:
- sequence-classification
- time-series
- stochastic-processes
- markov-jump-processes
license: cc-by-4.0
datasets:
- custom
metrics:
- rmse
- Hellinger distance
---

# Foundation Inference Model (FIM) for Markov Jump Processes

## Model Description

The Foundation Inference Model (FIM) is a neural recognition model for zero-shot inference of Markov jump processes (MJPs) on bounded state spaces. From noisy and sparse observations, FIM estimates the transition rate matrix and the initial condition of an MJP without any fine-tuning on the target dataset.

FIM combines supervised training on a synthetic ensemble of MJPs with attention mechanisms, enabling robust inference for empirical processes of varying dimensionality. To the best of our knowledge, it is the first generative zero-shot model for MJPs, with broad applicability across domains such as molecular dynamics, ion-channel dynamics, and discrete flashing ratchet systems.

[![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?logo=github)](https://github.com/cvejoski/OpenFIM)
[![arXiv](https://img.shields.io/badge/arXiv-2406.06419-B31B1B.svg)](https://arxiv.org/abs/2406.06419)

## Intended Use

- **Applications:**
  - Inferring the dynamics of physical, chemical, and biological systems.
  - Estimating transition rates and initial conditions from noisy observations.
  - Zero-shot simulation and analysis of MJPs for:
    - Molecular dynamics simulations (e.g., alanine dipeptide conformations).
    - Protein folding models.
    - Ion channel dynamics.
    - Brownian motor systems.
- **Users:** Researchers in statistical physics, molecular biology, and stochastic processes.
- **Limitations:**
  - The model performs well only for processes whose dynamics are similar to its synthetic training distribution.
  - Poor estimates are likely for datasets whose distributions deviate significantly from the synthetic priors (e.g., systems with power-law distributed rates).

### Installation

To install the `fim` library, which provides the Foundation Inference Model (FIM), follow these steps:

1. **Clone the repository:**

   ```bash
   git clone https://github.com/cvejoski/OpenFIM.git
   cd OpenFIM
   ```

2. **Install the required dependencies:**

   ```bash
   pip install -r requirements.txt
   ```

3. **Install the `fim` library:**

   ```bash
   pip install .
   ```

After completing these steps, the `fim` library is installed and ready to use.

### Input and Output Data of the `FIMMJP` Forward Step

#### Input Data

The input to the forward step of the `FIMMJP` model is a dictionary whose keys correspond to the specific pieces of data required by the model. In the shapes below, `B` is the batch size, `P` the number of paths, `L` the length of each path, `D` the dimensionality of the observations, and `N` the number of states of the process.

1. **`observation_grid`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, P, L, 1]`
   - **Description**: The observation time grid; the trailing `1` is a single time dimension.

2. **`observation_values`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, P, L, D]`
   - **Description**: The observed (noisy) values of the process.

3. **`seq_lengths`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, P]`
   - **Description**: The number of valid observations in each path.

4. **`initial_distributions`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, N]`
   - **Description**: The initial distributions over states.

5. **Optional keys**:
   - **`time_normalization_factors`**:
     - **Type**: `torch.Tensor`
     - **Shape**: `[B, 1]`
     - **Description**: The time normalization factors.
   - **`intensity_matrices`**:
     - **Type**: `torch.Tensor`
     - **Shape**: `[B, N, N]`
     - **Description**: The ground-truth intensity matrices (used for computing losses).
   - **`adjacency_matrices`**:
     - **Type**: `torch.Tensor`
     - **Shape**: `[B, N, N]`
     - **Description**: The adjacency matrices.

A minimal sketch of how such a batch might be assembled is shown below.
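
For illustration only, the following sketch builds a syntactically valid input batch from dummy tensors. The sizes, the float encoding of the observed states, and the uniform initial distribution are assumptions of this example, not requirements of the model:

```python
import torch

# Illustrative sizes (assumptions for this sketch): batch of 2 samples,
# 30 paths per sample, 50 observation times, 1-dimensional observations,
# and a 6-state process.
B, P, L, D, N = 2, 30, 50, 1, 6

batch = {
    # Sorted observation times for every path, shape [B, P, L, 1].
    "observation_grid": torch.rand(B, P, L, 1).sort(dim=2).values,
    # Noisy observed states; the exact state encoding (indices vs. one-hot)
    # is determined by the dataset, so treat this line as a placeholder.
    "observation_values": torch.randint(0, N, (B, P, L, D)).float(),
    # Number of valid observations per path, shape [B, P].
    "seq_lengths": torch.full((B, P), L, dtype=torch.long),
    # Uniform initial distribution over the N states, shape [B, N].
    "initial_distributions": torch.full((B, N), 1.0 / N),
}
```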

#### Output Data

The output of the forward step of the `FIMMJP` model is a dictionary with the following key-value pairs:

1. **`intensity_matrices`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, N, N]`
   - **Description**: The predicted intensity matrix for each sample in the batch; `N` is the number of states of the process.

2. **`intensity_matrices_variance`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, N, N]`
   - **Description**: The log variance of the predicted intensity matrix for each sample in the batch.

3. **`initial_condition`**:
   - **Type**: `torch.Tensor`
   - **Shape**: `[B, N]`
   - **Description**: The predicted initial distribution over states for each sample in the batch.

4. **`losses`** (optional):
   - **Type**: `dict`
   - **Description**: The calculated losses, returned only when the required keys (`intensity_matrices` and `initial_distributions`) are present in the input data. Its keys are:
     - **`loss`**: The total loss.
     - **`loss_gauss`**: The Gaussian negative log-likelihood loss.
     - **`loss_initial`**: The cross-entropy loss for the initial distribution.
     - **`loss_missing_link`**: The loss for missing links in the intensity matrix.
     - **`rmse_loss`**: The root mean square error loss.
     - **`beta_gauss_nll`**: The weight of the Gaussian negative log-likelihood loss.
     - **`beta_init_cross_entropy`**: The weight of the cross-entropy loss.
     - **`beta_missing_link`**: The weight of the missing-link loss.
     - **`number_of_paths`**: The number of paths in the batch.
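
The exact combination of these terms is defined in the OpenFIM repository; a plausible reading of the reported weights and terms, as a sketch under that assumption only, is:

```python
def total_loss(losses: dict):
    """Hypothetical weighted combination of the reported loss terms;
    the authoritative formula is defined in the OpenFIM repository."""
    return (losses["beta_gauss_nll"] * losses["loss_gauss"]
            + losses["beta_init_cross_entropy"] * losses["loss_initial"]
            + losses["beta_missing_link"] * losses["loss_missing_link"])
```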

### Example Usage

Here is an example of how to use the `FIMMJP` model for inference:

```python
import torch
from datasets import load_dataset
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model
model = AutoModel.from_pretrained("cvejoski/FIMMJP", trust_remote_code=True)
model = model.to(device)
model.eval()

# Load the Discrete Flashing Ratchet (DFR) dataset from the Hugging Face Hub
data = load_dataset("cvejoski/mjp", download_mode="force_redownload", trust_remote_code=True, name="DFR_V=1")
data.set_format("torch")

# Create a batch from the first training example
inputs = {k: v.to(device) for k, v in data["train"][:1].items()}

# Perform inference
outputs = model(inputs, n_states=6)

# Process the output as needed
intensity_matrix = outputs["intensity_matrices"]
initial_distribution = outputs["initial_condition"]

print(intensity_matrix)
print(initial_distribution)
```

In this example, the input data is prepared and passed to the model's forward step. The model returns the predicted intensity matrix and initial distribution, which can then be processed as needed.
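
Downstream quantities reported in the evaluation, such as stationary distributions, can be derived from the predicted intensity matrix. A minimal sketch, assuming the prediction is a proper generator `Q` whose rows sum to zero, so that the stationary distribution `pi` solves `pi @ Q = 0`:

```python
import torch

def stationary_distribution(Q: torch.Tensor) -> torch.Tensor:
    """Solve pi @ Q = 0 subject to pi.sum() == 1 by least squares,
    assuming Q is an [N, N] generator whose rows sum to zero."""
    n = Q.shape[-1]
    # Stack the normalization constraint onto the transposed generator.
    A = torch.cat([Q.T, torch.ones(1, n)], dim=0)  # [(n + 1), n]
    b = torch.zeros(n + 1)
    b[-1] = 1.0                                    # sum(pi) = 1
    pi = torch.linalg.lstsq(A, b.unsqueeze(-1)).solution.squeeze(-1)
    pi = pi.clamp_min(0.0)                         # clip numerical noise
    return pi / pi.sum()

# e.g. for the first sample of the batch above:
# pi = stationary_distribution(intensity_matrix[0])
```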

## Model Training

### Training Dataset

- Synthetic MJPs covering state spaces of 2 to 6 states, with up to 300 paths per process.
- 45,000 MJPs sampled with the Gillespie algorithm (sketched below), under various grid and noise configurations.
- Noise: mislabeled states (1% to 10% noise).
- Observations: regular and irregular grids with up to 100 time points.
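
The Gillespie algorithm itself is standard: from the current state, draw an exponential waiting time from the state's total exit rate, then jump to a successor with probability proportional to the off-diagonal rates. A minimal, self-contained sketch (illustrative only, not the authors' data generator):

```python
import torch

def gillespie_path(Q: torch.Tensor, x0: int, t_max: float):
    """Simulate one MJP trajectory with the Gillespie algorithm.
    Q is an [N, N] intensity matrix whose rows sum to zero; off-diagonal
    entries are jump rates. Returns jump times and visited states."""
    t, x = 0.0, x0
    times, states = [0.0], [x0]
    while True:
        rate = -Q[x, x].item()            # total exit rate of current state
        if rate <= 0:                     # absorbing state
            break
        t += torch.distributions.Exponential(rate).sample().item()
        if t >= t_max:
            break
        probs = Q[x].clamp_min(0.0)       # off-diagonal rates only
        x = torch.multinomial(probs / rate, 1).item()
        times.append(t)
        states.append(x)
    return times, states

# Example: simulate a 3-state process for 10 time units.
# Q = torch.tensor([[-1.0, 0.6, 0.4],
#                   [0.5, -1.0, 0.5],
#                   [0.3, 0.7, -1.0]])
# times, states = gillespie_path(Q, x0=0, t_max=10.0)
```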

### Architecture

- **Input**: K time series of noisy observations and their associated grids.
- **Encoder**: LSTM or Transformer for time-series embedding.
- **Attention**: A self-attention mechanism aggregates the path embeddings.
- **Output**: Transition rate matrix, variance matrix, and initial distribution.

A schematic sketch of this pipeline is shown below.
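
The released architecture is specified in the OpenFIM repository; the following is only a schematic sketch of the pipeline above, with illustrative choices (layer sizes, head count, and mean pooling are assumptions, not the released configuration):

```python
import torch
import torch.nn as nn

class FIMLikeEncoder(nn.Module):
    """Schematic sketch: per-path LSTM encoder, self-attention across
    paths, and linear heads for the three outputs. Illustrative only."""

    def __init__(self, obs_dim: int = 2, hidden: int = 128, n_states: int = 6):
        super().__init__()
        # Per-path sequence encoder over (time, observation) pairs.
        self.encoder = nn.LSTM(obs_dim, hidden, batch_first=True)
        # Self-attention aggregates the embeddings of the K paths.
        self.attention = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
        # Output heads: rate matrix, its (log) variance, initial distribution.
        self.rate_head = nn.Linear(hidden, n_states * n_states)
        self.var_head = nn.Linear(hidden, n_states * n_states)
        self.init_head = nn.Linear(hidden, n_states)
        self.n_states = n_states

    def forward(self, grid, values):
        B, P, L, _ = values.shape
        x = torch.cat([grid, values], dim=-1)   # [B, P, L, 1 + D]
        x = x.view(B * P, L, -1)
        _, (h, _) = self.encoder(x)             # last hidden state per path
        h = h[-1].view(B, P, -1)                # [B, P, hidden]
        h, _ = self.attention(h, h, h)          # attend across paths
        h = h.mean(dim=1)                       # pooled context, [B, hidden]
        N = self.n_states
        return (self.rate_head(h).view(B, N, N),
                self.var_head(h).view(B, N, N),
                self.init_head(h).softmax(dim=-1))
```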

### Loss Function

- Supervised likelihood maximization, with regularization for missing links in the intensity matrix.

## Evaluation

The Foundation Inference Model (FIM) was evaluated on a diverse set of datasets to demonstrate its zero-shot inference capabilities for Markov Jump Processes (MJPs). The evaluation spans domains such as statistical physics, molecular dynamics, and experimental biology, testing FIM's ability to infer transition rate matrices and initial distributions and to compute physical properties such as stationary distributions and relaxation times.

### Datasets

The following datasets were used to evaluate FIM:

1. **Discrete Flashing Ratchet (DFR)**:
   - A 6-state stochastic model of a Brownian motor under a periodic potential.
   - Dataset: 5,000 paths recorded on an irregular grid of 50 time points.
   - Metrics: transition rates, stationary distributions, and entropy production.

2. **Switching Ion Channel (IonCh)**:
   - A 3-state model of ion flow across viral potassium channels.
   - Dataset: experimental recordings of 5,000 paths sampled at 5 kHz over one second.
   - Metrics: mean first-passage times, stationary distributions.

3. **Alanine Dipeptide (ADP)**:
   - A 6-state molecular dynamics model describing the dihedral angles of alanine dipeptide.
   - Dataset: a 1-microsecond simulation of atom trajectories, mapped to coarse-grained states.
   - Metrics: relaxation times, stationary distributions.

4. **Simple Protein Folding (PFold)**:
   - A 2-state model describing the folding and unfolding rates of proteins.
   - Dataset: simulated transitions between folded and unfolded states.
   - Metrics: transition rates, mean first-passage times.
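
Several of these metrics compare a predicted distribution against a reference one; the Hellinger distance listed in this card's metadata is one such measure. For discrete distributions `p` and `q` it is `H(p, q) = (1/sqrt(2)) * ||sqrt(p) - sqrt(q)||_2`; a minimal implementation:

```python
import math
import torch

def hellinger(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Hellinger distance between discrete distributions along the last dim."""
    return torch.linalg.norm(p.sqrt() - q.sqrt(), dim=-1) / math.sqrt(2.0)

# e.g. comparing a predicted stationary distribution with a reference one:
# d = hellinger(pi_pred, pi_ref)   # 0 = identical, 1 = disjoint support
```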

### Summary

The Foundation Inference Model (FIM) provides a new approach to zero-shot inference for Markov Jump Processes (MJPs). FIM accurately estimates a variety of properties, including stationary distributions, relaxation times, mean first-passage times, time-dependent moments, and thermodynamic quantities (e.g., entropy production), from noisy and discretely observed MJPs with state spaces of varying dimensionality. Importantly, FIM operates in a zero-shot mode, requiring no additional fine-tuning or retraining on target datasets.

To the best of our knowledge, FIM is the first zero-shot generative model for MJPs, offering a versatile methodology for a wide range of physical, chemical, and biological systems. Future directions include extending FIM to birth-and-death processes and incorporating more complex prior distributions over transition rates to improve generalization.

## Limitations

While FIM performs strongly on synthetic datasets, its methodology relies heavily on the distribution of those synthetic data. As a result, its effectiveness diminishes on empirical datasets that deviate significantly from the synthetic distribution. For instance, as shown in Figure 4 (right) of the paper, FIM's performance degrades rapidly once the ratio between the largest and smallest transition rates exceeds three orders of magnitude; such scenarios fall outside the range of FIM's prior Beta distributions and are difficult to infer accurately.

Additionally, the dynamics of MJPs underlying systems with long-lived, metastable states depend heavily on the shape of the energy landscape that defines the state space. Transition rates in these systems are characterized by the depth of energy traps and can follow distributions not represented in FIM's training prior, such as power laws (e.g., in glassy systems). These distributions lie outside the synthetic ensemble used to train FIM, limiting its ability to generalize to such cases.

To address these limitations, future work will explore training FIM on synthetic MJPs with more complex transition rate distributions, such as those arising from exponentially distributed energy traps or power-law-distributed rates, to cover a broader range of real-world scenarios.

## License

The model is licensed under CC BY 4.0, as declared in the `license` field of this card's metadata.

## Citation

If you use this model in your research, please cite:

```bibtex
@article{berghaus2024foundation,
  title={Foundation Inference Models for Markov Jump Processes},
  author={Berghaus, David and Cvejoski, Kostadin and Seifner, Patrick and Ojeda, Cesar and Sanchez, Ramses J},
  journal={arXiv preprint arXiv:2406.06419},
  year={2024}
}
```

## Contact

For questions or issues, please contact Kostadin Cvejoski at cvejoski@gmail.com.