blackwingedkite commited on
Commit
b87a3ce
1 Parent(s): ed0179e

Upload 96 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -35
  2. .gitignore +160 -0
  3. LICENSE +201 -0
  4. data/README.md +32 -0
  5. data/README_zh.md +32 -0
  6. data/belle_multiturn/belle_multiturn.py +79 -0
  7. data/dataset_info.json +193 -0
  8. data/example_dataset/example_dataset.py +46 -0
  9. data/example_dataset/examples.json +20 -0
  10. data/hh_rlhf_en/hh_rlhf_en.py +97 -0
  11. data/output.json +0 -0
  12. data/ultra_chat/ultra_chat.py +76 -0
  13. data/wiki_demo.txt +50 -0
  14. pyproject.toml +3 -0
  15. requirements.txt +19 -0
  16. setup.py +55 -0
  17. src/api_demo.py +14 -0
  18. src/cli_demo.py +38 -0
  19. src/export_model.py +9 -0
  20. src/llmtuner/__init__.py +9 -0
  21. src/llmtuner/api/__init__.py +1 -0
  22. src/llmtuner/api/app.py +126 -0
  23. src/llmtuner/api/protocol.py +85 -0
  24. src/llmtuner/chat/__init__.py +1 -0
  25. src/llmtuner/chat/stream_chat.py +101 -0
  26. src/llmtuner/dsets/__init__.py +3 -0
  27. src/llmtuner/dsets/loader.py +92 -0
  28. src/llmtuner/dsets/preprocess.py +201 -0
  29. src/llmtuner/dsets/utils.py +59 -0
  30. src/llmtuner/extras/__init__.py +0 -0
  31. src/llmtuner/extras/callbacks.py +150 -0
  32. src/llmtuner/extras/constants.py +82 -0
  33. src/llmtuner/extras/logging.py +43 -0
  34. src/llmtuner/extras/misc.py +90 -0
  35. src/llmtuner/extras/patches/__init__.py +0 -0
  36. src/llmtuner/extras/patches/flash_llama.py +301 -0
  37. src/llmtuner/extras/ploting.py +52 -0
  38. src/llmtuner/extras/save_and_load.py +21 -0
  39. src/llmtuner/extras/template.py +603 -0
  40. src/llmtuner/hparams/__init__.py +5 -0
  41. src/llmtuner/hparams/data_args.py +130 -0
  42. src/llmtuner/hparams/finetuning_args.py +98 -0
  43. src/llmtuner/hparams/general_args.py +13 -0
  44. src/llmtuner/hparams/generating_args.py +51 -0
  45. src/llmtuner/hparams/model_args.py +79 -0
  46. src/llmtuner/tuner/__init__.py +1 -0
  47. src/llmtuner/tuner/core/__init__.py +2 -0
  48. src/llmtuner/tuner/core/adapter.py +101 -0
  49. src/llmtuner/tuner/core/loader.py +225 -0
  50. src/llmtuner/tuner/core/parser.py +262 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
data/README.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
2
+
3
+ ```json
4
+ "dataset_name": {
5
+ "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
6
+ "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
7
+ "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
8
+ "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
9
+ "ranking": "whether the examples contains ranked responses or not. (default: false)",
10
+ "columns": {
11
+ "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
12
+ "query": "the name of the column in the datasets containing the queries. (default: input)",
13
+ "response": "the name of the column in the datasets containing the responses. (default: output)",
14
+ "history": "the name of the column in the datasets containing the history of chat. (default: None)"
15
+ }
16
+ }
17
+ ```
18
+
19
+ where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
20
+
21
+ For datasets used in reward modeling or DPO training, the `response` column should be a string list, with the preferred answers appearing first, for example:
22
+
23
+ ```json
24
+ {
25
+ "instruction": "Question",
26
+ "input": "",
27
+ "output": [
28
+ "Chosen answer",
29
+ "Rejected answer"
30
+ ]
31
+ }
32
+ ```
data/README_zh.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
2
+
3
+ ```json
4
+ "数据集名称": {
5
+ "hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
6
+ "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
7
+ "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
8
+ "file_sha1": "数据集文件的SHA-1哈希值(可选)",
9
+ "ranking": "数据集是否包含排序后的回答(默认:false)",
10
+ "columns": {
11
+ "prompt": "数据集代表提示词的表头名称(默认:instruction)",
12
+ "query": "数据集代表请求的表头名称(默认:input)",
13
+ "response": "数据集代表回答的表头名称(默认:output)",
14
+ "history": "数据集代表历史对话的表头名称(默认:None)"
15
+ }
16
+ }
17
+ ```
18
+
19
+ 其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
20
+
21
+ 对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
22
+
23
+ ```json
24
+ {
25
+ "instruction": "Question",
26
+ "input": "",
27
+ "output": [
28
+ "Chosen answer",
29
+ "Rejected answer"
30
+ ]
31
+ }
32
+ ```
data/belle_multiturn/belle_multiturn.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "BELLE multiturn chat dataset."
7
+
8
+ _CITATION = """\
9
+ @article{belle2023exploring,
10
+ title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases},
11
+ author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
12
+ journal={arXiv preprint arXiv:2303.14742},
13
+ year={2023}
14
+ }
15
+ """
16
+
17
+ _HOMEPAGE = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M"
18
+ _LICENSE = "gpl-3.0"
19
+ _URL = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
20
+
21
+
22
+ class BelleMultiturn(datasets.GeneratorBasedBuilder):
23
+
24
+ VERSION = datasets.Version("0.0.0")
25
+
26
+ def _info(self) -> datasets.DatasetInfo:
27
+ features = datasets.Features({
28
+ "instruction": datasets.Value("string"),
29
+ "output": datasets.Value("string"),
30
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
31
+ })
32
+ return datasets.DatasetInfo(
33
+ description=_DESCRIPTION,
34
+ features=features,
35
+ homepage=_HOMEPAGE,
36
+ license=_LICENSE,
37
+ citation=_CITATION
38
+ )
39
+
40
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
41
+ file_path = dl_manager.download(_URL)
42
+ return [
43
+ datasets.SplitGenerator(
44
+ name=datasets.Split.TRAIN,
45
+ gen_kwargs={
46
+ "filepath": file_path
47
+ }
48
+ )
49
+ ]
50
+
51
+ def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history
52
+ with open(filepath, "r", encoding="utf-8") as f:
53
+ for key, row in enumerate(f):
54
+ data = json.loads(row)
55
+ prompt = data["instruction"].strip()
56
+ response = data["output"].strip()
57
+
58
+ assist_idx = prompt.rfind("Assistant:")
59
+ human_idx = prompt.rfind("Human:")
60
+ query = prompt[human_idx+6:assist_idx].strip()
61
+ prompt = prompt[:human_idx].strip()
62
+ history = []
63
+
64
+ while prompt.rfind("Assistant:") != -1:
65
+ assist_idx = prompt.rfind("Assistant:")
66
+ human_idx = prompt.rfind("Human:")
67
+ if human_idx != -1:
68
+ old_query = prompt[human_idx+6:assist_idx].strip()
69
+ old_resp = prompt[assist_idx+10:].strip()
70
+ history.insert(0, (old_query, old_resp))
71
+ else:
72
+ break
73
+ prompt = prompt[:human_idx].strip()
74
+
75
+ yield key, {
76
+ "instruction": query,
77
+ "output": response,
78
+ "history": history
79
+ }
data/dataset_info.json ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mydata": {
3
+ "file_name": "output.json",
4
+ "file_sha1": "123456789abc"
5
+ },
6
+ "example": {
7
+ "script_url": "example_dataset",
8
+ "columns": {
9
+ "prompt": "instruction",
10
+ "query": "input",
11
+ "response": "output",
12
+ "history": "history"
13
+ }
14
+ },
15
+ "guanaco": {
16
+ "hf_hub_url": "JosephusCheung/GuanacoDataset"
17
+ },
18
+ "belle_0.5m": {
19
+ "hf_hub_url": "BelleGroup/train_0.5M_CN"
20
+ },
21
+ "belle_1m": {
22
+ "hf_hub_url": "BelleGroup/train_1M_CN"
23
+ },
24
+ "belle_2m": {
25
+ "hf_hub_url": "BelleGroup/train_2M_CN"
26
+ },
27
+ "belle_dialog": {
28
+ "hf_hub_url": "BelleGroup/generated_chat_0.4M"
29
+ },
30
+ "belle_math": {
31
+ "hf_hub_url": "BelleGroup/school_math_0.25M"
32
+ },
33
+ "belle_multiturn": {
34
+ "script_url": "belle_multiturn",
35
+ "columns": {
36
+ "prompt": "instruction",
37
+ "query": "",
38
+ "response": "output",
39
+ "history": "history"
40
+ }
41
+ },
42
+ "codealpaca": {
43
+ "hf_hub_url": "sahil2801/CodeAlpaca-20k"
44
+ },
45
+ "alpaca_cot": {
46
+ "hf_hub_url": "QingyiSi/Alpaca-CoT"
47
+ },
48
+ "firefly": {
49
+ "hf_hub_url": "YeungNLP/firefly-train-1.1M",
50
+ "columns": {
51
+ "prompt": "input",
52
+ "query": "",
53
+ "response": "target",
54
+ "history": ""
55
+ }
56
+ },
57
+ "mathinstruct": {
58
+ "hf_hub_url": "TIGER-Lab/MathInstruct",
59
+ "columns": {
60
+ "prompt": "instruction",
61
+ "query": "",
62
+ "response": "output",
63
+ "history": ""
64
+ }
65
+ },
66
+ "webqa": {
67
+ "hf_hub_url": "suolyer/webqa",
68
+ "columns": {
69
+ "prompt": "input",
70
+ "query": "",
71
+ "response": "output",
72
+ "history": ""
73
+ }
74
+ },
75
+ "ultra_chat": {
76
+ "script_url": "ultra_chat",
77
+ "columns": {
78
+ "prompt": "instruction",
79
+ "query": "",
80
+ "response": "output",
81
+ "history": "history"
82
+ }
83
+ },
84
+ "novel_tokens512_50k": {
85
+ "hf_hub_url": "zxbsmk/webnovel_cn"
86
+ },
87
+ "adgen": {
88
+ "hf_hub_url": "HasturOfficial/adgen",
89
+ "columns": {
90
+ "prompt": "content",
91
+ "query": "",
92
+ "response": "summary",
93
+ "history": ""
94
+ }
95
+ },
96
+ "comparison_gpt4_en": {
97
+ "file_name": "comparison_gpt4_data_en.json",
98
+ "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae",
99
+ "ranking": true
100
+ },
101
+ "comparison_gpt4_zh": {
102
+ "file_name": "comparison_gpt4_data_zh.json",
103
+ "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
104
+ "ranking": true
105
+ },
106
+ "hh_rlhf_en": {
107
+ "script_url": "hh_rlhf_en",
108
+ "columns": {
109
+ "prompt": "instruction",
110
+ "query": "",
111
+ "response": "output",
112
+ "history": "history"
113
+ },
114
+ "ranking": true
115
+ },
116
+ "oaast_rm": {
117
+ "file_name": "oaast_rm.json",
118
+ "file_sha1": "622d420e9b70003b210618253bd3d9d2891d86cb",
119
+ "columns": {
120
+ "prompt": "instruction",
121
+ "query": "input",
122
+ "response": "output",
123
+ "history": "history"
124
+ },
125
+ "ranking": true
126
+ },
127
+ "oaast_rm_zh": {
128
+ "file_name": "oaast_rm_zh.json",
129
+ "file_sha1": "1065af1f3784dd61be5e79713a35f427b713a232",
130
+ "columns": {
131
+ "prompt": "instruction",
132
+ "query": "input",
133
+ "response": "output",
134
+ "history": "history"
135
+ },
136
+ "ranking": true
137
+ },
138
+ "wiki_demo": {
139
+ "file_name": "wiki_demo.txt",
140
+ "file_sha1": "b2288edb05b233e5b35250fd4b308a5fa21fa66d",
141
+ "columns": {
142
+ "prompt": "text",
143
+ "query": "",
144
+ "response": "",
145
+ "history": ""
146
+ }
147
+ },
148
+ "refinedweb": {
149
+ "hf_hub_url": "tiiuae/falcon-refinedweb",
150
+ "columns": {
151
+ "prompt": "content",
152
+ "query": "",
153
+ "response": "",
154
+ "history": ""
155
+ }
156
+ },
157
+ "wikipedia_en": {
158
+ "hf_hub_url": "olm/olm-wikipedia-20221220",
159
+ "columns": {
160
+ "prompt": "text",
161
+ "query": "",
162
+ "response": "",
163
+ "history": ""
164
+ }
165
+ },
166
+ "wikipedia_zh": {
167
+ "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered",
168
+ "columns": {
169
+ "prompt": "completion",
170
+ "query": "",
171
+ "response": "",
172
+ "history": ""
173
+ }
174
+ },
175
+ "the_stack": {
176
+ "hf_hub_url": "bigcode/the-stack",
177
+ "columns": {
178
+ "prompt": "content",
179
+ "query": "",
180
+ "response": "",
181
+ "history": ""
182
+ }
183
+ },
184
+ "starcoder": {
185
+ "hf_hub_url": "bigcode/starcoderdata",
186
+ "columns": {
187
+ "prompt": "content",
188
+ "query": "",
189
+ "response": "",
190
+ "history": ""
191
+ }
192
+ }
193
+ }
data/example_dataset/example_dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "An example of dataset for LLaMA."
7
+ _CITATION = ""
8
+ _HOMEPAGE = ""
9
+ _LICENSE = ""
10
+ _URL = "examples.json"
11
+
12
+
13
+ class ExampleDataset(datasets.GeneratorBasedBuilder):
14
+
15
+ VERSION = datasets.Version("0.0.0")
16
+
17
+ def _info(self) -> datasets.DatasetInfo:
18
+ features = datasets.Features({
19
+ "instruction": datasets.Value("string"),
20
+ "input": datasets.Value("string"),
21
+ "output": datasets.Value("string"),
22
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
23
+ })
24
+ return datasets.DatasetInfo(
25
+ description=_DESCRIPTION,
26
+ features=features,
27
+ homepage=_HOMEPAGE,
28
+ license=_LICENSE,
29
+ citation=_CITATION
30
+ )
31
+
32
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
33
+ file_path = dl_manager.download(_URL)
34
+ return [
35
+ datasets.SplitGenerator(
36
+ name=datasets.Split.TRAIN,
37
+ gen_kwargs={
38
+ "filepath": file_path
39
+ }
40
+ )
41
+ ]
42
+
43
+ def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
44
+ example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
45
+ for key, example in enumerate(example_dataset):
46
+ yield key, example
data/example_dataset/examples.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
4
+ "input": "",
5
+ "output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
6
+ "history": [
7
+ ["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
8
+ ["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
9
+ ]
10
+ },
11
+ {
12
+ "instruction": "好的,谢谢你!",
13
+ "input": "",
14
+ "output": "不客气,有其他需要帮忙的地方可以继续问我。",
15
+ "history": [
16
+ ["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
17
+ ["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
18
+ ]
19
+ }
20
+ ]
data/hh_rlhf_en/hh_rlhf_en.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
7
+ _CITATION = ""
8
+ _HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
9
+ _LICENSE = "mit"
10
+ _URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
11
+ _URLS = {
12
+ "train": [
13
+ _URL + "harmless-base/train.jsonl.gz",
14
+ _URL + "helpful-base/train.jsonl.gz",
15
+ _URL + "helpful-online/train.jsonl.gz",
16
+ _URL + "helpful-rejection-sampled/train.jsonl.gz"
17
+ ],
18
+ "test": [
19
+ _URL + "harmless-base/test.jsonl.gz",
20
+ _URL + "helpful-base/test.jsonl.gz",
21
+ _URL + "helpful-online/test.jsonl.gz",
22
+ _URL + "helpful-rejection-sampled/test.jsonl.gz"
23
+ ]
24
+ }
25
+
26
+
27
+ class HhRlhfEn(datasets.GeneratorBasedBuilder):
28
+
29
+ VERSION = datasets.Version("0.0.0")
30
+
31
+ def _info(self) -> datasets.DatasetInfo:
32
+ features = datasets.Features({
33
+ "instruction": datasets.Value("string"),
34
+ "output": datasets.Sequence(datasets.Value("string")),
35
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
36
+ })
37
+ return datasets.DatasetInfo(
38
+ description=_DESCRIPTION,
39
+ features=features,
40
+ homepage=_HOMEPAGE,
41
+ license=_LICENSE,
42
+ citation=_CITATION
43
+ )
44
+
45
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
46
+ file_path = dl_manager.download_and_extract(_URLS)
47
+ return [
48
+ datasets.SplitGenerator(
49
+ name=datasets.Split.TRAIN,
50
+ gen_kwargs={
51
+ "filepaths": file_path["train"]
52
+ }
53
+ ),
54
+ datasets.SplitGenerator(
55
+ name=datasets.Split.TEST,
56
+ gen_kwargs={
57
+ "filepaths": file_path["test"]
58
+ }
59
+ )
60
+ ]
61
+
62
+ def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
63
+ key = 0
64
+ for filepath in filepaths:
65
+ with open(filepath, "r", encoding="utf-8") as f:
66
+ for row in f:
67
+ data = json.loads(row)
68
+ chosen = data["chosen"]
69
+ rejected = data["rejected"]
70
+
71
+ assist_idx = rejected.rfind("\n\nAssistant: ")
72
+ r_reject = rejected[assist_idx+13:].strip()
73
+ assist_idx = chosen.rfind("\n\nAssistant: ")
74
+ r_accept = chosen[assist_idx+13:].strip()
75
+
76
+ human_idx = chosen.rfind("\n\nHuman: ")
77
+ query = chosen[human_idx+9:assist_idx].strip()
78
+ prompt = chosen[:human_idx]
79
+ history = []
80
+
81
+ while prompt.rfind("\n\nAssistant: ") != -1:
82
+ assist_idx = prompt.rfind("\n\nAssistant: ")
83
+ human_idx = prompt.rfind("\n\nHuman: ")
84
+ if human_idx != -1:
85
+ old_query = prompt[human_idx+9:assist_idx].strip()
86
+ old_resp = prompt[assist_idx+13:].strip()
87
+ history.insert(0, (old_query, old_resp))
88
+ else:
89
+ break
90
+ prompt = prompt[:human_idx]
91
+
92
+ yield key, {
93
+ "instruction": query,
94
+ "output": [r_accept, r_reject],
95
+ "history": history
96
+ }
97
+ key += 1
data/output.json ADDED
The diff for this file is too large to render. See raw diff
 
data/ultra_chat/ultra_chat.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
7
+
8
+ _CITATION = """\
9
+ @misc{UltraChat,
10
+ author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
11
+ title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
12
+ year = {2023},
13
+ publisher = {GitHub},
14
+ journal = {GitHub repository},
15
+ howpublished = {\\url{https://github.com/thunlp/ultrachat}},
16
+ }
17
+ """
18
+
19
+ _HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
20
+ _LICENSE = "cc-by-nc-4.0"
21
+ _BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
22
+
23
+
24
+ class BelleMultiturn(datasets.GeneratorBasedBuilder):
25
+
26
+ VERSION = datasets.Version("0.0.0")
27
+
28
+ def _info(self) -> datasets.DatasetInfo:
29
+ features = datasets.Features({
30
+ "instruction": datasets.Value("string"),
31
+ "output": datasets.Value("string"),
32
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
33
+ })
34
+ return datasets.DatasetInfo(
35
+ description=_DESCRIPTION,
36
+ features=features,
37
+ homepage=_HOMEPAGE,
38
+ license=_LICENSE,
39
+ citation=_CITATION
40
+ )
41
+
42
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
43
+ file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
44
+ return [
45
+ datasets.SplitGenerator(
46
+ name=datasets.Split.TRAIN,
47
+ gen_kwargs={
48
+ "filepaths": file_paths
49
+ }
50
+ )
51
+ ]
52
+
53
+ def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
54
+ for filepath in filepaths:
55
+ with open(filepath, "r", encoding="utf-8") as f:
56
+ for row in f:
57
+ try:
58
+ data = json.loads(row)
59
+ except:
60
+ continue
61
+ key = data["id"]
62
+ content = data["data"]
63
+ if len(content) % 2 == 1:
64
+ content.pop(-1)
65
+ if len(content) < 2:
66
+ continue
67
+
68
+ query = content[-2]
69
+ response = content[-1]
70
+ history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
71
+
72
+ yield key, {
73
+ "instruction": query,
74
+ "output": response,
75
+ "history": history
76
+ }
data/wiki_demo.txt ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Machine learning (ML) is a field devoted to understanding and building methods that let machines "learn" – that is, methods that leverage data to improve computer performance on some set of tasks.
2
+ Machine learning algorithms build a model based on sample data, known as training data, in order to make predictions or decisions without being explicitly programmed to do so. Machine learning algorithms are used in a wide variety of applications, such as in medicine, email filtering, speech recognition, agriculture, and computer vision, where it is difficult or unfeasible to develop conventional algorithms to perform the needed tasks.
3
+ A subset of machine learning is closely related to computational statistics, which focuses on making predictions using computers, but not all machine learning is statistical learning. The study of mathematical optimization delivers methods, theory and application domains to the field of machine learning. Data mining is a related field of study, focusing on exploratory data analysis through unsupervised learning.
4
+ Some implementations of machine learning use data and neural networks in a way that mimics the working of a biological brain.
5
+ In its application across business problems, machine learning is also referred to as predictive analytics.
6
+ Learning algorithms work on the basis that strategies, algorithms, and inferences that worked well in the past are likely to continue working well in the future. These inferences can sometimes be obvious, such as "since the sun rose every morning for the last 10,000 days, it will probably rise tomorrow morning as well". Other times, they can be more nuanced, such as "X% of families have geographically separate species with color variants, so there is a Y% chance that undiscovered black swans exist".
7
+ Machine learning programs can perform tasks without being explicitly programmed to do so. It involves computers learning from data provided so that they carry out certain tasks. For simple tasks assigned to computers, it is possible to program algorithms telling the machine how to execute all steps required to solve the problem at hand; on the computer's part, no learning is needed. For more advanced tasks, it can be challenging for a human to manually create the needed algorithms. In practice, it can turn out to be more effective to help the machine develop its own algorithm, rather than having human programmers specify every needed step.
8
+ The discipline of machine learning employs various approaches to teach computers to accomplish tasks where no fully satisfactory algorithm is available. In cases where vast numbers of potential answers exist, one approach is to label some of the correct answers as valid. This can then be used as training data for the computer to improve the algorithm(s) it uses to determine correct answers. For example, to train a system for the task of digital character recognition, the MNIST dataset of handwritten digits has often been used.
9
+ The term machine learning was coined in 1959 by Arthur Samuel, an IBM employee and pioneer in the field of computer gaming and artificial intelligence. The synonym self-teaching computers was also used in this time period.
10
+ By the early 1960s an experimental "learning machine" with punched tape memory, called Cybertron, had been developed by Raytheon Company to analyze sonar signals, electrocardiograms, and speech patterns using rudimentary reinforcement learning. It was repetitively "trained" by a human operator/teacher to recognize patterns and equipped with a "goof" button to cause it to re-evaluate incorrect decisions. A representative book on research into machine learning during the 1960s was Nilsson's book on Learning Machines, dealing mostly with machine learning for pattern classification. Interest related to pattern recognition continued into the 1970s, as described by Duda and Hart in 1973. In 1981 a report was given on using teaching strategies so that a neural network learns to recognize 40 characters (26 letters, 10 digits, and 4 special symbols) from a computer terminal.
11
+ Tom M. Mitchell provided a widely quoted, more formal definition of the algorithms studied in the machine learning field: "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P if its performance at tasks in T, as measured by P, improves with experience E." This definition of the tasks in which machine learning is concerned offers a fundamentally operational definition rather than defining the field in cognitive terms. This follows Alan Turing's proposal in his paper "Computing Machinery and Intelligence", in which the question "Can machines think?" is replaced with the question "Can machines do what we (as thinking entities) can do?".
12
+ Modern-day machine learning has two objectives, one is to classify data based on models which have been developed, the other purpose is to make predictions for future outcomes based on these models. A hypothetical algorithm specific to classifying data may use computer vision of moles coupled with supervised learning in order to train it to classify the cancerous moles. A machine learning algorithm for stock trading may inform the trader of future potential predictions.
13
+ As a scientific endeavor, machine learning grew out of the quest for artificial intelligence (AI). In the early days of AI as an academic discipline, some researchers were interested in having machines learn from data. They attempted to approach the problem with various symbolic methods, as well as what were then termed "neural networks"; these were mostly perceptrons and other models that were later found to be reinventions of the generalized linear models of statistics. Probabilistic reasoning was also employed, especially in automated medical diagnosis.: 488 
14
+ However, an increasing emphasis on the logical, knowledge-based approach caused a rift between AI and machine learning. Probabilistic systems were plagued by theoretical and practical problems of data acquisition and representation.: 488  By 1980, expert systems had come to dominate AI, and statistics was out of favor. Work on symbolic/knowledge-based learning did continue within AI, leading to inductive logic programming, but the more statistical line of research was now outside the field of AI proper, in pattern recognition and information retrieval.: 708–710, 755  Neural networks research had been abandoned by AI and computer science around the same time. This line, too, was continued outside the AI/CS field, as "connectionism", by researchers from other disciplines including Hopfield, Rumelhart, and Hinton. Their main success came in the mid-1980s with the reinvention of backpropagation.: 25 
15
+ Machine learning (ML), reorganized and recognized as its own field, started to flourish in the 1990s. The field changed its goal from achieving artificial intelligence to tackling solvable problems of a practical nature. It shifted focus away from the symbolic approaches it had inherited from AI, and toward methods and models borrowed from statistics, fuzzy logic, and probability theory.
16
+ Machine learning and data mining often employ the same methods and overlap significantly, but while machine learning focuses on prediction, based on known properties learned from the training data, data mining focuses on the discovery of (previously) unknown properties in the data (this is the analysis step of knowledge discovery in databases). Data mining uses many machine learning methods, but with different goals; on the other hand, machine learning also employs data mining methods as "unsupervised learning" or as a preprocessing step to improve learner accuracy. Much of the confusion between these two research communities (which do often have separate conferences and separate journals, ECML PKDD being a major exception) comes from the basic assumptions they work with: in machine learning, performance is usually evaluated with respect to the ability to reproduce known knowledge, while in knowledge discovery and data mining (KDD) the key task is the discovery of previously unknown knowledge. Evaluated with respect to known knowledge, an uninformed (unsupervised) method will easily be outperformed by other supervised methods, while in a typical KDD task, supervised methods cannot be used due to the unavailability of training data.
17
+ Machine learning also has intimate ties to optimization: many learning problems are formulated as minimization of some loss function on a training set of examples. Loss functions express the discrepancy between the predictions of the model being trained and the actual problem instances (for example, in classification, one wants to assign a label to instances, and models are trained to correctly predict the pre-assigned labels of a set of examples).
18
+ The difference between optimization and machine learning arises from the goal of generalization: while optimization algorithms can minimize the loss on a training set, machine learning is concerned with minimizing the loss on unseen samples. Characterizing the generalization of various learning algorithms is an active topic of current research, especially for deep learning algorithms.
19
+ Machine learning and statistics are closely related fields in terms of methods, but distinct in their principal goal: statistics draws population inferences from a sample, while machine learning finds generalizable predictive patterns. According to Michael I. Jordan, the ideas of machine learning, from methodological principles to theoretical tools, have had a long pre-history in statistics. He also suggested the term data science as a placeholder to call the overall field.
20
+ Leo Breiman distinguished two statistical modeling paradigms: data model and algorithmic model, wherein "algorithmic model" means more or less the machine learning algorithms like Random Forest.
21
+ Some statisticians have adopted methods from machine learning, leading to a combined field that they call statistical learning.
22
+ Analytical and computational techniques derived from deep-rooted physics of disordered systems can be extended to large-scale problems, including machine learning, e.g., to analyze the weight space of deep neural networks. Statistical physics is thus finding applications in the area of medical diagnostics.
23
+ A core objective of a learner is to generalize from its experience. Generalization in this context is the ability of a learning machine to perform accurately on new, unseen examples/tasks after having experienced a learning data set. The training examples come from some generally unknown probability distribution (considered representative of the space of occurrences) and the learner has to build a general model about this space that enables it to produce sufficiently accurate predictions in new cases.
24
+ The computational analysis of machine learning algorithms and their performance is a branch of theoretical computer science known as computational learning theory via the Probably Approximately Correct Learning (PAC) model. Because training sets are finite and the future is uncertain, learning theory usually does not yield guarantees of the performance of algorithms. Instead, probabilistic bounds on the performance are quite common. The bias–variance decomposition is one way to quantify generalization error.
25
+ For the best performance in the context of generalization, the complexity of the hypothesis should match the complexity of the function underlying the data. If the hypothesis is less complex than the function, then the model has under fitted the data. If the complexity of the model is increased in response, then the training error decreases. But if the hypothesis is too complex, then the model is subject to overfitting and generalization will be poorer.
26
+ In addition to performance bounds, learning theorists study the time complexity and feasibility of learning. In computational learning theory, a computation is considered feasible if it can be done in polynomial time. There are two kinds of time complexity results: Positive results show that a certain class of functions can be learned in polynomial time. Negative results show that certain classes cannot be learned in polynomial time.
27
+ Machine learning approaches are traditionally divided into three broad categories, which correspond to learning paradigms, depending on the nature of the "signal" or "feedback" available to the learning system:
28
+ Supervised learning: The computer is presented with example inputs and their desired outputs, given by a "teacher", and the goal is to learn a general rule that maps inputs to outputs.
29
+ Unsupervised learning: No labels are given to the learning algorithm, leaving it on its own to find structure in its input. Unsupervised learning can be a goal in itself (discovering hidden patterns in data) or a means towards an end (feature learning).
30
+ Reinforcement learning: A computer program interacts with a dynamic environment in which it must perform a certain goal (such as driving a vehicle or playing a game against an opponent). As it navigates its problem space, the program is provided feedback that's analogous to rewards, which it tries to maximize. Although each algorithm has advantages and limitations, no single algorithm works for all problems.
31
+ Supervised learning algorithms build a mathematical model of a set of data that contains both the inputs and the desired outputs. The data is known as training data, and consists of a set of training examples. Each training example has one or more inputs and the desired output, also known as a supervisory signal. In the mathematical model, each training example is represented by an array or vector, sometimes called a feature vector, and the training data is represented by a matrix. Through iterative optimization of an objective function, supervised learning algorithms learn a function that can be used to predict the output associated with new inputs. An optimal function will allow the algorithm to correctly determine the output for inputs that were not a part of the training data. An algorithm that improves the accuracy of its outputs or predictions over time is said to have learned to perform that task.
32
+ Types of supervised-learning algorithms include active learning, classification and regression. Classification algorithms are used when the outputs are restricted to a limited set of values, and regression algorithms are used when the outputs may have any numerical value within a range. As an example, for a classification algorithm that filters emails, the input would be an incoming email, and the output would be the name of the folder in which to file the email.
33
+ Similarity learning is an area of supervised machine learning closely related to regression and classification, but the goal is to learn from examples using a similarity function that measures how similar or related two objects are. It has applications in ranking, recommendation systems, visual identity tracking, face verification, and speaker verification.
34
+ Unsupervised learning algorithms take a set of data that contains only inputs, and find structure in the data, like grouping or clustering of data points. The algorithms, therefore, learn from test data that has not been labeled, classified or categorized. Instead of responding to feedback, unsupervised learning algorithms identify commonalities in the data and react based on the presence or absence of such commonalities in each new piece of data. A central application of unsupervised learning is in the field of density estimation in statistics, such as finding the probability density function. Though unsupervised learning encompasses other domains involving summarizing and explaining data features. Unsupervised learning algorithms streamlined the process of survey and graph large indel based haplotypes of a gene of interest from pan-genome.
35
+ Cluster analysis is the assignment of a set of observations into subsets (called clusters) so that observations within the same cluster are similar according to one or more predesignated criteria, while observations drawn from different clusters are dissimilar. Different clustering techniques make different assumptions on the structure of the data, often defined by some similarity metric and evaluated, for example, by internal compactness, or the similarity between members of the same cluster, and separation, the difference between clusters. Other methods are based on estimated density and graph connectivity.
36
+ Semi-supervised learning falls between unsupervised learning (without any labeled training data) and supervised learning (with completely labeled training data). Some of the training examples are missing training labels, yet many machine-learning researchers have found that unlabeled data, when used in conjunction with a small amount of labeled data, can produce a considerable improvement in learning accuracy.
37
+ In weakly supervised learning, the training labels are noisy, limited, or imprecise; however, these labels are often cheaper to obtain, resulting in larger effective training sets.
38
+ Reinforcement learning is an area of machine learning concerned with how software agents ought to take actions in an environment so as to maximize some notion of cumulative reward. Due to its generality, the field is studied in many other disciplines, such as game theory, control theory, operations research, information theory, simulation-based optimization, multi-agent systems, swarm intelligence, statistics and genetic algorithms. In machine learning, the environment is typically represented as a Markov decision process (MDP). Many reinforcements learning algorithms use dynamic programming techniques. Reinforcement learning algorithms do not assume knowledge of an exact mathematical model of the MDP and are used when exact models are infeasible. Reinforcement learning algorithms are used in autonomous vehicles or in learning to play a game against a human opponent.
39
+ Dimensionality reduction is a process of reducing the number of random variables under consideration by obtaining a set of principal variables. In other words, it is a process of reducing the dimension of the feature set, also called the "number of features". Most of the dimensionality reduction techniques can be considered as either feature elimination or extraction. One of the popular methods of dimensionality reduction is principal component analysis (PCA). PCA involves changing higher-dimensional data (e.g., 3D) to a smaller space (e.g., 2D). This results in a smaller dimension of data (2D instead of 3D), while keeping all original variables in the model without changing the data. The manifold hypothesis proposes that high-dimensional data sets lie along low-dimensional manifolds, and many dimensionality reduction techniques make this assumption, leading to the area of manifold learning and manifold regularization.
40
+ Although machine learning has been transformative in some fields, machine-learning programs often fail to deliver expected results. Reasons for this are numerous: lack of (suitable) data, lack of access to the data, data bias, privacy problems, badly chosen tasks and algorithms, wrong tools and people, lack of resources, and evaluation problems.
41
+ In 2018, a self-driving car from Uber failed to detect a pedestrian, who was killed after a collision. Attempts to use machine learning in healthcare with the IBM Watson system failed to deliver even after years of time and billions of dollars invested.
42
+ Machine learning has been used as a strategy to update the evidence related to a systematic review and increased reviewer burden related to the growth of biomedical literature. While it has improved with training sets, it has not yet developed sufficiently to reduce the workload burden without limiting the necessary sensitivity for the findings research themselves.
43
+ Machine learning approaches in particular can suffer from different data biases. A machine learning system trained specifically on current customers may not be able to predict the needs of new customer groups that are not represented in the training data. When trained on human-made data, machine learning is likely to pick up the constitutional and unconscious biases already present in society. Language models learned from data have been shown to contain human-like biases. Machine learning systems used for criminal risk assessment have been found to be biased against black people. In 2015, Google photos would often tag black people as gorillas, and in 2018 this still was not well resolved, but Google reportedly was still using the workaround to remove all gorillas from the training data, and thus was not able to recognize real gorillas at all. Similar issues with recognizing non-white people have been found in many other systems. In 2016, Microsoft tested a chatbot that learned from Twitter, and it quickly picked up racist and sexist language. Because of such challenges, the effective use of machine learning may take longer to be adopted in other domains. Concern for fairness in machine learning, that is, reducing bias in machine learning and propelling its use for human good is increasingly expressed by artificial intelligence scientists, including Fei-Fei Li, who reminds engineers that "There's nothing artificial about AI...It's inspired by people, it's created by people, and—most importantly—it impacts people. It is a powerful tool we are only just beginning to understand, and that is a profound responsibility."
44
+ Learners can also disappoint by "learning the wrong lesson". A toy example is that an image classifier trained only on pictures of brown horses and black cats might conclude that all brown patches are likely to be horses. A real-world example is that, unlike humans, current image classifiers often do not primarily make judgments from the spatial relationship between components of the picture, and they learn relationships between pixels that humans are oblivious to, but that still correlate with images of certain types of real objects. Modifying these patterns on a legitimate image can result in "adversarial" images that the system misclassifies.
45
+ Adversarial vulnerabilities can also result in nonlinear systems, or from non-pattern perturbations. Some systems are so brittle that changing a single adversarial pixel predictably induces misclassification.[citation needed] Machine learning models are often vulnerable to manipulation and/or evasion via adversarial machine learning.
46
+ Researchers have demonstrated how backdoors can be placed undetectably into classifying (e.g., for categories "spam" and well-visible "not spam" of posts) machine learning models which are often developed and/or trained by third parties. Parties can change the classification of any input, including in cases for which a type of data/software transparency is provided, possibly including white-box access.
47
+ Machine learning poses a host of ethical questions. Systems that are trained on datasets collected with biases may exhibit these biases upon use (algorithmic bias), thus digitizing cultural prejudices. For example, in 1988, the UK's Commission for Racial Equality found that St. George's Medical School had been using a computer program trained from data of previous admissions staff and this program had denied nearly 60 candidates who were found to be either women or had non-European sounding names. Using job hiring data from a firm with racist hiring policies may lead to a machine learning system duplicating the bias by scoring job applicants by similarity to previous successful applicants. Responsible collection of data and documentation of algorithmic rules used by a system thus is a critical part of machine learning.
48
+ AI can be well-equipped to make decisions in technical fields, which rely heavily on data and historical information. These decisions rely on the objectivity and logical reasoning. Because human languages contain biases, machines trained on language corpora will necessarily also learn these biases.
49
+ Other forms of ethical challenges, not related to personal biases, are seen in health care. There are concerns among health care professionals that these systems might not be designed in the public's interest but as income-generating machines. This is especially true in the United States where there is a long-standing ethical dilemma of improving health care, but also increase profits. For example, the algorithms could be designed to provide patients with unnecessary tests or medication in which the algorithm's proprietary owners hold stakes. There is potential for machine learning in health care to provide professionals an additional tool to diagnose, medicate, and plan recovery paths for patients, but this requires these biases to be mitigated.
50
+ Since the 2010s, advances in both machine learning algorithms and computer hardware have led to more efficient methods for training deep neural networks (a particular narrow subdomain of machine learning) that contain many layers of non-linear hidden units. By 2019, graphic processing units (GPUs), often with AI-specific enhancements, had displaced CPUs as the dominant method of training large-scale commercial cloud AI. OpenAI estimated the hardware computing used in the largest deep learning projects from AlexNet (2012) to AlphaZero (2017), and found a 300,000-fold increase in the amount of compute required, with a doubling-time trendline of 3.4 months.
pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.13.1
2
+ transformers>=4.30.0
3
+ datasets>=2.12.0
4
+ accelerate>=0.21.0
5
+ peft>=0.4.0
6
+ trl>=0.7.1
7
+ scipy
8
+ sentencepiece
9
+ protobuf
10
+ tiktoken
11
+ jieba
12
+ rouge-chinese
13
+ nltk
14
+ gradio>=3.36.0
15
+ uvicorn
16
+ pydantic==1.10.11
17
+ fastapi==0.95.1
18
+ sse-starlette
19
+ matplotlib
setup.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from setuptools import setup, find_packages
4
+
5
+
6
+ def get_version():
7
+ with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
8
+ file_content = f.read()
9
+ pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
10
+ version, = re.findall(pattern, file_content)
11
+ return version
12
+
13
+
14
+ def get_requires():
15
+ with open("requirements.txt", "r", encoding="utf-8") as f:
16
+ file_content = f.read()
17
+ lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
18
+ return lines
19
+
20
+
21
+ def main():
22
+
23
+ setup(
24
+ name="llmtuner",
25
+ version=get_version(),
26
+ author="hiyouga",
27
+ author_email="hiyouga" "@" "buaa.edu.cn",
28
+ description="Easy-to-use fine-tuning framework using PEFT",
29
+ long_description=open("README.md", "r", encoding="utf-8").read(),
30
+ long_description_content_type="text/markdown",
31
+ keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
32
+ license="Apache 2.0 License",
33
+ url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
34
+ package_dir={"": "src"},
35
+ packages=find_packages("src"),
36
+ python_requires=">=3.8.0",
37
+ install_requires=get_requires(),
38
+ classifiers=[
39
+ "Development Status :: 3 - Alpha",
40
+ "Intended Audience :: Developers",
41
+ "Intended Audience :: Education",
42
+ "Intended Audience :: Science/Research",
43
+ "License :: OSI Approved :: Apache Software License",
44
+ "Operating System :: OS Independent",
45
+ "Programming Language :: Python :: 3",
46
+ "Programming Language :: Python :: 3.8",
47
+ "Programming Language :: Python :: 3.9",
48
+ "Programming Language :: Python :: 3.10",
49
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
50
+ ]
51
+ )
52
+
53
+
54
+ if __name__ == "__main__":
55
+ main()
src/api_demo.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+
3
+ from llmtuner import ChatModel, create_app
4
+
5
+
6
+ def main():
7
+ chat_model = ChatModel()
8
+ app = create_app(chat_model)
9
+ uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
10
+ print("Visit http://localhost:8000/docs for API document.")
11
+
12
+
13
+ if __name__ == "__main__":
14
+ main()
src/cli_demo.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llmtuner import ChatModel
2
+
3
+
4
+ def main():
5
+ chat_model = ChatModel()
6
+ history = []
7
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
8
+
9
+ while True:
10
+ try:
11
+ query = input("\nUser: ")
12
+ except UnicodeDecodeError:
13
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
14
+ continue
15
+ except Exception:
16
+ raise
17
+
18
+ if query.strip() == "exit":
19
+ break
20
+
21
+ if query.strip() == "clear":
22
+ history = []
23
+ print("History has been removed.")
24
+ continue
25
+
26
+ print("Assistant: ", end="", flush=True)
27
+
28
+ response = ""
29
+ for new_text in chat_model.stream_chat(query, history):
30
+ print(new_text, end="", flush=True)
31
+ response += new_text
32
+ print()
33
+
34
+ history = history + [(query, response)]
35
+
36
+
37
+ if __name__ == "__main__":
38
+ main()
src/export_model.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from llmtuner import export_model
2
+
3
+
4
+ def main():
5
+ export_model()
6
+
7
+
8
+ if __name__ == "__main__":
9
+ main()
src/llmtuner/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Level: api, webui > chat > tuner > dsets > extras, hparams
2
+
3
+ from llmtuner.api import create_app
4
+ from llmtuner.chat import ChatModel
5
+ from llmtuner.tuner import export_model, run_exp
6
+ from llmtuner.webui import create_ui, create_web_demo
7
+
8
+
9
+ __version__ = "0.1.8"
src/llmtuner/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from llmtuner.api.app import create_app
src/llmtuner/api/app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from contextlib import asynccontextmanager
5
+ from sse_starlette import EventSourceResponse
6
+ from typing import List, Tuple
7
+
8
+ from llmtuner.extras.misc import torch_gc
9
+ from llmtuner.chat import ChatModel
10
+ from llmtuner.api.protocol import (
11
+ Role,
12
+ Finish,
13
+ ModelCard,
14
+ ModelList,
15
+ ChatMessage,
16
+ DeltaMessage,
17
+ ChatCompletionRequest,
18
+ ChatCompletionResponse,
19
+ ChatCompletionStreamResponse,
20
+ ChatCompletionResponseChoice,
21
+ ChatCompletionResponseStreamChoice,
22
+ ChatCompletionResponseUsage
23
+ )
24
+
25
+
26
+ @asynccontextmanager
27
+ async def lifespan(app: FastAPI): # collects GPU memory
28
+ yield
29
+ torch_gc()
30
+
31
+
32
+ def create_app(chat_model: ChatModel) -> FastAPI:
33
+ app = FastAPI(lifespan=lifespan)
34
+
35
+ app.add_middleware(
36
+ CORSMiddleware,
37
+ allow_origins=["*"],
38
+ allow_credentials=True,
39
+ allow_methods=["*"],
40
+ allow_headers=["*"],
41
+ )
42
+
43
+ @app.get("/v1/models", response_model=ModelList)
44
+ async def list_models():
45
+ model_card = ModelCard(id="gpt-3.5-turbo")
46
+ return ModelList(data=[model_card])
47
+
48
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
49
+ async def create_chat_completion(request: ChatCompletionRequest):
50
+ if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
51
+ raise HTTPException(status_code=400, detail="Invalid request")
52
+
53
+ query = request.messages[-1].content
54
+ prev_messages = request.messages[:-1]
55
+ if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
56
+ system = prev_messages.pop(0).content
57
+ else:
58
+ system = None
59
+
60
+ history = []
61
+ if len(prev_messages) % 2 == 0:
62
+ for i in range(0, len(prev_messages), 2):
63
+ if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
64
+ history.append([prev_messages[i].content, prev_messages[i+1].content])
65
+
66
+ if request.stream:
67
+ generate = predict(query, history, system, request)
68
+ return EventSourceResponse(generate, media_type="text/event-stream")
69
+
70
+ response, (prompt_length, response_length) = chat_model.chat(
71
+ query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
72
+ )
73
+
74
+ usage = ChatCompletionResponseUsage(
75
+ prompt_tokens=prompt_length,
76
+ completion_tokens=response_length,
77
+ total_tokens=prompt_length+response_length
78
+ )
79
+
80
+ choice_data = ChatCompletionResponseChoice(
81
+ index=0,
82
+ message=ChatMessage(role=Role.ASSISTANT, content=response),
83
+ finish_reason=Finish.STOP
84
+ )
85
+
86
+ return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
87
+
88
+ async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
89
+ choice_data = ChatCompletionResponseStreamChoice(
90
+ index=0,
91
+ delta=DeltaMessage(role=Role.ASSISTANT),
92
+ finish_reason=None
93
+ )
94
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
95
+ yield chunk.json(exclude_unset=True, ensure_ascii=False)
96
+
97
+ for new_text in chat_model.stream_chat(
98
+ query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
99
+ ):
100
+ if len(new_text) == 0:
101
+ continue
102
+
103
+ choice_data = ChatCompletionResponseStreamChoice(
104
+ index=0,
105
+ delta=DeltaMessage(content=new_text),
106
+ finish_reason=None
107
+ )
108
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
109
+ yield chunk.json(exclude_unset=True, ensure_ascii=False)
110
+
111
+ choice_data = ChatCompletionResponseStreamChoice(
112
+ index=0,
113
+ delta=DeltaMessage(),
114
+ finish_reason=Finish.STOP
115
+ )
116
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
117
+ yield chunk.json(exclude_unset=True, ensure_ascii=False)
118
+ yield "[DONE]"
119
+
120
+ return app
121
+
122
+
123
+ if __name__ == "__main__":
124
+ chat_model = ChatModel()
125
+ app = create_app(chat_model)
126
+ uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
src/llmtuner/api/protocol.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from enum import Enum
3
+ from pydantic import BaseModel, Field
4
+ from typing import List, Optional
5
+
6
+
7
+ class Role(str, Enum):
8
+ USER = "user"
9
+ ASSISTANT = "assistant"
10
+ SYSTEM = "system"
11
+
12
+
13
+ class Finish(str, Enum):
14
+ STOP = "stop"
15
+ LENGTH = "length"
16
+
17
+
18
+ class ModelCard(BaseModel):
19
+ id: str
20
+ object: Optional[str] = "model"
21
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
22
+ owned_by: Optional[str] = "owner"
23
+ root: Optional[str] = None
24
+ parent: Optional[str] = None
25
+ permission: Optional[list] = []
26
+
27
+
28
+ class ModelList(BaseModel):
29
+ object: Optional[str] = "list"
30
+ data: Optional[List[ModelCard]] = []
31
+
32
+
33
+ class ChatMessage(BaseModel):
34
+ role: Role
35
+ content: str
36
+
37
+
38
+ class DeltaMessage(BaseModel):
39
+ role: Optional[Role] = None
40
+ content: Optional[str] = None
41
+
42
+
43
+ class ChatCompletionRequest(BaseModel):
44
+ model: str
45
+ messages: List[ChatMessage]
46
+ temperature: Optional[float] = None
47
+ top_p: Optional[float] = None
48
+ n: Optional[int] = 1
49
+ max_tokens: Optional[int] = None
50
+ stream: Optional[bool] = False
51
+
52
+
53
+ class ChatCompletionResponseChoice(BaseModel):
54
+ index: int
55
+ message: ChatMessage
56
+ finish_reason: Finish
57
+
58
+
59
+ class ChatCompletionResponseStreamChoice(BaseModel):
60
+ index: int
61
+ delta: DeltaMessage
62
+ finish_reason: Optional[Finish] = None
63
+
64
+
65
+ class ChatCompletionResponseUsage(BaseModel):
66
+ prompt_tokens: int
67
+ completion_tokens: int
68
+ total_tokens: int
69
+
70
+
71
+ class ChatCompletionResponse(BaseModel):
72
+ id: Optional[str] = "chatcmpl-default"
73
+ object: Optional[str] = "chat.completion"
74
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
75
+ model: str
76
+ choices: List[ChatCompletionResponseChoice]
77
+ usage: ChatCompletionResponseUsage
78
+
79
+
80
+ class ChatCompletionStreamResponse(BaseModel):
81
+ id: Optional[str] = "chatcmpl-default"
82
+ object: Optional[str] = "chat.completion.chunk"
83
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
84
+ model: str
85
+ choices: List[ChatCompletionResponseStreamChoice]
src/llmtuner/chat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from llmtuner.chat.stream_chat import ChatModel
src/llmtuner/chat/stream_chat.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Any, Dict, Generator, List, Optional, Tuple
3
+ from threading import Thread
4
+ from transformers import GenerationConfig, TextIteratorStreamer
5
+
6
+ from llmtuner.extras.misc import dispatch_model, get_logits_processor
7
+ from llmtuner.extras.template import get_template_and_fix_tokenizer
8
+ from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
9
+
10
+
11
+ class ChatModel:
12
+
13
+ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
14
+ model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
15
+ self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
16
+ self.tokenizer.padding_side = "left"
17
+ self.model = dispatch_model(self.model)
18
+ self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
19
+ self.system_prompt = data_args.system_prompt
20
+
21
+ def process_args(
22
+ self,
23
+ query: str,
24
+ history: Optional[List[Tuple[str, str]]] = None,
25
+ system: Optional[str] = None,
26
+ **input_kwargs
27
+ ) -> Tuple[Dict[str, Any], int]:
28
+ system = system or self.system_prompt
29
+
30
+ prompt, _ = self.template.encode_oneturn(
31
+ tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
32
+ )
33
+ input_ids = torch.tensor([prompt], device=self.model.device)
34
+ prompt_length = len(input_ids[0])
35
+
36
+ do_sample = input_kwargs.pop("do_sample", None)
37
+ temperature = input_kwargs.pop("temperature", None)
38
+ top_p = input_kwargs.pop("top_p", None)
39
+ top_k = input_kwargs.pop("top_k", None)
40
+ repetition_penalty = input_kwargs.pop("repetition_penalty", None)
41
+ max_length = input_kwargs.pop("max_length", None)
42
+ max_new_tokens = input_kwargs.pop("max_new_tokens", None)
43
+
44
+ generating_args = self.generating_args.to_dict()
45
+ generating_args.update(dict(
46
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
47
+ temperature=temperature or generating_args["temperature"],
48
+ top_p=top_p or generating_args["top_p"],
49
+ top_k=top_k or generating_args["top_k"],
50
+ repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
51
+ eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
52
+ pad_token_id=self.tokenizer.pad_token_id
53
+ ))
54
+
55
+ if max_length:
56
+ generating_args.pop("max_new_tokens", None)
57
+ generating_args["max_length"] = max_length
58
+
59
+ if max_new_tokens:
60
+ generating_args.pop("max_length", None)
61
+ generating_args["max_new_tokens"] = max_new_tokens
62
+
63
+ gen_kwargs = dict(
64
+ inputs=input_ids,
65
+ generation_config=GenerationConfig(**generating_args),
66
+ logits_processor=get_logits_processor()
67
+ )
68
+
69
+ return gen_kwargs, prompt_length
70
+
71
+ @torch.inference_mode()
72
+ def chat(
73
+ self,
74
+ query: str,
75
+ history: Optional[List[Tuple[str, str]]] = None,
76
+ system: Optional[str] = None,
77
+ **input_kwargs
78
+ ) -> Tuple[str, Tuple[int, int]]:
79
+ gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
80
+ generation_output = self.model.generate(**gen_kwargs)
81
+ outputs = generation_output.tolist()[0][prompt_length:]
82
+ response = self.tokenizer.decode(outputs, skip_special_tokens=True)
83
+ response_length = len(outputs)
84
+ return response, (prompt_length, response_length)
85
+
86
+ @torch.inference_mode()
87
+ def stream_chat(
88
+ self,
89
+ query: str,
90
+ history: Optional[List[Tuple[str, str]]] = None,
91
+ system: Optional[str] = None,
92
+ **input_kwargs
93
+ ) -> Generator[str, None, None]:
94
+ gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
95
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
96
+ gen_kwargs["streamer"] = streamer
97
+
98
+ thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
99
+ thread.start()
100
+
101
+ yield from streamer
src/llmtuner/dsets/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from llmtuner.dsets.loader import get_dataset
2
+ from llmtuner.dsets.preprocess import preprocess_dataset
3
+ from llmtuner.dsets.utils import split_dataset
src/llmtuner/dsets/loader.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TYPE_CHECKING, List, Union
3
+
4
+ from datasets import concatenate_datasets, interleave_datasets, load_dataset
5
+
6
+ from llmtuner.dsets.utils import checksum, EXT2TYPE
7
+ from llmtuner.extras.logging import get_logger
8
+
9
+ if TYPE_CHECKING:
10
+ from datasets import Dataset, IterableDataset
11
+ from llmtuner.hparams import ModelArguments, DataArguments
12
+
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ def get_dataset(
18
+ model_args: "ModelArguments",
19
+ data_args: "DataArguments"
20
+ ) -> Union["Dataset", "IterableDataset"]:
21
+ max_samples = data_args.max_samples
22
+ all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
23
+
24
+ for dataset_attr in data_args.dataset_list:
25
+ logger.info("Loading dataset {}...".format(dataset_attr))
26
+
27
+ if dataset_attr.load_from == "hf_hub":
28
+ data_path = dataset_attr.dataset_name
29
+ data_files = None
30
+ elif dataset_attr.load_from == "script":
31
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
32
+ data_files = None
33
+ elif dataset_attr.load_from == "file":
34
+ data_path = None
35
+ data_files: List[str] = []
36
+
37
+ if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
38
+ for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
39
+ data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
40
+ if data_path is None:
41
+ data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
42
+ else:
43
+ assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
44
+ elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
45
+ data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
46
+ data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
47
+ else:
48
+ raise ValueError("File not found.")
49
+
50
+ assert data_path, "File extension must be txt, csv, json or jsonl."
51
+ checksum(data_files, dataset_attr.dataset_sha1)
52
+ else:
53
+ raise NotImplementedError
54
+
55
+ dataset = load_dataset(
56
+ data_path,
57
+ data_files=data_files,
58
+ split=data_args.split,
59
+ cache_dir=model_args.cache_dir,
60
+ streaming=data_args.streaming,
61
+ use_auth_token=True if model_args.use_auth_token else None
62
+ )
63
+
64
+ if max_samples is not None:
65
+ max_samples_temp = min(len(dataset), max_samples)
66
+ dataset = dataset.select(range(max_samples_temp))
67
+
68
+ for column_name in ["prompt", "query", "response", "history"]: # align datasets
69
+ if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
70
+ dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
71
+
72
+ if dataset_attr.system_prompt: # add system prompt
73
+ if data_args.streaming:
74
+ dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
75
+ else:
76
+ dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
77
+
78
+ all_datasets.append(dataset)
79
+
80
+ if len(data_args.dataset_list) == 1:
81
+ return all_datasets[0]
82
+ elif data_args.mix_strategy == "concat":
83
+ if data_args.streaming:
84
+ logger.warning("The samples between different datasets will not be mixed in streaming mode.")
85
+ return concatenate_datasets(all_datasets)
86
+ elif data_args.mix_strategy.startswith("interleave"):
87
+ if not data_args.streaming:
88
+ logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
89
+ stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
90
+ return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
91
+ else:
92
+ raise ValueError("Unknown mixing strategy.")
src/llmtuner/dsets/preprocess.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
3
+ from itertools import chain
4
+
5
+ from llmtuner.extras.constants import IGNORE_INDEX
6
+ from llmtuner.extras.template import get_template_and_fix_tokenizer
7
+
8
+ if TYPE_CHECKING:
9
+ from datasets import Dataset, IterableDataset
10
+ from transformers import Seq2SeqTrainingArguments
11
+ from transformers.tokenization_utils import PreTrainedTokenizer
12
+ from llmtuner.hparams import DataArguments
13
+
14
+
15
+ def preprocess_dataset(
16
+ dataset: Union["Dataset", "IterableDataset"],
17
+ tokenizer: "PreTrainedTokenizer",
18
+ data_args: "DataArguments",
19
+ training_args: "Seq2SeqTrainingArguments",
20
+ stage: Literal["pt", "sft", "rm", "ppo"]
21
+ ) -> Union["Dataset", "IterableDataset"]:
22
+ column_names = list(next(iter(dataset)).keys())
23
+ template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
24
+
25
+ def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
26
+ for i in range(len(examples["prompt"])):
27
+ query, response = examples["prompt"][i], examples["response"][i]
28
+ query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
29
+ history = examples["history"][i] if "history" in examples else None
30
+ system = examples["system"][i] if "system" in examples else None
31
+ yield query, response, history, system
32
+
33
+ def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
34
+ # build grouped texts with format `X1 X2 X3 ...`
35
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
36
+ kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
37
+ else:
38
+ kwargs = dict(add_special_tokens=True)
39
+
40
+ if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
41
+ setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
42
+ setattr(tokenizer, "add_eos_token", True)
43
+
44
+ tokenized_examples = tokenizer(examples["prompt"], **kwargs)
45
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
46
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
47
+ block_size = data_args.cutoff_len
48
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
49
+ total_length = (total_length // block_size) * block_size
50
+ # split by chunks of cutoff_len
51
+ result = {
52
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
53
+ for k, t in concatenated_examples.items()
54
+ }
55
+ return result
56
+
57
+ def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
58
+ # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
59
+ # for multiturn examples, we only mask the prompt part in each prompt-response pair.
60
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
61
+
62
+ for query, response, history, system in construct_example(examples):
63
+ input_ids, labels = [], []
64
+
65
+ for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
66
+ tokenizer, query, response, history, system
67
+ )):
68
+ total_len = len(source_ids) + len(target_ids)
69
+ max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
70
+ max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
71
+
72
+ if len(source_ids) > max_source_len:
73
+ source_ids = source_ids[:max_source_len]
74
+ if len(target_ids) > max_target_len:
75
+ target_ids = target_ids[:max_target_len]
76
+
77
+ if turn_idx != 0 and template.efficient_eos:
78
+ source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
79
+ else:
80
+ source_mask = [IGNORE_INDEX] * len(source_ids)
81
+
82
+ input_ids += source_ids + target_ids
83
+ labels += source_mask + target_ids
84
+
85
+ if template.efficient_eos:
86
+ input_ids += [tokenizer.eos_token_id]
87
+ labels += [tokenizer.eos_token_id]
88
+
89
+ if len(input_ids) > data_args.cutoff_len:
90
+ input_ids = input_ids[:data_args.cutoff_len]
91
+ labels = labels[:data_args.cutoff_len]
92
+
93
+ model_inputs["input_ids"].append(input_ids)
94
+ model_inputs["attention_mask"].append([1] * len(input_ids))
95
+ model_inputs["labels"].append(labels)
96
+
97
+ return model_inputs
98
+
99
+ def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
100
+ # build inputs with format `<bos> X` and labels with format `Y <eos>`
101
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
102
+
103
+ for query, response, history, system in construct_example(examples):
104
+ input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
105
+
106
+ if template.efficient_eos:
107
+ labels += [tokenizer.eos_token_id]
108
+
109
+ if len(input_ids) > data_args.cutoff_len:
110
+ input_ids = input_ids[:data_args.cutoff_len]
111
+ if len(labels) > data_args.cutoff_len:
112
+ labels = labels[:data_args.cutoff_len]
113
+
114
+ model_inputs["input_ids"].append(input_ids)
115
+ model_inputs["attention_mask"].append([1] * len(input_ids))
116
+ model_inputs["labels"].append(labels)
117
+
118
+ return model_inputs
119
+
120
+ def preprocess_pairwise_dataset(examples):
121
+ # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
122
+ model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
123
+ for query, response, history, system in construct_example(examples):
124
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
125
+ _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
126
+
127
+ if template.efficient_eos:
128
+ chosen_ids += [tokenizer.eos_token_id]
129
+ rejected_ids += [tokenizer.eos_token_id]
130
+
131
+ total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
132
+ max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
133
+ max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
134
+
135
+ if len(prompt_ids) > max_source_len:
136
+ prompt_ids = prompt_ids[:max_source_len]
137
+ if len(chosen_ids) > max_target_len:
138
+ chosen_ids = chosen_ids[:max_target_len]
139
+ if len(rejected_ids) > max_target_len:
140
+ rejected_ids = rejected_ids[:max_target_len]
141
+
142
+ model_inputs["prompt_ids"].append(prompt_ids)
143
+ model_inputs["chosen_ids"].append(chosen_ids)
144
+ model_inputs["rejected_ids"].append(rejected_ids)
145
+ return model_inputs
146
+
147
+ def print_supervised_dataset_example(example):
148
+ print("input_ids:\n{}".format(example["input_ids"]))
149
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
150
+ print("label_ids:\n{}".format(example["labels"]))
151
+ print("labels:\n{}".format(
152
+ tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
153
+ ))
154
+
155
+ def print_pairwise_dataset_example(example):
156
+ print("prompt_ids:\n{}".format(example["prompt_ids"]))
157
+ print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
158
+ print("chosen_ids:\n{}".format(example["chosen_ids"]))
159
+ print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
160
+ print("rejected_ids:\n{}".format(example["rejected_ids"]))
161
+ print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
162
+
163
+ def print_unsupervised_dataset_example(example):
164
+ print("input_ids:\n{}".format(example["input_ids"]))
165
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
166
+
167
+ if stage == "pt":
168
+ dataset = dataset.filter(lambda example: example["prompt"])
169
+ preprocess_function = preprocess_pretrain_dataset
170
+ print_function = print_unsupervised_dataset_example
171
+ elif stage == "sft" and not training_args.predict_with_generate:
172
+ dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
173
+ preprocess_function = preprocess_supervised_dataset
174
+ print_function = print_supervised_dataset_example
175
+ elif stage == "rm":
176
+ dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
177
+ preprocess_function = preprocess_pairwise_dataset
178
+ print_function = print_pairwise_dataset_example
179
+ else:
180
+ dataset = dataset.filter(lambda example: example["prompt"])
181
+ preprocess_function = preprocess_unsupervised_dataset
182
+ print_function = print_unsupervised_dataset_example
183
+
184
+ with training_args.main_process_first(desc="dataset map pre-processing"):
185
+ kwargs = {}
186
+ if not data_args.streaming:
187
+ kwargs = dict(
188
+ num_proc=data_args.preprocessing_num_workers,
189
+ load_from_cache_file=not data_args.overwrite_cache,
190
+ desc="Running tokenizer on dataset"
191
+ )
192
+
193
+ dataset = dataset.map(
194
+ preprocess_function,
195
+ batched=True,
196
+ remove_columns=column_names,
197
+ **kwargs
198
+ )
199
+
200
+ print_function(next(iter(dataset)))
201
+ return dataset
src/llmtuner/dsets/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
3
+
4
+ from llmtuner.extras.logging import get_logger
5
+
6
+ if TYPE_CHECKING:
7
+ from datasets import Dataset, IterableDataset
8
+ from transformers import TrainingArguments
9
+ from llmtuner.hparams import DataArguments
10
+
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ EXT2TYPE = {
16
+ "csv": "csv",
17
+ "json": "json",
18
+ "jsonl": "json",
19
+ "txt": "text"
20
+ }
21
+
22
+
23
+ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
24
+ if file_sha1 is None:
25
+ logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
26
+ return
27
+
28
+ if len(data_files) != 1:
29
+ logger.warning("Checksum failed: too many files.")
30
+ return
31
+
32
+ with open(data_files[0], "rb") as f:
33
+ sha1 = hashlib.sha1(f.read()).hexdigest()
34
+ if sha1 != file_sha1:
35
+ logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
36
+
37
+
38
+ def split_dataset(
39
+ dataset: Union["Dataset", "IterableDataset"],
40
+ data_args: "DataArguments",
41
+ training_args: "TrainingArguments"
42
+ ) -> Dict[str, "Dataset"]:
43
+ if training_args.do_train:
44
+ if data_args.val_size > 1e-6: # Split the dataset
45
+ if data_args.streaming:
46
+ val_set = dataset.take(int(data_args.val_size))
47
+ train_set = dataset.skip(int(data_args.val_size))
48
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
49
+ return {"train_dataset": train_set, "eval_dataset": val_set}
50
+ else:
51
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
52
+ dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
53
+ return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
54
+ else:
55
+ if data_args.streaming:
56
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
57
+ return {"train_dataset": dataset}
58
+ else: # do_eval or do_predict
59
+ return {"eval_dataset": dataset}
src/llmtuner/extras/__init__.py ADDED
File without changes
src/llmtuner/extras/callbacks.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from typing import TYPE_CHECKING
5
+ from datetime import timedelta
6
+
7
+ from transformers import TrainerCallback
8
+ from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
9
+
10
+ from llmtuner.extras.constants import LOG_FILE_NAME
11
+ from llmtuner.extras.logging import get_logger
12
+
13
+ if TYPE_CHECKING:
14
+ from transformers import TrainingArguments, TrainerState, TrainerControl
15
+
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ class SavePeftModelCallback(TrainerCallback):
21
+
22
+ def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
23
+ r"""
24
+ Event called after a checkpoint save.
25
+ """
26
+ if args.should_save:
27
+ output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
28
+ model = kwargs.pop("model")
29
+ if getattr(model, "is_peft_model", False):
30
+ getattr(model, "pretrained_model").save_pretrained(output_dir)
31
+
32
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
33
+ r"""
34
+ Event called at the end of training.
35
+ """
36
+ if args.should_save:
37
+ model = kwargs.pop("model")
38
+ if getattr(model, "is_peft_model", False):
39
+ getattr(model, "pretrained_model").save_pretrained(args.output_dir)
40
+
41
+
42
+ class LogCallback(TrainerCallback):
43
+
44
+ def __init__(self, runner=None):
45
+ self.runner = runner
46
+ self.in_training = False
47
+ self.start_time = time.time()
48
+ self.cur_steps = 0
49
+ self.max_steps = 0
50
+ self.elapsed_time = ""
51
+ self.remaining_time = ""
52
+
53
+ def timing(self):
54
+ cur_time = time.time()
55
+ elapsed_time = cur_time - self.start_time
56
+ avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
57
+ remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
58
+ self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
59
+ self.remaining_time = str(timedelta(seconds=int(remaining_time)))
60
+
61
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
62
+ r"""
63
+ Event called at the beginning of training.
64
+ """
65
+ if state.is_local_process_zero:
66
+ self.in_training = True
67
+ self.start_time = time.time()
68
+ self.max_steps = state.max_steps
69
+ if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
70
+ logger.warning("Previous log file in this folder will be deleted.")
71
+ os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
72
+
73
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
74
+ r"""
75
+ Event called at the end of training.
76
+ """
77
+ if state.is_local_process_zero:
78
+ self.in_training = False
79
+ self.cur_steps = 0
80
+ self.max_steps = 0
81
+
82
+ def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
83
+ r"""
84
+ Event called at the end of an substep during gradient accumulation.
85
+ """
86
+ if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
87
+ control.should_epoch_stop = True
88
+ control.should_training_stop = True
89
+
90
+ def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
91
+ r"""
92
+ Event called at the end of a training step.
93
+ """
94
+ if state.is_local_process_zero:
95
+ self.cur_steps = state.global_step
96
+ self.timing()
97
+ if self.runner is not None and self.runner.aborted:
98
+ control.should_epoch_stop = True
99
+ control.should_training_stop = True
100
+
101
+ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
102
+ r"""
103
+ Event called after an evaluation phase.
104
+ """
105
+ if state.is_local_process_zero and not self.in_training:
106
+ self.cur_steps = 0
107
+ self.max_steps = 0
108
+
109
+ def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
110
+ r"""
111
+ Event called after a successful prediction.
112
+ """
113
+ if state.is_local_process_zero and not self.in_training:
114
+ self.cur_steps = 0
115
+ self.max_steps = 0
116
+
117
+ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
118
+ r"""
119
+ Event called after logging the last logs.
120
+ """
121
+ if not state.is_local_process_zero:
122
+ return
123
+
124
+ logs = dict(
125
+ current_steps=self.cur_steps,
126
+ total_steps=self.max_steps,
127
+ loss=state.log_history[-1].get("loss", None),
128
+ eval_loss=state.log_history[-1].get("eval_loss", None),
129
+ predict_loss=state.log_history[-1].get("predict_loss", None),
130
+ reward=state.log_history[-1].get("reward", None),
131
+ learning_rate=state.log_history[-1].get("learning_rate", None),
132
+ epoch=state.log_history[-1].get("epoch", None),
133
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
134
+ elapsed_time=self.elapsed_time,
135
+ remaining_time=self.remaining_time
136
+ )
137
+ os.makedirs(args.output_dir, exist_ok=True)
138
+ with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
139
+ f.write(json.dumps(logs) + "\n")
140
+
141
+ def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
142
+ r"""
143
+ Event called after a prediction step.
144
+ """
145
+ eval_dataloader = kwargs.pop("eval_dataloader", None)
146
+ if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
147
+ if self.max_steps == 0:
148
+ self.max_steps = len(eval_dataloader)
149
+ self.cur_steps += 1
150
+ self.timing()
src/llmtuner/extras/constants.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ IGNORE_INDEX = -100
2
+
3
+ LOG_FILE_NAME = "trainer_log.jsonl"
4
+
5
+ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
6
+
7
+ METHODS = ["full", "freeze", "lora"]
8
+
9
+ TRAINING_STAGES = {
10
+ "Supervised Fine-Tuning": "sft",
11
+ "Reward Modeling": "rm",
12
+ "PPO": "ppo",
13
+ "DPO": "dpo",
14
+ "Pre-Training": "pt"
15
+ }
16
+
17
+ SUPPORTED_MODELS = {
18
+ "LLaMA-7B": "huggyllama/llama-7b",
19
+ "LLaMA-13B": "huggyllama/llama-13b",
20
+ "LLaMA-30B": "huggyllama/llama-30b",
21
+ "LLaMA-65B": "huggyllama/llama-65b",
22
+ "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
23
+ "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
24
+ "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
25
+ "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
26
+ "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
27
+ "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
28
+ "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
29
+ "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
30
+ "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
31
+ "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
32
+ "BLOOM-560M": "bigscience/bloom-560m",
33
+ "BLOOM-3B": "bigscience/bloom-3b",
34
+ "BLOOM-7B1": "bigscience/bloom-7b1",
35
+ "BLOOMZ-560M": "bigscience/bloomz-560m",
36
+ "BLOOMZ-3B": "bigscience/bloomz-3b",
37
+ "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
38
+ "Falcon-7B": "tiiuae/falcon-7b",
39
+ "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
40
+ "Falcon-40B": "tiiuae/falcon-40b",
41
+ "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
42
+ "Baichuan-7B": "baichuan-inc/Baichuan-7B",
43
+ "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
44
+ "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
45
+ "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
46
+ "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
47
+ "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
48
+ "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
49
+ "InternLM-7B": "internlm/internlm-7b",
50
+ "InternLM-7B-Chat": "internlm/internlm-chat-7b",
51
+ "Qwen-7B": "Qwen/Qwen-7B",
52
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
53
+ "XVERSE-13B": "xverse/XVERSE-13B",
54
+ "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
55
+ "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
56
+ }
57
+
58
+ DEFAULT_MODULE = {
59
+ "LLaMA": "q_proj,v_proj",
60
+ "LLaMA2": "q_proj,v_proj",
61
+ "ChineseLLaMA2": "q_proj,v_proj",
62
+ "BLOOM": "query_key_value",
63
+ "BLOOMZ": "query_key_value",
64
+ "Falcon": "query_key_value",
65
+ "Baichuan": "W_pack",
66
+ "Baichuan2": "W_pack",
67
+ "InternLM": "q_proj,v_proj",
68
+ "Qwen": "c_attn",
69
+ "XVERSE": "q_proj,v_proj",
70
+ "ChatGLM2": "query_key_value"
71
+ }
72
+
73
+ DEFAULT_TEMPLATE = {
74
+ "LLaMA2": "llama2",
75
+ "ChineseLLaMA2": "llama2_zh",
76
+ "Baichuan": "baichuan",
77
+ "Baichuan2": "baichuan2",
78
+ "InternLM": "intern",
79
+ "Qwen": "chatml",
80
+ "XVERSE": "xverse",
81
+ "ChatGLM2": "chatglm2"
82
+ }
src/llmtuner/extras/logging.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+
4
+
5
+ class LoggerHandler(logging.Handler):
6
+
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.log = ""
10
+
11
+ def reset(self):
12
+ self.log = ""
13
+
14
+ def emit(self, record):
15
+ if record.name == "httpx":
16
+ return
17
+ log_entry = self.format(record)
18
+ self.log += log_entry
19
+ self.log += "\n\n"
20
+
21
+
22
+ def reset_logging():
23
+ r"""
24
+ Removes basic config of root logger
25
+ """
26
+ root = logging.getLogger()
27
+ list(map(root.removeHandler, root.handlers))
28
+ list(map(root.removeFilter, root.filters))
29
+
30
+
31
+ def get_logger(name: str) -> logging.Logger:
32
+ formatter = logging.Formatter(
33
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
34
+ datefmt="%m/%d/%Y %H:%M:%S"
35
+ )
36
+ handler = logging.StreamHandler(sys.stdout)
37
+ handler.setFormatter(formatter)
38
+
39
+ logger = logging.getLogger(name)
40
+ logger.setLevel(logging.INFO)
41
+ logger.addHandler(handler)
42
+
43
+ return logger
src/llmtuner/extras/misc.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import torch
3
+ from typing import TYPE_CHECKING, Tuple
4
+ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
5
+
6
+ if TYPE_CHECKING:
7
+ from transformers.modeling_utils import PreTrainedModel
8
+
9
+
10
+ class AverageMeter:
11
+ r"""
12
+ Computes and stores the average and current value.
13
+ """
14
+ def __init__(self):
15
+ self.reset()
16
+
17
+ def reset(self):
18
+ self.val = 0
19
+ self.avg = 0
20
+ self.sum = 0
21
+ self.count = 0
22
+
23
+ def update(self, val, n=1):
24
+ self.val = val
25
+ self.sum += val * n
26
+ self.count += n
27
+ self.avg = self.sum / self.count
28
+
29
+
30
+ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
31
+ r"""
32
+ Returns the number of trainable parameters and number of all parameters in the model.
33
+ """
34
+ trainable_params, all_param = 0, 0
35
+ for param in model.parameters():
36
+ num_params = param.numel()
37
+ # if using DS Zero 3 and the weights are initialized empty
38
+ if num_params == 0 and hasattr(param, "ds_numel"):
39
+ num_params = param.ds_numel
40
+
41
+ # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
42
+ if param.__class__.__name__ == "Params4bit":
43
+ num_params = num_params * 2
44
+
45
+ all_param += num_params
46
+ if param.requires_grad:
47
+ trainable_params += num_params
48
+
49
+ return trainable_params, all_param
50
+
51
+
52
+ def get_logits_processor() -> LogitsProcessorList:
53
+ logits_processor = LogitsProcessorList()
54
+ logits_processor.append(InfNanRemoveLogitsProcessor())
55
+ return logits_processor
56
+
57
+
58
+ def torch_gc() -> None:
59
+ r"""
60
+ Collects GPU memory.
61
+ """
62
+ gc.collect()
63
+ if torch.cuda.is_available():
64
+ torch.cuda.empty_cache()
65
+ torch.cuda.ipc_collect()
66
+
67
+
68
+ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
69
+ r"""
70
+ Dispatches a pre-trained model to GPUs with balanced memory.
71
+ Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
72
+ """
73
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
74
+ return model
75
+
76
+ if torch.cuda.device_count() > 1:
77
+ from accelerate import dispatch_model
78
+ from accelerate.utils import infer_auto_device_map, get_balanced_memory
79
+
80
+ if model._no_split_modules is None:
81
+ raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
82
+
83
+ kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
84
+ max_memory = get_balanced_memory(model, **kwargs)
85
+ # Make sure tied weights are tied before creating the device map.
86
+ model.tie_weights()
87
+ device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
88
+ return dispatch_model(model, device_map)
89
+ else:
90
+ return model.cuda()
src/llmtuner/extras/patches/__init__.py ADDED
File without changes
src/llmtuner/extras/patches/flash_llama.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Modified from:
3
+ # [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
4
+ # [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
5
+ # [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
6
+ # [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
7
+ # With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
8
+
9
+ import torch
10
+ from typing import TYPE_CHECKING, Optional, Tuple
11
+ from transformers.utils import logging
12
+
13
+ if TYPE_CHECKING:
14
+ from transformers.models.llama.configuration_llama import LlamaConfig
15
+
16
+ try:
17
+ from flash_attn.flash_attn_interface import (
18
+ flash_attn_kvpacked_func,
19
+ flash_attn_varlen_kvpacked_func
20
+ )
21
+ from flash_attn.bert_padding import pad_input, unpad_input
22
+ print(">>>> FlashAttention installed")
23
+ except ImportError:
24
+ raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
25
+
26
+ try:
27
+ from flash_attn.layers.rotary import apply_rotary_emb_func
28
+ print(">>>> Flash RoPE installed")
29
+ except ImportError:
30
+ raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class LlamaRMSNorm(torch.nn.Module):
37
+
38
+ def __init__(self, hidden_size, eps=1e-6):
39
+ super().__init__()
40
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
41
+ self.variance_epsilon = eps
42
+
43
+ def forward(self, hidden_states):
44
+ input_dtype = hidden_states.dtype
45
+ hidden_states = hidden_states.to(torch.float32)
46
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
47
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
48
+ return (self.weight * hidden_states).to(input_dtype) # for fp32 weight
49
+
50
+
51
+ class FlashRotaryEmbedding(torch.nn.Module):
52
+
53
+ def __init__(
54
+ self,
55
+ dim: int,
56
+ base=10000.0,
57
+ interleaved=False,
58
+ scale_base=None,
59
+ scaling_factor=1.0,
60
+ pos_idx_in_fp32=True,
61
+ device=None
62
+ ):
63
+ super().__init__()
64
+ self.dim = dim
65
+ self.base = float(base)
66
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
67
+ # Generate and save the inverse frequency buffer (non trainable)
68
+ inv_freq = self._compute_inv_freq(device)
69
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
70
+ self.interleaved = interleaved
71
+ self.scale_base = scale_base
72
+ self.scaling_factor = scaling_factor
73
+ scale = (
74
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
75
+ if scale_base is not None else None
76
+ )
77
+ self.register_buffer("scale", scale)
78
+
79
+ self._seq_len_cached = 0
80
+ self._cos_cached = None
81
+ self._sin_cached = None
82
+ self._cos_k_cached = None
83
+ self._sin_k_cached = None
84
+
85
+ def _compute_inv_freq(self, device=None):
86
+ return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
87
+
88
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
89
+ if (
90
+ seqlen > self._seq_len_cached or self._cos_cached.device != device
91
+ or self._cos_cached.dtype != dtype
92
+ or (self.training and self._cos_cached.is_inference())
93
+ ):
94
+ self._seq_len_cached = seqlen
95
+ if self.pos_idx_in_fp32:
96
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
97
+ t /= self.scaling_factor
98
+ if self.inv_freq.dtype != torch.float32:
99
+ inv_freq = self.inv_freq.to(torch.float32)
100
+ else:
101
+ inv_freq = self.inv_freq
102
+ else:
103
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
104
+ t /= self.scaling_factor
105
+ inv_freq = self.inv_freq
106
+ freqs = torch.outer(t, inv_freq)
107
+ if self.scale is None:
108
+ self._cos_cached = torch.cos(freqs).to(dtype)
109
+ self._sin_cached = torch.sin(freqs).to(dtype)
110
+ else:
111
+ power = (
112
+ (torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
113
+ )
114
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
115
+ # We want the multiplication by scale to happen in fp32
116
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
117
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
118
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
119
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
120
+
121
+ def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
122
+ r"""
123
+ q: (batch, seqlen, nheads, headdim)
124
+ k: (batch, seqlen, nheads, headdim)
125
+ seqlen_offset: can be used in generation where the qkv being passed in is only the last
126
+ token in the batch.
127
+ """
128
+ self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
129
+ if self.scale is None:
130
+ return apply_rotary_emb_func(
131
+ q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
132
+ self.interleaved, True # inplace=True
133
+ ), apply_rotary_emb_func(
134
+ k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
135
+ self.interleaved, True # inplace=True
136
+ )
137
+ else:
138
+ assert False
139
+
140
+
141
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
142
+ r"""
143
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
144
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
145
+ """
146
+ batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
147
+ if n_rep == 1:
148
+ return hidden_states
149
+ hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
150
+ return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
151
+
152
+
153
+ class LlamaAttention(torch.nn.Module):
154
+
155
+ def __init__(self, config: "LlamaConfig"):
156
+ super().__init__()
157
+ self.config = config
158
+ self.hidden_size = config.hidden_size
159
+ self.num_heads = config.num_attention_heads
160
+ self.head_dim = self.hidden_size // self.num_heads
161
+ self.num_key_value_heads = config.num_key_value_heads
162
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
163
+ self.max_position_embeddings = config.max_position_embeddings
164
+
165
+ if (self.head_dim * self.num_heads) != self.hidden_size:
166
+ raise ValueError(
167
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
168
+ f" and `num_heads`: {self.num_heads})."
169
+ )
170
+
171
+ self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
172
+ self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
173
+ self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
174
+ self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
175
+
176
+ self.register_buffer(
177
+ "norm_factor",
178
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
179
+ persistent=False,
180
+ )
181
+
182
+ if self.config.rope_scaling is None:
183
+ scaling_factor = 1
184
+ else:
185
+ scaling_type = self.config.rope_scaling["type"]
186
+ scaling_factor = self.config.rope_scaling["factor"]
187
+ assert scaling_type == "linear"
188
+
189
+ self.rotary_emb = FlashRotaryEmbedding(
190
+ self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
191
+ )
192
+
193
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
194
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.Tensor,
199
+ attention_mask: Optional[torch.Tensor] = None,
200
+ position_ids: Optional[torch.LongTensor] = None,
201
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
202
+ output_attentions: bool = False,
203
+ use_cache: bool = False
204
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
205
+ bsz, q_len, h_size = hidden_states.size()
206
+
207
+ has_layer_past = past_key_value is not None
208
+
209
+ if has_layer_past:
210
+ past_kv = past_key_value[0]
211
+ past_len = past_key_value[1]
212
+ else:
213
+ past_len = 0
214
+
215
+ q = self.q_proj(hidden_states)
216
+ k = self.k_proj(hidden_states)
217
+ v = self.v_proj(hidden_states)
218
+
219
+ q = q.view(bsz, q_len, self.num_heads, self.head_dim)
220
+ k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
221
+ v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
222
+
223
+ q, k = self.rotary_emb(q, k, past_len)
224
+
225
+ kv = torch.stack([k, v], 2)
226
+ kv = repeat_kv(kv, self.num_key_value_groups)
227
+
228
+ # Cache QKV values
229
+ if has_layer_past:
230
+ new_len = past_len+q.size(1)
231
+ if new_len > past_kv.size(1):
232
+ past_kv = torch.cat(
233
+ [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
234
+ dim=1
235
+ )
236
+ past_kv[:, past_len:new_len] = kv
237
+ kv = past_kv[:, :new_len]
238
+ else:
239
+ past_kv = kv
240
+
241
+ past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None
242
+
243
+ if attention_mask is not None:
244
+ # varlen, ignore padding tokens, efficient for large batch with many paddings
245
+ logger.warning_once("padded sequences is less efficient")
246
+
247
+ unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
248
+ unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
249
+ attn_outputs = flash_attn_varlen_kvpacked_func(
250
+ unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
251
+ max_seqlen_q, max_seqlen_k,
252
+ dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
253
+ causal=(not has_layer_past), return_attn_probs=output_attentions
254
+ )
255
+
256
+ attn_output = attn_outputs[0] if output_attentions else attn_outputs
257
+ attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
258
+ attn_weights = attn_outputs[2] if output_attentions else None
259
+
260
+ else:
261
+ # no padding tokens, more efficient
262
+ attn_outputs = flash_attn_kvpacked_func(
263
+ q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
264
+ causal=(not has_layer_past), return_attn_probs=output_attentions
265
+ )
266
+ attn_output = attn_outputs[0] if output_attentions else attn_outputs
267
+ attn_output = attn_output.reshape(bsz, q_len, h_size)
268
+ attn_weights = attn_outputs[2] if output_attentions else None
269
+
270
+ attn_output = self.o_proj(attn_output)
271
+
272
+ if not output_attentions:
273
+ attn_weights = None
274
+
275
+ return attn_output, attn_weights, past_key_value
276
+
277
+
278
+ # Disable the transformation of the attention mask in LlamaModel as flash attention
279
+ # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
280
+ def _prepare_decoder_attention_mask(
281
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
282
+ ):
283
+ # [bsz, seq_len]
284
+ if past_key_values_length > 0 and attention_mask is not None:
285
+ attention_mask = torch.cat(
286
+ (
287
+ torch.full(
288
+ (input_shape[0], past_key_values_length),
289
+ True,
290
+ dtype=attention_mask.dtype,
291
+ device=attention_mask.device
292
+ ),
293
+ attention_mask
294
+ ),
295
+ dim=-1
296
+ )
297
+
298
+ if attention_mask is not None and torch.all(attention_mask):
299
+ return None # This uses the faster call when training with full samples
300
+
301
+ return attention_mask
src/llmtuner/extras/ploting.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+ from typing import List, Optional
6
+ from transformers.trainer import TRAINER_STATE_NAME
7
+
8
+ from llmtuner.extras.logging import get_logger
9
+
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ def smooth(scalars: List[float]) -> List[float]:
15
+ r"""
16
+ EMA implementation according to TensorBoard.
17
+ """
18
+ last = scalars[0]
19
+ smoothed = list()
20
+ weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
21
+ for next_val in scalars:
22
+ smoothed_val = last * weight + (1 - weight) * next_val
23
+ smoothed.append(smoothed_val)
24
+ last = smoothed_val
25
+ return smoothed
26
+
27
+
28
+ def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
29
+
30
+ with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
31
+ data = json.load(f)
32
+
33
+ for key in keys:
34
+ steps, metrics = [], []
35
+ for i in range(len(data["log_history"])):
36
+ if key in data["log_history"][i]:
37
+ steps.append(data["log_history"][i]["step"])
38
+ metrics.append(data["log_history"][i][key])
39
+
40
+ if len(metrics) == 0:
41
+ logger.warning(f"No metric {key} to plot.")
42
+ continue
43
+
44
+ plt.figure()
45
+ plt.plot(steps, metrics, alpha=0.4, label="original")
46
+ plt.plot(steps, smooth(metrics), label="smoothed")
47
+ plt.title("training {} of {}".format(key, save_dictionary))
48
+ plt.xlabel("step")
49
+ plt.ylabel(key)
50
+ plt.legend()
51
+ plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
52
+ print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
src/llmtuner/extras/save_and_load.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers.trainer import WEIGHTS_NAME
4
+
5
+ from llmtuner.extras.logging import get_logger
6
+
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
+ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
12
+ vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
13
+ if not os.path.exists(vhead_file):
14
+ logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
15
+ return False
16
+ vhead_params = torch.load(vhead_file, map_location="cpu")
17
+ model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
18
+ model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
19
+ model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
20
+ model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
21
+ return True
src/llmtuner/extras/template.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
4
+
5
+ from llmtuner.extras.logging import get_logger
6
+
7
+ if TYPE_CHECKING:
8
+ from transformers import PreTrainedTokenizer
9
+
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ @dataclass
15
+ class Template:
16
+
17
+ prefix: List[Union[str, Dict[str, str]]]
18
+ prompt: List[Union[str, Dict[str, str]]]
19
+ system: str
20
+ sep: List[Union[str, Dict[str, str]]]
21
+ stop_words: List[str]
22
+ use_history: bool
23
+ efficient_eos: bool
24
+
25
+ def encode_oneturn(
26
+ self,
27
+ tokenizer: "PreTrainedTokenizer",
28
+ query: str,
29
+ resp: str,
30
+ history: Optional[List[Tuple[str, str]]] = None,
31
+ system: Optional[str] = None
32
+ ) -> Tuple[List[int], List[int]]:
33
+ r"""
34
+ Returns a single pair of token ids representing prompt and response respectively.
35
+ """
36
+ system, history = self._format(query, resp, history, system)
37
+ encoded_pairs = self._encode(tokenizer, system, history)
38
+ prompt_ids = []
39
+ for query_ids, resp_ids in encoded_pairs[:-1]:
40
+ prompt_ids = prompt_ids + query_ids + resp_ids
41
+ prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
42
+ return prompt_ids, answer_ids
43
+
44
+ def encode_multiturn(
45
+ self,
46
+ tokenizer: "PreTrainedTokenizer",
47
+ query: str,
48
+ resp: str,
49
+ history: Optional[List[Tuple[str, str]]] = None,
50
+ system: Optional[str] = None
51
+ ) -> List[Tuple[List[int], List[int]]]:
52
+ r"""
53
+ Returns multiple pairs of token ids representing prompts and responses respectively.
54
+ """
55
+ system, history = self._format(query, resp, history, system)
56
+ encoded_pairs = self._encode(tokenizer, system, history)
57
+ return encoded_pairs
58
+
59
+ def _format(
60
+ self,
61
+ query: str,
62
+ resp: str,
63
+ history: Optional[List[Tuple[str, str]]] = None,
64
+ system: Optional[str] = None
65
+ ) -> Tuple[str, List[Tuple[str, str]]]:
66
+ r"""
67
+ Aligns inputs to the standard format.
68
+ """
69
+ system = system or self.system # use system if provided
70
+ history = history if (history and self.use_history) else []
71
+ history = history + [(query, resp)]
72
+ return system, history
73
+
74
+ def _get_special_ids(
75
+ self,
76
+ tokenizer: "PreTrainedTokenizer"
77
+ ) -> Tuple[List[int], List[int]]:
78
+ if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
79
+ bos_ids = [tokenizer.bos_token_id]
80
+ else: # baichuan, qwen and gpt2 models have no bos token
81
+ bos_ids = []
82
+
83
+ if tokenizer.eos_token_id is None:
84
+ raise ValueError("EOS token is required.")
85
+
86
+ if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
87
+ eos_ids = []
88
+ else:
89
+ eos_ids = [tokenizer.eos_token_id]
90
+
91
+ return bos_ids, eos_ids
92
+
93
+ def _encode(
94
+ self,
95
+ tokenizer: "PreTrainedTokenizer",
96
+ system: str,
97
+ history: List[Tuple[str, str]]
98
+ ) -> List[Tuple[List[int], List[int]]]:
99
+ r"""
100
+ Encodes formatted inputs to pairs of token ids.
101
+ Turn 0: bos + prefix + sep + query resp + eos
102
+ Turn t: sep + bos + query resp + eos
103
+ """
104
+ bos_ids, eos_ids = self._get_special_ids(tokenizer)
105
+ sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
106
+ encoded_pairs = []
107
+ for turn_idx, (query, resp) in enumerate(history):
108
+ if turn_idx == 0:
109
+ prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
110
+ if len(prefix_ids) != 0: # has prefix
111
+ prefix_ids = bos_ids + prefix_ids + sep_ids
112
+ else:
113
+ prefix_ids = bos_ids
114
+ else:
115
+ prefix_ids = sep_ids + bos_ids
116
+
117
+ query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
118
+ resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
119
+ encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
120
+ return encoded_pairs
121
+
122
+ def _convert_inputs_to_ids(
123
+ self,
124
+ tokenizer: "PreTrainedTokenizer",
125
+ context: List[Union[str, Dict[str, str]]],
126
+ system: Optional[str] = None,
127
+ query: Optional[str] = None,
128
+ idx: Optional[str] = None
129
+ ) -> List[int]:
130
+ r"""
131
+ Converts context to token ids.
132
+ """
133
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
134
+ kwargs = dict(allowed_special="all")
135
+ else:
136
+ kwargs = dict(add_special_tokens=False)
137
+
138
+ token_ids = []
139
+ for elem in context:
140
+ if isinstance(elem, str):
141
+ if len(elem) == 0:
142
+ continue
143
+ elem = elem.replace("{{system}}", system, 1) if system is not None else elem
144
+ elem = elem.replace("{{query}}", query, 1) if query is not None else elem
145
+ elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
146
+ token_ids = token_ids + tokenizer.encode(elem, **kwargs)
147
+ elif isinstance(elem, dict):
148
+ token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
149
+ else:
150
+ raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
151
+
152
+ return token_ids
153
+
154
+
155
+ @dataclass
156
+ class Llama2Template(Template):
157
+
158
+ def _encode(
159
+ self,
160
+ tokenizer: "PreTrainedTokenizer",
161
+ system: str,
162
+ history: List[Tuple[str, str]]
163
+ ) -> List[Tuple[List[int], List[int]]]:
164
+ r"""
165
+ Encodes formatted inputs to pairs of token ids.
166
+ Turn 0: bos + prefix + query resp + eos
167
+ Turn t: bos + query resp + eos
168
+ """
169
+ bos_ids, eos_ids = self._get_special_ids(tokenizer)
170
+ encoded_pairs = []
171
+ for turn_idx, (query, resp) in enumerate(history):
172
+ if turn_idx == 0: # llama2 template has no sep_ids
173
+ query = self.prefix[0].replace("{{system}}", system) + query
174
+ query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
175
+ resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
176
+ encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
177
+ return encoded_pairs
178
+
179
+
180
+ templates: Dict[str, Template] = {}
181
+
182
+
183
+ def register_template(
184
+ name: str,
185
+ prefix: List[Union[str, Dict[str, str]]],
186
+ prompt: List[Union[str, Dict[str, str]]],
187
+ system: str,
188
+ sep: List[Union[str, Dict[str, str]]],
189
+ stop_words: Optional[List[str]] = [],
190
+ use_history: Optional[bool] = True,
191
+ efficient_eos: Optional[bool] = False
192
+ ) -> None:
193
+ template_class = Llama2Template if "llama2" in name else Template
194
+ templates[name] = template_class(
195
+ prefix=prefix,
196
+ prompt=prompt,
197
+ system=system,
198
+ sep=sep,
199
+ stop_words=stop_words,
200
+ use_history=use_history,
201
+ efficient_eos=efficient_eos
202
+ )
203
+
204
+
205
+ def get_template_and_fix_tokenizer(
206
+ name: str,
207
+ tokenizer: "PreTrainedTokenizer"
208
+ ) -> Template:
209
+ if tokenizer.eos_token_id is None:
210
+ tokenizer.eos_token = "<|endoftext|>"
211
+ logger.info("Add eos token: {}".format(tokenizer.eos_token))
212
+
213
+ if tokenizer.pad_token_id is None:
214
+ tokenizer.pad_token = tokenizer.eos_token
215
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
216
+
217
+ if name is None:
218
+ return None
219
+
220
+ template = templates.get(name, None)
221
+ assert template is not None, "Template {} does not exist.".format(name)
222
+ tokenizer.add_special_tokens(
223
+ dict(additional_special_tokens=template.stop_words),
224
+ replace_additional_special_tokens=False
225
+ )
226
+ return template
227
+
228
+
229
+ r"""
230
+ Supports language model inference without histories.
231
+ """
232
+ register_template(
233
+ name="vanilla",
234
+ prefix=[],
235
+ prompt=[
236
+ "{{query}}"
237
+ ],
238
+ system="",
239
+ sep=[],
240
+ use_history=False
241
+ )
242
+
243
+
244
+ r"""
245
+ Default template.
246
+ """
247
+ register_template(
248
+ name="default",
249
+ prefix=[
250
+ "{{system}}"
251
+ ],
252
+ prompt=[
253
+ "Human: {{query}}\nAssistant: "
254
+ ],
255
+ system=(
256
+ "A chat between a curious user and an artificial intelligence assistant. "
257
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
258
+ ),
259
+ sep=[
260
+ "\n"
261
+ ]
262
+ )
263
+
264
+
265
+ r"""
266
+ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
267
+ https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
268
+ https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
269
+ """
270
+ register_template(
271
+ name="llama2",
272
+ prefix=[
273
+ "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
274
+ ],
275
+ prompt=[
276
+ "[INST] {{query}} [/INST] "
277
+ ],
278
+ system=(
279
+ "You are a helpful, respectful and honest assistant. "
280
+ "Always answer as helpfully as possible, while being safe. "
281
+ "Your answers should not include any harmful, unethical, "
282
+ "racist, sexist, toxic, dangerous, or illegal content. "
283
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
284
+ "If a question does not make any sense, or is not factually coherent, "
285
+ "explain why instead of answering something not correct. "
286
+ "If you don't know the answer to a question, please don't share false information."
287
+ ),
288
+ sep=[]
289
+ )
290
+
291
+
292
+ r"""
293
+ Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
294
+ https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
295
+ """
296
+ register_template(
297
+ name="llama2_zh",
298
+ prefix=[
299
+ "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
300
+ ],
301
+ prompt=[
302
+ "[INST] {{query}} [/INST] "
303
+ ],
304
+ system="You are a helpful assistant. 你是一个乐于助人的助手。",
305
+ sep=[]
306
+ )
307
+
308
+
309
+ r"""
310
+ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
311
+ https://github.com/ymcui/Chinese-LLaMA-Alpaca
312
+ """
313
+ register_template(
314
+ name="alpaca",
315
+ prefix=[
316
+ "{{system}}"
317
+ ],
318
+ prompt=[
319
+ "### Instruction:\n{{query}}\n\n### Response:\n"
320
+ ],
321
+ system=(
322
+ "Below is an instruction that describes a task. "
323
+ "Write a response that appropriately completes the request."
324
+ ),
325
+ sep=[
326
+ "\n\n"
327
+ ]
328
+ )
329
+
330
+
331
+ r"""
332
+ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
333
+ https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
334
+ """
335
+ register_template(
336
+ name="vicuna",
337
+ prefix=[
338
+ "{{system}}"
339
+ ],
340
+ prompt=[
341
+ "USER: {{query}} ASSISTANT: "
342
+ ],
343
+ system=(
344
+ "A chat between a curious user and an artificial intelligence assistant. "
345
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
346
+ ),
347
+ sep=[]
348
+ )
349
+
350
+
351
+ r"""
352
+ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
353
+ """
354
+ register_template(
355
+ name="belle",
356
+ prefix=[
357
+ "{{system}}"
358
+ ],
359
+ prompt=[
360
+ "Human: {{query}}\n\nBelle: "
361
+ ],
362
+ system="",
363
+ sep=[
364
+ "\n\n"
365
+ ]
366
+ )
367
+
368
+
369
+ r"""
370
+ Supports: https://github.com/CVI-SZU/Linly
371
+ """
372
+ register_template(
373
+ name="linly",
374
+ prefix=[
375
+ "{{system}}"
376
+ ],
377
+ prompt=[
378
+ "User: {{query}}\nBot: "
379
+ ],
380
+ system="",
381
+ sep=[
382
+ "\n"
383
+ ]
384
+ )
385
+
386
+
387
+ r"""
388
+ Supports: https://github.com/Neutralzz/BiLLa
389
+ """
390
+ register_template(
391
+ name="billa",
392
+ prefix=[
393
+ "{{system}}"
394
+ ],
395
+ prompt=[
396
+ "Human: {{query}}\nAssistant: "
397
+ ],
398
+ system="",
399
+ sep=[
400
+ "\n"
401
+ ]
402
+ )
403
+
404
+
405
+ r"""
406
+ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
407
+ """
408
+ register_template(
409
+ name="ziya",
410
+ prefix=[
411
+ "{{system}}"
412
+ ],
413
+ prompt=[
414
+ {"token": "<human>"},
415
+ ":{{query}}\n",
416
+ {"token": "<bot>"},
417
+ ":"
418
+ ],
419
+ system="",
420
+ sep=[
421
+ "\n"
422
+ ]
423
+ )
424
+
425
+
426
+ r"""
427
+ Supports: https://huggingface.co/qhduan/aquilachat-7b
428
+ """
429
+ register_template(
430
+ name="aquila",
431
+ prefix=[
432
+ "{{system}}"
433
+ ],
434
+ prompt=[
435
+ "Human: {{query}}###Assistant: "
436
+ ],
437
+ system=(
438
+ "A chat between a curious human and an artificial intelligence assistant. "
439
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."
440
+ ),
441
+ sep=[
442
+ "###"
443
+ ]
444
+ )
445
+
446
+
447
+ r"""
448
+ Supports: https://huggingface.co/internlm/internlm-chat-7b
449
+ """
450
+ register_template(
451
+ name="intern",
452
+ prefix=[
453
+ "{{system}}"
454
+ ],
455
+ prompt=[
456
+ "<|User|>:{{query}}",
457
+ {"token": "<eoh>"},
458
+ "\n<|Bot|>:"
459
+ ],
460
+ system="",
461
+ sep=[
462
+ {"token": "<eoa>"},
463
+ "\n"
464
+ ],
465
+ stop_words=[
466
+ "<eoa>"
467
+ ],
468
+ efficient_eos=True
469
+ )
470
+
471
+
472
+ r"""
473
+ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
474
+ """
475
+ register_template(
476
+ name="baichuan",
477
+ prefix=[
478
+ "{{system}}"
479
+ ],
480
+ prompt=[
481
+ {"token": "<reserved_102>"}, # user token
482
+ "{{query}}",
483
+ {"token": "<reserved_103>"} # assistant token
484
+ ],
485
+ system="",
486
+ sep=[],
487
+ efficient_eos=True
488
+ )
489
+
490
+
491
+ r"""
492
+ Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
493
+ https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
494
+ """
495
+ register_template(
496
+ name="baichuan2",
497
+ prefix=[
498
+ "{{system}}"
499
+ ],
500
+ prompt=[
501
+ {"token": "<reserved_106>"}, # user token
502
+ "{{query}}",
503
+ {"token": "<reserved_107>"} # assistant token
504
+ ],
505
+ system="",
506
+ sep=[],
507
+ efficient_eos=True
508
+ )
509
+
510
+
511
+ r"""
512
+ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
513
+ https://huggingface.co/HuggingFaceH4/starchat-beta
514
+ """
515
+ register_template(
516
+ name="starchat",
517
+ prefix=[
518
+ {"token": "<|system|>"},
519
+ "\n{{system}}",
520
+ ],
521
+ prompt=[
522
+ {"token": "<|user|>"},
523
+ "\n{{query}}",
524
+ {"token": "<|end|>"},
525
+ "\n",
526
+ {"token": "<|assistant|>"}
527
+ ],
528
+ system="",
529
+ sep=[
530
+ {"token": "<|end|>"},
531
+ "\n"
532
+ ],
533
+ stop_words=[
534
+ "<|end|>"
535
+ ],
536
+ efficient_eos=True
537
+ )
538
+
539
+
540
+ r"""
541
+ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
542
+ """
543
+ register_template(
544
+ name="chatml",
545
+ prefix=[
546
+ {"token": "<|im_start|>"},
547
+ "system\n{{system}}"
548
+ ],
549
+ prompt=[
550
+ {"token": "<|im_start|>"},
551
+ "user\n{{query}}",
552
+ {"token": "<|im_end|>"},
553
+ "\n",
554
+ {"token": "<|im_start|>"},
555
+ "assistant\n"
556
+ ],
557
+ system="You are a helpful assistant.",
558
+ sep=[
559
+ {"token": "<|im_end|>"},
560
+ "\n"
561
+ ],
562
+ stop_words=[
563
+ "<|im_end|>"
564
+ ],
565
+ efficient_eos=True
566
+ )
567
+
568
+
569
+ r"""
570
+ Supports: https://huggingface.co/THUDM/chatglm2-6b
571
+ """
572
+ register_template(
573
+ name="chatglm2",
574
+ prefix=[
575
+ {"token": "[gMASK]"},
576
+ {"token": "sop"},
577
+ "{{system}}"
578
+ ],
579
+ prompt=[
580
+ "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
581
+ ],
582
+ system="",
583
+ sep=[
584
+ "\n\n"
585
+ ],
586
+ efficient_eos=True
587
+ )
588
+
589
+
590
+ r"""
591
+ Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
592
+ """
593
+ register_template(
594
+ name="xverse",
595
+ prefix=[
596
+ "{{system}}"
597
+ ],
598
+ prompt=[
599
+ "Human: {{query}}\n\nAssistant: "
600
+ ],
601
+ system="",
602
+ sep=[]
603
+ )
src/llmtuner/hparams/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .data_args import DataArguments
2
+ from .finetuning_args import FinetuningArguments
3
+ from .general_args import GeneralArguments
4
+ from .generating_args import GeneratingArguments
5
+ from .model_args import ModelArguments
src/llmtuner/hparams/data_args.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import List, Literal, Optional
4
+ from dataclasses import dataclass, field
5
+
6
+
7
+ @dataclass
8
+ class DatasetAttr:
9
+
10
+ load_from: str
11
+ dataset_name: Optional[str] = None
12
+ dataset_sha1: Optional[str] = None
13
+ system_prompt: Optional[str] = None
14
+ ranking: Optional[bool] = False
15
+ prompt: Optional[str] = "instruction"
16
+ query: Optional[str] = "input"
17
+ response: Optional[str] = "output"
18
+ history: Optional[str] = None
19
+
20
+ def __repr__(self) -> str:
21
+ return self.dataset_name
22
+
23
+
24
+ @dataclass
25
+ class DataArguments:
26
+ r"""
27
+ Arguments pertaining to what data we are going to input our model for training and evaluation.
28
+ """
29
+ template: Optional[str] = field(
30
+ default=None,
31
+ metadata={"help": "Which template to use for constructing prompts in training and inference."}
32
+ )
33
+ dataset: Optional[str] = field(
34
+ default="alpaca_en",
35
+ metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
36
+ )
37
+ dataset_dir: Optional[str] = field(
38
+ default="data",
39
+ metadata={"help": "The name of the folder containing datasets."}
40
+ )
41
+ split: Optional[str] = field(
42
+ default="train",
43
+ metadata={"help": "Which dataset split to use for training and evaluation."}
44
+ )
45
+ cutoff_len: Optional[int] = field(
46
+ default=1024,
47
+ metadata={"help": "The maximum length of the model inputs after tokenization."}
48
+ )
49
+ streaming: Optional[bool] = field(
50
+ default=False,
51
+ metadata={"help": "Enable streaming mode."}
52
+ )
53
+ buffer_size: Optional[int] = field(
54
+ default=16384,
55
+ metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
56
+ )
57
+ mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
58
+ default="concat",
59
+ metadata={"help": "Strategy to use in dataset mixing."}
60
+ )
61
+ interleave_probs: Optional[str] = field(
62
+ default=None,
63
+ metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
64
+ )
65
+ overwrite_cache: Optional[bool] = field(
66
+ default=False,
67
+ metadata={"help": "Overwrite the cached training and evaluation sets."}
68
+ )
69
+ preprocessing_num_workers: Optional[int] = field(
70
+ default=None,
71
+ metadata={"help": "The number of processes to use for the preprocessing."}
72
+ )
73
+ max_samples: Optional[int] = field(
74
+ default=None,
75
+ metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
76
+ )
77
+ eval_num_beams: Optional[int] = field(
78
+ default=None,
79
+ metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
80
+ )
81
+ ignore_pad_token_for_loss: Optional[bool] = field(
82
+ default=True,
83
+ metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
84
+ )
85
+ system_prompt: Optional[str] = field(
86
+ default=None,
87
+ metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
88
+ )
89
+ val_size: Optional[float] = field(
90
+ default=0,
91
+ metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
92
+ )
93
+
94
+ def init_for_training(self): # support mixing multiple datasets
95
+ dataset_names = [ds.strip() for ds in self.dataset.split(",")]
96
+ with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
97
+ dataset_info = json.load(f)
98
+
99
+ prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
100
+ prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
101
+ assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
102
+
103
+ if self.interleave_probs is not None:
104
+ self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
105
+
106
+ self.dataset_list: List[DatasetAttr] = []
107
+ for i, name in enumerate(dataset_names):
108
+ if name not in dataset_info:
109
+ raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
110
+
111
+ if "hf_hub_url" in dataset_info[name]:
112
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
113
+ elif "script_url" in dataset_info[name]:
114
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
115
+ else:
116
+ dataset_attr = DatasetAttr(
117
+ "file",
118
+ dataset_name=dataset_info[name]["file_name"],
119
+ dataset_sha1=dataset_info[name].get("file_sha1", None)
120
+ )
121
+
122
+ if "columns" in dataset_info[name]:
123
+ dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
124
+ dataset_attr.query = dataset_info[name]["columns"].get("query", None)
125
+ dataset_attr.response = dataset_info[name]["columns"].get("response", None)
126
+ dataset_attr.history = dataset_info[name]["columns"].get("history", None)
127
+
128
+ dataset_attr.ranking = dataset_info[name].get("ranking", False)
129
+ dataset_attr.system_prompt = prompt_list[i]
130
+ self.dataset_list.append(dataset_attr)
src/llmtuner/hparams/finetuning_args.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Literal, Optional
3
+ from dataclasses import asdict, dataclass, field
4
+
5
+
6
+ @dataclass
7
+ class FinetuningArguments:
8
+ r"""
9
+ Arguments pertaining to which techniques we are going to fine-tuning with.
10
+ """
11
+ finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
12
+ default="lora",
13
+ metadata={"help": "Which fine-tuning method to use."}
14
+ )
15
+ num_hidden_layers: Optional[int] = field(
16
+ default=32,
17
+ metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
18
+ LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
19
+ LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
20
+ BLOOM choices: [\"24\", \"30\", \"70\"], \
21
+ Falcon choices: [\"32\", \"60\"], \
22
+ Baichuan choices: [\"32\", \"40\"] \
23
+ Qwen choices: [\"32\"], \
24
+ XVERSE choices: [\"40\"], \
25
+ ChatGLM2 choices: [\"28\"]"}
26
+ )
27
+ num_layer_trainable: Optional[int] = field(
28
+ default=3,
29
+ metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
30
+ )
31
+ name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
32
+ default="mlp",
33
+ metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
34
+ LLaMA choices: [\"mlp\", \"self_attn\"], \
35
+ BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
36
+ Baichuan choices: [\"mlp\", \"self_attn\"], \
37
+ Qwen choices: [\"mlp\", \"attn\"], \
38
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
39
+ )
40
+ lora_rank: Optional[int] = field(
41
+ default=8,
42
+ metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
43
+ )
44
+ lora_alpha: Optional[float] = field(
45
+ default=32.0,
46
+ metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
47
+ )
48
+ lora_dropout: Optional[float] = field(
49
+ default=0.1,
50
+ metadata={"help": "Dropout rate for the LoRA fine-tuning."}
51
+ )
52
+ lora_target: Optional[str] = field(
53
+ default=None,
54
+ metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
55
+ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
56
+ BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
57
+ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
58
+ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
59
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
60
+ )
61
+ resume_lora_training: Optional[bool] = field(
62
+ default=True,
63
+ metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
64
+ )
65
+ ppo_score_norm: Optional[bool] = field(
66
+ default=False,
67
+ metadata={"help": "Use score normalization in PPO Training."}
68
+ )
69
+ dpo_beta: Optional[float] = field(
70
+ default=0.1,
71
+ metadata={"help": "The beta parameter for the DPO loss."}
72
+ )
73
+
74
+ def __post_init__(self):
75
+ if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
76
+ self.lora_target = [target.strip() for target in self.lora_target.split(",")]
77
+
78
+ if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
79
+ trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
80
+ else: # fine-tuning the first n layers if num_layer_trainable < 0
81
+ trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
82
+
83
+ self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
84
+
85
+ assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
86
+
87
+ def save_to_json(self, json_path: str):
88
+ r"""Saves the content of this instance in JSON format inside `json_path`."""
89
+ json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
90
+ with open(json_path, "w", encoding="utf-8") as f:
91
+ f.write(json_string)
92
+
93
+ @classmethod
94
+ def load_from_json(cls, json_path: str):
95
+ r"""Creates an instance from the content of `json_path`."""
96
+ with open(json_path, "r", encoding="utf-8") as f:
97
+ text = f.read()
98
+ return cls(**json.loads(text))
src/llmtuner/hparams/general_args.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional
2
+ from dataclasses import dataclass, field
3
+
4
+
5
+ @dataclass
6
+ class GeneralArguments:
7
+ r"""
8
+ Arguments pertaining to which stage we are going to perform.
9
+ """
10
+ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
11
+ default="sft",
12
+ metadata={"help": "Which stage will be performed in training."}
13
+ )
src/llmtuner/hparams/generating_args.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+ from dataclasses import asdict, dataclass, field
3
+
4
+
5
+ @dataclass
6
+ class GeneratingArguments:
7
+ r"""
8
+ Arguments pertaining to specify the decoding parameters.
9
+ """
10
+ do_sample: Optional[bool] = field(
11
+ default=True,
12
+ metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
13
+ )
14
+ temperature: Optional[float] = field(
15
+ default=0.95,
16
+ metadata={"help": "The value used to modulate the next token probabilities."}
17
+ )
18
+ top_p: Optional[float] = field(
19
+ default=0.7,
20
+ metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
21
+ )
22
+ top_k: Optional[int] = field(
23
+ default=50,
24
+ metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
25
+ )
26
+ num_beams: Optional[int] = field(
27
+ default=1,
28
+ metadata={"help": "Number of beams for beam search. 1 means no beam search."}
29
+ )
30
+ max_length: Optional[int] = field(
31
+ default=None,
32
+ metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
33
+ )
34
+ max_new_tokens: Optional[int] = field(
35
+ default=512,
36
+ metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
37
+ )
38
+ repetition_penalty: Optional[float] = field(
39
+ default=1.0,
40
+ metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
41
+ )
42
+ length_penalty: Optional[float] = field(
43
+ default=1.0,
44
+ metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
45
+ )
46
+
47
+ def to_dict(self) -> Dict[str, Any]:
48
+ args = asdict(self)
49
+ if args.get("max_new_tokens", None):
50
+ args.pop("max_length", None)
51
+ return args
src/llmtuner/hparams/model_args.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Literal, Optional
3
+ from dataclasses import dataclass, field
4
+
5
+
6
+ @dataclass
7
+ class ModelArguments:
8
+ r"""
9
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
10
+ """
11
+ model_name_or_path: str = field(
12
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
13
+ )
14
+ cache_dir: Optional[str] = field(
15
+ default=None,
16
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
17
+ )
18
+ use_fast_tokenizer: Optional[bool] = field(
19
+ default=True,
20
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
21
+ )
22
+ use_auth_token: Optional[bool] = field(
23
+ default=False,
24
+ metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
25
+ )
26
+ model_revision: Optional[str] = field(
27
+ default="main",
28
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
29
+ )
30
+ quantization_bit: Optional[int] = field(
31
+ default=None,
32
+ metadata={"help": "The number of bits to quantize the model."}
33
+ )
34
+ quantization_type: Optional[Literal["fp4", "nf4"]] = field(
35
+ default="nf4",
36
+ metadata={"help": "Quantization data type to use in int4 training."}
37
+ )
38
+ double_quantization: Optional[bool] = field(
39
+ default=True,
40
+ metadata={"help": "Whether to use double quantization in int4 training or not."}
41
+ )
42
+ rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
43
+ default=None,
44
+ metadata={"help": "Adopt scaled rotary positional embeddings."}
45
+ )
46
+ flash_attn: Optional[bool] = field(
47
+ default=False,
48
+ metadata={"help": "Enable flash attention for faster training."}
49
+ )
50
+ checkpoint_dir: Optional[str] = field(
51
+ default=None,
52
+ metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
53
+ )
54
+ reward_model: Optional[str] = field(
55
+ default=None,
56
+ metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
57
+ )
58
+ plot_loss: Optional[bool] = field(
59
+ default=False,
60
+ metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
61
+ )
62
+ hf_auth_token: Optional[str] = field(
63
+ default=None,
64
+ metadata={"help": "Auth token to log in with Hugging Face Hub."}
65
+ )
66
+
67
+ def __post_init__(self):
68
+ self.compute_dtype = None
69
+ self.model_max_length = None
70
+
71
+ if self.checkpoint_dir is not None: # support merging multiple lora weights
72
+ self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
73
+
74
+ if self.quantization_bit is not None:
75
+ assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
76
+
77
+ if self.use_auth_token == True and self.hf_auth_token is not None:
78
+ from huggingface_hub.hf_api import HfFolder # lazy load
79
+ HfFolder.save_token(self.hf_auth_token)
src/llmtuner/tuner/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from llmtuner.tuner.tune import export_model, run_exp
src/llmtuner/tuner/core/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from llmtuner.tuner.core.parser import get_train_args, get_infer_args
2
+ from llmtuner.tuner.core.loader import load_model_and_tokenizer
src/llmtuner/tuner/core/adapter.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import TYPE_CHECKING
4
+
5
+ from peft import (
6
+ PeftModel,
7
+ TaskType,
8
+ LoraConfig,
9
+ get_peft_model
10
+ )
11
+ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
12
+
13
+ from llmtuner.extras.logging import get_logger
14
+ from llmtuner.tuner.core.utils import find_all_linear_modules
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from llmtuner.hparams import ModelArguments, FinetuningArguments
19
+
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ def init_adapter(
25
+ model: "PreTrainedModel",
26
+ model_args: "ModelArguments",
27
+ finetuning_args: "FinetuningArguments",
28
+ is_trainable: bool,
29
+ is_mergeable: bool
30
+ ) -> "PreTrainedModel":
31
+ r"""
32
+ Initializes the adapters.
33
+
34
+ Support full-parameter, freeze and LoRA training.
35
+
36
+ Note that the trainable parameters must be cast to float32.
37
+ """
38
+
39
+ if finetuning_args.finetuning_type == "none" and is_trainable:
40
+ raise ValueError("You cannot use finetuning_type=none while training.")
41
+
42
+ if finetuning_args.finetuning_type == "full" and is_trainable:
43
+ logger.info("Fine-tuning method: Full")
44
+ model = model.float()
45
+
46
+ if finetuning_args.finetuning_type == "freeze":
47
+ logger.info("Fine-tuning method: Freeze")
48
+
49
+ for name, param in model.named_parameters():
50
+ if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
51
+ param.requires_grad_(False)
52
+ else:
53
+ param.data = param.data.to(torch.float32)
54
+
55
+ if finetuning_args.finetuning_type == "lora":
56
+ logger.info("Fine-tuning method: LoRA")
57
+ latest_checkpoint = None
58
+
59
+ if model_args.checkpoint_dir is not None:
60
+ assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
61
+ "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
62
+ assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
63
+ "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
64
+
65
+ if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
66
+ checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
67
+ else:
68
+ checkpoints_to_merge = model_args.checkpoint_dir
69
+
70
+ for checkpoint in checkpoints_to_merge:
71
+ model = PeftModel.from_pretrained(model, checkpoint)
72
+ model = model.merge_and_unload()
73
+
74
+ if len(checkpoints_to_merge) > 0:
75
+ logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
76
+
77
+ if latest_checkpoint is not None: # resume lora training or quantized inference
78
+ model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
79
+
80
+ if is_trainable and latest_checkpoint is None: # create new lora weights while training
81
+ if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
82
+ target_modules = find_all_linear_modules(model, model_args.quantization_bit)
83
+ else:
84
+ target_modules = finetuning_args.lora_target
85
+
86
+ lora_config = LoraConfig(
87
+ task_type=TaskType.CAUSAL_LM,
88
+ inference_mode=False,
89
+ r=finetuning_args.lora_rank,
90
+ lora_alpha=finetuning_args.lora_alpha,
91
+ lora_dropout=finetuning_args.lora_dropout,
92
+ target_modules=target_modules
93
+ )
94
+ model = get_peft_model(model, lora_config)
95
+ if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
96
+ model.base_model.peft_config = model.peft_config
97
+
98
+ if model_args.checkpoint_dir is not None:
99
+ logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
100
+
101
+ return model
src/llmtuner/tuner/core/loader.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ from types import MethodType
5
+ from typing import TYPE_CHECKING, Literal, Optional, Tuple
6
+
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ BitsAndBytesConfig,
12
+ PretrainedConfig,
13
+ PreTrainedModel,
14
+ PreTrainedTokenizerBase
15
+ )
16
+ from transformers.utils import check_min_version
17
+ from transformers.utils.versions import require_version
18
+ from trl import AutoModelForCausalLMWithValueHead
19
+
20
+ try:
21
+ from transformers.integrations import is_deepspeed_zero3_enabled
22
+ except ImportError:
23
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
24
+
25
+ from llmtuner.extras.logging import reset_logging, get_logger
26
+ from llmtuner.extras.misc import count_parameters
27
+ from llmtuner.extras.save_and_load import load_valuehead_params
28
+ from llmtuner.hparams import FinetuningArguments
29
+ from llmtuner.tuner.core.adapter import init_adapter
30
+ from llmtuner.tuner.core.utils import prepare_model_for_training
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import PreTrainedTokenizer
34
+ from llmtuner.hparams import ModelArguments
35
+
36
+
37
+ logger = get_logger(__name__)
38
+
39
+
40
+ check_min_version("4.30.0")
41
+ require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
42
+ require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
43
+ require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
44
+ require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
45
+
46
+
47
+ def load_model_and_tokenizer(
48
+ model_args: "ModelArguments",
49
+ finetuning_args: "FinetuningArguments",
50
+ is_trainable: Optional[bool] = False,
51
+ stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
52
+ ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
53
+ r"""
54
+ Loads pretrained model and tokenizer.
55
+
56
+ Support both training and inference.
57
+ """
58
+ if (not is_trainable) and model_args.checkpoint_dir is None:
59
+ logger.warning("Checkpoint is not found at evaluation, load the original model.")
60
+ finetuning_args = FinetuningArguments(finetuning_type="none")
61
+
62
+ config_kwargs = {
63
+ "trust_remote_code": True,
64
+ "cache_dir": model_args.cache_dir,
65
+ "revision": model_args.model_revision,
66
+ "use_auth_token": True if model_args.use_auth_token else None,
67
+ }
68
+
69
+ tokenizer = AutoTokenizer.from_pretrained(
70
+ model_args.model_name_or_path,
71
+ use_fast=model_args.use_fast_tokenizer,
72
+ padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
73
+ **config_kwargs
74
+ )
75
+
76
+ # Fix tokenizer (for ChatGLM2)
77
+ if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
78
+ tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
79
+
80
+ if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
81
+ model_to_load = model_args.checkpoint_dir[0]
82
+ else:
83
+ model_to_load = model_args.model_name_or_path
84
+
85
+ config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
86
+
87
+ # Fix config (for Qwen)
88
+ if hasattr(config, "fp16") and hasattr(config, "bf16"):
89
+ setattr(config, "fp16", model_args.compute_dtype == torch.float16)
90
+ setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
91
+
92
+ # Set RoPE scaling
93
+ if model_args.rope_scaling is not None:
94
+ if hasattr(config, "use_dynamic_ntk"): # for Qwen models
95
+ if is_trainable:
96
+ logger.warning("Qwen model does not support RoPE scaling in training.")
97
+ else:
98
+ setattr(config, "use_dynamic_ntk", True)
99
+ setattr(config, "use_logn_attn", True)
100
+ logger.info("Using dynamic NTK scaling.")
101
+
102
+ elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
103
+ require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
104
+ if is_trainable:
105
+ if model_args.rope_scaling == "dynamic":
106
+ assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
107
+ logger.warning(
108
+ "Dynamic NTK may not work well with fine-tuning. "
109
+ "See: https://github.com/huggingface/transformers/pull/24653"
110
+ )
111
+
112
+ current_max_length = getattr(config, "max_position_embeddings", None)
113
+ if current_max_length and model_args.model_max_length > current_max_length:
114
+ scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
115
+ else:
116
+ logger.warning("Input length is smaller than max length. Consider increase input length.")
117
+ scaling_factor = 1.0
118
+ else:
119
+ scaling_factor = 2.0
120
+
121
+ setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
122
+ logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
123
+ model_args.rope_scaling, scaling_factor
124
+ ))
125
+
126
+ else:
127
+ logger.warning("Current model does not support RoPE scaling.")
128
+
129
+ # Set flash attention
130
+ if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
131
+ import transformers.models.llama.modeling_llama as LlamaModule
132
+ import llmtuner.extras.patches.flash_llama as FlashLlama
133
+ LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
134
+ LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
135
+ LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
136
+ if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
137
+ setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
138
+ if getattr(config, "pretraining_tp", 1) != 1:
139
+ setattr(config, "pretraining_tp", 1)
140
+
141
+ # Quantization configurations (using bitsandbytes library).
142
+ is_mergeable = True
143
+ if model_args.quantization_bit is not None:
144
+ if is_deepspeed_zero3_enabled():
145
+ raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
146
+
147
+ if model_args.quantization_bit == 8:
148
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
149
+ config_kwargs["load_in_8bit"] = True
150
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
151
+
152
+ elif model_args.quantization_bit == 4:
153
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
154
+ config_kwargs["load_in_4bit"] = True
155
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(
156
+ load_in_4bit=True,
157
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
158
+ bnb_4bit_use_double_quant=model_args.double_quantization,
159
+ bnb_4bit_quant_type=model_args.quantization_type
160
+ )
161
+
162
+ is_mergeable = False
163
+ config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
164
+ logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
165
+
166
+ # Load and prepare pre-trained models (without valuehead).
167
+ model = AutoModelForCausalLM.from_pretrained(
168
+ model_to_load,
169
+ config=config,
170
+ torch_dtype=model_args.compute_dtype,
171
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
172
+ **config_kwargs
173
+ )
174
+
175
+ # Disable custom generate method (for Qwen)
176
+ if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
177
+ model.generate = MethodType(PreTrainedModel.generate, model)
178
+
179
+ # Fix LM head (for ChatGLM2)
180
+ if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
181
+ setattr(model, "lm_head", model.transformer.output_layer)
182
+
183
+ # Register auto class to save the custom code files.
184
+ if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
185
+ config.__class__.register_for_auto_class()
186
+ if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
187
+ model.__class__.register_for_auto_class()
188
+ if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
189
+ tokenizer.__class__.register_for_auto_class()
190
+
191
+ # Initialize adapters
192
+ model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
193
+ model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
194
+ model = model.train() if is_trainable else model.eval()
195
+
196
+ # Prepare model with valuehead for RLHF
197
+ if stage == "rm" or stage == "ppo":
198
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
199
+ model._keys_to_ignore_on_save = None
200
+ reset_logging()
201
+ if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
202
+ logger.warning("Only the last checkpoint containing valuehead will be loaded.")
203
+ if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
204
+ model.v_head.load_state_dict({
205
+ "summary.weight": getattr(model, "reward_head_weight"),
206
+ "summary.bias": getattr(model, "reward_head_bias")
207
+ })
208
+
209
+ if stage == "ppo": # load reward model
210
+ logger.info("Load reward model from {}".format(model_args.reward_model))
211
+ if getattr(model, "is_peft_model", False):
212
+ model.pretrained_model.load_adapter(model_args.reward_model, "reward")
213
+ assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
214
+
215
+ # Prepare model for inference
216
+ if not is_trainable:
217
+ model.requires_grad_(False) # fix all model params
218
+ model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
219
+
220
+ trainable_params, all_param = count_parameters(model)
221
+ logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
222
+ trainable_params, all_param, 100 * trainable_params / all_param
223
+ ))
224
+
225
+ return model, tokenizer
src/llmtuner/tuner/core/parser.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import datasets
5
+ import transformers
6
+ from typing import Any, Dict, Optional, Tuple
7
+ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
8
+ from transformers.utils.versions import require_version
9
+ from transformers.trainer_utils import get_last_checkpoint
10
+
11
+ try:
12
+ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
13
+ is_bf16_available = is_torch_bf16_gpu_available()
14
+ is_npu_available = is_torch_npu_available()
15
+ except ImportError:
16
+ is_bf16_available = torch.cuda.is_bf16_supported()
17
+ is_npu_available = False
18
+
19
+ from llmtuner.extras.logging import get_logger
20
+ from llmtuner.hparams import (
21
+ ModelArguments,
22
+ DataArguments,
23
+ FinetuningArguments,
24
+ GeneratingArguments,
25
+ GeneralArguments
26
+ )
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
33
+ if args is not None:
34
+ return parser.parse_dict(args)
35
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
36
+ return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
37
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
38
+ return parser.parse_json_file(os.path.abspath(sys.argv[1]))
39
+ else:
40
+ return parser.parse_args_into_dataclasses()
41
+
42
+
43
+ def parse_train_args(
44
+ args: Optional[Dict[str, Any]] = None
45
+ ) -> Tuple[
46
+ ModelArguments,
47
+ DataArguments,
48
+ Seq2SeqTrainingArguments,
49
+ FinetuningArguments,
50
+ GeneratingArguments,
51
+ GeneralArguments
52
+ ]:
53
+ parser = HfArgumentParser((
54
+ ModelArguments,
55
+ DataArguments,
56
+ Seq2SeqTrainingArguments,
57
+ FinetuningArguments,
58
+ GeneratingArguments,
59
+ GeneralArguments
60
+ ))
61
+ return _parse_args(parser, args)
62
+
63
+
64
+ def parse_infer_args(
65
+ args: Optional[Dict[str, Any]] = None
66
+ ) -> Tuple[
67
+ ModelArguments,
68
+ DataArguments,
69
+ FinetuningArguments,
70
+ GeneratingArguments
71
+ ]:
72
+ parser = HfArgumentParser((
73
+ ModelArguments,
74
+ DataArguments,
75
+ FinetuningArguments,
76
+ GeneratingArguments
77
+ ))
78
+ return _parse_args(parser, args)
79
+
80
+
81
+ def get_train_args(
82
+ args: Optional[Dict[str, Any]] = None
83
+ ) -> Tuple[
84
+ ModelArguments,
85
+ DataArguments,
86
+ Seq2SeqTrainingArguments,
87
+ FinetuningArguments,
88
+ GeneratingArguments,
89
+ GeneralArguments
90
+ ]:
91
+ model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
92
+
93
+ # Setup logging
94
+ if training_args.should_log:
95
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
96
+ transformers.utils.logging.set_verbosity_info()
97
+
98
+ log_level = training_args.get_process_log_level()
99
+ datasets.utils.logging.set_verbosity(log_level)
100
+ transformers.utils.logging.set_verbosity(log_level)
101
+ transformers.utils.logging.enable_default_handler()
102
+ transformers.utils.logging.enable_explicit_format()
103
+
104
+ # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
105
+ data_args.init_for_training()
106
+
107
+ if general_args.stage != "pt" and data_args.template is None:
108
+ raise ValueError("Please specify which `template` to use.")
109
+
110
+ if general_args.stage != "sft" and training_args.predict_with_generate:
111
+ raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
112
+
113
+ if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
114
+ raise ValueError("Please enable `predict_with_generate` to save model predictions.")
115
+
116
+ if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
117
+ raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
118
+
119
+ if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
120
+ raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
121
+
122
+ if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
123
+ raise ValueError("PPO and DPO stages can only be performed at training.")
124
+
125
+ if general_args.stage in ["rm", "dpo"]:
126
+ for dataset_attr in data_args.dataset_list:
127
+ if not dataset_attr.ranking:
128
+ raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
129
+
130
+ if general_args.stage == "ppo" and model_args.reward_model is None:
131
+ raise ValueError("Reward model is necessary for PPO training.")
132
+
133
+ if general_args.stage == "ppo" and training_args.deepspeed is not None:
134
+ raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
135
+
136
+ if general_args.stage == "ppo" and data_args.streaming:
137
+ raise ValueError("Streaming mode does not suppport PPO training currently.")
138
+
139
+ if training_args.max_steps == -1 and data_args.streaming:
140
+ raise ValueError("Please specify `max_steps` in streaming mode.")
141
+
142
+ if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
143
+ raise ValueError("Streaming mode should have an integer val size.")
144
+
145
+ if training_args.do_train and training_args.predict_with_generate:
146
+ raise ValueError("`predict_with_generate` cannot be set as True while training.")
147
+
148
+ if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
149
+ raise ValueError("Please specify `lora_target` in LoRA training.")
150
+
151
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
152
+ raise ValueError("Quantization is only compatible with the LoRA method.")
153
+
154
+ if model_args.checkpoint_dir is not None:
155
+ if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
156
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
157
+
158
+ if model_args.quantization_bit is not None:
159
+ if len(model_args.checkpoint_dir) != 1:
160
+ raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
161
+
162
+ if not finetuning_args.resume_lora_training:
163
+ raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
164
+
165
+ if model_args.quantization_bit is not None and (not training_args.do_train):
166
+ logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
167
+
168
+ if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
169
+ logger.warning("We recommend enable mixed precision training.")
170
+
171
+ # postprocess data_args
172
+ if data_args.max_samples is not None and data_args.streaming:
173
+ logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
174
+ data_args.max_samples = None
175
+
176
+ # postprocess training_args
177
+ if (
178
+ training_args.local_rank != -1
179
+ and training_args.ddp_find_unused_parameters is None
180
+ and finetuning_args.finetuning_type == "lora"
181
+ ):
182
+ logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
183
+ training_args_dict = training_args.to_dict()
184
+ training_args_dict.update(dict(ddp_find_unused_parameters=False))
185
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
186
+
187
+ if (
188
+ training_args.resume_from_checkpoint is None
189
+ and training_args.do_train
190
+ and os.path.isdir(training_args.output_dir)
191
+ and not training_args.overwrite_output_dir
192
+ ):
193
+ require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
194
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
195
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
196
+ raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
197
+
198
+ if last_checkpoint is not None:
199
+ training_args_dict = training_args.to_dict()
200
+ training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
201
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
202
+ logger.info(
203
+ "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
204
+ )
205
+
206
+ # postprocess model_args
207
+ if training_args.bf16:
208
+ if not is_bf16_available:
209
+ raise ValueError("Current device does not support bf16 training.")
210
+ model_args.compute_dtype = torch.bfloat16
211
+ elif training_args.fp16:
212
+ model_args.compute_dtype = torch.float16
213
+ else:
214
+ model_args.compute_dtype = torch.float32
215
+
216
+ model_args.model_max_length = data_args.cutoff_len
217
+
218
+ # Log on each process the small summary:
219
+ logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
220
+ training_args.local_rank, training_args.device, training_args.n_gpu,
221
+ bool(training_args.local_rank != -1), str(model_args.compute_dtype)
222
+ ))
223
+ logger.info(f"Training/evaluation parameters {training_args}")
224
+
225
+ # Set seed before initializing model.
226
+ transformers.set_seed(training_args.seed)
227
+
228
+ return model_args, data_args, training_args, finetuning_args, generating_args, general_args
229
+
230
+
231
+ def get_infer_args(
232
+ args: Optional[Dict[str, Any]] = None
233
+ ) -> Tuple[
234
+ ModelArguments,
235
+ DataArguments,
236
+ FinetuningArguments,
237
+ GeneratingArguments
238
+ ]:
239
+ model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
240
+
241
+ if data_args.template is None:
242
+ raise ValueError("Please specify which `template` to use.")
243
+
244
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
245
+ raise ValueError("Quantization is only compatible with the LoRA method.")
246
+
247
+ if model_args.checkpoint_dir is not None:
248
+ if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
249
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
250
+
251
+ if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
252
+ raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
253
+
254
+ # auto-detect cuda capability
255
+ if is_npu_available:
256
+ model_args.compute_dtype = torch.float16
257
+ elif is_bf16_available:
258
+ model_args.compute_dtype = torch.bfloat16
259
+ else:
260
+ model_args.compute_dtype = torch.float16
261
+
262
+ return model_args, data_args, finetuning_args, generating_args