"""
Test NVFP4 model token generation.

This tests autoregressive token generation (5 tokens only for speed).
"""

import sys
import json

import torch
from transformers import AutoTokenizer

from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales
from encoding_dsv32 import encode_messages, eos_token
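
# model, generate, and encoding_dsv32 are assumed to be the local inference
# modules that ship alongside this checkpoint (per the import paths);
# load_sharded_model is expected to populate the freshly built Transformer
# from the sharded weight files under the checkpoint directory.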


def test_minimal_generation():
    """Test generating 5 tokens autoregressively."""
    print("\n" + "=" * 70)
    print("NVFP4 Minimal Generation Test")
    print("=" * 70)
    print("Testing autoregressive generation (5 tokens)")
    print("Expected runtime: 1-10 minutes")
    print("=" * 70 + "\n")
    print("Loading config...")
    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(" PASS: Config loaded\n")
    print("Creating model...")
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print(" PASS: Model created\n")

    print("Loading weights...")
    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
    load_sharded_model(model, ckpt_path)
    print(" PASS: Weights loaded\n")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print(f" PASS: Tokenizer loaded (vocab size: {len(tokenizer)})\n")

    user_message = "Hello"
    print(f"User message: '{user_message}'")

    messages = [{"role": "user", "content": user_message}]

    prompt_str = encode_messages(messages, thinking_mode="chat")
    print(f"Encoded prompt string (first 100 chars): '{prompt_str[:100]}...'")
    prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
    print(f"Encoded tokens: {len(prompt_tokens)} tokens")
    print(f"First 10 tokens: {prompt_tokens[:10]}\n")

    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device="cpu")

    eos_id = tokenizer.convert_tokens_to_ids(eos_token)
    print(f"EOS token: '{eos_token}' -> ID {eos_id}\n")

    max_new_tokens = 5
    print(f"Generating {max_new_tokens} tokens...")
    print("-" * 70)

    generated_tokens = []
    prev_pos = 0
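
    # Decoding scheme: the first forward pass feeds the entire prompt (prefill);
    # every later pass feeds only the newest token, with start_pos pointing at
    # its position. This assumes the Transformer maintains its own KV cache
    # indexed by start_pos, as in the reference DeepSeek inference code.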
    try:
        with torch.inference_mode():
            for step in range(max_new_tokens):
                print(f"\nStep {step+1}/{max_new_tokens}:")

                logits = model(tokens[:, prev_pos:], start_pos=prev_pos)
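
                # The model is assumed to return logits for the last position
                # only (shape [batch, vocab]); taking the argmax over the vocab
                # dimension is plain greedy decoding, which keeps this test
                # deterministic.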
                next_token = logits.argmax(dim=-1)
                next_token_id = next_token.item()

                generated_tokens.append(next_token_id)

                decoded = tokenizer.decode([next_token_id])
                print(f" Generated token {next_token_id}: '{decoded}'")

                if next_token_id == eos_id:
                    print(" PASS: Reached EOS token, stopping generation")
                    break
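
                # Sanity check: NaN or Inf logits would point to broken NVFP4
                # dequantization or an overflow somewhere in the quantized
                # matmuls, so fail fast instead of generating garbage.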
                if torch.isnan(logits).any():
                    print(f" FAIL: NaN in logits at step {step+1}")
                    return 1
                if torch.isinf(logits).any():
                    print(f" FAIL: Inf in logits at step {step+1}")
                    return 1

                tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
                prev_pos = tokens.shape[1] - 1

                generated_text = tokenizer.decode(generated_tokens)
                print(f" Generated so far: '{generated_text}'")

        print("\n" + "-" * 70)

        full_text = tokenizer.decode(tokens[0].tolist())
        generated_text = tokenizer.decode(generated_tokens)

        print("\nPASS: Generation completed successfully!")
        print("\nResults:")
        print(f" User message: '{user_message}'")
        print(f" Generated: '{generated_text}'")
        print(f" Full text: '{full_text}'")
        print(f" Generated tokens: {generated_tokens}")

        if len(generated_tokens) != max_new_tokens:
            print(f"\nWARN: Expected {max_new_tokens} tokens, got {len(generated_tokens)}")

        print("\n" + "=" * 70)
        print("PASS: GENERATION TEST PASSED")
        print("=" * 70)
        print("Token generation working correctly!")
        print("Ready for full interactive inference.")
        print("=" * 70)

        return 0

    except Exception as e:
        print(f"\nFAIL: GENERATION FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(test_minimal_generation())