---
language:
- en
license: mit
tags:
- chain-of-thought
- implicit-reasoning
- multimodal
- llama3
- instruction-tuned
datasets:
- gsm8k
- svamp
- multi_arith
model-index:
- name: SIM_COT-LLaMA3-CODI-8B
  results:
  - task:
      type: math-word-problems
      name: Arithmetic Reasoning
    dataset:
      name: GSM8K
      type: gsm8k
    metrics:
    - type: accuracy
      value: xx.x
  - task:
      type: math-word-problems
      name: MultiArith
    dataset:
      name: MultiArith
      type: multi_arith
    metrics:
    - type: accuracy
      value: xx.x
  - task:
      type: math-word-problems
      name: SVAMP
    dataset:
      name: SVAMP
      type: svamp
    metrics:
    - type: accuracy
      value: xx.x
---

# 🚀 SIM_COT-LLaMA3-CODI-8B

[![🤗 Model Repo](https://img.shields.io/badge/HuggingFace-Model-blue)](https://huggingface.co/internlm/SIM_COT-LLaMA3-CODI-8B)
[![📂 GitHub](https://img.shields.io/badge/Code-GitHub-black?logo=github)](https://github.com/InternLM/SIM-CoT)
[![📄 Paper](https://img.shields.io/badge/Paper-arXiv-red?logo=arxiv)]()

<p align="center">
  <img src="./asset/coconut_teaser.png" alt="Teaser Figure" width="600"/>
</p>

## 📖 Introduction

Chain-of-Thought (CoT) prompting has become a widely adopted strategy for enhancing the reasoning capabilities of Large Language Models (LLMs). By decomposing problems into intermediate steps, explicit CoT improves accuracy across a variety of reasoning tasks. However, the token cost of explicit reasoning severely limits its scalability, especially when applied to long-horizon tasks or deployed under strict computational budgets.

Implicit CoT methods attempt to address this issue by replacing explicit intermediate steps with continuous latent representations. These approaches achieve higher token efficiency while retaining some of the benefits of step-wise reasoning. Despite this promise, a persistent performance gap remains: implicit CoT methods often underperform explicit reasoning, especially as the number of latent tokens is scaled. Our analysis identifies a fundamental **latent instability problem**: as more implicit reasoning tokens are introduced, training frequently becomes unstable, with latent representations collapsing into homogeneous states that lack semantic diversity. This failure is largely due to the absence of fine-grained, step-level supervision in existing approaches.

To overcome this limitation, we introduce **SIM-CoT**, a plug-and-play training module designed to stabilize and enrich the latent reasoning space. SIM-CoT leverages an auxiliary decoder during training that aligns each implicit token with its corresponding explicit reasoning step. This step-level supervision ensures that latent states encode distinct and meaningful information. Importantly, the auxiliary decoder is removed during inference, so SIM-CoT preserves the computational efficiency of implicit CoT without adding runtime overhead.

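
To make the step-level supervision concrete, the toy PyTorch sketch below illustrates the idea: a training-only head reads each latent reasoning state and is trained to reproduce the tokens of the corresponding explicit reasoning step, then is discarded at inference. This is not the released implementation; the module names, shapes, and the simple linear read-out standing in for the auxiliary decoder are all illustrative assumptions.

```python
# Illustrative toy of SIM-CoT-style step-level supervision (not the official code).
# Assumption: the backbone yields one latent state per implicit reasoning step, and each
# step has a tokenized explicit reasoning step available as a training-only target.
import torch
import torch.nn as nn
import torch.nn.functional as F

class StepSupervisionHead(nn.Module):
    """Auxiliary read-out: maps one latent state to logits over the tokens of the
    matching explicit reasoning step. Used only during training."""
    def __init__(self, hidden_dim: int, vocab_size: int, step_len: int):
        super().__init__()
        self.step_len = step_len
        # Expand one latent vector into `step_len` positions, then score the vocabulary.
        self.expand = nn.Linear(hidden_dim, hidden_dim * step_len)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        # latents: (batch, num_latent_steps, hidden_dim)
        b, n, d = latents.shape
        hidden = self.expand(latents).view(b, n, self.step_len, d)
        return self.lm_head(hidden)  # (batch, num_latent_steps, step_len, vocab_size)

def step_supervision_loss(head, latents, step_token_ids, pad_id=0):
    # step_token_ids: (batch, num_latent_steps, step_len), padded with pad_id
    logits = head(latents)
    return F.cross_entropy(logits.flatten(0, 2), step_token_ids.flatten(), ignore_index=pad_id)

# Toy usage: 2 problems, 6 latent steps, hidden size 64, 8 target tokens per step.
head = StepSupervisionHead(hidden_dim=64, vocab_size=100, step_len=8)
latents = torch.randn(2, 6, 64)
targets = torch.randint(1, 100, (2, 6, 8))
aux_loss = step_supervision_loss(head, latents, targets)
# The total training loss adds aux_loss to the usual implicit-CoT objective; at inference
# the head is dropped, so the latent reasoning path incurs no extra cost.
print(aux_loss.item())
```
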

Empirical results demonstrate that SIM-CoT substantially improves both **in-domain accuracy** and **out-of-domain stability**. On smaller models such as GPT-2, SIM-CoT not only boosts implicit baselines like Coconut by +8.2% but also **surpasses explicit CoT by +2.1% while being 2.3× more token-efficient**. On larger models, including LLaMA-3.1 8B, SIM-CoT delivers consistent gains, improving CODI by +3.0% and significantly narrowing the performance gap with explicit reasoning. These findings highlight SIM-CoT as an effective and scalable solution for advancing implicit reasoning in LLMs.

---

**SIM_COT-LLaMA3-CODI-8B** is an implicit-reasoning language model based on **Meta LLaMA-3.1-8B-Instruct**, fine-tuned with **SIM-CoT (Supervised Implicit Chain-of-Thought)** on top of the **CODI latent reasoning framework**.
It is designed to improve ✨ *implicit reasoning* and 🧮 *multi-step arithmetic problem solving* across benchmarks such as **GSM8K, GSM-Hard, MultiArith, and SVAMP**.

---

## 📊 Experimental Results

We evaluate **SIM-CoT** on both **in-domain** (GSM8K-Aug) and **out-of-domain** (GSM-Hard, MultiArith, SVAMP) benchmarks, using **GPT-2**, **LLaMA-3.2 1B**, **LLaMA-3.2 3B**, and **LLaMA-3.1 8B** as backbones, applied to both the **Coconut** and **CODI** frameworks.

<p align="center">
  <img src="./asset/gpt2.png" alt="Main Results on GPT2" width="750"/>
</p>

*Main results on GPT-2. We report accuracy (%) on the in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks. SIM-CoT provides accuracy gains on top of existing methods such as Coconut and CODI.*

<p align="center">
  <img src="./asset/llama1b.png" alt="Main Results on LLaMA3 1B" width="750"/>
</p>

*Main results on LLaMA 3.2 1B. We report accuracy (%) on the in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks. SIM-CoT builds on CODI to set a new state of the art for implicit reasoning, with performance comparable to explicit CoT.*

<p align="center">
  <img src="./asset/llama3b_8b.png" alt="Main Results on LLaMA3 3B and 8B" width="750"/>
</p>

*Main results on LLaMA 3.2 3B and LLaMA 3.1 8B. We report accuracy (%) on the in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks.*

---

## 📌 Model Details

- 🏗️ **Base model**: [LLaMA-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
- ⚡ **Fine-tuning method**: LoRA (r=128, alpha=32)
- 🔑 **Latent reasoning**: 6 latent steps, projection dimension = 4096
- 🎯 **Dropout**: 0.0 (projection layer)
- 🖥️ **Precision**: bf16
- 📏 **Context length**: 512 tokens

The model integrates **implicit reasoning tokens** during training and inference.
Unlike standard explicit CoT models, SIM-CoT encourages the model to generate **latent structured thoughts** that are decoded only during training, while remaining implicit during inference.

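
For a quick look at the weights outside the evaluation pipeline, here is a hedged sketch that loads the base model in bf16 and attaches the released checkpoint as a standard PEFT LoRA adapter. Treating the checkpoint as a plain PEFT adapter is an assumption, and ordinary `generate` does not run the latent-reasoning loop; use the repository scripts in the Usage section below to reproduce the reported behavior.

```python
# Hedged sketch: attach the SIM-CoT LoRA weights to the LLaMA-3.1-8B-Instruct base for
# inspection only. This does NOT exercise the latent-reasoning inference loop; the
# adapter layout and paths below are assumptions, not confirmed details.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "meta-llama/Llama-3.1-8B-Instruct"
adapter_dir = "path/to/sim_cot-checkpoints"  # assumed to contain a standard PEFT adapter

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(
    base_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, adapter_dir)  # LoRA r=128, alpha=32 per the card
model.eval()

# Plain generation works as a sanity check, but the projection layer and implicit
# reasoning tokens are only driven by the repo's test.py.
prompt = "A bakery sells 3 boxes of 12 cookies each. How many cookies is that in total?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```
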

---

## 🎯 Intended Uses

- 🔬 **AI-related research** (reasoning, representation learning, interpretability)
- 📊 **Benchmarking** on arithmetic reasoning datasets (e.g., GSM8K, SVAMP, MultiArith, GSM-Hard)
- 🧩 Studying **latent representation learning** and **reasoning generalization**

⚠️ *Not intended for deployment in production without careful alignment and safety evaluation.*

---

## 💻 Usage

To reproduce our results, follow the steps below:

### 1. Clone the repository

```bash
git clone https://github.com/InternLM/SIM-CoT.git
cd SIM-CoT/CODI
```

### 2. Run the evaluation script

We provide shell scripts for different backbones and datasets.
For example, to evaluate on **LLaMA-3.1 8B** with the **SVAMP** dataset, run:

```bash
bash test_llama8b.sh
```

This will internally call the following command:

```bash
python test.py \
    --data_name "svamp" \
    --output_dir "$SAVE_DIR" \
    --model_name_or_path path/to/Llama-3.1-8B-Instruct \
    --seed 11 \
    --model_max_length 512 \
    --bf16 \
    --lora_r 128 --lora_alpha 32 --lora_init \
    --batch_size 128 \
    --greedy True \
    --num_latent 6 \
    --use_prj True \
    --prj_dim 4096 \
    --prj_no_ln False \
    --prj_dropout 0.0 \
    --inf_latent_iterations 6 \
    --inf_num_iterations 1 \
    --remove_eos True \
    --use_lora True \
    --ckpt_dir path/to/sim_cot-checkpoints
```

### 3. Expected output

After running, the script will print the evaluation summary.
An example output format is:

```text
adapter: None | GSM8K test accuracy: xxx% |
average length of COT: xxx
Average accuracy over 1 sampling: xxx
```

- **test accuracy**: accuracy on the specified benchmark.
- **average length of COT**: average number of latent reasoning tokens.
- **average accuracy**: aggregated accuracy across sampled runs.
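
When sweeping several benchmarks, a small helper like the one below can pull the accuracy out of each printed summary line. It is a convenience sketch that relies only on the log format shown above and is not part of the repository.

```python
# Convenience sketch: extract the reported accuracy from an evaluation summary line
# of the form shown above, e.g. "adapter: None | <benchmark> test accuracy: NN.N% |".
import re
from typing import Optional

def parse_accuracy(log_text: str) -> Optional[float]:
    """Return the first 'test accuracy' percentage found in the log, or None."""
    match = re.search(r"test accuracy:\s*([0-9.]+)%", log_text)
    return float(match.group(1)) if match else None

# The 61.2 below is an arbitrary placeholder, not a reported result.
example = "adapter: None | GSM8K test accuracy: 61.2% |\naverage length of COT: 6.0"
print(parse_accuracy(example))  # -> 61.2
```
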