AesSedai committed
Commit befade2 · verified · 1 Parent(s): 5106dce

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. chat_template.jinja +103 -0
  3. config.json +46 -0
  4. generation_config.json +10 -0
  5. model-00001-of-00075.safetensors +3 -0
  6. model-00004-of-00075.safetensors +3 -0
  7. model-00005-of-00075.safetensors +3 -0
  8. model-00007-of-00075.safetensors +3 -0
  9. model-00008-of-00075.safetensors +3 -0
  10. model-00010-of-00075.safetensors +3 -0
  11. model-00016-of-00075.safetensors +3 -0
  12. model-00020-of-00075.safetensors +3 -0
  13. model-00022-of-00075.safetensors +3 -0
  14. model-00023-of-00075.safetensors +3 -0
  15. model-00024-of-00075.safetensors +3 -0
  16. model-00025-of-00075.safetensors +3 -0
  17. model-00026-of-00075.safetensors +3 -0
  18. model-00027-of-00075.safetensors +3 -0
  19. model-00028-of-00075.safetensors +3 -0
  20. model-00032-of-00075.safetensors +3 -0
  21. model-00035-of-00075.safetensors +3 -0
  22. model-00037-of-00075.safetensors +3 -0
  23. model-00038-of-00075.safetensors +3 -0
  24. model-00039-of-00075.safetensors +3 -0
  25. model-00041-of-00075.safetensors +3 -0
  26. model-00042-of-00075.safetensors +3 -0
  27. model-00043-of-00075.safetensors +3 -0
  28. model-00047-of-00075.safetensors +3 -0
  29. model-00050-of-00075.safetensors +3 -0
  30. model-00052-of-00075.safetensors +3 -0
  31. model-00054-of-00075.safetensors +3 -0
  32. model-00055-of-00075.safetensors +3 -0
  33. model-00056-of-00075.safetensors +3 -0
  34. model-00059-of-00075.safetensors +3 -0
  35. model-00060-of-00075.safetensors +3 -0
  36. model-00061-of-00075.safetensors +3 -0
  37. model-00064-of-00075.safetensors +3 -0
  38. model-00066-of-00075.safetensors +3 -0
  39. model-00067-of-00075.safetensors +3 -0
  40. model-00068-of-00075.safetensors +3 -0
  41. model-00069-of-00075.safetensors +3 -0
  42. model-00071-of-00075.safetensors +3 -0
  43. model-00072-of-00075.safetensors +3 -0
  44. model-00074-of-00075.safetensors +3 -0
  45. model-00075-of-00075.safetensors +3 -0
  46. model.safetensors.index.json +0 -0
  47. modeling_glm4_moe.py +624 -0
  48. reap_args.yaml +76 -0
  49. special_tokens_map.json +40 -0
  50. tokenizer.json +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
chat_template.jinja ADDED
@@ -0,0 +1,103 @@
+[gMASK]<sop>
+{%- if tools -%}
+<|system|>
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{% for tool in tools %}
+{{ tool | tojson(ensure_ascii=False) }}
+{% endfor %}
+</tools>
+
+For each function call, output the function name and arguments within the following XML format:
+<tool_call>{function-name}
+<arg_key>{arg-key-1}</arg_key>
+<arg_value>{arg-value-1}</arg_value>
+<arg_key>{arg-key-2}</arg_key>
+<arg_value>{arg-value-2}</arg_value>
+...
+</tool_call>{%- endif -%}
+{%- macro visible_text(content) -%}
+{%- if content is string -%}
+{{- content }}
+{%- elif content is iterable and content is not mapping -%}
+{%- for item in content -%}
+{%- if item is mapping and item.type == 'text' -%}
+{{- item.text }}
+{%- elif item is string -%}
+{{- item }}
+{%- endif -%}
+{%- endfor -%}
+{%- else -%}
+{{- content }}
+{%- endif -%}
+{%- endmacro -%}
+{%- set ns = namespace(last_user_index=-1) %}
+{%- for m in messages %}
+{%- if m.role == 'user' %}
+{% set ns.last_user_index = loop.index0 -%}
+{%- endif %}
+{%- endfor %}
+{% for m in messages %}
+{%- if m.role == 'user' -%}<|user|>
+{{ visible_text(m.content) }}
+{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}
+{%- elif m.role == 'assistant' -%}
+<|assistant|>
+{%- set reasoning_content = '' %}
+{%- set content = visible_text(m.content) %}
+{%- if m.reasoning_content is string %}
+{%- set reasoning_content = m.reasoning_content %}
+{%- else %}
+{%- if '</think>' in content %}
+{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+{%- set content = content.split('</think>')[-1].lstrip('\n') %}
+{%- endif %}
+{%- endif %}
+{%- if loop.index0 > ns.last_user_index and reasoning_content -%}
+{{ '\n<think>' + reasoning_content.strip() + '</think>'}}
+{%- else -%}
+{{ '\n<think></think>' }}
+{%- endif -%}
+{%- if content.strip() -%}
+{{ '\n' + content.strip() }}
+{%- endif -%}
+{% if m.tool_calls %}
+{% for tc in m.tool_calls %}
+{%- if tc.function %}
+{%- set tc = tc.function %}
+{%- endif %}
+{{ '\n<tool_call>' + tc.name }}
+{% set _args = tc.arguments %}
+{% for k, v in _args.items() %}
+<arg_key>{{ k }}</arg_key>
+<arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>
+{% endfor %}
+</tool_call>{% endfor %}
+{% endif %}
+{%- elif m.role == 'tool' -%}
+{%- if m.content is string -%}
+{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+{{- '<|observation|>' }}
+{%- endif %}
+{{- '\n<tool_response>\n' }}
+{{- m.content }}
+{{- '\n</tool_response>' }}
+{%- else -%}
+<|observation|>{% for tr in m.content %}
+
+<tool_response>
+{{ tr.output if tr.output is defined else tr }}
+</tool_response>{% endfor -%}
+{% endif -%}
+{%- elif m.role == 'system' -%}
+<|system|>
+{{ visible_text(m.content) }}
+{%- endif -%}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+<|assistant|>{{- '\n<think></think>' if (enable_thinking is defined and not enable_thinking) else '' -}}
+{%- endif -%}
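
For reference, a minimal sketch of rendering this template through transformers. The local path and the message are placeholders; extra keyword arguments such as `enable_thinking` are forwarded into the Jinja context, which is what gates the `/nothink` suffix and the empty `<think></think>` block above:

```python
from transformers import AutoTokenizer

# Placeholder path: a local clone of this repo.
tok = AutoTokenizer.from_pretrained(".")

messages = [{"role": "user", "content": "What is 2 + 2?"}]

# enable_thinking=False appends '/nothink' to the last user turn and opens the
# assistant turn with an empty <think></think>, per the template logic above.
prompt = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    enable_thinking=False,
)
print(prompt)  # begins with "[gMASK]<sop>"
```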
config.json ADDED
@@ -0,0 +1,46 @@
+{
+  "architectures": [
+    "Glm4MoeForCausalLM"
+  ],
+  "attention_bias": true,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoModelForCausalLM": "modeling_glm4_moe.Glm4MoeForCausalLM"
+  },
+  "eos_token_id": [
+    151329,
+    151336,
+    151338
+  ],
+  "first_k_dense_replace": 3,
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 5120,
+  "initializer_range": 0.02,
+  "intermediate_size": 12288,
+  "max_position_embeddings": 202752,
+  "model_type": "glm4_moe",
+  "moe_intermediate_size": 1536,
+  "n_group": 1,
+  "n_routed_experts": 80,
+  "n_shared_experts": 1,
+  "norm_topk_prob": true,
+  "num_attention_heads": 96,
+  "num_experts_per_tok": 8,
+  "num_hidden_layers": 92,
+  "num_key_value_heads": 8,
+  "num_nextn_predict_layers": 1,
+  "pad_token_id": 151329,
+  "partial_rotary_factor": 0.5,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 1000000,
+  "routed_scaling_factor": 2.5,
+  "tie_word_embeddings": false,
+  "topk_group": 1,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.55.0",
+  "use_cache": true,
+  "use_qk_norm": true,
+  "vocab_size": 151552
+}
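
Since "auto_map" points AutoModelForCausalLM at the bundled modeling_glm4_moe.py, loading this checkpoint goes through trust_remote_code. A sketch, with the repo id as a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "AesSedai/GLM-4.6-REAP",      # placeholder repo id; substitute the actual one
    torch_dtype=torch.bfloat16,   # matches "torch_dtype" in config.json
    device_map="auto",
    trust_remote_code=True,       # executes the bundled modeling_glm4_moe.py
)
```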
generation_config.json ADDED
@@ -0,0 +1,10 @@
+{
+  "_from_model_config": true,
+  "eos_token_id": [
+    151329,
+    151336,
+    151338
+  ],
+  "pad_token_id": 151329,
+  "transformers_version": "4.55.0"
+}
model-00001-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc2614aad7b7f125cbe6665005daf458deeb2415c4125ee5d1b49c7601676d14
+size 4986172552
model-00004-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:493b79e8398d37f09af8ed27b9c86a5262412b106de57843aac55669531b7784
+size 4914326208
model-00005-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1183507706c9b62bc5c6cb21d50c9e35d458677be5134adbce802e054131355a
+size 4997400520
model-00007-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0aa881e1ad09d159742c3e1a3e83d2c532b7017a236e6479fd5a8a3feae3a11e
+size 4992129752
model-00008-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d89035f2d7cfa55395baadcb6d9af4fe75803566cac1924044de4cd68e393bfb
+size 4992129888
model-00010-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1aaa899b941123f413c99515d60737d9204bcf9027fc71a29e44d1f9be3a5401
+size 4992129872
model-00016-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:639ea6c4185edfb1fafd584333a17217c469f357239b6579ad127167e52a137c
+size 4992129888
model-00020-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b47f9c7e0a956ceae7b24c6f40f49a3893ceb6b67fea265e0019234546789aa6
+size 4992129888
model-00022-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fd00ebf2d982f3953188a7fc07858f2ebc1caea9f3156bfd19e7d5ff30782f0
+size 4992129888
model-00023-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:102c92f4f24c856fff2b9bbb8061b4a8845b0b845e0491d82385059f38db340f
+size 4998241280
model-00024-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b66f97680b69f674faa733cadd93681989cf821b5a1abe3671df82e16369133
+size 4992129888
model-00025-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c47cf41434b3bae0b12f2368f1108b02eb47600cdd8e8d8f41fff05e76be933c
+size 4992129888
model-00026-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:270868fb827dcd0bc71383fc310955125461ef87954a35661feb107a4f71ff17
+size 4992129888
model-00027-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50938aa2a5ad56baa0dd9799632e9f4e169ab71dee59046cf68694a47379e7a3
+size 4966783928
model-00028-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2649719678c0bff61e4e46fc51c96af3140f173db2187d5a2f3b2d20838ec829
+size 4992129712
model-00032-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a541c7503038db458ee565899ed2c762502d877cb21dfc68cb83ae51817d9897
+size 4998241288
model-00035-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79d9054b9d5008ce549c20f47fce85b20c5fc20fba97e27fd1e931f41b6bac17
+size 4992129888
model-00037-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8f1a9dd3c6482fac33eee59474a07d34dc47c33a43136aa01e46ea56ee4252d
+size 4997400824
model-00038-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70ebe15ead11ae38b3201bf5804bc2610afa277cc4a751447163e57d6c6640e7
+size 4992129888
model-00039-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c130a9b072c765aaec3426883e4e5588ca991511637ff7317013386876259f6
+size 4992129888
model-00041-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3e8959a0eaaa8c84e323ca03f33e85c4b42c63a5fe861d32f9e7f716a73b994
+size 4998241296
model-00042-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef9f8ff060a8e72f12cc73ff390ae2e1689c039c4d7ad25a8933935bd8afdd32
+size 4992129872
model-00043-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db490d5668c50934aeadf4bddca146658f2292cba09aafd69178a6ebfafffa35
+size 4992129888
model-00047-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:199b559733ae5a1a9130c0b63a54d5a52dd8896c5dfddf192dd102e6c56671df
+size 4992129888
model-00050-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88bc64b28d732b669e9a99daf7d6a9ba4146b9dfb63da21609e97c491d26b0fc
+size 4998241304
model-00052-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c4b7f8275bf25321fb34015ef0e4f064068898a3408b4d64dd218dac7b59f63
+size 4992129888
model-00054-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e9155526068ed0b07761b84d613f9df11759c8778955d3c0cab8dddf5e513bb
+size 4992129888
model-00055-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60e7adb20158d493162cd8d19de5c36b92c7a1fd1a2b5a61a77b87dd32b1fc8b
+size 4998241280
model-00056-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e0ec8f6944758593ecd406a1538962a81ab251580273955b7caa7a6d72ae00f
+size 4992129888
model-00059-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40332b2c2e3073b3eadaca092cb06e85a5be36856cd3dc89ce8e13fa316f7b3d
+size 4966783928
model-00060-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd800f47b5d18ec7a0eb7833a57087a1d5c0877ea52e62a81960c5f585ed7ab6
+size 4992129712
model-00061-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7216dd3f84d8817efcc0369c9a2e996105b4a56a04babc067cb188d53d5c303d
+size 4992129888
model-00064-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbcaa525dd438741cf4edde64edad984d535118cfb226f2d1df584664c118b87
+size 4998241288
model-00066-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8578888eb758d57317b1763e61284e302f62697fae64dbba47c7f44a5cff7eee
+size 4992129888
model-00067-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4dbf39269b4bdec2d96c0f5cb128fef777288e18b76fdc24e9f8d15316d343e5
+size 4992129888
model-00068-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd6da0d4f40e5652ced325a888746cf48caa17d65ee67bfa8c2ffe2a3b0e3987
+size 4961512816
model-00069-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbbb42468a112a1b466ec12a2e371c41f87a5248d1e87ede340d48f11a3f4363
+size 4997400824
model-00071-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2b09a61b9db2866fe55e000a9142b4c6f4c8e2042c92a22914f366ab5955334
+size 4992129888
model-00072-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:755b0f2c87b75ff87d0768f86ed0f60f952222a3abd335978567cbd136c45774
+size 4992129888
model-00074-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77b1a60271e8afe5762f608bab4edd3e0b1a40e3c8234065744e13e480189b11
+size 3697110504
model-00075-of-00075.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7ed732a1f52d39814adc6a0b127bad1c816b89d7055fa07b029cba98deba7bf
+size 1551892608
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
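
The index maps every parameter name to the shard that stores it. A small sketch of inspecting it from a local clone of this repo:

```python
import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])      # total bytes across all shards
shard_counts = Counter(index["weight_map"].values())
print(len(shard_counts), "shards")          # should report 75 for this upload
```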
 
modeling_glm4_moe.py ADDED
@@ -0,0 +1,624 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.
+# Do NOT edit this file manually, as any edits will be overwritten by the generation of
+# the file from the modular. If any change needs to be made, please apply it to the
+# modular_glm4_moe.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
+from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    # Keep half or full tensor for later concatenation
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    # Apply rotary embeddings on the first half or full tensor
+    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+    # Concatenate back to full shape
+    q_embed = torch.cat([q_embed, q_pass], dim=-1)
+    k_embed = torch.cat([k_embed, k_pass], dim=-1)
+    return q_embed, k_embed
+
+
+class Glm4MoeAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+        self.use_qk_norm = config.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+            self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape)
+        key_states = self.k_proj(hidden_states).view(hidden_shape)
+        value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+        if self.use_qk_norm:  # main diff from Llama
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Glm4MoeMLP(nn.Module):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class Glm4MoeTopkRouter(nn.Module):
+    def __init__(self, config: Glm4MoeConfig):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.norm_topk_prob = config.norm_topk_prob
+
+        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores):
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+        return topk_indices
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor
+        return topk_indices, topk_weights, router_logits
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4MoeRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Glm4MoeRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4MoeMoE(nn.Module):
+    """
+    A mixed expert module containing shared experts.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.experts = nn.ModuleList(
+            [
+                Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = Glm4MoeTopkRouter(config)
+        self.shared_experts = Glm4MoeMLP(
+            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
+
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        r"""
+        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
+        so we don't have to loop here (deepseek has 256 experts soooo yeah).
+        """
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        # in original deepseek, the outputs of the experts are gathered once we leave this module
+        # thus the moe module is itself an IsolatedParallel module
+        # and all experts are "local", meaning we shard but we don't gather
+        return final_hidden_states.type(hidden_states.dtype)
+
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states, router_logits
+
+
+class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: Glm4MoeConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)
+
+        if layer_idx >= config.first_k_dense_replace:
+            self.mlp = Glm4MoeMoE(config)
+        else:
+            self.mlp = Glm4MoeMLP(config)
+
+        self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        mlp_output = self.mlp(hidden_states)
+        # Note: an isinstance check is used here rather than len(), since len() on a
+        # tensor returns the batch size and would misfire for dense layers at batch size 2.
+        if isinstance(mlp_output, tuple):
+            # MoE layers return both hidden states and router logits
+            hidden_states, _ = mlp_output
+        else:
+            # Dense layers return only hidden states
+            hidden_states = mlp_output
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+@auto_docstring
+class Glm4MoePreTrainedModel(PreTrainedModel):
+    config: Glm4MoeConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Glm4MoeDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _can_compile_fullgraph = False
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": Glm4MoeDecoderLayer,
+        "attentions": Glm4MoeAttention,
+    }
+
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Glm4MoeTopkRouter):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class Glm4MoeRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Glm4MoeConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class Glm4MoeModel(Glm4MoePreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
+
+    def __init__(self, config: Glm4MoeConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Glm4MoeRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Glm4MoeModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM
+
+        >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"]
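
For intuition about the router above: with n_group = 1 and topk_group = 1 (as in config.json), the group masking in get_topk_indices is a no-op, and selection reduces to a plain top-k over bias-corrected sigmoid scores. A toy standalone sketch of that path, with made-up sizes:

```python
import torch

# Toy sizes; the real model uses hidden=5120, 80 experts, top_k=8.
n_tokens, hidden, n_experts, top_k = 4, 16, 8, 2
hidden_states = torch.randn(n_tokens, hidden)
router_weight = torch.randn(n_experts, hidden)
bias = torch.zeros(n_experts)  # e_score_correction_bias

scores = (hidden_states @ router_weight.T).sigmoid()  # (tokens, experts)
# Bias only influences which experts are selected; the gathered weights
# come from the raw scores, mirroring Glm4MoeTopkRouter.forward.
topk_indices = torch.topk(scores + bias, k=top_k, dim=-1).indices
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights / (topk_weights.sum(-1, keepdim=True) + 1e-20)  # norm_topk_prob
topk_weights = topk_weights * 2.5  # routed_scaling_factor from config.json
print(topk_indices, topk_weights)
```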
reap_args.yaml ADDED
@@ -0,0 +1,76 @@
+cluster_args:
+  cluster_description: null
+  cluster_method: agglomerative
+  compression_ratio: 0.5
+  expert_sim: ttm
+  frequency_penalty: true
+  linkage_method: average
+  max_cluster_size: null
+  multi_layer: null
+  num_clusters: null
+  singleton_outlier_experts: false
+  singleton_super_experts: false
+  softmax_temperature: null
+ds_args:
+  dataset_config_name: all
+  dataset_name: theblackcat102/evol-codealpaca-v1
+  dataset_test_split: test
+  shuffle: true
+  split: train
+eval_args:
+  evalplus_tasks:
+  - mbpp
+  - humaneval
+  greedy: true
+  lm_eval_tasks:
+  - winogrande
+  - arc_challenge
+  - arc_easy
+  - boolq
+  - hellaswag
+  - mmlu
+  - openbookqa
+  - rte
+  min_p: 0.0
+  parallel_tasks: 32
+  results_dir: null
+  run_evalplus: true
+  run_livecodebench: true
+  run_lm_eval: true
+  run_math: false
+  run_wildbench: false
+  server_log_file_name: pruning-cli-0.log
+  temperature: 0.7
+  top_k: 20
+  top_p: 0.8
+  use_server: true
+  vllm_port: 8000
+model_args:
+  model_name: zai-org/GLM-4.6
+  num_experts_per_tok_override: null
+obs_args:
+  distance_measure: cosine
+  model_max_length: 2048
+  output_file_name: observations_10_cosine-seed_42.pt
+  overwrite_observations: false
+  record_pruning_metrics_only: true
+  renormalize_router_weights: false
+  return_vllm_tokens_prompt: false
+  samples_per_category: 10
+  select_only_categories: null
+  split_by_category: false
+  truncate: false
+prune_args:
+  n_experts_to_prune: null
+  overwrite_pruned_model: false
+  perserve_outliers: false
+  perserve_super_experts: false
+  prune_method: reap
+reap_args:
+  debug: false
+  do_eval: false
+  plot_clusters: true
+  profile: false
+  run_observer_only: false
+  seed: 42
+  smoke_test: true
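
reap_args.yaml records the REAP pruning run that produced this checkpoint. A sketch of reading it back with PyYAML from a local clone:

```python
import yaml

with open("reap_args.yaml") as f:
    args = yaml.safe_load(f)

print(args["model_args"]["model_name"])           # zai-org/GLM-4.6
print(args["cluster_args"]["compression_ratio"])  # 0.5 (config.json above shows 80 routed experts remaining)
```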
special_tokens_map.json ADDED
@@ -0,0 +1,40 @@
+{
+  "additional_special_tokens": [
+    "<|endoftext|>",
+    "[MASK]",
+    "[gMASK]",
+    "[sMASK]",
+    "<sop>",
+    "<eop>",
+    "<|system|>",
+    "<|user|>",
+    "<|assistant|>",
+    "<|observation|>",
+    "<|begin_of_image|>",
+    "<|end_of_image|>",
+    "<|begin_of_video|>",
+    "<|end_of_video|>",
+    "<|begin_of_audio|>",
+    "<|end_of_audio|>",
+    "<|begin_of_transcription|>",
+    "<|end_of_transcription|>",
+    "<|code_prefix|>",
+    "<|code_middle|>",
+    "<|code_suffix|>",
+    "/nothink"
+  ],
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
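
A quick sketch of verifying this mapping after loading the tokenizer (local path is a placeholder; because /nothink is registered as a special token here, it tokenizes atomically instead of being split into pieces):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")          # placeholder: local clone of this repo
print(tok.eos_token, tok.pad_token)               # both "<|endoftext|>" per this file
print(tok.convert_tokens_to_ids("/nothink"))      # a single token id
```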
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bda8e2146c3bb7b7e0fc96dcc4f0aeff041c6c27952e3ace0665663ebff346ba
+size 19970700