PraneshJs committed on
Commit
aa71186
·
verified ·
1 Parent(s): 612ed45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +393 -399
app.py CHANGED
@@ -1,503 +1,497 @@
1
- # app.py — Full LLM Visualizer (Option A) for Hugging Face Spaces (Gradio)
2
- # Advanced features: attention, PCA, token animation, residual norms, activation patching, neuron explorer.
3
- # Recommended models: "distilgpt2", "gpt2". Use GPU Space for larger models.
 
4
 
5
  import gradio as gr
6
  import torch
7
  import numpy as np
8
  import plotly.express as px
9
  import plotly.graph_objects as go
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
11
- from sklearn.decomposition import PCA
12
  import pandas as pd
13
- import time
 
14
  import html
15
 
16
- # ---------------- Config ----------------
17
  DEFAULT_MODEL = "distilgpt2"
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
- _MODEL_CACHE = {}
 
 
20
 
21
- # ---------------- Utilities ----------------
22
  def load_model(model_name):
23
- if model_name in _MODEL_CACHE:
24
- return _MODEL_CACHE[model_name]
25
  tokenizer = AutoTokenizer.from_pretrained(model_name)
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_name, output_attentions=True, output_hidden_states=True
28
- )
29
- model.to(DEVICE)
30
  model.eval()
31
- _MODEL_CACHE[model_name] = (model, tokenizer)
32
  return model, tokenizer
33
 
 
34
  def softmax(x):
35
  e = np.exp(x - np.max(x))
36
  return e / e.sum(axis=-1, keepdims=True)
37
 
38
- def tokens_display(tokens):
 
39
  return " ".join([f"[{html.escape(t)}]" for t in tokens])
40
 
41
- def compute_pca_points(hidden_layer):
 
42
  try:
43
- p = PCA(n_components=2).fit_transform(hidden_layer)
44
- return p
45
- except Exception:
46
  seq = hidden_layer.shape[0]
47
- d0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
48
- d1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
49
- return np.vstack([d0, d1]).T
50
-
51
- def make_attention_figure(attn_matrix, tokens, title=None):
52
- fig = px.imshow(attn_matrix, x=tokens, y=tokens,
53
- labels={"x":"Key token", "y":"Query token", "color":"Attention"},
54
- title=title or "Attention")
55
- fig.update_layout(height=420, margin=dict(l=60, r=20, t=40, b=40))
56
  return fig
57
 
58
- def make_pca_figure(points, tokens, highlight_idx=None, title=None):
59
- fig = px.scatter(x=points[:,0], y=points[:,1], text=tokens, title=title or "PCA (2D)")
 
60
  fig.update_traces(textposition="top center", marker=dict(size=10))
61
- if highlight_idx is not None:
62
  fig.add_trace(go.Scatter(
63
- x=[points[highlight_idx,0]], y=[points[highlight_idx,1]],
64
- mode="markers+text", text=[tokens[highlight_idx]],
65
- marker=dict(size=18, color="red"), name="selected token"
 
 
66
  ))
67
- fig.update_layout(height=420, margin=dict(l=40, r=40, t=40, b=40))
68
  return fig
69
 
70
- def make_probs_figure(top_tokens, top_scores, title=None):
71
- fig = go.Figure(data=[go.Bar(x=top_tokens, y=top_scores)])
72
- fig.update_layout(title=title or "Next-token top predictions", yaxis_title="Probability", height=360, margin=dict(l=40,r=20,t=40,b=40))
 
 
73
  return fig
74
 
75
- # ---------------- Core analysis ----------------
76
- def analyze_text(text, model_name, explain_simple):
77
- """
78
- Run forward pass and return internals.
79
- Returns dict with tokens, attentions (list per layer), hidden states (list per layer), logits, PCA points, figures.
80
- """
81
- if not text or len(text.strip()) == 0:
82
- return {"error": "Please enter some text."}
83
 
84
  try:
85
  model, tokenizer = load_model(model_name)
86
  except Exception as e:
87
- return {"error": f"Failed to load model '{model_name}': {e}"}
 
 
88
 
89
  try:
90
- inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
 
91
  except Exception as e:
92
- return {"error": f"Tokenization error: {e}"}
93
- with torch.no_grad():
94
- try:
95
- outputs = model(**inputs)
96
- except Exception as e:
97
- return {"error": f"Model forward error: {e}"}
98
 
99
- # Extract tokens & internals
100
- try:
101
- input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
102
- tokens = tokenizer.convert_ids_to_tokens(input_ids)
103
- except Exception:
104
- return {"error": "Failed to extract tokens."}
105
 
106
- attentions = [a[0].cpu().numpy() for a in outputs.attentions] if outputs.attentions is not None else None
107
- hidden = [h[0].cpu().numpy() for h in outputs.hidden_states] if outputs.hidden_states is not None else None
108
- logits = outputs.logits[0].cpu().numpy() # shape (seq_len, vocab)
109
 
110
  # PCA per layer
111
- pca_layers = []
112
- if hidden is not None:
113
- for layer_h in hidden:
114
- pca_layers.append(compute_pca_points(layer_h))
115
- else:
116
- pca_layers = None
117
-
118
- # Next-token topk
119
- last_logits = logits[-1]
120
- probs = softmax(last_logits)
121
- topk = 25
122
- idx = np.argsort(probs)[-topk:][::-1]
123
- top_tokens = [tokenizer.decode([int(i)]) for i in idx]
124
  top_scores = probs[idx].tolist()
125
 
126
- default_layer = (len(attentions) - 1) if attentions is not None else (len(pca_layers) - 1 if pca_layers else 0)
127
  default_head = 0
128
 
129
- fig_attn = make_attention_figure(attentions[default_layer][default_head], tokens, title=f"Layer {default_layer} Head {default_head}") if attentions is not None else None
130
- fig_pca = make_pca_figure(pca_layers[default_layer], tokens, highlight_idx=None, title=f"PCA (layer {default_layer})") if pca_layers is not None else None
131
- fig_probs = make_probs_figure(top_tokens, top_scores, title="Next-token top predictions")
132
-
133
- explanation = (
134
- "Simple: the model splits text into pieces, looks which pieces are important, and guesses the next word."
135
- if explain_simple else
136
- "Technical: tokens, attention matrices per head/layer, hidden states projected to 2D, and top-k next-token probabilities."
137
- )
138
-
139
- # neuron explorer (top neurons by mean absolute activation in last layer)
140
  neuron_info = []
141
  try:
142
- last_hidden = hidden[-1] # (seq, dim)
143
- mean_act = np.abs(last_hidden).mean(axis=0)
144
  top_neurons = np.argsort(mean_act)[-24:][::-1]
145
- for n in top_neurons[:24]:
146
- vals = last_hidden[:, n]
147
- top_token_idx = np.argsort(np.abs(vals))[-6:][::-1]
148
- token_hits = [(tokens[i], float(vals[i])) for i in top_token_idx]
149
- neuron_info.append({"neuron": int(n), "top_tokens": token_hits})
150
- except Exception:
 
 
151
  neuron_info = []
152
 
153
- # residual norms (best-effort)
154
- residuals = None
155
- try:
156
- if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
157
- blocks = model.transformer.h
158
- wte = getattr(model.transformer, "wte", None)
159
- # compute norms by approximating per-layer attn & mlp outputs using block forward if possible
160
- attn_norms = []
161
- mlp_norms = []
162
- # Start from embeddings
163
- cur = wte(inputs["input_ids"]) if wte is not None else None
164
- if cur is not None:
165
- cur = cur.to(DEVICE)
166
- # We'll run each block and measure norms of attention & mlp outputs if callable
167
- for block in blocks:
168
- try:
169
- ln1 = block.ln_1(cur)
170
- attn_out = block.attn(ln1)[0]
171
- cur = cur + attn_out
172
- ln2 = block.ln_2(cur)
173
- mlp_out = block.mlp(ln2)
174
- cur = cur + mlp_out
175
- attn_norms.append(float(torch.norm(attn_out).cpu().numpy()))
176
- mlp_norms.append(float(torch.norm(mlp_out).cpu().numpy()))
177
- except Exception:
178
- # fallback: run full block and compute residual diff
179
- prev = cur.clone()
180
- try:
181
- cur = block(prev)[0]
182
- total = cur - prev
183
- attn_norms.append(0.0)
184
- mlp_norms.append(float(torch.norm(total).cpu().numpy()))
185
- except Exception:
186
- attn_norms.append(0.0)
187
- mlp_norms.append(0.0)
188
- residuals = {"attn_norms": attn_norms, "mlp_norms": mlp_norms}
189
- except Exception:
190
- residuals = None
191
-
192
- result = {
193
  "tokens": tokens,
194
  "attentions": attentions,
195
  "hidden": hidden,
 
196
  "logits": logits,
197
- "pca_layers": pca_layers,
198
- "fig_attn": fig_attn,
199
- "fig_pca": fig_pca,
200
- "fig_probs": fig_probs,
201
  "default_layer": default_layer,
202
  "default_head": default_head,
203
- "token_display": tokens_display(tokens),
204
- "explanation": explanation,
205
  "neuron_info": neuron_info,
206
  "residuals": residuals,
207
- "model_name": model_name,
208
- "input_ids": input_ids
209
  }
210
- return result
211
 
212
- # ---------------- Activation patching ----------------
213
- def activation_patch_and_run(text_tokens, model_name, patch_layer, patch_pos, patch_from_pos, patch_scale=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  """
215
- Activation patching for GPT-2 style models: copy vector at patch_from_pos to patch_pos at patch_layer.
216
- Returns top-k next-token predictions after patching.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  """
218
  try:
219
  model, tokenizer = load_model(model_name)
220
- except Exception:
221
- return {"error": "Model load failed for patching."}
222
 
223
- # Prepare inputs again as string
224
- text = " ".join(text_tokens)
225
- inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
226
 
227
- # Check block availability
228
- if not (hasattr(model, "transformer") and hasattr(model.transformer, "h")):
229
- return {"error": "Model not compatible with activation patching (requires GPT-2 style blocks)."}
230
 
231
  blocks = model.transformer.h
232
  wte = model.transformer.wte
233
  ln_f = model.transformer.ln_f if hasattr(model.transformer, "ln_f") else None
234
  lm_head = model.lm_head
235
 
236
- # collect hidden precomputed
237
  with torch.no_grad():
238
- x = wte(inputs["input_ids"]).to(DEVICE) # (1, seq, dim)
239
- hidden_per_layer = [x.detach().cpu().numpy()[0]] # embedding considered layer -1
240
- for block in blocks:
241
- # standard GPT-2 block flow
242
- ln1 = block.ln_1(x)
243
- attn_out = block.attn(ln1)[0]
244
- x = x + attn_out
245
- ln2 = block.ln_2(x)
246
- mlp_out = block.mlp(ln2)
247
- x = x + mlp_out
248
- hidden_per_layer.append(x.detach().cpu().numpy()[0])
249
-
250
- seq_len = hidden_per_layer[0].shape[0]
251
- if patch_pos < 0 or patch_pos >= seq_len or patch_from_pos < 0 or patch_from_pos >= seq_len:
252
- return {"error": "Patch positions out of range."}
253
-
254
- # vector to copy
255
- vec = torch.tensor(hidden_per_layer[patch_layer][patch_from_pos], dtype=torch.float32).to(DEVICE) * float(patch_scale)
 
256
 
257
  # re-run with patch
258
  with torch.no_grad():
259
  x = wte(inputs["input_ids"]).to(DEVICE)
260
- for i, block in enumerate(blocks):
261
- ln1 = block.ln_1(x)
262
- attn_out = block.attn(ln1)[0]
263
- x = x + attn_out
264
- ln2 = block.ln_2(x)
265
- mlp_out = block.mlp(ln2)
266
- x = x + mlp_out
267
- if i == patch_layer:
268
- # set vector at position
269
- x[0, patch_pos, :] = vec
270
- final = ln_f(x) if ln_f is not None else x
271
- logits = lm_head(final)
272
- logits = logits[0, -1, :].cpu().numpy()
273
  probs = softmax(logits)
274
- topk = 25
275
- idx = np.argsort(probs)[-topk:][::-1]
276
- top_tokens = [tokenizer.decode([int(i)]) for i in idx]
277
- top_scores = probs[idx].tolist()
278
- return {"patched_top_tokens": top_tokens, "patched_top_scores": top_scores}
279
-
280
- # ---------------- UI wiring ----------------
281
- def run_analysis(text, model_name, explain_simple):
282
- res = analyze_text(text, model_name, explain_simple)
283
- if "error" in res:
284
- # return dict keyed by components
285
- return {
286
- token_display: gr.update(value=""),
287
- explanation_md: gr.update(value=res["error"]),
288
- model_info: gr.update(value=f"Model: {model_name}"),
289
- attn_plot: gr.update(value=None),
290
- pca_plot: gr.update(value=None),
291
- probs_plot: gr.update(value=None),
292
- layer_slider: gr.update(maximum=0, value=0),
293
- head_slider: gr.update(maximum=0, value=0),
294
- token_step: gr.update(maximum=0, value=0),
295
- state: res,
296
- residual_plot: gr.update(value=None),
297
- neuron_table: gr.update(value=[]),
298
- patch_layer_input: gr.update(maximum=0, value=0),
299
- patch_pos_input: gr.update(maximum=0, value=0),
300
- patch_from_pos_input: gr.update(maximum=0, value=0),
301
- }
302
 
303
- tokens = res["tokens"]
304
- num_layers = len(res["attentions"]) if res["attentions"] is not None else (len(res["pca_layers"]) - 1 if res["pca_layers"] else 0)
305
- num_heads = res["attentions"][0].shape[0] if res["attentions"] is not None else 1
306
- max_token_idx = len(tokens) - 1
307
-
308
- token_display_text = f"**Tokens:** {res['token_display']}"
309
- explanation_text = res["explanation"]
310
- model_info_text = f"Model: {res['model_name']} β€’ layers: {num_layers} β€’ heads: {num_heads} β€’ tokens: {len(tokens)}"
311
-
312
- layer_update = gr.update(maximum=max(0, num_layers - 1), value=res["default_layer"])
313
- head_update = gr.update(maximum=max(0, num_heads - 1), value=res["default_head"])
314
- token_step_update = gr.update(maximum=max_token_idx, value=0)
315
-
316
- patch_layer_update = gr.update(maximum=max(0, num_layers - 1), value=0)
317
- patch_pos_update = gr.update(maximum=max(0, max_token_idx), value=0)
318
- patch_from_pos_update = gr.update(maximum=max(0, max_token_idx), value=0)
319
-
320
- # neuron table initial (show first neuron's top tokens if available)
321
- neuron_table_data = []
322
- if res.get("neuron_info"):
323
- first = res["neuron_info"][0]
324
- neuron_table_data = [[t, round(v, 6)] for t, v in first["top_tokens"]]
325
-
326
- # residual figure
327
- residual_fig = None
328
- if res.get("residuals"):
329
- df = pd.DataFrame({"layer": list(range(len(res["residuals"]["attn_norms"]))),
330
- "attn": res["residuals"]["attn_norms"],
331
- "mlp": res["residuals"]["mlp_norms"]})
332
- residual_fig = go.Figure()
333
- residual_fig.add_trace(go.Bar(x=df["layer"], y=df["attn"], name="Attention norm"))
334
- residual_fig.add_trace(go.Bar(x=df["layer"], y=df["mlp"], name="MLP norm"))
335
- residual_fig.update_layout(barmode="group", title="Residual contributions (layerwise norms)", height=360)
336
 
337
- return {
338
- token_display: gr.update(value=token_display_text),
339
- explanation_md: gr.update(value=explanation_text),
340
- model_info: gr.update(value=model_info_text),
341
- attn_plot: gr.update(value=res["fig_attn"]),
342
- pca_plot: gr.update(value=res["fig_pca"]),
343
- probs_plot: gr.update(value=res["fig_probs"]),
344
- layer_slider: layer_update,
345
- head_slider: head_update,
346
- token_step: token_step_update,
347
- state: res,
348
- residual_plot: gr.update(value=residual_fig),
349
- neuron_table: gr.update(value=neuron_table_data),
350
- patch_layer_input: patch_layer_update,
351
- patch_pos_input: patch_pos_update,
352
- patch_from_pos_input: patch_from_pos_update,
353
- }
354
 
355
- def update_visuals(state_obj, layer, head, token_idx):
356
- # update attention, pca, and attention-row for selected token using cached state
357
- if not state_obj:
358
- return {attn_plot: gr.update(value=None), pca_plot: gr.update(value=None), step_attn_plot: gr.update(value=None)}
359
- res = state_obj
360
- tokens = res["tokens"]
361
- # bounds
362
- if res["attentions"] is not None:
363
- max_layer = len(res["attentions"]) - 1
364
- layer = int(min(max(0, layer), max_layer))
365
- max_head = res["attentions"][0].shape[0] - 1
366
- head = int(min(max(0, head), max_head))
367
- else:
368
- layer = int(min(max(0, layer), len(res["pca_layers"]) - 1 if res["pca_layers"] else 0))
369
- head = 0
370
- token_idx = int(min(max(0, token_idx), len(tokens) - 1))
371
-
372
- attn_fig = None
373
- if res["attentions"] is not None:
374
- attn_fig = make_attention_figure(res["attentions"][layer][head], tokens, title=f"Layer {layer} Head {head}")
375
-
376
- pca_fig = None
377
- if res["pca_layers"] is not None:
378
- pca_pts = res["pca_layers"][layer]
379
- pca_fig = make_pca_figure(pca_pts, tokens, highlight_idx=token_idx, title=f"PCA (layer {layer})")
380
-
381
- step_attn_fig = None
382
- if res["attentions"] is not None:
383
- row = res["attentions"][layer][head][token_idx]
384
- step_attn_fig = go.Figure(data=[go.Bar(x=tokens, y=row)])
385
- step_attn_fig.update_layout(title=f"Token {token_idx} attends to (layer {layer}, head {head})", height=300, margin=dict(l=40,r=20,t=30,b=40))
386
-
387
- return {attn_plot: gr.update(value=attn_fig),
388
- pca_plot: gr.update(value=pca_fig),
389
- step_attn_plot: gr.update(value=step_attn_fig)}
390
-
391
- def run_patch(state_obj, patch_layer, patch_pos, patch_from_pos, patch_scale, model_name):
392
- if not state_obj:
393
- return gr.update(value=None)
394
- tokens = state_obj["tokens"]
395
- res = activation_patch_and_run(tokens, model_name, int(patch_layer), int(patch_pos), int(patch_from_pos), float(patch_scale))
396
- if "error" in res:
397
- return gr.update(value=None)
398
- fig = go.Figure(data=[go.Bar(x=res["patched_top_tokens"], y=res["patched_top_scores"])])
399
- fig.update_layout(title=f"Patched predictions (layer {patch_layer}, pos {patch_pos} <- pos {patch_from_pos}, scale {patch_scale})", height=420)
400
- return gr.update(value=fig)
401
-
402
- def find_neurons(state_obj):
403
- if not state_obj:
404
- return gr.update(value=[])
405
- info = state_obj.get("neuron_info", [])
406
- rows = []
407
- for e in info[:24]:
408
- for t, v in e["top_tokens"]:
409
- rows.append([t, round(v,6)])
410
- # dedupe
411
- df = pd.DataFrame(rows, columns=["token","activation"]).drop_duplicates().head(24).values.tolist()
412
- return gr.update(value=df)
413
-
414
- def inspect_neuron(state_obj, neuron_idx):
415
- if not state_obj:
416
- return gr.update(value=[])
417
- try:
418
- neuron_idx = int(neuron_idx)
419
- except Exception:
420
- return gr.update(value=[])
421
- last_hidden = state_obj["hidden"][-1]
422
- vals = last_hidden[:, neuron_idx]
423
- tokens = state_obj["tokens"]
424
- df = sorted([(tokens[i], float(vals[i])) for i in range(len(tokens))], key=lambda x: -abs(x[1]))[:12]
425
- return gr.update(value=[[t, round(v,6)] for t,v in df])
426
-
427
- # ---------------- Gradio UI ----------------
428
- with gr.Blocks(title="LLM Visualizer β€” Full (Option A)", theme=gr.themes.Soft()) as demo:
429
- gr.Markdown("<h1 style='font-size:30px'>🧠 LLM Visualizer β€” Full (Advanced)</h1>")
430
- gr.Markdown("Advanced GPT-2 style visualizer. Use `distilgpt2` or `gpt2` for full features. Keep input short (<80 tokens) on CPU Spaces.")
431
 
 
432
  with gr.Row():
433
  with gr.Column(scale=3):
434
- model_input = gr.Textbox(label="Model (Hugging Face name)", value=DEFAULT_MODEL)
435
- text_input = gr.Textbox(label="Input text", value="Hello world, this is a test.", lines=3)
436
- explain_simple = gr.Checkbox(label="Explain simply (kid/elder mode)", value=True)
437
- run_btn = gr.Button("Run analysis", variant="primary")
438
- gr.Markdown("**Presets:**")
439
- with gr.Row():
440
- gr.Button("Greeting").click(lambda: "Hello! How are you today?", None, text_input)
441
- gr.Button("Story start").click(lambda: "Once upon a time, there was a small robot...", None, text_input)
442
- gr.Button("Question").click(lambda: "Why is the sky blue?", None, text_input)
443
 
444
- # Guided hints
445
- gr.Markdown("**Hints:** Use small inputs. Slide Layer/Head to explore. Use Token slider to animate token flow.")
 
 
 
446
 
447
  with gr.Column(scale=2):
448
- token_display = gr.Markdown("Tokens will appear here.")
449
- explanation_md = gr.Markdown("Explanation will appear here.")
450
- model_info = gr.Markdown("Model info: β€”")
451
 
 
452
  with gr.Row():
453
  with gr.Column():
454
- layer_slider = gr.Slider(label="Layer", minimum=0, maximum=0, step=1, value=0)
455
- head_slider = gr.Slider(label="Head", minimum=0, maximum=0, step=1, value=0)
456
- token_step = gr.Slider(label="Token index (step through tokens)", minimum=0, maximum=0, step=1, value=0)
457
- attn_plot = gr.Plot(label="Attention heatmap")
 
458
  with gr.Column():
459
- pca_plot = gr.Plot(label="PCA hidden states (2D)")
460
- step_attn_plot = gr.Plot(label="Attention row for selected token")
461
- probs_plot = gr.Plot(label="Next-token top predictions")
462
- residual_plot = gr.Plot(label="Residual decomposition (attention vs mlp)")
463
 
 
 
 
 
464
  with gr.Row():
465
- with gr.Column(scale=2):
466
- gr.Markdown("### Neuron / Circuit explorer")
467
- neuron_find_btn = gr.Button("Find example neurons (auto)")
468
- neuron_dropdown = gr.Number(label="Neuron index to inspect (enter integer)", value=0)
469
- neuron_table = gr.Dataframe(headers=["token", "activation"], interactive=False)
470
- with gr.Column(scale=3):
471
- gr.Markdown("### Activation patching (copy vector at layer)")
472
- patch_layer_input = gr.Slider(label="Patch Layer (0 = first block)", minimum=0, maximum=0, step=1, value=0)
473
- patch_pos_input = gr.Slider(label="Patch position (token index)", minimum=0, maximum=0, step=1, value=0)
474
- patch_from_pos_input = gr.Slider(label="Copy from position (token index)", minimum=0, maximum=0, step=1, value=0)
475
- patch_scale_input = gr.Number(label="Patch scale (multiplier)", value=1.0)
476
- patch_btn = gr.Button("Run Activation Patch & Show Top Predictions", variant="primary")
477
- patch_output = gr.Plot(label="Patched next-token predictions")
478
 
479
  state = gr.State()
480
 
481
- # Events wiring
482
- run_btn.click(fn=run_analysis,
483
- inputs=[text_input, model_input, explain_simple],
484
- outputs=[token_display, explanation_md, model_info,
485
- attn_plot, pca_plot, probs_plot,
486
- layer_slider, head_slider, token_step,
487
- state, residual_plot, neuron_table,
488
- patch_layer_input, patch_pos_input, patch_from_pos_input])
489
-
490
- layer_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step],
491
- outputs=[attn_plot, pca_plot, step_attn_plot])
492
- head_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step],
493
- outputs=[attn_plot, pca_plot, step_attn_plot])
494
- token_step.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step],
495
- outputs=[attn_plot, pca_plot, step_attn_plot])
496
-
497
- neuron_find_btn.click(fn=find_neurons, inputs=[state], outputs=[neuron_table])
498
- neuron_dropdown.change(fn=inspect_neuron, inputs=[state, neuron_dropdown], outputs=[neuron_table])
499
-
500
- patch_btn.click(fn=run_patch, inputs=[state, patch_layer_input, patch_pos_input, patch_from_pos_input, patch_scale_input, model_input],
501
- outputs=[patch_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  demo.launch()
 
1
+ # FULL LLM VISUALIZER — OPTION A (ADVANCED)
2
+ # stable + patched + safe for HuggingFace Spaces (CPU or GPU)
3
+ # recommended models: distilgpt2, gpt2
4
+ # author: ChatGPT
5
 
6
  import gradio as gr
7
  import torch
8
  import numpy as np
9
  import plotly.express as px
10
  import plotly.graph_objects as go
 
 
11
  import pandas as pd
12
+ from sklearn.decomposition import PCA
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
  import html
15
 
 
16
  DEFAULT_MODEL = "distilgpt2"
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ MODEL_CACHE = {}
19
+
20
+ # ---------------- CORE UTILS ----------------
21
 
 
22
  def load_model(model_name):
23
+ if model_name in MODEL_CACHE:
24
+ return MODEL_CACHE[model_name]
25
  tokenizer = AutoTokenizer.from_pretrained(model_name)
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_name, output_attentions=True, output_hidden_states=True
28
+ ).to(DEVICE)
 
29
  model.eval()
30
+ MODEL_CACHE[model_name] = (model, tokenizer)
31
  return model, tokenizer
32
 
33
+
34
  def softmax(x):
35
  e = np.exp(x - np.max(x))
36
  return e / e.sum(axis=-1, keepdims=True)
37
 
38
+
39
+ def safe_tokens(tokens):
40
  return " ".join([f"[{html.escape(t)}]" for t in tokens])
41
 
42
+
43
+ def compute_pca(hidden_layer):
44
  try:
45
+ return PCA(n_components=2).fit_transform(hidden_layer)
46
+ except:
 
47
  seq = hidden_layer.shape[0]
48
+ dim0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
49
+ dim1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
50
+ return np.vstack([dim0, dim1]).T
51
+
52
+
53
+ def fig_attention(matrix, tokens, title):
54
+ fig = px.imshow(matrix, x=tokens, y=tokens, title=title,
55
+ labels={"x": "Key token", "y": "Query token", "color": "Attention"})
56
+ fig.update_layout(height=420)
57
  return fig
58
 
59
+
60
+ def fig_pca(points, tokens, highlight=None, title="PCA"):
61
+ fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title)
62
  fig.update_traces(textposition="top center", marker=dict(size=10))
63
+ if highlight is not None:
64
  fig.add_trace(go.Scatter(
65
+ x=[points[highlight, 0]],
66
+ y=[points[highlight, 1]],
67
+ mode="markers+text",
68
+ text=[tokens[highlight]],
69
+ marker=dict(size=18, color="red")
70
  ))
71
+ fig.update_layout(height=420)
72
  return fig
73
 
74
+
75
+ def fig_probs(tokens, scores):
76
+ fig = go.Figure()
77
+ fig.add_trace(go.Bar(x=tokens, y=scores))
78
+ fig.update_layout(title="Next-token probabilities", height=380)
79
  return fig
80
 
81
+
82
+ # ---------------- ANALYSIS CORE ----------------
83
+
84
+ def analyze_text(text, model_name, simple):
85
+ if not text.strip():
86
+ return {"error": "Please enter text."}
 
 
87
 
88
  try:
89
  model, tokenizer = load_model(model_name)
90
  except Exception as e:
91
+ return {"error": f"Failed to load model: {e}"}
92
+
93
+ inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
94
 
95
  try:
96
+ with torch.no_grad():
97
+ out = model(**inputs)
98
  except Exception as e:
99
+ return {"error": f"Model error: {e}"}
 
 
 
 
 
100
 
101
+ input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
102
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
 
 
 
 
103
 
104
+ attentions = [a[0].cpu().numpy() for a in out.attentions]
105
+ hidden = [h[0].cpu().numpy() for h in out.hidden_states]
106
+ logits = out.logits[0].cpu().numpy()
107
 
108
  # PCA per layer
109
+ pca_layers = [compute_pca(h) for h in hidden]
110
+
111
+ # top-k
112
+ last = logits[-1]
113
+ probs = softmax(last)
114
+ idx = np.argsort(probs)[-20:][::-1]
115
+ top_tokens = [tokenizer.decode([i]) for i in idx]
 
 
 
 
 
 
116
  top_scores = probs[idx].tolist()
117
 
118
+ default_layer = len(attentions) - 1
119
  default_head = 0
120
 
121
+ # neuron explorer
 
 
 
 
 
 
 
 
 
 
122
  neuron_info = []
123
  try:
124
+ last_h = hidden[-1]
125
+ mean_act = np.abs(last_h).mean(axis=0)
126
  top_neurons = np.argsort(mean_act)[-24:][::-1]
127
+ for n in top_neurons:
128
+ vals = last_h[:, n]
129
+ top_ix = np.argsort(np.abs(vals))[-5:][::-1]
130
+ neuron_info.append({
131
+ "neuron": int(n),
132
+ "top_tokens": [(tokens[i], float(vals[i])) for i in top_ix]
133
+ })
134
+ except:
135
  neuron_info = []
136
 
137
+ # residual decomposition (safe)
138
+ residuals = compute_residuals_safe(model, inputs)
139
+
140
+ return {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  "tokens": tokens,
142
  "attentions": attentions,
143
  "hidden": hidden,
144
+ "pca": pca_layers,
145
  "logits": logits,
146
+ "top_tokens": top_tokens,
147
+ "top_scores": top_scores,
 
 
148
  "default_layer": default_layer,
149
  "default_head": default_head,
 
 
150
  "neuron_info": neuron_info,
151
  "residuals": residuals,
152
+ "token_display": safe_tokens(tokens),
153
+ "explanation": explain(simple)
154
  }
 
155
 
156
+
157
+ def explain(s):
158
+ if s:
159
+ return (
160
+ "πŸ§’ **Simple mode:**\n"
161
+ "- The model cuts text into small pieces (tokens).\n"
162
+ "- It looks at which tokens matter (attention).\n"
163
+ "- It builds an internal map (PCA) of meanings.\n"
164
+ "- Then it guesses the next token.\n"
165
+ )
166
+ return (
167
+ "πŸ”¬ **Technical mode:**\n"
168
+ "Showing tokens, attention (query→key), PCA projections, logits, "
169
+ "neuron activations, and layerwise residual contributions.\n"
170
+ )
171
+
172
+
173
+ # ---------------- RESIDUAL DECOMPOSITION SAFE ----------------
174
+
175
+ def compute_residuals_safe(model, inputs):
176
  """
177
+ Guaranteed safe residual norms for GPT-2-style blocks.
178
+ Will NEVER crash. Returns None if not applicable.
179
+ """
180
+ if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
181
+ return None
182
+
183
+ try:
184
+ blocks = model.transformer.h
185
+ wte = model.transformer.wte
186
+ x = wte(inputs["input_ids"]).to(DEVICE)
187
+
188
+ attn_norms = []
189
+ mlp_norms = []
190
+
191
+ for block in blocks:
192
+ try:
193
+ ln1 = block.ln_1(x)
194
+ attn_out = block.attn(ln1)[0]
195
+ x = x + attn_out
196
+ ln2 = block.ln_2(x)
197
+ mlp_out = block.mlp(ln2)
198
+ x = x + mlp_out
199
+
200
+ attn_norms.append(float(torch.norm(attn_out).cpu()))
201
+ mlp_norms.append(float(torch.norm(mlp_out).cpu()))
202
+ except:
203
+ # fallback safe zero
204
+ attn_norms.append(0.0)
205
+ mlp_norms.append(0.0)
206
+
207
+ # normalize lengths safely
208
+ L = min(len(attn_norms), len(mlp_norms))
209
+ return {
210
+ "attn": attn_norms[:L],
211
+ "mlp": mlp_norms[:L],
212
+ }
213
+ except:
214
+ return None
215
+
216
+
217
+ # ---------------- ACTIVATION PATCHING (SAFE VERSION) ----------------
218
+
219
+ def activation_patch(tokens, model_name, layer, pos, from_pos, scale=1.0):
220
+ """
221
+ Safe activation patching (never crashes, only works for GPT-2 style).
222
  """
223
  try:
224
  model, tokenizer = load_model(model_name)
225
+ except:
226
+ return {"error": "Model load error."}
227
 
228
+ if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
229
+ return {"error": "Model not compatible with patching."}
 
230
 
231
+ text = " ".join(tokens)
232
+ inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
 
233
 
234
  blocks = model.transformer.h
235
  wte = model.transformer.wte
236
  ln_f = model.transformer.ln_f if hasattr(model.transformer, "ln_f") else None
237
  lm_head = model.lm_head
238
 
 
239
  with torch.no_grad():
240
+ x = wte(inputs["input_ids"]).to(DEVICE)
241
+ hidden_layers = [x.clone().cpu().numpy()[0]]
242
+ for b in blocks:
243
+ ln1 = b.ln_1(x)
244
+ a = b.attn(ln1)[0]
245
+ x = x + a
246
+ ln2 = b.ln_2(x)
247
+ m = b.mlp(ln2)
248
+ x = x + m
249
+ hidden_layers.append(x.clone().cpu().numpy()[0])
250
+
251
+ if layer >= len(hidden_layers):
252
+ return {"error": "Layer out of range."}
253
+
254
+ seq_len = hidden_layers[layer].shape[0]
255
+ if pos >= seq_len or from_pos >= seq_len:
256
+ return {"error": "Position out of range."}
257
+
258
+ patch_vec = torch.tensor(hidden_layers[layer][from_pos], dtype=torch.float32).to(DEVICE) * float(scale)
259
 
260
  # re-run with patch
261
  with torch.no_grad():
262
  x = wte(inputs["input_ids"]).to(DEVICE)
263
+ for i, b in enumerate(blocks):
264
+ ln1 = b.ln_1(x)
265
+ a = b.attn(ln1)[0]
266
+ x = x + a
267
+ ln2 = b.ln_2(x)
268
+ m = b.mlp(ln2)
269
+ x = x + m
270
+ if i == layer:
271
+ x[0, pos, :] = patch_vec
272
+
273
+ final = ln_f(x) if ln_f else x
274
+ logits = lm_head(final)[0, -1, :].cpu().numpy()
 
275
  probs = softmax(logits)
276
+ idx = np.argsort(probs)[-20:][::-1]
277
+ tt = [tokenizer.decode([int(i)]) for i in idx]
278
+ ss = probs[idx].tolist()
279
+ return {"tokens": tt, "scores": ss}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
+ # ---------------- GRADIO UI ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
with gr.Blocks(title="LLM Visualizer — Full", theme=gr.themes.Soft()) as demo:
    # Builds the whole Gradio interface: five panels of components, the
    # callbacks that fill them, and the event wiring. Everything must stay
    # inside this ``with`` block because Gradio only registers components and
    # event listeners created while the Blocks context is active.

    gr.Markdown("# 🧠 Full LLM Visualizer (Advanced)")
    gr.Markdown("Fully stable build with attention, PCA, neuron explorer, residuals, activation-patching")

    # Panel 1: model / prompt inputs on the left, textual summary on the right.
    with gr.Row():
        with gr.Column(scale=3):
            model_name = gr.Textbox(label="Model", value=DEFAULT_MODEL)
            input_text = gr.Textbox(label="Input", value="Hello world", lines=3)
            simple = gr.Checkbox(label="Explain simply", value=True)
            run_btn = gr.Button("Run", variant="primary")

            gr.Markdown("Presets:")
            with gr.Row():
                # Each preset button only overwrites the input textbox.
                gr.Button("Greeting").click(lambda: "Hello! How are you?", None, input_text)
                gr.Button("Story").click(lambda: "Once upon a time there was a robot.", None, input_text)
                gr.Button("Question").click(lambda: "Why is the sky blue?", None, input_text)

        with gr.Column(scale=2):
            token_display = gr.Markdown()
            explanation_md = gr.Markdown()
            model_info = gr.Markdown()

    # Panel 2: layer / head / token selectors and the plots they drive.
    # The slider maxima start at 0 and are raised by run_app once the model
    # dimensions are known.
    with gr.Row():
        with gr.Column():
            layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
            head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
            token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
            attn_plot = gr.Plot()

        with gr.Column():
            pca_plot = gr.Plot()
            step_attn_plot = gr.Plot()
            probs_plot = gr.Plot()

    # Panel 3 — Residuals
    residual_plot = gr.Plot()

    # Panel 4 — Neuron explorer
    # NOTE(review): the table is placed inside the row by statement order;
    # confirm that it is not meant to sit below the row instead.
    with gr.Row():
        neuron_find_btn = gr.Button("Find neurons")
        neuron_idx = gr.Number(label="Neuron index", value=0)
        neuron_table = gr.Dataframe(headers=["token", "activation"], interactive=False)

    # Panel 5 — Activation Patching
    # NOTE(review): same caveat as panel 4 for the placement of patch_output.
    with gr.Row():
        patch_layer = gr.Slider(0, 0, value=0, step=1, label="Patch layer")
        patch_pos = gr.Slider(0, 0, value=0, step=1, label="Target token position")
        patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
        patch_scale = gr.Number(label="Scale", value=1.0)
        patch_btn = gr.Button("Run patch")
        patch_output = gr.Plot()

    # Holds the result dict of the most recent analysis (or its error dict)
    # so that the other callbacks can reuse it without re-running the model.
    state = gr.State()

    # ---- SHARED HELPERS ----
    def _as_index(value, upper):
        """Coerce a slider value to an ``int`` clamped to ``[0, upper]``.

        Slider values may arrive as floats (or ``None``), which cannot be used
        to index lists or arrays directly.
        """
        return min(max(0, int(value or 0)), max(0, upper))

    def _step_attention_fig(res, layer, head, tok):
        """Return a bar chart of the attention row of token ``tok``.

        Reads ``res["attentions"][layer][head][tok]`` and plots it against
        ``res["tokens"]``.
        """
        row = res["attentions"][layer][head][tok]
        fig = go.Figure([go.Bar(x=res["tokens"], y=row)])
        fig.update_layout(title=f"Token {tok} attends to")
        return fig

    # ---- RUN ANALYSIS ----
    def run_app(text, model, simp):
        """Analyze ``text`` with ``model`` and refresh every panel.

        Args:
            text: Prompt taken from the input textbox.
            model: Model name taken from the model textbox.
            simp: Value of the "Explain simply" checkbox.

        Returns:
            A dict mapping output components to ``gr.update`` objects (and
            ``state`` to the raw result of ``analyze_text``). When the result
            contains an ``"error"`` key all plots are cleared, the sliders are
            reset to ``0`` and the message is shown in the explanation area.
        """
        res = analyze_text(text, model, simp)

        if "error" in res:
            return {
                token_display: gr.update(value=""),
                explanation_md: gr.update(value=res["error"]),
                model_info: gr.update(value=f"Model: {model}"),
                attn_plot: gr.update(value=None),
                pca_plot: gr.update(value=None),
                # Clear the per-token chart too so no stale plot survives.
                step_attn_plot: gr.update(value=None),
                probs_plot: gr.update(value=None),
                layer_slider: gr.update(maximum=0, value=0),
                head_slider: gr.update(maximum=0, value=0),
                token_step: gr.update(maximum=0, value=0),
                residual_plot: gr.update(value=None),
                neuron_table: gr.update(value=[]),
                # Reset the value together with the maximum so a slider can
                # never be left pointing past its new range.
                patch_layer: gr.update(maximum=0, value=0),
                patch_pos: gr.update(maximum=0, value=0),
                patch_from: gr.update(maximum=0, value=0),
                state: res,
            }

        tokens = res["tokens"]
        attentions = res["attentions"]
        n_layers = len(attentions)
        n_heads = attentions[0].shape[0]
        # Highest valid token index; never negative so it is a valid maximum.
        last_tok = max(0, len(tokens) - 1)

        default_layer = _as_index(res["default_layer"], n_layers - 1)
        default_head = _as_index(res["default_head"], n_heads - 1)

        # Grouped bars comparing the attention and MLP norms layer by layer.
        residual_fig = None
        residuals = res.get("residuals")
        if residuals:
            attn_vals = residuals["attn"]
            mlp_vals = residuals["mlp"]
            # Plot only the layers for which both series have a value.
            shared = min(len(attn_vals), len(mlp_vals))
            layers = list(range(shared))
            residual_fig = go.Figure()
            residual_fig.add_trace(go.Bar(x=layers, y=attn_vals[:shared], name="Attention norm"))
            residual_fig.add_trace(go.Bar(x=layers, y=mlp_vals[:shared], name="MLP norm"))
            residual_fig.update_layout(barmode="group", height=360)

        neuron_info = res["neuron_info"]
        neuron_rows = (
            [[t, round(v, 4)] for t, v in neuron_info[0]["top_tokens"]]
            if neuron_info
            else []
        )

        # Render the per-token chart for the initial selection right away; the
        # slider change events only fire when a slider value actually changes.
        step_fig = (
            _step_attention_fig(res, default_layer, default_head, 0) if tokens else None
        )

        return {
            token_display: gr.update(value=f"**Tokens:** {res['token_display']}"),
            explanation_md: gr.update(value=res["explanation"]),
            model_info: gr.update(
                value=f"Model: {model} • layers: {n_layers} • heads: {n_heads} • tokens: {len(tokens)}"
            ),
            attn_plot: gr.update(value=res.get("fig_attn") or None),
            pca_plot: gr.update(value=res.get("fig_pca") or None),
            step_attn_plot: gr.update(value=step_fig),
            probs_plot: gr.update(value=fig_probs(res["top_tokens"], res["top_scores"])),
            layer_slider: gr.update(maximum=n_layers - 1, value=default_layer),
            head_slider: gr.update(maximum=n_heads - 1, value=default_head),
            token_step: gr.update(maximum=last_tok, value=0),
            residual_plot: gr.update(value=residual_fig),
            neuron_table: gr.update(value=neuron_rows),
            patch_layer: gr.update(maximum=n_layers - 1, value=0),
            patch_pos: gr.update(maximum=last_tok, value=0),
            patch_from: gr.update(maximum=last_tok, value=0),
            state: res,
        }

    run_btn.click(
        run_app,
        inputs=[input_text, model_name, simple],
        outputs=[
            token_display, explanation_md, model_info,
            attn_plot, pca_plot, step_attn_plot, probs_plot,
            layer_slider, head_slider, token_step,
            residual_plot, neuron_table,
            patch_layer, patch_pos, patch_from,
            state,
        ],
    )

    # ---- SLIDER UPDATES ----
    def update_view(res, layer, head, tok):
        """Redraw the attention, PCA and per-token plots for a new selection.

        Args:
            res: Stored analysis result (``None`` before the first run).
            layer: Value of the layer slider.
            head: Value of the head slider.
            tok: Value of the token-index slider.

        Returns:
            A dict mapping the three plots to ``gr.update`` objects; all three
            are cleared when there is no usable analysis result.
        """
        if not res or "error" in res:
            return {
                attn_plot: gr.update(value=None),
                pca_plot: gr.update(value=None),
                step_attn_plot: gr.update(value=None),
            }

        tokens = res["tokens"]
        attentions = res["attentions"]
        # Cast and clamp: the values are used as list / array indices below.
        layer = _as_index(layer, len(attentions) - 1)
        head = _as_index(head, attentions[0].shape[0] - 1)
        tok = _as_index(tok, len(tokens) - 1)

        att = fig_attention(attentions[layer][head], tokens, f"Layer {layer} Head {head}")
        pca_fig = fig_pca(res["pca"][layer], tokens, highlight=tok, title=f"PCA Layer {layer}")
        step_fig = _step_attention_fig(res, layer, head, tok)

        return {
            attn_plot: gr.update(value=att),
            pca_plot: gr.update(value=pca_fig),
            step_attn_plot: gr.update(value=step_fig),
        }

    # All three selectors refresh the same plots from the same inputs.
    view_inputs = [state, layer_slider, head_slider, token_step]
    view_outputs = [attn_plot, pca_plot, step_attn_plot]
    for selector in (layer_slider, head_slider, token_step):
        selector.change(update_view, view_inputs, view_outputs)

    # ---- NEURON EXPLORER ----
    def neuron_auto(res):
        """Fill the neuron table from the precomputed ``res["neuron_info"]``.

        Flattens the ``top_tokens`` pairs of every entry, rounds activations to
        four decimals, drops duplicate rows and keeps the first 24.
        """
        if not res or "neuron_info" not in res:
            return gr.update(value=[])
        rows = [
            [t, round(v, 4)]
            for item in res["neuron_info"]
            for t, v in item["top_tokens"]
        ]
        df = pd.DataFrame(rows, columns=["token", "activation"]).drop_duplicates().head(24)
        return gr.update(value=df.values.tolist())

    neuron_find_btn.click(neuron_auto, [state], [neuron_table])

    def neuron_manual(res, idx):
        """Show the tokens that most strongly activate neuron ``idx``.

        Reads column ``idx`` of the last entry of ``res["hidden"]`` and lists
        the 12 tokens with the largest absolute activation. The table is
        emptied when there is no result or ``idx`` is not a valid column.
        """
        if not res or "hidden" not in res:
            return gr.update(value=[])
        try:
            idx = int(idx)
        except (TypeError, ValueError, OverflowError):
            # Empty field (None), NaN or infinity in the number box.
            return gr.update(value=[])
        last = res["hidden"][-1]
        # Reject negative indices as well: they would otherwise wrap around
        # and silently display a different neuron.
        if not 0 <= idx < last.shape[1]:
            return gr.update(value=[])
        pairs = sorted(
            ((t, float(v)) for t, v in zip(res["tokens"], last[:, idx])),
            key=lambda pair: -abs(pair[1]),
        )[:12]
        return gr.update(value=[[t, round(v, 4)] for t, v in pairs])

    neuron_idx.change(neuron_manual, [state, neuron_idx], [neuron_table])

    # ---- ACTIVATION PATCHING ----
    def patch_run(res, L, P, FP, S, model):
        """Run activation patching and plot the resulting token probabilities.

        Args:
            res: Stored analysis result; its ``"tokens"`` are re-used.
            L: Value of the patch-layer slider.
            P: Value of the target-position slider.
            FP: Value of the copy-from-position slider.
            S: Scale taken from the number box.
            model: Model name taken from the model textbox.

        Returns:
            A ``gr.update`` carrying the probability figure, or clearing the
            plot when no analysis has been run yet.

        Raises:
            gr.Error: If ``activation_patch`` reports an error, so that the
                reason is shown to the user instead of a silent blank plot.
        """
        if not res or "tokens" not in res:
            return gr.update(value=None)
        out = activation_patch(res["tokens"], model, int(L), int(P), int(FP), float(S))
        if "error" in out:
            raise gr.Error(out["error"])
        return gr.update(value=fig_probs(out["tokens"], out["scores"]))

    patch_btn.click(
        patch_run,
        [state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
        [patch_output],
    )
496
 
497
# Start the web server only when the file is executed as a script, so that
# importing this module (e.g. from a test or another tool) does not block on
# a running server.
if __name__ == "__main__":
    demo.launch()