Update README.md
---
license: apache-2.0
---
# Model Card for World Model with MCTS and Transformer Components

## Model Overview

This model is a **World Model** that combines **Transformers**, **Mixture of Experts (MoE)** layers, **Monte Carlo Tree Search (MCTS)**, and **Proximal Policy Optimization (PPO)** to simulate and optimize a state-based environment. Designed for complex tasks involving decision-making and action prediction, it uses these components to encode states, predict transitions, and refine action sequences.

### Key Components
1. **Transformer**: A custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both encoder and decoder, enabling sequential processing of input and target data (a sketch of the MoE layer follows this list).
2. **MCTS**: The Monte Carlo Tree Search module iteratively simulates actions and selects the most promising path by balancing exploration and exploitation.
3. **PPO Agent**: A Proximal Policy Optimization agent updates the policy and value functions. The PPO loss is combined with other regularization losses to improve model performance.
4. **Custom Losses**: Several custom loss functions guide the model's learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
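
The repository's exact MoE implementation is not reproduced here; the PyTorch sketch below is only a minimal illustration of the routing idea, assuming top-1 token routing, four experts, and omitting the rotary encoding (which would live in the attention layers). All class names and sizes (`MoEFeedForward`, `GELUExpert`, `SwiGLUExpert`, `d_model=512`, etc.) are placeholders, not the model's actual API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Feed-forward expert using a SwiGLU (SiLU-gated) activation."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden)
        self.w_up = nn.Linear(d_model, d_hidden)
        self.w_down = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class GELUExpert(nn.Module):
    """Feed-forward expert using a GELU activation."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MoEFeedForward(nn.Module):
    """MoE layer: a router assigns each token to one expert (top-1 routing).

    Experts alternate between GELU and SwiGLU blocks, mirroring the
    description in the Model Architecture section below.
    """

    def __init__(self, d_model: int = 512, d_hidden: int = 2048, num_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList(
            [
                GELUExpert(d_model, d_hidden) if i % 2 == 0 else SwiGLUExpert(d_model, d_hidden)
                for i in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        scores = self.router(x)                      # (batch, seq, num_experts)
        expert_idx = scores.argmax(dim=-1)           # hard top-1 assignment per token
        gate = F.softmax(scores, dim=-1).gather(-1, expert_idx.unsqueeze(-1))
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i                   # tokens routed to expert i
            if mask.any():
                out[mask] = expert(x[mask])
        return gate * out                            # scale by router confidence
```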

### Intended Use
This model is suitable for tasks that require complex decision-making and optimization over action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.

## Model Architecture

The model is constructed from several primary components (a sketch of the Representation, Dynamics, and Prediction networks follows the list):
1. **Transformer**: Encoder and decoder layers with rotary positional encoding and Mixture of Experts (MoE) layers, which improve generalization and reduce computational cost by routing each token to only a subset of experts. The experts alternate between GELU and SwiGLU activation functions.
2. **Representation Network**: Encodes the Transformer output into a lower-dimensional state representation suitable for further processing.
3. **Dynamics Network**: Predicts the next state given the current state and an action, using layer normalization and a GELU activation.
4. **Prediction Network**: Predicts both policy logits and a value estimate for a given state: a distribution over actions plus a single scalar value.
5. **MCTS**: Performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes using the policy logits from the Prediction Network and backs the simulated value estimates up the tree.
6. **PPO Agent**: Uses the policy and value estimates to compute the PPO loss, which updates the policy while constraining the KL divergence between the old and new policies.
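
To make items 2-4 concrete, here is a minimal PyTorch sketch of what the Representation, Dynamics, and Prediction networks could look like, based only on the descriptions above. The class names, layer sizes, MLP bodies, and the mean-pooling step are assumptions, not the actual implementation.

```python
import torch
import torch.nn as nn


class RepresentationNetwork(nn.Module):
    """Encodes Transformer output into a compact state representation (item 2)."""

    def __init__(self, d_model: int = 512, state_dim: int = 256):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(d_model, state_dim), nn.LayerNorm(state_dim), nn.GELU())

    def forward(self, transformer_out: torch.Tensor) -> torch.Tensor:
        # Pool over the sequence dimension, then project into the state space.
        return self.proj(transformer_out.mean(dim=1))


class DynamicsNetwork(nn.Module):
    """Predicts the next state from the current state and an action embedding (item 3)."""

    def __init__(self, state_dim: int = 256, action_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, state_dim),
            nn.LayerNorm(state_dim),
            nn.GELU(),
            nn.Linear(state_dim, state_dim),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state, action], dim=-1))


class PredictionNetwork(nn.Module):
    """Outputs policy logits over actions and a scalar value estimate (item 4)."""

    def __init__(self, state_dim: int = 256, num_actions: int = 32):
        super().__init__()
        self.policy_head = nn.Linear(state_dim, num_actions)
        self.value_head = nn.Linear(state_dim, 1)

    def forward(self, state: torch.Tensor):
        return self.policy_head(state), self.value_head(state).squeeze(-1)
```

In the full model these modules sit between the custom Transformer and the MCTS/PPO components described above.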

## Training Details

The model is trained with the following components and techniques:

### Training Procedure
- **Data Loading**: Text is tokenized with padding and truncation, then grouped into fixed-length sequences for efficient training.
- **Optimization**: Training uses the **AdamW** optimizer with a **CosineAnnealingLR** scheduler for learning-rate adjustment, and a gradient scaler to avoid gradient underflow when training in mixed precision (a sketch of this training step follows the list).
- **Gradient Accumulation**: Because the model is computationally heavy, gradients are accumulated over several smaller batches, giving a larger effective batch size without the memory cost.
- **Loss Functions**: The training process leverages a comprehensive set of custom loss functions:
  - **InfoNCE Loss**: A contrastive loss that encourages similar representations for related pairs.
  - **Covariance Regularization**: Encourages diverse state representations by penalizing correlated embedding dimensions.
  - **Dynamics Performance Loss**: Combines MSE and variance terms to penalize incorrect state predictions.
  - **Thought Consistency Loss**: Encourages the model to output consistent states for similar actions.
  - **Policy Value Joint Loss**: A weighted combination of policy and value losses for the PPO agent.
  - **Action Diversity Reward**: Rewards diverse action embeddings to avoid mode collapse.
  - **Exploration Regularization**: Encourages exploration by penalizing high visitation counts.
  - **KL Divergence Loss**: Keeps each policy update close to the previous policy to stabilize training.
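
The following is a minimal sketch of how such a training step can be wired together with standard PyTorch AMP utilities. `model`, `train_loader`, and the `compute_losses` helper (standing in for the weighted sum of the losses listed above), as well as all hyperparameter values, are placeholders rather than the repository's actual names; only the AdamW / CosineAnnealingLR / GradScaler / accumulation pattern is taken from the description.

```python
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


def make_optimization(model, num_epochs, lr=1e-4):
    """AdamW + cosine annealing + AMP gradient scaler, as described above."""
    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()
    return optimizer, scheduler, scaler


def train_one_epoch(model, train_loader, compute_losses,
                    optimizer, scheduler, scaler,
                    accum_steps=4, device="cuda"):
    """One epoch with mixed precision and gradient accumulation (sketch)."""
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with autocast():
            # compute_losses stands in for the weighted sum of the custom
            # losses listed above (InfoNCE, covariance, PPO terms, ...).
            loss = compute_losses(model, batch) / accum_steps
        scaler.scale(loss).backward()   # loss is scaled to avoid fp16 gradient underflow
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)      # step is skipped if inf/NaN gradients are found
            scaler.update()
            optimizer.zero_grad()
    scheduler.step()                    # one cosine-annealing step per epoch (sketch)
```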

### Evaluation
After each epoch, the model is evaluated on the validation set by computing the average loss over the dataset. The evaluation function uses the same loss functions as training but does not backpropagate, so it can run in inference mode.
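
A corresponding evaluation loop might look like the sketch below; again, `model`, `val_loader`, and `compute_losses` are placeholder names, not the repository's API.

```python
import torch


@torch.no_grad()
def evaluate(model, val_loader, compute_losses, device="cuda"):
    """Average validation loss without gradient tracking (sketch)."""
    model.eval()
    total, batches = 0.0, 0
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        total += compute_losses(model, batch).item()
        batches += 1
    return total / max(batches, 1)
```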

### Checkpoints
At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
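
Saving every component could be as simple as the sketch below; the `components` mapping (component name to module) and the file layout are assumptions for illustration only.

```python
import os

import torch


def save_checkpoint(epoch, components, save_dir="checkpoints"):
    """Save the state of every sub-module (Transformer, networks, PPO agent, ...)."""
    os.makedirs(save_dir, exist_ok=True)
    state = {name: module.state_dict() for name, module in components.items()}
    state["epoch"] = epoch
    torch.save(state, os.path.join(save_dir, f"world_model_epoch_{epoch}.pt"))
```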

## Usage

To use this model, make sure the required libraries are installed, including `torch`, `transformers`, and `datasets` (`argparse` is part of the Python standard library). The model can be initialized with pretrained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here's an example of how to start training:

```bash
python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
```

This script trains the model on the specified dataset for the given number of epochs, using a batch size of 2 and loading a pretrained Transformer model from the specified path.

### Model Hyperparameters
The main parameters you can set are listed below (a sketch of the corresponding `argparse` setup follows the list):
- `--model_name`: Name of the pretrained model used for tokenization.
- `--dataset_name`: Hugging Face dataset name.
- `--dataset_config`: Dataset configuration name.
- `--batch_size`: Batch size for training.
- `--num_epochs`: Number of training epochs.
- `--max_length`: Maximum sequence length.
- `--transformer_model_path`: Path to the pretrained Transformer model.
- `--learning_rate`: Learning rate for the optimizer.
- `--save_dir`: Directory for model checkpoints.
- `--temperature`, `--alpha`, `--beta`, `--lambda_reg`: Hyperparameters for the regularization losses.
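
A hedged sketch of how these flags might be declared with `argparse`; the defaults shown are arbitrary placeholders taken from or inspired by the usage example above, not the script's real defaults.

```python
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Train the World Model (sketch).")
    parser.add_argument("--model_name", type=str, default="gpt2")
    parser.add_argument("--dataset_name", type=str, default="wikitext")
    parser.add_argument("--dataset_config", type=str, default="wikitext-2-raw-v1")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--transformer_model_path", type=str, default=None)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--save_dir", type=str, default="checkpoints")
    parser.add_argument("--temperature", type=float, default=0.07)  # e.g. for InfoNCE
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--lambda_reg", type=float, default=0.01)
    return parser.parse_args()
```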

### Expected Results
As training proceeds, you should see the training and evaluation losses decrease. Once trained, the model can perform complex decision-making tasks by generating sequences of actions with MCTS and PPO optimization (a sketch of the MCTS selection rule is shown below).
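
To make the exploration/exploitation trade-off described under *Model Architecture* concrete, here is a sketch of a PUCT-style selection and backup rule of the kind such an MCTS commonly uses. The `Node` class, `c_puct` constant, and function names are illustrative assumptions; the repository's actual MCTS may differ.

```python
import math


class Node:
    """Minimal MCTS node; the prior comes from the Prediction Network's policy."""

    def __init__(self, prior: float):
        self.prior = prior          # softmax of the policy logit for this action
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}          # action -> Node

    def value(self) -> float:
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def select_child(node: Node, c_puct: float = 1.25):
    """Pick the child maximizing Q + U (exploitation + exploration)."""
    def score(child: Node) -> float:
        u = c_puct * child.prior * math.sqrt(node.visit_count) / (1 + child.visit_count)
        return child.value() + u
    return max(node.children.items(), key=lambda kv: score(kv[1]))


def backpropagate(path, value: float):
    """Back the simulated value estimate up the visited path, as described above."""
    for node in reversed(path):
        node.visit_count += 1
        node.value_sum += value
```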

## Requirements

This code requires:
- Python 3.7+
- `torch>=1.7.1`
- `transformers`
- `datasets`
- `argparse` (included in the Python standard library)

## Limitations

Due to the model's heavy computational requirements, training time may be significant, especially on a CPU; a GPU is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.

## Citation

If you use this model in your research, please cite the author.