Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,503 +1,497 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
-
#
|
|
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import torch
|
| 7 |
import numpy as np
|
| 8 |
import plotly.express as px
|
| 9 |
import plotly.graph_objects as go
|
| 10 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 11 |
-
from sklearn.decomposition import PCA
|
| 12 |
import pandas as pd
|
| 13 |
-
import
|
|
|
|
| 14 |
import html
|
| 15 |
|
| 16 |
-
# ---------------- Config ----------------
|
| 17 |
DEFAULT_MODEL = "distilgpt2"
|
| 18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
# ---------------- Utilities ----------------
|
| 22 |
def load_model(model_name):
|
| 23 |
-
if model_name in
|
| 24 |
-
return
|
| 25 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 26 |
model = AutoModelForCausalLM.from_pretrained(
|
| 27 |
model_name, output_attentions=True, output_hidden_states=True
|
| 28 |
-
)
|
| 29 |
-
model.to(DEVICE)
|
| 30 |
model.eval()
|
| 31 |
-
|
| 32 |
return model, tokenizer
|
| 33 |
|
|
|
|
| 34 |
def softmax(x):
|
| 35 |
e = np.exp(x - np.max(x))
|
| 36 |
return e / e.sum(axis=-1, keepdims=True)
|
| 37 |
|
| 38 |
-
|
|
|
|
| 39 |
return " ".join([f"[{html.escape(t)}]" for t in tokens])
|
| 40 |
|
| 41 |
-
|
|
|
|
| 42 |
try:
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
except Exception:
|
| 46 |
seq = hidden_layer.shape[0]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
return np.vstack([
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
fig.update_layout(height=420
|
| 56 |
return fig
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
| 60 |
fig.update_traces(textposition="top center", marker=dict(size=10))
|
| 61 |
-
if
|
| 62 |
fig.add_trace(go.Scatter(
|
| 63 |
-
x=[points[
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
))
|
| 67 |
-
fig.update_layout(height=420
|
| 68 |
return fig
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
fig
|
|
|
|
|
|
|
| 73 |
return fig
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
if not text or len(text.strip()) == 0:
|
| 82 |
-
return {"error": "Please enter some text."}
|
| 83 |
|
| 84 |
try:
|
| 85 |
model, tokenizer = load_model(model_name)
|
| 86 |
except Exception as e:
|
| 87 |
-
return {"error": f"Failed to load model
|
|
|
|
|
|
|
| 88 |
|
| 89 |
try:
|
| 90 |
-
|
|
|
|
| 91 |
except Exception as e:
|
| 92 |
-
return {"error": f"
|
| 93 |
-
with torch.no_grad():
|
| 94 |
-
try:
|
| 95 |
-
outputs = model(**inputs)
|
| 96 |
-
except Exception as e:
|
| 97 |
-
return {"error": f"Model forward error: {e}"}
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
|
| 102 |
-
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
| 103 |
-
except Exception:
|
| 104 |
-
return {"error": "Failed to extract tokens."}
|
| 105 |
|
| 106 |
-
attentions = [a[0].cpu().numpy() for a in
|
| 107 |
-
hidden = [h[0].cpu().numpy() for h in
|
| 108 |
-
logits =
|
| 109 |
|
| 110 |
# PCA per layer
|
| 111 |
-
pca_layers = []
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# Next-token topk
|
| 119 |
-
last_logits = logits[-1]
|
| 120 |
-
probs = softmax(last_logits)
|
| 121 |
-
topk = 25
|
| 122 |
-
idx = np.argsort(probs)[-topk:][::-1]
|
| 123 |
-
top_tokens = [tokenizer.decode([int(i)]) for i in idx]
|
| 124 |
top_scores = probs[idx].tolist()
|
| 125 |
|
| 126 |
-
default_layer =
|
| 127 |
default_head = 0
|
| 128 |
|
| 129 |
-
|
| 130 |
-
fig_pca = make_pca_figure(pca_layers[default_layer], tokens, highlight_idx=None, title=f"PCA (layer {default_layer})") if pca_layers is not None else None
|
| 131 |
-
fig_probs = make_probs_figure(top_tokens, top_scores, title="Next-token top predictions")
|
| 132 |
-
|
| 133 |
-
explanation = (
|
| 134 |
-
"Simple: the model splits text into pieces, looks which pieces are important, and guesses the next word."
|
| 135 |
-
if explain_simple else
|
| 136 |
-
"Technical: tokens, attention matrices per head/layer, hidden states projected to 2D, and top-k next-token probabilities."
|
| 137 |
-
)
|
| 138 |
-
|
| 139 |
-
# neuron explorer (top neurons by mean absolute activation in last layer)
|
| 140 |
neuron_info = []
|
| 141 |
try:
|
| 142 |
-
|
| 143 |
-
mean_act = np.abs(
|
| 144 |
top_neurons = np.argsort(mean_act)[-24:][::-1]
|
| 145 |
-
for n in top_neurons
|
| 146 |
-
vals =
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
neuron_info = []
|
| 152 |
|
| 153 |
-
# residual
|
| 154 |
-
residuals =
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
blocks = model.transformer.h
|
| 158 |
-
wte = getattr(model.transformer, "wte", None)
|
| 159 |
-
# compute norms by approximating per-layer attn & mlp outputs using block forward if possible
|
| 160 |
-
attn_norms = []
|
| 161 |
-
mlp_norms = []
|
| 162 |
-
# Start from embeddings
|
| 163 |
-
cur = wte(inputs["input_ids"]) if wte is not None else None
|
| 164 |
-
if cur is not None:
|
| 165 |
-
cur = cur.to(DEVICE)
|
| 166 |
-
# We'll run each block and measure norms of attention & mlp outputs if callable
|
| 167 |
-
for block in blocks:
|
| 168 |
-
try:
|
| 169 |
-
ln1 = block.ln_1(cur)
|
| 170 |
-
attn_out = block.attn(ln1)[0]
|
| 171 |
-
cur = cur + attn_out
|
| 172 |
-
ln2 = block.ln_2(cur)
|
| 173 |
-
mlp_out = block.mlp(ln2)
|
| 174 |
-
cur = cur + mlp_out
|
| 175 |
-
attn_norms.append(float(torch.norm(attn_out).cpu().numpy()))
|
| 176 |
-
mlp_norms.append(float(torch.norm(mlp_out).cpu().numpy()))
|
| 177 |
-
except Exception:
|
| 178 |
-
# fallback: run full block and compute residual diff
|
| 179 |
-
prev = cur.clone()
|
| 180 |
-
try:
|
| 181 |
-
cur = block(prev)[0]
|
| 182 |
-
total = cur - prev
|
| 183 |
-
attn_norms.append(0.0)
|
| 184 |
-
mlp_norms.append(float(torch.norm(total).cpu().numpy()))
|
| 185 |
-
except Exception:
|
| 186 |
-
attn_norms.append(0.0)
|
| 187 |
-
mlp_norms.append(0.0)
|
| 188 |
-
residuals = {"attn_norms": attn_norms, "mlp_norms": mlp_norms}
|
| 189 |
-
except Exception:
|
| 190 |
-
residuals = None
|
| 191 |
-
|
| 192 |
-
result = {
|
| 193 |
"tokens": tokens,
|
| 194 |
"attentions": attentions,
|
| 195 |
"hidden": hidden,
|
|
|
|
| 196 |
"logits": logits,
|
| 197 |
-
"
|
| 198 |
-
"
|
| 199 |
-
"fig_pca": fig_pca,
|
| 200 |
-
"fig_probs": fig_probs,
|
| 201 |
"default_layer": default_layer,
|
| 202 |
"default_head": default_head,
|
| 203 |
-
"token_display": tokens_display(tokens),
|
| 204 |
-
"explanation": explanation,
|
| 205 |
"neuron_info": neuron_info,
|
| 206 |
"residuals": residuals,
|
| 207 |
-
"
|
| 208 |
-
"
|
| 209 |
}
|
| 210 |
-
return result
|
| 211 |
|
| 212 |
-
|
| 213 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"""
|
| 215 |
-
|
| 216 |
-
Returns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
"""
|
| 218 |
try:
|
| 219 |
model, tokenizer = load_model(model_name)
|
| 220 |
-
except
|
| 221 |
-
return {"error": "Model load
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
return {"error": "Model not compatible with activation patching (requires GPT-2 style blocks)."}
|
| 230 |
|
| 231 |
blocks = model.transformer.h
|
| 232 |
wte = model.transformer.wte
|
| 233 |
ln_f = model.transformer.ln_f if hasattr(model.transformer, "ln_f") else None
|
| 234 |
lm_head = model.lm_head
|
| 235 |
|
| 236 |
-
# collect hidden precomputed
|
| 237 |
with torch.no_grad():
|
| 238 |
-
x = wte(inputs["input_ids"]).to(DEVICE)
|
| 239 |
-
|
| 240 |
-
for
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
x
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
|
|
|
| 256 |
|
| 257 |
# re-run with patch
|
| 258 |
with torch.no_grad():
|
| 259 |
x = wte(inputs["input_ids"]).to(DEVICE)
|
| 260 |
-
for i,
|
| 261 |
-
ln1 =
|
| 262 |
-
|
| 263 |
-
x = x +
|
| 264 |
-
ln2 =
|
| 265 |
-
|
| 266 |
-
x = x +
|
| 267 |
-
if i ==
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
final = ln_f(x) if ln_f
|
| 271 |
-
logits = lm_head(final)
|
| 272 |
-
logits = logits[0, -1, :].cpu().numpy()
|
| 273 |
probs = softmax(logits)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
return {"patched_top_tokens": top_tokens, "patched_top_scores": top_scores}
|
| 279 |
-
|
| 280 |
-
# ---------------- UI wiring ----------------
|
| 281 |
-
def run_analysis(text, model_name, explain_simple):
|
| 282 |
-
res = analyze_text(text, model_name, explain_simple)
|
| 283 |
-
if "error" in res:
|
| 284 |
-
# return dict keyed by components
|
| 285 |
-
return {
|
| 286 |
-
token_display: gr.update(value=""),
|
| 287 |
-
explanation_md: gr.update(value=res["error"]),
|
| 288 |
-
model_info: gr.update(value=f"Model: {model_name}"),
|
| 289 |
-
attn_plot: gr.update(value=None),
|
| 290 |
-
pca_plot: gr.update(value=None),
|
| 291 |
-
probs_plot: gr.update(value=None),
|
| 292 |
-
layer_slider: gr.update(maximum=0, value=0),
|
| 293 |
-
head_slider: gr.update(maximum=0, value=0),
|
| 294 |
-
token_step: gr.update(maximum=0, value=0),
|
| 295 |
-
state: res,
|
| 296 |
-
residual_plot: gr.update(value=None),
|
| 297 |
-
neuron_table: gr.update(value=[]),
|
| 298 |
-
patch_layer_input: gr.update(maximum=0, value=0),
|
| 299 |
-
patch_pos_input: gr.update(maximum=0, value=0),
|
| 300 |
-
patch_from_pos_input: gr.update(maximum=0, value=0),
|
| 301 |
-
}
|
| 302 |
|
| 303 |
-
tokens = res["tokens"]
|
| 304 |
-
num_layers = len(res["attentions"]) if res["attentions"] is not None else (len(res["pca_layers"]) - 1 if res["pca_layers"] else 0)
|
| 305 |
-
num_heads = res["attentions"][0].shape[0] if res["attentions"] is not None else 1
|
| 306 |
-
max_token_idx = len(tokens) - 1
|
| 307 |
-
|
| 308 |
-
token_display_text = f"**Tokens:** {res['token_display']}"
|
| 309 |
-
explanation_text = res["explanation"]
|
| 310 |
-
model_info_text = f"Model: {res['model_name']} β’ layers: {num_layers} β’ heads: {num_heads} β’ tokens: {len(tokens)}"
|
| 311 |
-
|
| 312 |
-
layer_update = gr.update(maximum=max(0, num_layers - 1), value=res["default_layer"])
|
| 313 |
-
head_update = gr.update(maximum=max(0, num_heads - 1), value=res["default_head"])
|
| 314 |
-
token_step_update = gr.update(maximum=max_token_idx, value=0)
|
| 315 |
-
|
| 316 |
-
patch_layer_update = gr.update(maximum=max(0, num_layers - 1), value=0)
|
| 317 |
-
patch_pos_update = gr.update(maximum=max(0, max_token_idx), value=0)
|
| 318 |
-
patch_from_pos_update = gr.update(maximum=max(0, max_token_idx), value=0)
|
| 319 |
-
|
| 320 |
-
# neuron table initial (show first neuron's top tokens if available)
|
| 321 |
-
neuron_table_data = []
|
| 322 |
-
if res.get("neuron_info"):
|
| 323 |
-
first = res["neuron_info"][0]
|
| 324 |
-
neuron_table_data = [[t, round(v, 6)] for t, v in first["top_tokens"]]
|
| 325 |
-
|
| 326 |
-
# residual figure
|
| 327 |
-
residual_fig = None
|
| 328 |
-
if res.get("residuals"):
|
| 329 |
-
df = pd.DataFrame({"layer": list(range(len(res["residuals"]["attn_norms"]))),
|
| 330 |
-
"attn": res["residuals"]["attn_norms"],
|
| 331 |
-
"mlp": res["residuals"]["mlp_norms"]})
|
| 332 |
-
residual_fig = go.Figure()
|
| 333 |
-
residual_fig.add_trace(go.Bar(x=df["layer"], y=df["attn"], name="Attention norm"))
|
| 334 |
-
residual_fig.add_trace(go.Bar(x=df["layer"], y=df["mlp"], name="MLP norm"))
|
| 335 |
-
residual_fig.update_layout(barmode="group", title="Residual contributions (layerwise norms)", height=360)
|
| 336 |
|
| 337 |
-
|
| 338 |
-
token_display: gr.update(value=token_display_text),
|
| 339 |
-
explanation_md: gr.update(value=explanation_text),
|
| 340 |
-
model_info: gr.update(value=model_info_text),
|
| 341 |
-
attn_plot: gr.update(value=res["fig_attn"]),
|
| 342 |
-
pca_plot: gr.update(value=res["fig_pca"]),
|
| 343 |
-
probs_plot: gr.update(value=res["fig_probs"]),
|
| 344 |
-
layer_slider: layer_update,
|
| 345 |
-
head_slider: head_update,
|
| 346 |
-
token_step: token_step_update,
|
| 347 |
-
state: res,
|
| 348 |
-
residual_plot: gr.update(value=residual_fig),
|
| 349 |
-
neuron_table: gr.update(value=neuron_table_data),
|
| 350 |
-
patch_layer_input: patch_layer_update,
|
| 351 |
-
patch_pos_input: patch_pos_update,
|
| 352 |
-
patch_from_pos_input: patch_from_pos_update,
|
| 353 |
-
}
|
| 354 |
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
res = state_obj
|
| 360 |
-
tokens = res["tokens"]
|
| 361 |
-
# bounds
|
| 362 |
-
if res["attentions"] is not None:
|
| 363 |
-
max_layer = len(res["attentions"]) - 1
|
| 364 |
-
layer = int(min(max(0, layer), max_layer))
|
| 365 |
-
max_head = res["attentions"][0].shape[0] - 1
|
| 366 |
-
head = int(min(max(0, head), max_head))
|
| 367 |
-
else:
|
| 368 |
-
layer = int(min(max(0, layer), len(res["pca_layers"]) - 1 if res["pca_layers"] else 0))
|
| 369 |
-
head = 0
|
| 370 |
-
token_idx = int(min(max(0, token_idx), len(tokens) - 1))
|
| 371 |
-
|
| 372 |
-
attn_fig = None
|
| 373 |
-
if res["attentions"] is not None:
|
| 374 |
-
attn_fig = make_attention_figure(res["attentions"][layer][head], tokens, title=f"Layer {layer} Head {head}")
|
| 375 |
-
|
| 376 |
-
pca_fig = None
|
| 377 |
-
if res["pca_layers"] is not None:
|
| 378 |
-
pca_pts = res["pca_layers"][layer]
|
| 379 |
-
pca_fig = make_pca_figure(pca_pts, tokens, highlight_idx=token_idx, title=f"PCA (layer {layer})")
|
| 380 |
-
|
| 381 |
-
step_attn_fig = None
|
| 382 |
-
if res["attentions"] is not None:
|
| 383 |
-
row = res["attentions"][layer][head][token_idx]
|
| 384 |
-
step_attn_fig = go.Figure(data=[go.Bar(x=tokens, y=row)])
|
| 385 |
-
step_attn_fig.update_layout(title=f"Token {token_idx} attends to (layer {layer}, head {head})", height=300, margin=dict(l=40,r=20,t=30,b=40))
|
| 386 |
-
|
| 387 |
-
return {attn_plot: gr.update(value=attn_fig),
|
| 388 |
-
pca_plot: gr.update(value=pca_fig),
|
| 389 |
-
step_attn_plot: gr.update(value=step_attn_fig)}
|
| 390 |
-
|
| 391 |
-
def run_patch(state_obj, patch_layer, patch_pos, patch_from_pos, patch_scale, model_name):
|
| 392 |
-
if not state_obj:
|
| 393 |
-
return gr.update(value=None)
|
| 394 |
-
tokens = state_obj["tokens"]
|
| 395 |
-
res = activation_patch_and_run(tokens, model_name, int(patch_layer), int(patch_pos), int(patch_from_pos), float(patch_scale))
|
| 396 |
-
if "error" in res:
|
| 397 |
-
return gr.update(value=None)
|
| 398 |
-
fig = go.Figure(data=[go.Bar(x=res["patched_top_tokens"], y=res["patched_top_scores"])])
|
| 399 |
-
fig.update_layout(title=f"Patched predictions (layer {patch_layer}, pos {patch_pos} <- pos {patch_from_pos}, scale {patch_scale})", height=420)
|
| 400 |
-
return gr.update(value=fig)
|
| 401 |
-
|
| 402 |
-
def find_neurons(state_obj):
|
| 403 |
-
if not state_obj:
|
| 404 |
-
return gr.update(value=[])
|
| 405 |
-
info = state_obj.get("neuron_info", [])
|
| 406 |
-
rows = []
|
| 407 |
-
for e in info[:24]:
|
| 408 |
-
for t, v in e["top_tokens"]:
|
| 409 |
-
rows.append([t, round(v,6)])
|
| 410 |
-
# dedupe
|
| 411 |
-
df = pd.DataFrame(rows, columns=["token","activation"]).drop_duplicates().head(24).values.tolist()
|
| 412 |
-
return gr.update(value=df)
|
| 413 |
-
|
| 414 |
-
def inspect_neuron(state_obj, neuron_idx):
|
| 415 |
-
if not state_obj:
|
| 416 |
-
return gr.update(value=[])
|
| 417 |
-
try:
|
| 418 |
-
neuron_idx = int(neuron_idx)
|
| 419 |
-
except Exception:
|
| 420 |
-
return gr.update(value=[])
|
| 421 |
-
last_hidden = state_obj["hidden"][-1]
|
| 422 |
-
vals = last_hidden[:, neuron_idx]
|
| 423 |
-
tokens = state_obj["tokens"]
|
| 424 |
-
df = sorted([(tokens[i], float(vals[i])) for i in range(len(tokens))], key=lambda x: -abs(x[1]))[:12]
|
| 425 |
-
return gr.update(value=[[t, round(v,6)] for t,v in df])
|
| 426 |
-
|
| 427 |
-
# ---------------- Gradio UI ----------------
|
| 428 |
-
with gr.Blocks(title="LLM Visualizer β Full (Option A)", theme=gr.themes.Soft()) as demo:
|
| 429 |
-
gr.Markdown("<h1 style='font-size:30px'>π§ LLM Visualizer β Full (Advanced)</h1>")
|
| 430 |
-
gr.Markdown("Advanced GPT-2 style visualizer. Use `distilgpt2` or `gpt2` for full features. Keep input short (<80 tokens) on CPU Spaces.")
|
| 431 |
|
|
|
|
| 432 |
with gr.Row():
|
| 433 |
with gr.Column(scale=3):
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
run_btn = gr.Button("Run
|
| 438 |
-
gr.Markdown("**Presets:**")
|
| 439 |
-
with gr.Row():
|
| 440 |
-
gr.Button("Greeting").click(lambda: "Hello! How are you today?", None, text_input)
|
| 441 |
-
gr.Button("Story start").click(lambda: "Once upon a time, there was a small robot...", None, text_input)
|
| 442 |
-
gr.Button("Question").click(lambda: "Why is the sky blue?", None, text_input)
|
| 443 |
|
| 444 |
-
|
| 445 |
-
gr.
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
with gr.Column(scale=2):
|
| 448 |
-
token_display = gr.Markdown(
|
| 449 |
-
explanation_md = gr.Markdown(
|
| 450 |
-
model_info = gr.Markdown(
|
| 451 |
|
|
|
|
| 452 |
with gr.Row():
|
| 453 |
with gr.Column():
|
| 454 |
-
layer_slider = gr.Slider(
|
| 455 |
-
head_slider = gr.Slider(
|
| 456 |
-
token_step = gr.Slider(
|
| 457 |
-
attn_plot = gr.Plot(
|
|
|
|
| 458 |
with gr.Column():
|
| 459 |
-
pca_plot = gr.Plot(
|
| 460 |
-
step_attn_plot = gr.Plot(
|
| 461 |
-
probs_plot = gr.Plot(
|
| 462 |
-
residual_plot = gr.Plot(label="Residual decomposition (attention vs mlp)")
|
| 463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
with gr.Row():
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
patch_output = gr.Plot(label="Patched next-token predictions")
|
| 478 |
|
| 479 |
state = gr.State()
|
| 480 |
|
| 481 |
-
#
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
demo.launch()
|
|
|
|
| 1 |
+
# FULL LLM VISUALIZER β OPTION A (ADVANCED)
|
| 2 |
+
# stable + patched + safe for HuggingFace Spaces (CPU or GPU)
|
| 3 |
+
# recommended models: distilgpt2, gpt2
|
| 4 |
+
# author: ChatGPT
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import torch
|
| 8 |
import numpy as np
|
| 9 |
import plotly.express as px
|
| 10 |
import plotly.graph_objects as go
|
|
|
|
|
|
|
| 11 |
import pandas as pd
|
| 12 |
+
from sklearn.decomposition import PCA
|
| 13 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 14 |
import html
|
| 15 |
|
|
|
|
| 16 |
DEFAULT_MODEL = "distilgpt2"
|
| 17 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
MODEL_CACHE = {}
|
| 19 |
+
|
| 20 |
+
# ---------------- CORE UTILS ----------------
|
| 21 |
|
|
|
|
| 22 |
def load_model(model_name):
|
| 23 |
+
if model_name in MODEL_CACHE:
|
| 24 |
+
return MODEL_CACHE[model_name]
|
| 25 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 26 |
model = AutoModelForCausalLM.from_pretrained(
|
| 27 |
model_name, output_attentions=True, output_hidden_states=True
|
| 28 |
+
).to(DEVICE)
|
|
|
|
| 29 |
model.eval()
|
| 30 |
+
MODEL_CACHE[model_name] = (model, tokenizer)
|
| 31 |
return model, tokenizer
|
| 32 |
|
| 33 |
+
|
| 34 |
def softmax(x):
|
| 35 |
e = np.exp(x - np.max(x))
|
| 36 |
return e / e.sum(axis=-1, keepdims=True)
|
| 37 |
|
| 38 |
+
|
| 39 |
+
def safe_tokens(tokens):
|
| 40 |
return " ".join([f"[{html.escape(t)}]" for t in tokens])
|
| 41 |
|
| 42 |
+
|
| 43 |
+
def compute_pca(hidden_layer):
|
| 44 |
try:
|
| 45 |
+
return PCA(n_components=2).fit_transform(hidden_layer)
|
| 46 |
+
except:
|
|
|
|
| 47 |
seq = hidden_layer.shape[0]
|
| 48 |
+
dim0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
|
| 49 |
+
dim1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
|
| 50 |
+
return np.vstack([dim0, dim1]).T
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def fig_attention(matrix, tokens, title):
|
| 54 |
+
fig = px.imshow(matrix, x=tokens, y=tokens, title=title,
|
| 55 |
+
labels={"x": "Key token", "y": "Query token", "color": "Attention"})
|
| 56 |
+
fig.update_layout(height=420)
|
| 57 |
return fig
|
| 58 |
|
| 59 |
+
|
| 60 |
+
def fig_pca(points, tokens, highlight=None, title="PCA"):
|
| 61 |
+
fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title)
|
| 62 |
fig.update_traces(textposition="top center", marker=dict(size=10))
|
| 63 |
+
if highlight is not None:
|
| 64 |
fig.add_trace(go.Scatter(
|
| 65 |
+
x=[points[highlight, 0]],
|
| 66 |
+
y=[points[highlight, 1]],
|
| 67 |
+
mode="markers+text",
|
| 68 |
+
text=[tokens[highlight]],
|
| 69 |
+
marker=dict(size=18, color="red")
|
| 70 |
))
|
| 71 |
+
fig.update_layout(height=420)
|
| 72 |
return fig
|
| 73 |
|
| 74 |
+
|
| 75 |
+
def fig_probs(tokens, scores):
|
| 76 |
+
fig = go.Figure()
|
| 77 |
+
fig.add_trace(go.Bar(x=tokens, y=scores))
|
| 78 |
+
fig.update_layout(title="Next-token probabilities", height=380)
|
| 79 |
return fig
|
| 80 |
|
| 81 |
+
|
| 82 |
+
# ---------------- ANALYSIS CORE ----------------
|
| 83 |
+
|
| 84 |
+
def analyze_text(text, model_name, simple):
|
| 85 |
+
if not text.strip():
|
| 86 |
+
return {"error": "Please enter text."}
|
|
|
|
|
|
|
| 87 |
|
| 88 |
try:
|
| 89 |
model, tokenizer = load_model(model_name)
|
| 90 |
except Exception as e:
|
| 91 |
+
return {"error": f"Failed to load model: {e}"}
|
| 92 |
+
|
| 93 |
+
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
|
| 94 |
|
| 95 |
try:
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
out = model(**inputs)
|
| 98 |
except Exception as e:
|
| 99 |
+
return {"error": f"Model error: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
|
| 102 |
+
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
attentions = [a[0].cpu().numpy() for a in out.attentions]
|
| 105 |
+
hidden = [h[0].cpu().numpy() for h in out.hidden_states]
|
| 106 |
+
logits = out.logits[0].cpu().numpy()
|
| 107 |
|
| 108 |
# PCA per layer
|
| 109 |
+
pca_layers = [compute_pca(h) for h in hidden]
|
| 110 |
+
|
| 111 |
+
# top-k
|
| 112 |
+
last = logits[-1]
|
| 113 |
+
probs = softmax(last)
|
| 114 |
+
idx = np.argsort(probs)[-20:][::-1]
|
| 115 |
+
top_tokens = [tokenizer.decode([i]) for i in idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
top_scores = probs[idx].tolist()
|
| 117 |
|
| 118 |
+
default_layer = len(attentions) - 1
|
| 119 |
default_head = 0
|
| 120 |
|
| 121 |
+
# neuron explorer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
neuron_info = []
|
| 123 |
try:
|
| 124 |
+
last_h = hidden[-1]
|
| 125 |
+
mean_act = np.abs(last_h).mean(axis=0)
|
| 126 |
top_neurons = np.argsort(mean_act)[-24:][::-1]
|
| 127 |
+
for n in top_neurons:
|
| 128 |
+
vals = last_h[:, n]
|
| 129 |
+
top_ix = np.argsort(np.abs(vals))[-5:][::-1]
|
| 130 |
+
neuron_info.append({
|
| 131 |
+
"neuron": int(n),
|
| 132 |
+
"top_tokens": [(tokens[i], float(vals[i])) for i in top_ix]
|
| 133 |
+
})
|
| 134 |
+
except:
|
| 135 |
neuron_info = []
|
| 136 |
|
| 137 |
+
# residual decomposition (safe)
|
| 138 |
+
residuals = compute_residuals_safe(model, inputs)
|
| 139 |
+
|
| 140 |
+
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
"tokens": tokens,
|
| 142 |
"attentions": attentions,
|
| 143 |
"hidden": hidden,
|
| 144 |
+
"pca": pca_layers,
|
| 145 |
"logits": logits,
|
| 146 |
+
"top_tokens": top_tokens,
|
| 147 |
+
"top_scores": top_scores,
|
|
|
|
|
|
|
| 148 |
"default_layer": default_layer,
|
| 149 |
"default_head": default_head,
|
|
|
|
|
|
|
| 150 |
"neuron_info": neuron_info,
|
| 151 |
"residuals": residuals,
|
| 152 |
+
"token_display": safe_tokens(tokens),
|
| 153 |
+
"explanation": explain(simple)
|
| 154 |
}
|
|
|
|
| 155 |
|
| 156 |
+
|
| 157 |
+
def explain(s):
|
| 158 |
+
if s:
|
| 159 |
+
return (
|
| 160 |
+
"π§ **Simple mode:**\n"
|
| 161 |
+
"- The model cuts text into small pieces (tokens).\n"
|
| 162 |
+
"- It looks at which tokens matter (attention).\n"
|
| 163 |
+
"- It builds an internal map (PCA) of meanings.\n"
|
| 164 |
+
"- Then it guesses the next token.\n"
|
| 165 |
+
)
|
| 166 |
+
return (
|
| 167 |
+
"π¬ **Technical mode:**\n"
|
| 168 |
+
"Showing tokens, attention (queryβkey), PCA projections, logits, "
|
| 169 |
+
"neuron activations, and layerwise residual contributions.\n"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ---------------- RESIDUAL DECOMPOSITION SAFE ----------------
|
| 174 |
+
|
| 175 |
+
def compute_residuals_safe(model, inputs):
|
| 176 |
"""
|
| 177 |
+
Guaranteed safe residual norms for GPT-2-style blocks.
|
| 178 |
+
Will NEVER crash. Returns None if not applicable.
|
| 179 |
+
"""
|
| 180 |
+
if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
|
| 181 |
+
return None
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
blocks = model.transformer.h
|
| 185 |
+
wte = model.transformer.wte
|
| 186 |
+
x = wte(inputs["input_ids"]).to(DEVICE)
|
| 187 |
+
|
| 188 |
+
attn_norms = []
|
| 189 |
+
mlp_norms = []
|
| 190 |
+
|
| 191 |
+
for block in blocks:
|
| 192 |
+
try:
|
| 193 |
+
ln1 = block.ln_1(x)
|
| 194 |
+
attn_out = block.attn(ln1)[0]
|
| 195 |
+
x = x + attn_out
|
| 196 |
+
ln2 = block.ln_2(x)
|
| 197 |
+
mlp_out = block.mlp(ln2)
|
| 198 |
+
x = x + mlp_out
|
| 199 |
+
|
| 200 |
+
attn_norms.append(float(torch.norm(attn_out).cpu()))
|
| 201 |
+
mlp_norms.append(float(torch.norm(mlp_out).cpu()))
|
| 202 |
+
except:
|
| 203 |
+
# fallback safe zero
|
| 204 |
+
attn_norms.append(0.0)
|
| 205 |
+
mlp_norms.append(0.0)
|
| 206 |
+
|
| 207 |
+
# normalize lengths safely
|
| 208 |
+
L = min(len(attn_norms), len(mlp_norms))
|
| 209 |
+
return {
|
| 210 |
+
"attn": attn_norms[:L],
|
| 211 |
+
"mlp": mlp_norms[:L],
|
| 212 |
+
}
|
| 213 |
+
except:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ---------------- ACTIVATION PATCHING (SAFE VERSION) ----------------
|
| 218 |
+
|
| 219 |
+
def activation_patch(tokens, model_name, layer, pos, from_pos, scale=1.0):
|
| 220 |
+
"""
|
| 221 |
+
Safe activation patching (never crashes, only works for GPT-2 style).
|
| 222 |
"""
|
| 223 |
try:
|
| 224 |
model, tokenizer = load_model(model_name)
|
| 225 |
+
except:
|
| 226 |
+
return {"error": "Model load error."}
|
| 227 |
|
| 228 |
+
if not hasattr(model, "transformer") or not hasattr(model.transformer, "h"):
|
| 229 |
+
return {"error": "Model not compatible with patching."}
|
|
|
|
| 230 |
|
| 231 |
+
text = " ".join(tokens)
|
| 232 |
+
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
|
|
|
|
| 233 |
|
| 234 |
blocks = model.transformer.h
|
| 235 |
wte = model.transformer.wte
|
| 236 |
ln_f = model.transformer.ln_f if hasattr(model.transformer, "ln_f") else None
|
| 237 |
lm_head = model.lm_head
|
| 238 |
|
|
|
|
| 239 |
with torch.no_grad():
|
| 240 |
+
x = wte(inputs["input_ids"]).to(DEVICE)
|
| 241 |
+
hidden_layers = [x.clone().cpu().numpy()[0]]
|
| 242 |
+
for b in blocks:
|
| 243 |
+
ln1 = b.ln_1(x)
|
| 244 |
+
a = b.attn(ln1)[0]
|
| 245 |
+
x = x + a
|
| 246 |
+
ln2 = b.ln_2(x)
|
| 247 |
+
m = b.mlp(ln2)
|
| 248 |
+
x = x + m
|
| 249 |
+
hidden_layers.append(x.clone().cpu().numpy()[0])
|
| 250 |
+
|
| 251 |
+
if layer >= len(hidden_layers):
|
| 252 |
+
return {"error": "Layer out of range."}
|
| 253 |
+
|
| 254 |
+
seq_len = hidden_layers[layer].shape[0]
|
| 255 |
+
if pos >= seq_len or from_pos >= seq_len:
|
| 256 |
+
return {"error": "Position out of range."}
|
| 257 |
+
|
| 258 |
+
patch_vec = torch.tensor(hidden_layers[layer][from_pos], dtype=torch.float32).to(DEVICE) * float(scale)
|
| 259 |
|
| 260 |
# re-run with patch
|
| 261 |
with torch.no_grad():
|
| 262 |
x = wte(inputs["input_ids"]).to(DEVICE)
|
| 263 |
+
for i, b in enumerate(blocks):
|
| 264 |
+
ln1 = b.ln_1(x)
|
| 265 |
+
a = b.attn(ln1)[0]
|
| 266 |
+
x = x + a
|
| 267 |
+
ln2 = b.ln_2(x)
|
| 268 |
+
m = b.mlp(ln2)
|
| 269 |
+
x = x + m
|
| 270 |
+
if i == layer:
|
| 271 |
+
x[0, pos, :] = patch_vec
|
| 272 |
+
|
| 273 |
+
final = ln_f(x) if ln_f else x
|
| 274 |
+
logits = lm_head(final)[0, -1, :].cpu().numpy()
|
|
|
|
| 275 |
probs = softmax(logits)
|
| 276 |
+
idx = np.argsort(probs)[-20:][::-1]
|
| 277 |
+
tt = [tokenizer.decode([int(i)]) for i in idx]
|
| 278 |
+
ss = probs[idx].tolist()
|
| 279 |
+
return {"tokens": tt, "scores": ss}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
+
# ---------------- GRADIO UI ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
| 285 |
+
|
| 286 |
+
gr.Markdown("# π§ Full LLM Visualizer (Advanced)")
|
| 287 |
+
gr.Markdown("Fully stable build with attention, PCA, neuron explorer, residuals, activation-patching")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
+
# Panel 1
|
| 290 |
with gr.Row():
|
| 291 |
with gr.Column(scale=3):
|
| 292 |
+
model_name = gr.Textbox(label="Model", value=DEFAULT_MODEL)
|
| 293 |
+
input_text = gr.Textbox(label="Input", value="Hello world", lines=3)
|
| 294 |
+
simple = gr.Checkbox(label="Explain simply", value=True)
|
| 295 |
+
run_btn = gr.Button("Run", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
gr.Markdown("Presets:")
|
| 298 |
+
with gr.Row():
|
| 299 |
+
gr.Button("Greeting").click(lambda: "Hello! How are you?", None, input_text)
|
| 300 |
+
gr.Button("Story").click(lambda: "Once upon a time there was a robot.", None, input_text)
|
| 301 |
+
gr.Button("Question").click(lambda: "Why is the sky blue?", None, input_text)
|
| 302 |
|
| 303 |
with gr.Column(scale=2):
|
| 304 |
+
token_display = gr.Markdown()
|
| 305 |
+
explanation_md = gr.Markdown()
|
| 306 |
+
model_info = gr.Markdown()
|
| 307 |
|
| 308 |
+
# Panel 2
|
| 309 |
with gr.Row():
|
| 310 |
with gr.Column():
|
| 311 |
+
layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
|
| 312 |
+
head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
|
| 313 |
+
token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
|
| 314 |
+
attn_plot = gr.Plot()
|
| 315 |
+
|
| 316 |
with gr.Column():
|
| 317 |
+
pca_plot = gr.Plot()
|
| 318 |
+
step_attn_plot = gr.Plot()
|
| 319 |
+
probs_plot = gr.Plot()
|
|
|
|
| 320 |
|
| 321 |
+
# Panel 3 β Residuals
|
| 322 |
+
residual_plot = gr.Plot()
|
| 323 |
+
|
| 324 |
+
# Panel 4 β Neuron explorer
|
| 325 |
with gr.Row():
|
| 326 |
+
neuron_find_btn = gr.Button("Find neurons")
|
| 327 |
+
neuron_idx = gr.Number(label="Neuron index", value=0)
|
| 328 |
+
neuron_table = gr.Dataframe(headers=["token", "activation"], interactive=False)
|
| 329 |
+
|
| 330 |
+
# Panel 5 β Activation Patching
|
| 331 |
+
with gr.Row():
|
| 332 |
+
patch_layer = gr.Slider(0, 0, value=0, step=1, label="Patch layer")
|
| 333 |
+
patch_pos = gr.Slider(0, 0, value=0, step=1, label="Target token position")
|
| 334 |
+
patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
|
| 335 |
+
patch_scale = gr.Number(label="Scale", value=1.0)
|
| 336 |
+
patch_btn = gr.Button("Run patch")
|
| 337 |
+
patch_output = gr.Plot()
|
|
|
|
| 338 |
|
| 339 |
state = gr.State()
|
| 340 |
|
| 341 |
+
# ---- RUN ANALYSIS ----
|
| 342 |
+
def run_app(text, model, simp):
|
| 343 |
+
res = analyze_text(text, model, simp)
|
| 344 |
+
|
| 345 |
+
if "error" in res:
|
| 346 |
+
return {
|
| 347 |
+
token_display: gr.update(value=""),
|
| 348 |
+
explanation_md: gr.update(value=res["error"]),
|
| 349 |
+
model_info: gr.update(value=f"Model: {model}"),
|
| 350 |
+
attn_plot: gr.update(value=None),
|
| 351 |
+
pca_plot: gr.update(value=None),
|
| 352 |
+
probs_plot: gr.update(value=None),
|
| 353 |
+
layer_slider: gr.update(maximum=0, value=0),
|
| 354 |
+
head_slider: gr.update(maximum=0, value=0),
|
| 355 |
+
token_step: gr.update(maximum=0, value=0),
|
| 356 |
+
residual_plot: gr.update(value=None),
|
| 357 |
+
neuron_table: gr.update(value=[]),
|
| 358 |
+
patch_layer: gr.update(maximum=0),
|
| 359 |
+
patch_pos: gr.update(maximum=0),
|
| 360 |
+
patch_from: gr.update(maximum=0),
|
| 361 |
+
state: res
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
tokens = res["tokens"]
|
| 365 |
+
L = len(res["attentions"])
|
| 366 |
+
H = res["attentions"][0].shape[0]
|
| 367 |
+
T = len(tokens) - 1
|
| 368 |
+
|
| 369 |
+
residual_fig = None
|
| 370 |
+
if res["residuals"]:
|
| 371 |
+
attn_vals = res["residuals"]["attn"]
|
| 372 |
+
ml_vals = res["residuals"]["mlp"]
|
| 373 |
+
Lmin = min(len(attn_vals), len(ml_vals))
|
| 374 |
+
df = pd.DataFrame({
|
| 375 |
+
"layer": list(range(Lmin)),
|
| 376 |
+
"attention": attn_vals[:Lmin],
|
| 377 |
+
"mlp": ml_vals[:Lmin]
|
| 378 |
+
})
|
| 379 |
+
fig = go.Figure()
|
| 380 |
+
fig.add_trace(go.Bar(x=df["layer"], y=df["attention"], name="Attention norm"))
|
| 381 |
+
fig.add_trace(go.Bar(x=df["layer"], y=df["mlp"], name="MLP norm"))
|
| 382 |
+
fig.update_layout(barmode="group", height=360)
|
| 383 |
+
residual_fig = fig
|
| 384 |
+
|
| 385 |
+
return {
|
| 386 |
+
token_display: gr.update(value=f"**Tokens:** {res['token_display']}"),
|
| 387 |
+
explanation_md: gr.update(value=res["explanation"]),
|
| 388 |
+
model_info: gr.update(value=f"Model: {model} β’ layers: {L} β’ heads: {H} β’ tokens: {len(tokens)}"),
|
| 389 |
+
attn_plot: gr.update(value=res["fig_attn"] if res.get("fig_attn") else None),
|
| 390 |
+
pca_plot: gr.update(value=res["fig_pca"] if res.get("fig_pca") else None),
|
| 391 |
+
probs_plot: gr.update(value=fig_probs(res["top_tokens"], res["top_scores"])),
|
| 392 |
+
layer_slider: gr.update(maximum=L-1, value=res["default_layer"]),
|
| 393 |
+
head_slider: gr.update(maximum=H-1, value=res["default_head"]),
|
| 394 |
+
token_step: gr.update(maximum=T, value=0),
|
| 395 |
+
residual_plot: gr.update(value=residual_fig),
|
| 396 |
+
neuron_table: gr.update(value=[[t, round(v,4)] for t,v in res["neuron_info"][0]["top_tokens"]] if res["neuron_info"] else []),
|
| 397 |
+
patch_layer: gr.update(maximum=L-1, value=0),
|
| 398 |
+
patch_pos: gr.update(maximum=T, value=0),
|
| 399 |
+
patch_from: gr.update(maximum=T, value=0),
|
| 400 |
+
state: res
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
run_btn.click(
|
| 405 |
+
run_app,
|
| 406 |
+
inputs=[input_text, model_name, simple],
|
| 407 |
+
outputs=[
|
| 408 |
+
token_display, explanation_md, model_info,
|
| 409 |
+
attn_plot, pca_plot, probs_plot,
|
| 410 |
+
layer_slider, head_slider, token_step,
|
| 411 |
+
residual_plot, neuron_table,
|
| 412 |
+
patch_layer, patch_pos, patch_from,
|
| 413 |
+
state
|
| 414 |
+
]
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# ---- SLIDER UPDATES ----
|
| 418 |
+
def update_view(res, layer, head, tok):
|
| 419 |
+
if not res or "error" in res:
|
| 420 |
+
return {
|
| 421 |
+
attn_plot: gr.update(value=None),
|
| 422 |
+
pca_plot: gr.update(value=None),
|
| 423 |
+
step_attn_plot: gr.update(value=None),
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
tokens = res["tokens"]
|
| 427 |
+
layer = min(max(0, layer), len(res["attentions"]) - 1)
|
| 428 |
+
head = min(max(0, head), res["attentions"][0].shape[0] - 1)
|
| 429 |
+
tok = min(max(0, tok), len(tokens) - 1)
|
| 430 |
+
|
| 431 |
+
att = fig_attention(res["attentions"][layer][head], tokens, f"Layer {layer} Head {head}")
|
| 432 |
+
pts = res["pca"][layer]
|
| 433 |
+
pca_fig = fig_pca(pts, tokens, highlight=tok, title=f"PCA Layer {layer}")
|
| 434 |
+
|
| 435 |
+
row = res["attentions"][layer][head][tok]
|
| 436 |
+
step_fig = go.Figure([go.Bar(x=tokens, y=row)])
|
| 437 |
+
step_fig.update_layout(title=f"Token {tok} attends to")
|
| 438 |
+
|
| 439 |
+
return {
|
| 440 |
+
attn_plot: gr.update(value=att),
|
| 441 |
+
pca_plot: gr.update(value=pca_fig),
|
| 442 |
+
step_attn_plot: gr.update(value=step_fig)
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
layer_slider.change(update_view, [state, layer_slider, head_slider, token_step],
|
| 446 |
+
[attn_plot, pca_plot, step_attn_plot])
|
| 447 |
+
head_slider.change(update_view, [state, layer_slider, head_slider, token_step],
|
| 448 |
+
[attn_plot, pca_plot, step_attn_plot])
|
| 449 |
+
token_step.change(update_view, [state, layer_slider, head_slider, token_step],
|
| 450 |
+
[attn_plot, pca_plot, step_attn_plot])
|
| 451 |
+
|
| 452 |
+
# ---- NEURON EXPLORER ----
|
| 453 |
+
def neuron_auto(res):
|
| 454 |
+
if not res or "neuron_info" not in res:
|
| 455 |
+
return gr.update(value=[])
|
| 456 |
+
rows = []
|
| 457 |
+
for item in res["neuron_info"]:
|
| 458 |
+
for t, v in item["top_tokens"]:
|
| 459 |
+
rows.append([t, round(v,4)])
|
| 460 |
+
df = pd.DataFrame(rows, columns=["token","activation"]).drop_duplicates().head(24)
|
| 461 |
+
return gr.update(value=df.values.tolist())
|
| 462 |
+
|
| 463 |
+
neuron_find_btn.click(neuron_auto, [state], [neuron_table])
|
| 464 |
+
|
| 465 |
+
def neuron_manual(res, idx):
|
| 466 |
+
if not res or "hidden" not in res:
|
| 467 |
+
return gr.update(value=[])
|
| 468 |
+
try:
|
| 469 |
+
idx = int(idx)
|
| 470 |
+
except:
|
| 471 |
+
return gr.update(value=[])
|
| 472 |
+
last = res["hidden"][-1]
|
| 473 |
+
if idx >= last.shape[1]:
|
| 474 |
+
return gr.update(value=[])
|
| 475 |
+
vals = last[:, idx]
|
| 476 |
+
tokens = res["tokens"]
|
| 477 |
+
pairs = sorted([(tokens[i], float(vals[i])) for i in range(len(tokens))],
|
| 478 |
+
key=lambda x: -abs(x[1]))[:12]
|
| 479 |
+
return gr.update(value=[[t, round(v,4)] for t,v in pairs])
|
| 480 |
+
|
| 481 |
+
neuron_idx.change(neuron_manual, [state, neuron_idx], [neuron_table])
|
| 482 |
+
|
| 483 |
+
# ---- ACTIVATION PATCHING ----
|
| 484 |
+
def patch_run(res, L, P, FP, S, model):
|
| 485 |
+
if not res or "tokens" not in res:
|
| 486 |
+
return gr.update(value=None)
|
| 487 |
+
out = activation_patch(res["tokens"], model, int(L), int(P), int(FP), float(S))
|
| 488 |
+
if "error" in out:
|
| 489 |
+
return gr.update(value=None)
|
| 490 |
+
fig = fig_probs(out["tokens"], out["scores"])
|
| 491 |
+
return gr.update(value=fig)
|
| 492 |
+
|
| 493 |
+
patch_btn.click(patch_run,
|
| 494 |
+
[state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
|
| 495 |
+
[patch_output])
|
| 496 |
|
| 497 |
demo.launch()
|