#!/usr/bin/env python3
"""
Test NVFP4 model token generation.

This tests autoregressive token generation (5 tokens only for speed).
"""

import sys
import json
import torch
from transformers import AutoTokenizer

from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales
from encoding_dsv32 import encode_messages, eos_token


def test_minimal_generation():
    """Test generating 5 tokens autoregressively."""
    print("\n" + "=" * 70)
    print("NVFP4 Minimal Generation Test")
    print("=" * 70)
    print("Testing autoregressive generation (5 tokens)")
    print("Expected runtime: 1-10 minutes")
    print("=" * 70 + "\n")

    # Load config
    print("Loading config...")
    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(f"  PASS: Config loaded\n")

    # Create model
    print("Creating model...")
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print(f"  PASS: Model created\n")

    # Load weights
    print("Loading weights...")
    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
    load_sharded_model(model, ckpt_path)
    print(f"  PASS: Weights loaded\n")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print(f"  PASS: Tokenizer loaded (vocab size: {len(tokenizer)})\n")

    # Prepare prompt using DeepSeek V3.2 message encoding
    user_message = "Hello"
    print(f"User message: '{user_message}'")

    messages = [{"role": "user", "content": user_message}]
    # Use proper DeepSeek V3.2 encoding (thinking_mode="chat" for no reasoning)
    prompt_str = encode_messages(messages, thinking_mode="chat")
    print(f"Encoded prompt string (first 100 chars): '{prompt_str[:100]}...'")

    prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
    print(f"Encoded tokens: {len(prompt_tokens)} tokens")
    print(f"First 10 tokens: {prompt_tokens[:10]}\n")

    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device="cpu")

    # Get EOS token ID
    eos_id = tokenizer.convert_tokens_to_ids(eos_token)
    print(f"EOS token: '{eos_token}' -> ID {eos_id}\n")

    # Generate tokens
    max_new_tokens = 5
    print(f"Generating {max_new_tokens} tokens...")
    print("-" * 70)

    generated_tokens = []
    prev_pos = 0
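    # Incremental decoding: on each step only the tokens at positions
    # [prev_pos, current_len) are fed, with start_pos telling the model where
    # they sit in the sequence. The model is assumed to cache attention state
    # for earlier positions, so the prompt is processed once and each later
    # step processes a single new token.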

    try:
        with torch.inference_mode():
            for step in range(max_new_tokens):
                print(f"\nStep {step+1}/{max_new_tokens}:")

                # Forward pass (only tokens from prev_pos onward are fed)
                logits = model(tokens[:, prev_pos:], start_pos=prev_pos)

                # Validate logits before sampling so a NaN/Inf step cannot
                # masquerade as a valid (or EOS) token
                if torch.isnan(logits).any():
                    print(f"  FAIL: NaN in logits at step {step+1}")
                    return 1
                if torch.isinf(logits).any():
                    print(f"  FAIL: Inf in logits at step {step+1}")
                    return 1

                # Sample next token (argmax for deterministic output)
                next_token = logits.argmax(dim=-1)
                next_token_id = next_token.item()  # single-prompt batch

                generated_tokens.append(next_token_id)

                # Decode token
                decoded = tokenizer.decode([next_token_id])
                print(f"  Generated token {next_token_id}: '{decoded}'")

                # Check for EOS
                if next_token_id == eos_id:
                    print("  PASS: Reached EOS token, stopping generation")
                    break

                # Append to sequence
                tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
                prev_pos = tokens.shape[1] - 1

                # Show the text generated so far (new tokens only, prompt excluded)
                generated_text = tokenizer.decode(generated_tokens)
                print(f"  Generated so far: '{generated_text}'")

        print("\n" + "-" * 70)

        # Final output
        full_text = tokenizer.decode(tokens[0].tolist())
        generated_text = tokenizer.decode(generated_tokens)

        print(f"\nPASS: Generation completed successfully!")
        print(f"\nResults:")
        print(f"  User message: '{user_message}'")
        print(f"  Generated: '{generated_text}'")
        print(f"  Full text: '{full_text}'")
        print(f"  Generated tokens: {generated_tokens}")

        # Basic sanity check (fewer tokens than requested is expected if EOS was hit)
        if len(generated_tokens) != max_new_tokens:
            print(f"\nWARN: Expected {max_new_tokens} tokens, got {len(generated_tokens)}")

        print("\n" + "=" * 70)
        print("PASS: GENERATION TEST PASSED")
        print("=" * 70)
        print("Token generation working correctly!")
        print("Ready for full interactive inference.")
        print("=" * 70)

        return 0

    except Exception as e:
        print(f"\nFAIL: GENERATION FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(test_minimal_generation())