#!/usr/bin/env python3
"""
Test NVFP4 model token generation.
This tests autoregressive token generation (5 tokens only for speed).
"""
import sys
import json
import torch
from transformers import AutoTokenizer
from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales
from encoding_dsv32 import encode_messages, eos_token
def test_minimal_generation():
"""Test generating 5 tokens autoregressively."""
print("\n" + "=" * 70)
print("NVFP4 Minimal Generation Test")
print("=" * 70)
print("Testing autoregressive generation (5 tokens)")
print("Expected runtime: 1-10 minutes")
print("=" * 70 + "\n")
# Load config
print("Loading config...")
config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
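# The config JSON is unpacked directly into ModelArgs, so its keys must match
# the ModelArgs constructor's argument names.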
with open(config_path) as f:
args = ModelArgs(**json.load(f))
print(f" PASS: Config loaded\n")
# Create model
print("Creating model...")
torch.set_default_dtype(torch.bfloat16)
with torch.device("cpu"):
model = Transformer(args)
print(f" PASS: Model created\n")
# Load weights
print("Loading weights...")
ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
load_sharded_model(model, ckpt_path)
print(f" PASS: Weights loaded\n")
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print(f" PASS: Tokenizer loaded (vocab size: {len(tokenizer)})\n")
# Prepare prompt using DeepSeek V3.2 message encoding
user_message = "Hello"
print(f"User message: '{user_message}'")
messages = [{"role": "user", "content": user_message}]
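# The same role/content structure should extend to multi-turn history, e.g.
# (hypothetical conversation, not used in this test):
# messages = [
#     {"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi! How can I help?"},
#     {"role": "user", "content": "Tell me a joke."},
# ]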
# Use proper DeepSeek V3.2 encoding (thinking_mode="chat" for no reasoning)
prompt_str = encode_messages(messages, thinking_mode="chat")
print(f"Encoded prompt string (first 100 chars): '{prompt_str[:100]}...'")
prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
print(f"Encoded tokens: {len(prompt_tokens)} tokens")
print(f"First 10 tokens: {prompt_tokens[:10]}\n")
tokens = torch.tensor([prompt_tokens], dtype=torch.long, device="cpu")
# Get EOS token ID
eos_id = tokenizer.convert_tokens_to_ids(eos_token)
print(f"EOS token: '{eos_token}' -> ID {eos_id}\n")
# Generate tokens
max_new_tokens = 5
print(f"Generating {max_new_tokens} tokens...")
print("-" * 70)
generated_tokens = []
prev_pos = 0
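# Incremental decoding pattern: the first forward pass consumes the whole prompt
# at start_pos=0, and each later pass feeds only the newest token with start_pos
# advanced. This assumes the Transformer caches keys/values internally and indexes
# the cache by start_pos, as in the reference DeepSeek generate.py.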
try:
with torch.inference_mode():
for step in range(max_new_tokens):
print(f"\nStep {step+1}/{max_new_tokens}:")
# Forward pass
logits = model(tokens[:, prev_pos:], start_pos=prev_pos)
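# Assumption: the model returns logits for the last position only (shape
# [batch, vocab]), as in the reference generate.py, so argmax(dim=-1) below
# yields the next-token id directly.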
# Greedy decoding: pick the argmax token for deterministic output
next_token = logits.argmax(dim=-1)
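# A non-deterministic variant would sample instead of taking the argmax, e.g.
# (sketch only; `temperature` is a hypothetical knob not defined in this script):
#   probs = torch.softmax(logits / temperature, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)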
next_token_id = next_token.item()
generated_tokens.append(next_token_id)
# Decode token
decoded = tokenizer.decode([next_token_id])
print(f" Generated token {next_token_id}: '{decoded}'")
# Check for EOS
if next_token_id == eos_id:
print(f" PASS: Reached EOS token, stopping generation")
break
# Check for issues
if torch.isnan(logits).any():
print(f" FAIL: ERROR: NaN in logits at step {step+1}")
return 1
if torch.isinf(logits).any():
print(f" FAIL: ERROR: Inf in logits at step {step+1}")
return 1
# Append to sequence
tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
prev_pos = tokens.shape[1] - 1
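# prev_pos now points at the last position, so the next forward pass feeds only
# the newly appended token; earlier positions are served from the KV cache.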
# Show the text generated so far (decodes only the new tokens, not the prompt)
generated_text = tokenizer.decode(generated_tokens)
print(f" Generated so far: '{generated_text}'")
print("\n" + "-" * 70)
# Final output
full_text = tokenizer.decode(tokens[0].tolist())
generated_text = tokenizer.decode(generated_tokens)
print(f"\nPASS: Generation completed successfully!")
print(f"\nResults:")
print(f" User message: '{user_message}'")
print(f" Generated: '{generated_text}'")
print(f" Full text: '{full_text}'")
print(f" Generated tokens: {generated_tokens}")
# Basic sanity check
if len(generated_tokens) != max_new_tokens:
print(f"\nWARN: WARNING: Expected {max_new_tokens} tokens, got {len(generated_tokens)}")
print("\n" + "=" * 70)
print("PASS: GENERATION TEST PASSED")
print("=" * 70)
print("Token generation working correctly!")
print("Ready for full interactive inference.")
print("=" * 70)
return 0
except Exception as e:
print(f"\nFAIL: GENERATION FAILED: {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(test_minimal_generation())