Upload inference/generate.py with huggingface_hub

inference/generate.py  CHANGED  (+19 -17)
@@ -1,4 +1,5 @@
 import os
+import re
 import json
 from argparse import ArgumentParser
 from typing import List

@@ -50,18 +51,18 @@ def hf_to_deepseek_key(hf_key: str) -> str:
     key = key.replace(".mlp.up_proj.", ".ffn.w3.")
     key = key.replace(".mlp.down_proj.", ".ffn.w2.")

-    # MoE
-    key = key.replace(".mlp.shared_experts.gate_proj.", ".moe.shared_experts.w1.")
-    key = key.replace(".mlp.shared_experts.up_proj.", ".moe.shared_experts.w3.")
-    key = key.replace(".mlp.shared_experts.down_proj.", ".moe.shared_experts.w2.")
-    key = key.replace(".mlp.experts.", ".moe.experts.")
-    key = key.replace(".mlp.gate.weight", ".moe.gate.weight")
+    # MoE (uses "ffn" module name in model, not "moe")
+    key = key.replace(".mlp.shared_experts.gate_proj.", ".ffn.shared_experts.w1.")
+    key = key.replace(".mlp.shared_experts.up_proj.", ".ffn.shared_experts.w3.")
+    key = key.replace(".mlp.shared_experts.down_proj.", ".ffn.shared_experts.w2.")
+    key = key.replace(".mlp.experts.", ".ffn.experts.")
+    key = key.replace(".mlp.gate.weight", ".ffn.gate.weight")
+    key = key.replace(".mlp.gate.e_score_correction_bias", ".ffn.gate.bias")

     # Expert weights
-
-    key = re.sub(r"\.moe\.experts\.(\d+)\.gate_proj\.", r".moe.experts.\1.w1.", key)
-    key = re.sub(r"\.moe\.experts\.(\d+)\.up_proj\.", r".moe.experts.\1.w3.", key)
-    key = re.sub(r"\.moe\.experts\.(\d+)\.down_proj\.", r".moe.experts.\1.w2.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.gate_proj\.", r".ffn.experts.\1.w1.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.up_proj\.", r".ffn.experts.\1.w3.", key)
+    key = re.sub(r"\.ffn\.experts\.(\d+)\.down_proj\.", r".ffn.experts.\1.w2.", key)

     return key

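For a quick sanity check of the remapping rules in this hunk, here is a minimal sketch that reproduces only the MoE-related rules shown above. The helper name and the sample key fragments are illustrative, not part of the file, and the prefix handling elsewhere in hf_to_deepseek_key is not reproduced.

import re

def map_moe_fragment(key: str) -> str:
    # Mirrors only the rules visible in this hunk (illustrative helper, not in generate.py).
    key = key.replace(".mlp.experts.", ".ffn.experts.")
    key = key.replace(".mlp.gate.e_score_correction_bias", ".ffn.gate.bias")
    key = re.sub(r"\.ffn\.experts\.(\d+)\.gate_proj\.", r".ffn.experts.\1.w1.", key)
    return key

print(map_moe_fragment("layers.3.mlp.experts.41.gate_proj.weight"))
# layers.3.ffn.experts.41.w1.weight
print(map_moe_fragment("layers.3.mlp.gate.e_score_correction_bias"))
# layers.3.ffn.gate.bias
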
@@ -86,7 +87,7 @@ def load_sharded_model(model, ckpt_path):
     for i, shard_file in enumerate(shard_files):
         shard_path = os.path.join(ckpt_path, shard_file)
         print(f" [{i+1}/{len(shard_files)}] {shard_file}", end="", flush=True)
-        shard_dict = load_file(shard_path, device="cuda")
+        shard_dict = load_file(shard_path, device="cpu")

         # Copy matching tensors to model (with key mapping)
         matched = 0

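The rest of load_sharded_model is not shown in this diff, so the following is only a sketch of the loading pattern the change implies, kept entirely on CPU. The shard discovery and the dictionary-building step are assumptions; only the load_file(..., device="cpu") call appears in the hunk, and hf_to_deepseek_key is the mapping function defined earlier in this file.

import os
from safetensors.torch import load_file

def collect_state_dict(ckpt_path: str) -> dict:
    # Assumption: the shards are the *.safetensors files in the checkpoint directory.
    shard_files = sorted(f for f in os.listdir(ckpt_path) if f.endswith(".safetensors"))
    state_dict = {}
    for shard_file in shard_files:
        shard_path = os.path.join(ckpt_path, shard_file)
        # Load each shard straight into CPU memory, matching the device="cpu" change above.
        shard_dict = load_file(shard_path, device="cpu")
        for hf_key, tensor in shard_dict.items():
            state_dict[hf_to_deepseek_key(hf_key)] = tensor
    return state_dict
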
@@ -151,11 +152,12 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    device = next(model.parameters()).device
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

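The change above derives the target device from the model instead of hard-coding "cuda", so the same generation code runs on CPU or GPU. A minimal, self-contained sketch of that allocation pattern; the nn.Linear and the token ids are stand-ins for Transformer(args) and real prompts.

import torch

model = torch.nn.Linear(8, 8)                # stand-in for Transformer(args)
device = next(model.parameters()).device     # cpu here; cuda if the model lives on a GPU

prompt_tokens = [[1, 2, 3], [4, 5]]          # illustrative token ids
total_len = 8
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
for i, t in enumerate(prompt_tokens):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prompt_mask = tokens != -1                   # True wherever a real prompt token was written
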
@@ -205,14 +207,14 @@ def main(
     global print
     if rank != 0:
         print = lambda *_, **__: None
-    torch.cuda.set_device(local_rank)
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_num_threads(
+    torch.set_num_threads(96)  # Use all CPU threads
     torch.manual_seed(33377335)
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    print("Creating model on CPU (this may take a while)...")
+    with torch.device("cpu"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
     print("Loading model weights...")
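torch.set_num_threads(96) assumes a machine with 96 CPU threads available. Below is a small sketch of the same CPU setup with the thread count derived at runtime instead; this substitution is an editorial suggestion rather than what the commit does, and the nn.Linear stands in for Transformer(args).

import os
import torch

torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(os.cpu_count() or 1)   # portable alternative to a hard-coded 96

with torch.device("cpu"):                    # parameters are materialized directly on CPU
    model = torch.nn.Linear(1024, 1024)      # stand-in for Transformer(args)

print(next(model.parameters()).dtype, next(model.parameters()).device)
# torch.bfloat16 cpu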