bwshen-mi committed
Commit 2f5a22f · verified · 1 Parent(s): 2192998

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
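The commit message refers to the upload-large-folder tool from huggingface_hub. A hedged sketch of what that upload presumably looked like (repo id and local path are placeholders; exact parameters may differ by huggingface_hub version):

# Illustrative sketch only: resumable upload of a large local checkpoint folder.
# "org/placeholder-repo" and "/path/to/local/checkpoint" are assumptions, not values from this commit.
from huggingface_hub import HfApi

api = HfApi()
api.upload_large_folder(
    repo_id="org/placeholder-repo",
    repo_type="model",
    folder_path="/path/to/local/checkpoint",
)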
Files changed (50)
  1. .gitattributes +1 -0
  2. added_tokens.json +28 -0
  3. config.json +213 -0
  4. configuration_mimo_v2_flash.py +109 -0
  5. merges.txt +0 -0
  6. model.safetensors.index.json +0 -0
  7. model_10.safetensors +3 -0
  8. model_11.safetensors +3 -0
  9. model_12.safetensors +3 -0
  10. model_16.safetensors +3 -0
  11. model_2.safetensors +3 -0
  12. model_23.safetensors +3 -0
  13. model_24.safetensors +3 -0
  14. model_26.safetensors +3 -0
  15. model_27.safetensors +3 -0
  16. model_30.safetensors +3 -0
  17. model_31.safetensors +3 -0
  18. model_32_linear_fc2.safetensors +3 -0
  19. model_33.safetensors +3 -0
  20. model_33_linear_fc2.safetensors +3 -0
  21. model_34.safetensors +3 -0
  22. model_34_linear_fc2.safetensors +3 -0
  23. model_36_linear_fc2.safetensors +3 -0
  24. model_37_linear_fc2.safetensors +3 -0
  25. model_38_linear_fc2.safetensors +3 -0
  26. model_39_linear_fc2.safetensors +3 -0
  27. model_3_linear_fc2.safetensors +3 -0
  28. model_4.safetensors +3 -0
  29. model_41_linear_fc2.safetensors +3 -0
  30. model_42_linear_fc2.safetensors +3 -0
  31. model_43_linear_fc2.safetensors +3 -0
  32. model_44_linear_fc2.safetensors +3 -0
  33. model_45.safetensors +3 -0
  34. model_45_linear_fc2.safetensors +3 -0
  35. model_46_linear_fc2.safetensors +3 -0
  36. model_47.safetensors +3 -0
  37. model_47_linear_fc2.safetensors +3 -0
  38. model_4_linear_fc2.safetensors +3 -0
  39. model_5.safetensors +3 -0
  40. model_6.safetensors +3 -0
  41. model_7.safetensors +3 -0
  42. model_7_linear_fc2.safetensors +3 -0
  43. model_8.safetensors +3 -0
  44. model_8_linear_fc2.safetensors +3 -0
  45. model_9.safetensors +3 -0
  46. model_final.safetensors +3 -0
  47. modeling_mimo_v2_flash.py +664 -0
  48. special_tokens_map.json +31 -0
  49. tokenizer.json +3 -0
  50. tokenizer_config.json +240 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
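A quick sanity check on these mappings (illustrative sketch; the checkpoint path below is a placeholder for wherever this repo is loaded from):

# Illustrative sketch: confirm the added tokens resolve to the ids declared in added_tokens.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this-checkpoint")  # placeholder path

for token in ["<think>", "</think>", "<tool_call>", "<|im_start|>", "<|im_end|>"]:
    print(token, "->", tokenizer.convert_tokens_to_ids(token))
# Expected per added_tokens.json: <think> -> 151667, </think> -> 151668,
# <tool_call> -> 151657, <|im_start|> -> 151644, <|im_end|> -> 151645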
config.json ADDED
@@ -0,0 +1,213 @@
+ {
+   "architectures": [
+     "MiMoV2FlashForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_mimo_v2_flash.MiMoV2FlashConfig",
+     "AutoModel": "modeling_mimo_v2_flash.MiMoV2FlashModel",
+     "AutoModelForCausalLM": "modeling_mimo_v2_flash.MiMoV2FlashForCausalLM"
+   },
+   "quantization_config": {
+     "activation_scheme": "dynamic",
+     "fmt": "e4m3",
+     "packed_modules_mapping": {},
+     "quant_method": "fp8",
+     "ignored_layers": [
+       "model.layers.0.self_attn.o_proj",
+       "model.layers.1.self_attn.o_proj",
+       "model.layers.2.self_attn.o_proj",
+       "model.layers.3.self_attn.o_proj",
+       "model.layers.4.self_attn.o_proj",
+       "model.layers.5.self_attn.o_proj",
+       "model.layers.6.self_attn.o_proj",
+       "model.layers.7.self_attn.o_proj",
+       "model.layers.8.self_attn.o_proj",
+       "model.layers.9.self_attn.o_proj",
+       "model.layers.10.self_attn.o_proj",
+       "model.layers.11.self_attn.o_proj",
+       "model.layers.12.self_attn.o_proj",
+       "model.layers.13.self_attn.o_proj",
+       "model.layers.14.self_attn.o_proj",
+       "model.layers.15.self_attn.o_proj",
+       "model.layers.16.self_attn.o_proj",
+       "model.layers.17.self_attn.o_proj",
+       "model.layers.18.self_attn.o_proj",
+       "model.layers.19.self_attn.o_proj",
+       "model.layers.20.self_attn.o_proj",
+       "model.layers.21.self_attn.o_proj",
+       "model.layers.22.self_attn.o_proj",
+       "model.layers.23.self_attn.o_proj",
+       "model.layers.24.self_attn.o_proj",
+       "model.layers.25.self_attn.o_proj",
+       "model.layers.26.self_attn.o_proj",
+       "model.layers.27.self_attn.o_proj",
+       "model.layers.28.self_attn.o_proj",
+       "model.layers.29.self_attn.o_proj",
+       "model.layers.30.self_attn.o_proj",
+       "model.layers.31.self_attn.o_proj",
+       "model.layers.32.self_attn.o_proj",
+       "model.layers.33.self_attn.o_proj",
+       "model.layers.34.self_attn.o_proj",
+       "model.layers.35.self_attn.o_proj",
+       "model.layers.36.self_attn.o_proj",
+       "model.layers.37.self_attn.o_proj",
+       "model.layers.38.self_attn.o_proj",
+       "model.layers.39.self_attn.o_proj",
+       "model.layers.40.self_attn.o_proj",
+       "model.layers.41.self_attn.o_proj",
+       "model.layers.42.self_attn.o_proj",
+       "model.layers.43.self_attn.o_proj",
+       "model.layers.44.self_attn.o_proj",
+       "model.layers.45.self_attn.o_proj",
+       "model.layers.46.self_attn.o_proj",
+       "model.layers.47.self_attn.o_proj",
+       "model.decoder.self_attn.o_proj"
+     ],
+     "weight_block_size": [
+       128,
+       128
+     ]
+   },
+   "attention_dropout": 0.0,
+   "attention_value_scale": 0.707,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 16384,
+   "max_position_embeddings": 262144,
+   "model_type": "mimo_v2_flash",
+   "num_attention_heads": 64,
+   "head_dim": 192,
+   "num_hidden_layers": 48,
+   "num_key_value_heads": 4,
+   "layernorm_epsilon": 1e-05,
+   "rope_theta": 5000000,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.40.1",
+   "use_cache": true,
+   "vocab_size": 152576,
+   "partial_rotary_factor": 0.334,
+   "sliding_window": 128,
+   "swa_rope_theta": 10000,
+   "attention_bias": false,
+   "v_head_dim": 128,
+   "hybrid_layer_pattern": [
+     0,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     0
+   ],
+   "add_swa_attention_sink_bias": true,
+   "add_full_attention_sink_bias": false,
+   "sliding_window_size": 128,
+   "attention_chunk_size": 128,
+   "moe_layer_freq": [
+     0,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1
+   ],
+   "moe_intermediate_size": 2048,
+   "n_routed_experts": 256,
+   "n_shared_experts": null,
+   "num_experts_per_tok": 8,
+   "norm_topk_prob": true,
+   "scoring_func": "sigmoid",
+   "n_group": 1,
+   "topk_group": 1,
+   "topk_method": "noaux_tc",
+   "routed_scaling_factor": null,
+   "swa_num_attention_heads": 64,
+   "swa_num_key_value_heads": 8,
+   "swa_head_dim": 192,
+   "swa_v_head_dim": 128
+ }
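A minimal sketch of how this config is meant to be consumed (the checkpoint path is a placeholder; the MiMoV2Flash* classes are resolved through the auto_map entries above when trust_remote_code=True, and serving the FP8 block-quantized weights additionally requires a runtime with FP8 support):

# Illustrative sketch only; "path/to/this-checkpoint" is a placeholder.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/this-checkpoint", trust_remote_code=True)
print(config.num_hidden_layers)                                   # 48
print(sum(1 for p in config.hybrid_layer_pattern if p == 0))      # full-attention layers (9 with the pattern above)

# Model loading is sketched here without any claim about FP8 kernel availability.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/this-checkpoint",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)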
configuration_mimo_v2_flash.py ADDED
@@ -0,0 +1,109 @@
+ # coding=utf-8
+ #
+ # Copyright 2025 Xiaomi Corporation.
+ # Copyright 2025 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class MiMoV2FlashConfig(PretrainedConfig):
+
+     model_type = ""
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     # Default tensor parallel plan for base model `Hybrid`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     attribute_map = {
+         "num_local_experts": "n_routed_experts",
+     }
+
+     def __init__(
+         self,
+         vocab_size=151936,
+         hidden_size=4096,
+         intermediate_size=22016,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=32,
+         hidden_act="silu",
+         max_position_embeddings=32768,
+         initializer_range=0.02,
+         layernorm_epsilon=1e-6,
+         use_cache=True,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         attention_dropout=0.0,
+         hybrid_block_size=None,
+         hybrid_layer_pattern=None,
+         partial_rotary_factor=1.0,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.layernorm_epsilon = layernorm_epsilon
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+
+         if hybrid_block_size is not None and hybrid_layer_pattern is None:
+             hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]
+         self.hybrid_block_size = hybrid_block_size
+         self.hybrid_layer_pattern = hybrid_layer_pattern
+
+         self.partial_rotary_factor = partial_rotary_factor
+
+         # Validate the correctness of rotary position embeddings parameters
+         # BC: if there is a 'type' field, move it to 'rope_type'.
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
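The only derived field above is the hybrid_block_size fallback for hybrid_layer_pattern. A standalone restatement of that rule (pure Python, nothing assumed beyond the code above) makes the layer layout explicit:

# Standalone illustration of the hybrid_block_size -> hybrid_layer_pattern fallback in
# MiMoV2FlashConfig.__init__: every hybrid_block_size-th layer is full attention (0),
# all others are sliding-window layers (1).
def derive_hybrid_layer_pattern(num_hidden_layers: int, hybrid_block_size: int) -> list[int]:
    return [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]

print(derive_hybrid_layer_pattern(num_hidden_layers=12, hybrid_block_size=6))
# [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0]
# Note: this checkpoint's config.json pins hybrid_layer_pattern explicitly (48 entries,
# with full attention also at layer 0), so the fallback is not exercised here.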
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
model_10.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15a76d7cd96b8f855072b0b9b2eb2ef323f45605eab9942d1962f3b189d9ae38
3
+ size 132154328
model_11.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d34f5fc039a11df7686fed6a46f6f43e5241a49dd8e4df1b959ae3b512d889c5
3
+ size 126910184
model_12.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb75be363c969bc2487049d654b47418f453f8080a89eb510be9d5bd57c9620b
3
+ size 132154328
model_16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61e250132d8f4f753d1ee7d5cbffce109bac0e419685757781417d500d0bcc87
3
+ size 132154328
model_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ea731b8fe60181264e89e23b6f7ae43616353b2ceb843a9194806b424c7fcf
3
+ size 132154312
model_23.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e165718c4247b60b59d36338804175189c963231048b06013f5863b06510ac33
3
+ size 126910184
model_24.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feea6437a2cd47a770360fe43766f63631556fa4db7ebb3dcad8a7a03f9f3e31
3
+ size 132154328
model_26.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d310e39d013dc4b31b4446e4608fae98b48ec18c1c01e2001dd49386dfdb0183
3
+ size 132154328
model_27.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68d833853c487b7b4d362eef59a4d1c24966822a124686973267da8922811794
3
+ size 132154328
model_30.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:153fd1cb28d8ded5b7ea3ad3b3da0d6fcee30b369205955fd7caa121cea1dc95
3
+ size 132154328
model_31.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ff636ef43f13b0c98ebc1b0fcbc9b87386a494b607657aeef003143f7c67e54
3
+ size 132154328
model_32_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc3124abcc242848f94c957a6041794815486bf2293f18847ea5a2e23ce37be8
3
+ size 2148072376
model_33.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2d03e3a4487247dc26331a1fb7feeb652dfcdc54366b8b1f3f9f2d6fa3ffb03
3
+ size 132154328
model_33_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6259518d3e7f46a7b6e13f41ee9db710e5aaa4408bb9509e8118202bb39a56bb
3
+ size 2148072376
model_34.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1d7181c26f84fd2e3b05849adae8cf4f056afad0877039c7275b50b5c52914e
3
+ size 132154328
model_34_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9071848efe17a71f40c2a53814667c6fc956cbd4167afabac1de46321c75afde
3
+ size 2148072376
model_36_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e90abaa875cdbc9b909577f209d7a6b9e2e6bcffa5154ec68fd1c42a05d52a5a
3
+ size 2148072376
model_37_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60c9c7550c59d9ceb8eb0b4e85d3fc6c3ad5e91fe49a0b38a7d5c82f1d2913c4
3
+ size 2148072376
model_38_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d420a10c40167b511dcd1014c735ef9b17b64b6d09ba18c06c65dd9e35247b6
3
+ size 2148072376
model_39_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2a0addc55f7a9615a1bdeea587f3fe3388bc3b39ffc3ba0efac6b5fe2bd6497
3
+ size 2148072376
model_3_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:451a6f68da921add092f85ba78eb545f6cd182684a3151e701f1932752368021
3
+ size 2148071864
model_4.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3d91093a979f17cc49daf48e92b683a8d06f74aba1a716babd02409687e320e
3
+ size 132154312
model_41_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f675c8a46718a7b108730a2c7fef60140fafecb93eb1ec94a4ed6003c57897a4
3
+ size 2148072376
model_42_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc752712bcea4e6548d2058d315e856e4a72d3e3b1aa02ea24e80b28c041da2b
3
+ size 2148072376
model_43_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32ed548ceae946e085cc32c62d6dae38c29407c66772068011307a9732d6f26b
3
+ size 2148072376
model_44_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:838b6dc6cdc6dd81582383727f37b81fec6d6b10d82bf8b3c4ff39d0e5b3d326
3
+ size 2148072376
model_45.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a4811b644066e0d1a64d1b9b727ed0124fe1ca491c0eec4a157f9a9a1dcf4d5
3
+ size 132154328
model_45_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4855416cd86ed6531f19e767c35a3a2d1085eaca1b1e3ff74beffe4543f3df7f
3
+ size 2148072376
model_46_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3120281e83b4844bd287c15bba7b60c6427bf2234deeb89198b29f2628643657
3
+ size 2148072376
model_47.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffdc8fb803808c1f5a4d57a31bb3f60dbe6f851192a0970cedad11e31fe0d378
3
+ size 126910184
model_47_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ed8bb40c11e14e15e83c72169256a7fceeb15259bce088206d6b1c3134fa779
3
+ size 2148072376
model_4_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89d1cf0c021e612284ad9efcd4b0efe3f0fbba3a1b5dcfe667cfbbd0e345e4e4
3
+ size 2148071864
model_5.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04c3714f268d8b63a7899051cff1b261a5040775b6bb971436a800ba52ee934c
3
+ size 126910168
model_6.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0aee37741f0f614a0c699d6b0307b4b747a9760b21076dadf6d8a2d175516b30
3
+ size 132154312
model_7.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d532a38f375a26c439ae135a678178b2999758f6c306ec3a5ce2a61be651e82d
3
+ size 132154312
model_7_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d3744ee45c9993bdc1df21c61e9fb56551f40f92a9aea6a536880b8db15cdcd
3
+ size 2148071864
model_8.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af3430a88830091f179afc27e76bf121ff145125aebbe1f422fe3e0a3326ef64
3
+ size 132154312
model_8_linear_fc2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:456252982a334059d7c626f46fe2967809714014f10924585e9fcb6183ec0633
3
+ size 2148071864
model_9.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77e5c9fc719a87bc82e9a66bfc7e2188418a82c90fa35970495ddc68c88e964f
3
+ size 132154312
model_final.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e179b6c4c1974db5f8c71799eceb7bb0b35b27cf112b8bcd9ebe1ef01c53e6f7
3
+ size 1249910976
modeling_mimo_v2_flash.py ADDED
@@ -0,0 +1,664 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright 2025 Xiaomi Corporation.
4
+ # Copyright 2025 The HuggingFace Inc. team.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Callable, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.integrations import use_kernel_forward_from_hub
28
+
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ )
33
+
34
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import (
39
+ logging,
40
+ )
41
+
42
+ from transformers.modeling_outputs import MoeModelOutputWithPast
43
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
44
+ from .configuration_mimo_v2_flash import MiMoV2FlashConfig
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ def rotate_half(x):
50
+ """Rotates half the hidden dims of the input."""
51
+ x1 = x[..., : x.shape[-1] // 2]
52
+ x2 = x[..., x.shape[-1] // 2:]
53
+ return torch.cat((-x2, x1), dim=-1)
54
+
55
+
56
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
57
+ """Applies Rotary Position Embedding to the query and key tensors.
58
+
59
+ Args:
60
+ q (`torch.Tensor`): The query tensor.
61
+ k (`torch.Tensor`): The key tensor.
62
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
63
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
64
+ position_ids (`torch.Tensor`, *optional*):
65
+ Deprecated and unused.
66
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
67
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
68
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
69
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
70
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
71
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
72
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
73
+ Returns:
74
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
75
+ """
76
+ cos = cos.unsqueeze(unsqueeze_dim)
77
+ sin = sin.unsqueeze(unsqueeze_dim)
78
+ q_embed = (q * cos) + (rotate_half(q) * sin)
79
+ k_embed = (k * cos) + (rotate_half(k) * sin)
80
+ return q_embed, k_embed
81
+
82
+
83
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
84
+ """
85
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
86
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
87
+ """
88
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
89
+ if n_rep == 1:
90
+ return hidden_states
91
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
92
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
93
+
94
+
95
+ def eager_attention_forward(
96
+ module: nn.Module,
97
+ query: torch.Tensor,
98
+ key: torch.Tensor,
99
+ value: torch.Tensor,
100
+ attention_mask: Optional[torch.Tensor],
101
+ scaling: float,
102
+ dropout: float = 0.0,
103
+ sinks: Optional[torch.Tensor] = None,
104
+ ):
105
+ key_states = repeat_kv(key, module.num_key_value_groups)
106
+ value_states = repeat_kv(value, module.num_key_value_groups)
107
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
108
+ if attention_mask is not None:
109
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
110
+ attn_weights = attn_weights + causal_mask
111
+
112
+ if sinks is not None:
113
+ sinks = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
114
+ attn_weights = torch.cat([attn_weights, sinks], dim=-1)
115
+
116
+ attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
117
+ probs = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
118
+
119
+ if sinks is not None:
120
+ probs = probs[..., :-1] # we drop the sink here
121
+
122
+ attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training)
123
+ attn_output = torch.matmul(attn_weights, value_states)
124
+ attn_output = attn_output.transpose(1, 2).contiguous()
125
+ return attn_output, attn_weights
126
+
127
+
128
+ @use_kernel_forward_from_hub("RMSNorm")
129
+ class MiMoV2RMSNorm(nn.Module):
130
+ def __init__(self, hidden_size, eps=1e-6):
131
+ """
132
+ MiMoV2RMSNorm is equivalent to T5LayerNorm
133
+ """
134
+ super().__init__()
135
+ self.weight = nn.Parameter(torch.ones(hidden_size))
136
+ self.variance_epsilon = eps
137
+
138
+ def forward(self, hidden_states):
139
+ input_dtype = hidden_states.dtype
140
+ hidden_states = hidden_states.to(torch.float32)
141
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
142
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
143
+ return self.weight * hidden_states.to(input_dtype)
144
+
145
+
146
+ class MiMoV2MLP(nn.Module):
147
+ """MiMoV2MLP matching the gate, up, and down projection layers."""
148
+
149
+ def __init__(self, config: MiMoV2FlashConfig, intermediate_size=None):
150
+ super().__init__()
151
+ self.config = config
152
+ self.hidden_size = config.hidden_size
153
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
154
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
155
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
156
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
157
+ self.act_fn = ACT2FN[config.hidden_act]
158
+
159
+ def forward(self, hidden_states):
160
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
161
+ return down_proj
162
+
163
+
164
+ class MiMoV2MoEGate(nn.Module):
165
+ def __init__(self, config):
166
+ super().__init__()
167
+ self.config = config
168
+ self.top_k = config.num_experts_per_tok
169
+ self.n_routed_experts = config.n_routed_experts
170
+ self.routed_scaling_factor = (
171
+ config.routed_scaling_factor
172
+ if config.routed_scaling_factor is not None
173
+ else 1.0
174
+ )
175
+ self.scoring_func = config.scoring_func
176
+ self.topk_method = config.topk_method
177
+ self.n_group = config.n_group
178
+ self.topk_group = config.topk_group
179
+
180
+ # topk selection algorithm
181
+ self.norm_topk_prob = config.norm_topk_prob
182
+ self.gating_dim = config.hidden_size
183
+ self.weight = nn.Parameter(
184
+ torch.empty((self.n_routed_experts, self.gating_dim))
185
+ )
186
+ if self.topk_method == "noaux_tc":
187
+ self.e_score_correction_bias = nn.Parameter(
188
+ torch.empty((self.n_routed_experts))
189
+ )
190
+
191
+ def forward(self, hidden_states):
192
+ bsz, seq_len, h = hidden_states.shape
193
+ ### compute gating score
194
+ hidden_states = hidden_states.view(-1, h)
195
+ logits = F.linear(
196
+ hidden_states.type(torch.float32), self.weight.type(torch.float32), None
197
+ )
198
+ if self.scoring_func == "sigmoid":
199
+ scores = logits.sigmoid()
200
+ else:
201
+ raise NotImplementedError(
202
+ f"insupportable scoring function for MoE gating: {self.scoring_func}"
203
+ )
204
+
205
+ ### select top-k experts
206
+ if self.topk_method == "noaux_tc":
207
+ assert not self.training
208
+ scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
209
+ group_scores = (
210
+ scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
211
+ ) # [n, n_group]
212
+ group_idx = torch.topk(
213
+ group_scores, k=self.topk_group, dim=-1, sorted=False
214
+ )[
215
+ 1
216
+ ] # [n, top_k_group]
217
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
218
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
219
+ score_mask = (
220
+ group_mask.unsqueeze(-1)
221
+ .expand(
222
+ bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
223
+ )
224
+ .reshape(bsz * seq_len, -1)
225
+ ) # [n, e]
226
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
227
+ _, topk_idx = torch.topk(
228
+ tmp_scores, k=self.top_k, dim=-1, sorted=False
229
+ )
230
+ topk_weight = scores.gather(1, topk_idx)
231
+ else:
232
+ raise NotImplementedError(
233
+ f"insupportable TopK function for MoE gating: {self.topk_method}"
234
+ )
235
+
236
+ ### norm gate to sum 1
237
+ if self.top_k > 1 and self.norm_topk_prob:
238
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
239
+ topk_weight = topk_weight / denominator
240
+ topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
241
+
242
+ return topk_idx, topk_weight
243
+
244
+
245
+ class MiMoV2MoE(nn.Module):
246
+ """
247
+ A mixed expert module containing shared experts.
248
+ """
249
+
250
+ def __init__(self, config):
251
+ super().__init__()
252
+ self.config = config
253
+ self.experts = nn.ModuleList(
254
+ [
255
+ MiMoV2MLP(config, intermediate_size=config.moe_intermediate_size)
256
+ for _ in range(config.n_routed_experts)
257
+ ]
258
+ )
259
+ self.gate = MiMoV2MoEGate(config)
260
+
261
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
262
+ r"""
263
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
264
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
265
+ """
266
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
267
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
268
+ expert_mask = expert_mask.permute(2, 0, 1)
269
+
270
+ for expert_idx in range(len(self.experts)):
271
+ expert = self.experts[expert_idx]
272
+ mask = expert_mask[expert_idx]
273
+ token_indices, weight_indices = torch.where(mask)
274
+
275
+ if token_indices.numel() > 0:
276
+ expert_weights = topk_weights[token_indices, weight_indices]
277
+ expert_input = hidden_states[token_indices]
278
+ expert_output = expert(expert_input)
279
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
280
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
281
+
282
+ # in original deepseek, the output of the experts are gathered once we leave this module
283
+ # thus the moe module is itelsf an IsolatedParallel module
284
+ # and all expert are "local" meaning we shard but we don't gather
285
+ return final_hidden_states.type(hidden_states.dtype)
286
+
287
+
288
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
289
+ orig_shape = hidden_states.shape
290
+ topk_indices, topk_weights = self.gate(hidden_states)
291
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
292
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
293
+
294
+ return hidden_states
295
+
296
+
297
+ class MiMoV2Attention(nn.Module):
298
+ """MiMoV2 Global Attention (pattern == 0) and Sliding Window Attention (pattern == 1)."""
299
+
300
+ def __init__(self, config: MiMoV2FlashConfig, is_swa: bool, layer_idx: int):
301
+ super().__init__()
302
+ self.config = config
303
+ self.layer_idx = layer_idx
304
+
305
+ if is_swa:
306
+ self.head_dim = config.swa_head_dim
307
+ self.v_head_dim = config.swa_v_head_dim
308
+ self.num_attention_heads = config.swa_num_attention_heads
309
+ self.num_key_value_heads = config.swa_num_key_value_heads
310
+ else:
311
+ self.head_dim = config.head_dim
312
+ self.v_head_dim = config.v_head_dim
313
+ self.num_attention_heads = config.num_attention_heads
314
+ self.num_key_value_heads = config.num_key_value_heads
315
+
316
+ self.rope_dim = int(self.head_dim * config.partial_rotary_factor)
317
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
318
+ self.attention_bias = config.attention_bias
319
+ self.attention_dropout: float = config.attention_dropout
320
+ self.scaling = self.head_dim ** -0.5
321
+
322
+ # These dimensions are for the attention layers
323
+ q_hidden_size = self.num_attention_heads * self.head_dim
324
+ k_hidden_size = self.num_key_value_heads * self.head_dim
325
+ v_hidden_size = self.num_key_value_heads * self.v_head_dim
326
+ o_hidden_size = self.num_attention_heads * self.v_head_dim
327
+
328
+ self.q_proj = nn.Linear(config.hidden_size, q_hidden_size, bias=self.attention_bias)
329
+ self.k_proj = nn.Linear(config.hidden_size, k_hidden_size, bias=self.attention_bias)
330
+ self.v_proj = nn.Linear(config.hidden_size, v_hidden_size, bias=self.attention_bias)
331
+ self.o_proj = nn.Linear(o_hidden_size, config.hidden_size, bias=False)
332
+
333
+ self.attention_sink_bias = (
334
+ torch.nn.Parameter(torch.empty(config.num_attention_heads), requires_grad=False)
335
+ if (config.add_full_attention_sink_bias and not is_swa) or (config.add_swa_attention_sink_bias and is_swa)
336
+ else None
337
+ )
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
343
+ attention_mask: Optional[torch.Tensor],
344
+ past_key_values: Optional[Cache] = None,
345
+ cache_position: Optional[torch.LongTensor] = None,
346
+ position_ids: Optional[torch.LongTensor] = None,
347
+ **kwargs: Unpack[TransformersKwargs],
348
+ ) -> tuple[torch.Tensor, torch.Tensor]:
349
+ input_shape = hidden_states.shape[:-1]
350
+ qk_hidden_shape = (*input_shape, -1, self.head_dim)
351
+ v_hidden_shape = (*input_shape, -1, self.v_head_dim)
352
+
353
+ query_states = self.q_proj(hidden_states).view(qk_hidden_shape).transpose(1, 2)
354
+ key_states = self.k_proj(hidden_states).view(qk_hidden_shape).transpose(1, 2)
355
+ value_states = self.v_proj(hidden_states).view(v_hidden_shape).transpose(1, 2)
356
+
357
+ cos, sin = position_embeddings
358
+
359
+ query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
360
+ key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
361
+
362
+ query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin)
363
+
364
+ query_states = torch.cat([query_rope, query_nope], dim=-1)
365
+ key_states = torch.cat([key_rope, key_nope], dim=-1)
366
+
367
+ if past_key_values is not None:
368
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
369
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
370
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
371
+
372
+ attention_interface: Callable = eager_attention_forward
373
+ if self.config._attn_implementation != "eager":
374
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
375
+
376
+ attn_output, attn_weights = attention_interface(
377
+ self,
378
+ query_states,
379
+ key_states,
380
+ value_states,
381
+ attention_mask,
382
+ dropout=0.0 if not self.training else self.attention_dropout,
383
+ scaling=self.scaling,
384
+ position_ids=position_ids,
385
+ sinks=self.attention_sink_bias,
386
+ )
387
+
388
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
389
+ attn_output = self.o_proj(attn_output)
390
+ return attn_output, attn_weights
391
+
392
+
393
+ class MiMoV2DecoderLayer(nn.Module):
394
+ """
395
+ MiMoV2 Decoder Layer. It dynamically chooses the correct attention
396
+ module based on the layer index and the `hybrid_layer_pattern`.
397
+ """
398
+
399
+ def __init__(self, config: MiMoV2FlashConfig, layer_idx: int):
400
+ super().__init__()
401
+
402
+ # This is the key logic: choose the module based on the pattern
403
+ is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1
404
+ if is_swa_layer:
405
+ self.attention_type = "sliding_window_attention"
406
+ self.self_attn = MiMoV2Attention(config, True, layer_idx)
407
+ else:
408
+ self.attention_type = "full_attention"
409
+ self.self_attn = MiMoV2Attention(config, False, layer_idx)
410
+
411
+ self.mlp = (
412
+ MiMoV2MoE(config)
413
+ if (
414
+ getattr(config, 'n_routed_experts', None) is not None
415
+ and config.moe_layer_freq[layer_idx]
416
+ )
417
+ else MiMoV2MLP(config)
418
+ )
419
+
420
+ self.input_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
421
+ self.post_attention_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
422
+ self.hidden_size = config.hidden_size
423
+
424
+ def forward(
425
+ self,
426
+ hidden_states: torch.Tensor,
427
+ attention_mask: Optional[torch.Tensor] = None,
428
+ position_ids: Optional[torch.LongTensor] = None,
429
+ past_key_values: Optional[Cache] = None,
430
+ use_cache: Optional[bool] = False,
431
+ cache_position: Optional[torch.LongTensor] = None,
432
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
433
+ **kwargs: Unpack[TransformersKwargs],
434
+ ) -> torch.Tensor:
435
+ residual = hidden_states
436
+ hidden_states = self.input_layernorm(hidden_states)
437
+ # Self Attention
438
+ hidden_states, _ = self.self_attn(
439
+ hidden_states=hidden_states,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ past_key_values=past_key_values,
443
+ use_cache=use_cache,
444
+ cache_position=cache_position,
445
+ position_embeddings=position_embeddings,
446
+ **kwargs,
447
+ )
448
+ hidden_states = residual + hidden_states
449
+
450
+ # MLP or MOE
451
+ residual = hidden_states
452
+ hidden_states = self.post_attention_layernorm(hidden_states)
453
+ hidden_states = self.mlp(hidden_states)
454
+ hidden_states = residual + hidden_states
455
+ return hidden_states
456
+
457
+ class MiMoV2FlashRotaryEmbedding(nn.Module):
458
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
459
+
460
+ def __init__(self, config: MiMoV2FlashConfig, is_swa, device=None):
461
+ super().__init__()
462
+ # BC: "rope_type" was originally "type"
463
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
464
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
465
+ else:
466
+ self.rope_type = "default"
467
+ self.max_seq_len_cached = config.max_position_embeddings
468
+ self.original_max_seq_len = config.max_position_embeddings
469
+
470
+ self.config = config
471
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
472
+
473
+ if is_swa:
474
+ self.config.rope_theta = config.swa_rope_theta
475
+ self.config.head_dim = config.swa_head_dim
476
+
477
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
478
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
479
+ self.original_inv_freq = self.inv_freq
480
+
481
+ @torch.no_grad()
482
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
483
+ def forward(self, x, position_ids):
484
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
485
+ position_ids_expanded = position_ids[:, None, :].float()
486
+
487
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
488
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
489
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
490
+ emb = torch.cat((freqs, freqs), dim=-1)
491
+ cos = emb.cos() * self.attention_scaling
492
+ sin = emb.sin() * self.attention_scaling
493
+
494
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
495
+
496
+
497
+ @auto_docstring
498
+ class MiMoV2Model(PreTrainedModel):
499
+ """The main 'model' block, corresponding to `model.` in the weight map."""
500
+ config_class = MiMoV2FlashConfig
501
+
502
+ def __init__(self, config: MiMoV2FlashConfig):
503
+ super().__init__(config)
504
+ self.vocab_size = config.vocab_size
505
+
506
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
507
+ self.layers = nn.ModuleList(
508
+ [MiMoV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
509
+ )
510
+ self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
511
+ self.rotary_emb = MiMoV2FlashRotaryEmbedding(config=config, is_swa=False)
512
+ self.swa_rotary_emb = MiMoV2FlashRotaryEmbedding(config=config, is_swa=True)
513
+
514
+ self.has_sliding_layers = any(
515
+ pattern == 1 for pattern in config.hybrid_layer_pattern
516
+ )
517
+
518
+ # For Huggingface DynamicCache compatibility
519
+ self.config.layer_types = [
520
+ "sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention"
521
+ for i in range(config.num_hidden_layers)
522
+ ]
523
+
524
+ @auto_docstring
525
+ def forward(
526
+ self,
527
+ input_ids: Optional[torch.LongTensor] = None,
528
+ attention_mask: Optional[torch.Tensor] = None,
529
+ position_ids: Optional[torch.LongTensor] = None,
530
+ past_key_values: Optional[Cache] = None,
531
+ inputs_embeds: Optional[torch.FloatTensor] = None,
532
+ use_cache: Optional[bool] = None,
533
+ cache_position: Optional[torch.LongTensor] = None,
534
+ **kwargs: Unpack[TransformersKwargs],
535
+ ) -> MoeModelOutputWithPast:
536
+ if (input_ids is None) ^ (inputs_embeds is not None):
537
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
538
+
539
+ if inputs_embeds is None:
540
+ inputs_embeds = self.embed_tokens(input_ids)
541
+
542
+ if use_cache and past_key_values is None:
543
+ past_key_values = DynamicCache(config=self.config)
544
+
545
+ if cache_position is None:
546
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
547
+ cache_position = torch.arange(
548
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
549
+ )
550
+
551
+ if position_ids is None:
552
+ position_ids = cache_position.unsqueeze(0)
553
+
554
+ # It may already have been prepared by e.g. `generate`
555
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
556
+ # Prepare mask arguments
557
+ mask_kwargs = {
558
+ "config": self.config,
559
+ "input_embeds": inputs_embeds,
560
+ "attention_mask": attention_mask,
561
+ "cache_position": cache_position,
562
+ "past_key_values": past_key_values,
563
+ "position_ids": position_ids,
564
+ }
565
+ # Create the masks
566
+ causal_mask_mapping = {
567
+ "full_attention": create_causal_mask(**mask_kwargs),
568
+ }
569
+ # The sliding window alternating layers are not always activated depending on the config
570
+ if self.has_sliding_layers:
571
+ causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
572
+
573
+ hidden_states = inputs_embeds
574
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
575
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
576
+
577
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
578
+ hidden_states = decoder_layer(
579
+ hidden_states,
580
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
581
+ position_embeddings=(
582
+ position_embeddings
583
+ if decoder_layer.attention_type == "full_attention"
584
+ else swa_position_embeddings
585
+ ),
586
+ position_ids=position_ids,
587
+ past_key_values=past_key_values,
588
+ use_cache=use_cache,
589
+ cache_position=cache_position,
590
+ **kwargs,
591
+ )
592
+
593
+ hidden_states = self.norm(hidden_states)
594
+ return BaseModelOutputWithPast(
595
+ last_hidden_state=hidden_states,
596
+ past_key_values=past_key_values if use_cache else None,
597
+ )
598
+
599
+
600
+ @auto_docstring
601
+ class MiMoV2FlashForCausalLM(PreTrainedModel,GenerationMixin):
602
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
603
+ _tp_plan = {"lm_head": "colwise_rep"}
604
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
605
+
606
+ config_class = MiMoV2FlashConfig
607
+ _keys_to_ignore_on_load_unexpected = [r"model.layers\.\d+\.self_attn\.rotary_emb\.inv_freq"]
608
+
609
+ def __init__(self, config: MiMoV2FlashConfig):
610
+ super().__init__(config)
611
+ self.model = MiMoV2Model(config)
612
+ self.vocab_size = config.vocab_size
613
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
614
+
615
+ # Initialize weights and apply final processing
616
+ self.post_init()
617
+
618
+ @can_return_tuple
619
+ @auto_docstring
620
+ def forward(
621
+ self,
622
+ input_ids: Optional[torch.LongTensor] = None,
623
+ attention_mask: Optional[torch.Tensor] = None,
624
+ position_ids: Optional[torch.LongTensor] = None,
625
+ past_key_values: Optional[Cache] = None,
626
+ inputs_embeds: Optional[torch.FloatTensor] = None,
627
+ labels: Optional[torch.LongTensor] = None,
628
+ use_cache: Optional[bool] = None,
629
+ cache_position: Optional[torch.LongTensor] = None,
630
+ logits_to_keep: Union[int, torch.Tensor] = 0,
631
+ **kwargs: Unpack[TransformersKwargs],
632
+ ) -> CausalLMOutputWithPast:
633
+
634
+ outputs: BaseModelOutputWithPast = self.model(
635
+ input_ids=input_ids,
636
+ attention_mask=attention_mask,
637
+ position_ids=position_ids,
638
+ past_key_values=past_key_values,
639
+ inputs_embeds=inputs_embeds,
640
+ use_cache=use_cache,
641
+ cache_position=cache_position,
642
+ **kwargs,
643
+ )
644
+
645
+ hidden_states = outputs.last_hidden_state
646
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
647
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
648
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
649
+
650
+ loss = None
651
+ if labels is not None:
652
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
653
+
654
+ return CausalLMOutputWithPast(
655
+ loss=loss,
656
+ logits=logits,
657
+ past_key_values=outputs.past_key_values,
658
+ hidden_states=outputs.hidden_states,
659
+ attentions=outputs.attentions,
660
+ )
661
+
662
+ __all__ = [
663
+ "MiMoV2FlashForCausalLM"
664
+ ]
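One part of modeling_mimo_v2_flash.py that is easy to misread is the `noaux_tc` gate in MiMoV2MoEGate: scores come from a sigmoid (not a softmax), the e_score_correction_bias is added only when choosing the top-k experts, and the returned mixing weights are gathered from the unbiased scores and then normalized. A toy re-run of that routing path, written against plain torch for illustration (tiny dimensions instead of the real 256 experts / top-8; the group-masking step is omitted because it is a no-op with n_group=1 and topk_group=1, as in config.json):

# Toy illustration mirroring MiMoV2MoEGate's noaux_tc routing; not an import of the class.
import torch

torch.manual_seed(0)
n_experts, top_k = 8, 2
tokens = torch.randn(5, 16)                       # 5 tokens, toy hidden size 16
gate_weight = torch.randn(n_experts, 16)          # gate projection (self.weight)
bias = torch.randn(n_experts) * 0.1               # e_score_correction_bias

scores = torch.sigmoid(tokens @ gate_weight.t())          # sigmoid scoring, not softmax
scores_for_choice = scores + bias                          # bias influences *selection* only
topk_idx = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False).indices
topk_weight = scores.gather(1, topk_idx)                   # weights come from unbiased scores
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)  # norm_topk_prob

print(topk_idx)      # chosen expert ids per token
print(topk_weight)   # mixing weights, summing to ~1 per token (routed_scaling_factor is null -> 1.0)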
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+ "chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = false -%}\n{%- endif -%}\n{%- if not enable_thinking is defined -%}\n {%- set enable_thinking = false -%}\n{%- endif -%}\n{%- if not keep_all_reasoning is defined -%}\n {%- set keep_all_reasoning = false -%}\n{%- endif -%}\n{%- macro render_extra_keys(json_dict, handled_keys) -%}\n {%- if json_dict is mapping %}\n {%- for json_key in json_dict if json_key not in handled_keys %}\n {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}\n {{- '\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}\n {%- else %}\n {{-'\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n{%- endmacro -%}\n{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- set ns = namespace(last_user_index=-1) %}\n{%- for m in loop_messages %}\n {%- if m.role == 'user' %}\n {%- set ns.last_user_index = loop.index0 -%}\n {%- endif %}\n{%- endfor %}\n{%- if not tools is defined %}\n {%- set tools = [] %}\n{%- endif %}\n{%- if system_message is defined %}\n {{- \"<|im_start|>system\\n\" + system_message }}\n{%- else %}\n {{- \"<|im_start|>system\\nYou are MiMo, a helpful AI assistant engineered by Xiaomi.\" }}\n{%- endif %}\n{%- if tools is iterable and tools | length > 0 %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou have access to the following functions:\\n\\n\" }}\n {{- \"<tools>\" }}\n {%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- \"\\n<function>\\n<name>\" ~ tool.name ~ \"</name>\" }}\n {%- if tool.description is defined %}\n {{- '\\n<description>' ~ (tool.description | trim) ~ '</description>' }}\n {%- endif %}\n {{- '\\n<parameters>' }}\n {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- '\\n<parameter>' }}\n {{- '\\n<name>' ~ param_name ~ '</name>' }}\n {%- if param_fields.type is defined %}\n {{- '\\n<type>' ~ (param_fields.type | string) ~ '</type>' }}\n {%- endif %}\n {%- if param_fields.description is defined %}\n {{- '\\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}\n {%- endif %}\n {%- set handled_keys = ['name', 'type', 'description'] %}\n {{- render_extra_keys(param_fields, handled_keys) }}\n {{- '\\n</parameter>' }}\n {%- endfor %}\n {%- endif %}\n {%- set handled_keys = ['type', 'properties'] %}\n {{- render_extra_keys(tool.parameters, handled_keys) }}\n {{- '\\n</parameters>' }}\n {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}\n {{- render_extra_keys(tool, handled_keys) }}\n {{- '\\n</function>' }}\n {%- endfor %}\n {{- \"\\n</tools>\" }}\n {{- '\\n\\nFor each function call, output the function name and arguments in the following format:\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>value_1</parameter>\\n<parameter=example_parameter_2>This is the value for the second parameter\\nthat can span\\nmultiple lines</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\n- Function calls MUST 
follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\\n- DO NOT use function calls inside <think></think> tags.\\n- The value enclosed between parameter tags is preserved exactly as-is, including newlines and spaces.\\n</IMPORTANT>' }}\n{%- endif %}\n{{- '<|im_end|>' }}\n{%- for message in loop_messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if message.role == \"assistant\" %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- set reasoning_content = '' %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].split('<think>')[-1] %}\n {%- set content = content.split('</think>')[-1] %}\n {%- endif %}\n {%- endif %}\n {%- if (keep_all_reasoning or loop.index0 > ns.last_user_index) and reasoning_content -%}\n {{- '<|im_start|>' + message.role + '\\n<think>' + reasoning_content + '</think>' + content }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n<think></think>' + content }}\n {%- endif %}\n {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- for args_name, args_value in tool_call.arguments|items %}\n {{- '<parameter=' + args_name + '>' }}\n {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}\n {{- args_value }}\n {{- '</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>' }}\n {%- elif message.role == \"user\" or message.role == \"system\"%}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.previtem and loop.previtem.role != \"tool\" %}\n {{- '<|im_start|>tool\\n' }}\n {%- endif %}\n {{- '<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>\\n' }}\n {%- if not loop.last and loop.nextitem.role != \"tool\" %}\n {{- '<|im_end|>' }}\n {%- elif loop.last %}\n {{- '<|im_end|>' }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if not enable_thinking -%}\n {{- '<think></think>' -}}\n {%- else -%}\n {{- '' -}}\n {%- endif -%}\n{%- endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 262144,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
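The chat_template above drives prompt construction, including the empty <think></think> prefill when thinking is disabled. A hedged usage sketch (placeholder checkpoint path; enable_thinking and add_generation_prompt are forwarded to the Jinja template by apply_chat_template):

# Illustrative sketch: render a prompt with the chat template shipped in tokenizer_config.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this-checkpoint")  # placeholder path

messages = [
    {"role": "system", "content": "You are MiMo, a helpful AI assistant engineered by Xiaomi."},
    {"role": "user", "content": "Summarize what a sliding-window attention layer does."},
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
print(prompt)
# With enable_thinking=False the rendered prompt ends with:
# <|im_start|>assistant\n<think></think>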