eousphoros committed
Commit f6e007b · verified · 1 parent: b8ee892

Upload inference/generate.py with huggingface_hub

Files changed (1)
  1. inference/generate.py +19 -17
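As background on the commit message: pushing a single file to a Hub repo with huggingface_hub is typically done through HfApi.upload_file. A minimal sketch, assuming a placeholder repo id and a token available via login or the HF_TOKEN environment variable (neither is shown in this commit view):

    from huggingface_hub import HfApi

    api = HfApi()  # uses the cached login token or HF_TOKEN
    api.upload_file(
        path_or_fileobj="inference/generate.py",   # local file to push
        path_in_repo="inference/generate.py",      # destination path in the repo
        repo_id="user/repo",                       # placeholder; the target repo is not named here
        commit_message="Upload inference/generate.py with huggingface_hub",
    )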
inference/generate.py CHANGED

@@ -1,4 +1,5 @@
 import os
+import re
 import json
 from argparse import ArgumentParser
 from typing import List
@@ -50,18 +51,18 @@ def hf_to_deepseek_key(hf_key: str) -> str:
     key = key.replace(".mlp.up_proj.", ".ffn.w3.")
     key = key.replace(".mlp.down_proj.", ".ffn.w2.")
 
-    # MoE
-    key = key.replace(".mlp.shared_experts.gate_proj.", ".moe.shared_experts.w1.")
-    key = key.replace(".mlp.shared_experts.up_proj.", ".moe.shared_experts.w3.")
-    key = key.replace(".mlp.shared_experts.down_proj.", ".moe.shared_experts.w2.")
-    key = key.replace(".mlp.experts.", ".moe.experts.")
-    key = key.replace(".mlp.gate.weight", ".moe.gate.weight")
+    # MoE (uses "ffn" module name in model, not "moe")
+    key = key.replace(".mlp.shared_experts.gate_proj.", ".ffn.shared_experts.w1.")
+    key = key.replace(".mlp.shared_experts.up_proj.", ".ffn.shared_experts.w3.")
+    key = key.replace(".mlp.shared_experts.down_proj.", ".ffn.shared_experts.w2.")
+    key = key.replace(".mlp.experts.", ".ffn.experts.")
+    key = key.replace(".mlp.gate.weight", ".ffn.gate.weight")
+    key = key.replace(".mlp.gate.e_score_correction_bias", ".ffn.gate.bias")
 
     # Expert weights
-    import re
-    key = re.sub(r"\.moe\.experts\.(\d+)\.gate_proj\.", r".moe.experts.\1.w1.", key)
-    key = re.sub(r"\.moe\.experts\.(\d+)\.up_proj\.", r".moe.experts.\1.w3.", key)
-    key = re.sub(r"\.moe\.experts\.(\d+)\.down_proj\.", r".moe.experts.\1.w2.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.gate_proj\.", r".ffn.experts.\1.w1.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.up_proj\.", r".ffn.experts.\1.w3.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.down_proj\.", r".ffn.experts.\1.w2.", key)
 
     return key
 
@@ -86,7 +87,7 @@ def load_sharded_model(model, ckpt_path):
     for i, shard_file in enumerate(shard_files):
         shard_path = os.path.join(ckpt_path, shard_file)
         print(f" [{i+1}/{len(shard_files)}] {shard_file}", end="", flush=True)
-        shard_dict = load_file(shard_path, device="cuda")
+        shard_dict = load_file(shard_path, device="cpu")
 
         # Copy matching tensors to model (with key mapping)
         matched = 0
@@ -151,11 +152,12 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    device = next(model.parameters()).device
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -205,14 +207,14 @@ def main(
     global print
     if rank != 0:
         print = lambda *_, **__: None
-    torch.cuda.set_device(local_rank)
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_num_threads(8)
+    torch.set_num_threads(96) # Use all CPU threads
    torch.manual_seed(33377335)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)
-    with torch.device("cuda"):
+    print("Creating model on CPU (this may take a while)...")
+    with torch.device("cpu"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print("Loading model weights...")