# %% [markdown]
# # Self-attention as a heatmap
#
# The Python behind the "attention heatmap" explainer in the *Transformers,
# RAG & LLMs* chapter. Self-attention scores every token against every token
# (itself included) with a scaled dot product, then applies a softmax to each
# row so the row sums to 1. We use tiny hand-picked embeddings so the pattern
# is readable.
#
# Requirements: `numpy`, `matplotlib` (3.5 or newer, which accepts tick labels
# directly in `set_xticks`/`set_yticks`).

# %%
import numpy as np
import matplotlib.pyplot as plt

# %% [markdown]
# ## 1. Tokens and (toy) embeddings
#
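# Real models learn these vectors; here they are hand-tuned. The related
# tokens ("weld", "bead", "porosity") are long vectors with overlapping
# components, so their dot products with one another are large. The function
# words "The" and "has" are close to zero, so every score they take part in
# is small.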

# %%
tokens = ["The", "weld", "bead", "has", "porosity"]
E = np.array([
    [0.05, 0.00, 0.00, 0.00],
    [0.85, 0.25, 0.15, 0.05],
    [0.75, 0.35, 0.25, 0.10],
    [0.10, 0.05, 0.00, 0.00],
    [0.70, 0.15, 0.90, 0.20],
])
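
# %% [markdown]
# A quick look at the toy vectors (an addition to the original explainer). A
# dot product grows with both alignment and length, so this cell prints each
# embedding's length alongside the raw, unscaled dot products that section 2
# turns into scores. `E_check` is a copy of `E` made so the cell runs on its
# own; re-copy it if you edit `E`.

# %%
import numpy as np

E_check = np.array([
    [0.05, 0.00, 0.00, 0.00],  # The
    [0.85, 0.25, 0.15, 0.05],  # weld
    [0.75, 0.35, 0.25, 0.10],  # bead
    [0.10, 0.05, 0.00, 0.00],  # has
    [0.70, 0.15, 0.90, 0.20],  # porosity
])
norms = np.linalg.norm(E_check, axis=1)  # Euclidean length of each token vector
dots = E_check @ E_check.T               # dots[i, j] = E_i . E_j, before scaling
print("lengths:", np.round(norms, 2))
print("dot products:\n", np.round(dots, 2))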

# %% [markdown]
# ## 2. Scaled dot-product attention
#
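# `scores = (Q . K^T) / sqrt(d)`, then a row-wise softmax. Here Q = K = E, the
# raw embeddings; a real layer first multiplies E by learned query and key
# matrices. Dividing by sqrt(d), where d is the embedding size (4 here), keeps
# the dot products from growing with d and saturating the softmax. The result
# is how much each token (row) attends to every token (column), itself
# included.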
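
# %% [markdown]
# Why divide by `sqrt(d)`? A minimal sketch, assuming query and key entries are
# independent with unit variance (random stand-ins, not learned vectors). Their
# dot product then has variance `d`, so without scaling the scores spread out
# as `d` grows and the softmax collapses toward one-hot. The sizes (`d_big =
# 512`, 10 keys, 1000 queries) are arbitrary choices for illustration.

# %%
import numpy as np

rng = np.random.default_rng(0)
d_big, n_keys = 512, 10
q = rng.standard_normal((1000, d_big))    # 1000 random queries
k = rng.standard_normal((n_keys, d_big))  # 10 random keys


def row_softmax(x):
    x = x - x.max(axis=1, keepdims=True)  # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)


raw = q @ k.T  # unscaled scores; their std grows like sqrt(d_big)
peak_unscaled = row_softmax(raw).max(axis=1).mean()
peak_scaled = row_softmax(raw / np.sqrt(d_big)).max(axis=1).mean()
print(f"score std unscaled: {raw.std():.1f}, scaled: {(raw / np.sqrt(d_big)).std():.2f}")
print(f"mean largest weight unscaled: {peak_unscaled:.2f}, scaled: {peak_scaled:.2f}")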

# %%
def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

d = E.shape[1]
scores = E @ E.T / np.sqrt(d)
attention = softmax(scores, axis=1)

np.set_printoptions(precision=2, suppress=True)
print("attention weights (rows sum to 1):\n", attention)
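
# %% [markdown]
# The cell above sets Q = K = E and stops at the weights. A fuller sketch of
# one attention layer also projects the embeddings into queries, keys and
# values, then returns `weights @ V`: a weighted average of value vectors for
# each token. The random `W_q`, `W_k`, `W_v` below are placeholders for
# learned parameters (an assumption for illustration, not trained weights).
# With identity projections the weights match the toy Q = K = E version.

# %%
import numpy as np


def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    s = Q @ K.T / np.sqrt(d_k)
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(s)
    w = w / w.sum(axis=-1, keepdims=True)  # each row sums to 1
    return w @ V, w


X = np.array([
    [0.05, 0.00, 0.00, 0.00],
    [0.85, 0.25, 0.15, 0.05],
    [0.75, 0.35, 0.25, 0.10],
    [0.10, 0.05, 0.00, 0.00],
    [0.70, 0.15, 0.90, 0.20],
])
rng = np.random.default_rng(42)
# Hypothetical stand-ins for learned projection matrices.
W_q, W_k, W_v = (rng.normal(scale=0.5, size=(4, 4)) for _ in range(3))

out, weights = scaled_dot_product_attention(X @ W_q, X @ W_k, X @ W_v)
print("weights:\n", np.round(weights, 2))
print("output shape:", out.shape)  # one 4-D vector per token

# Identity projections reproduce the Q = K = E weights from the cell above.
eye = np.eye(4)
_, w_identity = scaled_dot_product_attention(X @ eye, X @ eye, X @ eye)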

# %% [markdown]
# ## 3. Draw the heatmap
#
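# Darker = stronger attention. Each row sums to 1, and the color scale runs
# from 0 to the largest weight in the matrix, so shades are relative rather
# than absolute.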

# %%
fig, ax = plt.subplots(figsize=(5.5, 5))
im = ax.imshow(attention, cmap="Blues", vmin=0, vmax=attention.max())
ax.set_xticks(range(len(tokens)), tokens, rotation=45, ha="right")
ax.set_yticks(range(len(tokens)), tokens)
ax.set_xlabel("attends to"); ax.set_ylabel("query token")
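dark = 0.6 * attention.max()  # cells above this are dark enough for white text
for i in range(len(tokens)):
    for j in range(len(tokens)):
        # The color scale tops out at attention.max(), not 1, so the
        # text-color threshold is relative to that maximum.
        ax.text(j, i, f"{attention[i, j]:.2f}", ha="center", va="center",
                color="white" if attention[i, j] > dark else "black")
fig.colorbar(im, fraction=0.046, pad=0.04)
ax.set_title("Self-attention weights")
plt.tight_layout()
plt.show()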
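
# %% [markdown]
# Reading the heatmap as text: a small sketch (our addition) that prints, for
# each query token, the token it weights most and the spread of its row. With
# these toy vectors the rows for "The" and "has" come out close to uniform
# (about 0.2 each), because their near-zero embeddings make all their scores
# nearly equal. `toks`, `X` and `A` are local copies so the cell runs alone.

# %%
import numpy as np

toks = ["The", "weld", "bead", "has", "porosity"]
X = np.array([
    [0.05, 0.00, 0.00, 0.00],
    [0.85, 0.25, 0.15, 0.05],
    [0.75, 0.35, 0.25, 0.10],
    [0.10, 0.05, 0.00, 0.00],
    [0.70, 0.15, 0.90, 0.20],
])
s = X @ X.T / np.sqrt(X.shape[1])           # scaled dot-product scores
A = np.exp(s - s.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)         # row-wise softmax
for tok, row in zip(toks, A):
    j = int(row.argmax())
    print(f"{tok:>9} -> {toks[j]:<9} ({row[j]:.2f}; row spans {row.min():.2f} to {row.max():.2f})")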

# %% [markdown]
# ## Your turn
#
# - Edit `tokens` and `E` (keep one row of `E` per token) and re-run.
# - Add a temperature: divide `scores` by a constant `T` before the softmax.
#   Small T makes attention sharp (near one-hot); large T makes it flat.
# - Split the 4-D embeddings into two 2-D "heads", run attention in each head
#   separately, and concatenate the heads' outputs (each head's weights times
#   its slice of the values). That is multi-head attention, minus the learned
#   projections a real layer adds.
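
# %% [markdown]
# One way to try the temperature exercise, as a sketch: `T` is a knob the
# cells above do not have, applied here by dividing the scaled scores by `T`
# before the softmax. The largest weight in the "porosity" row rises toward 1
# as `T` shrinks and falls toward 1/5 (uniform over five tokens) as `T` grows.
# `attention_with_temperature` is an illustrative name, not a library call.

# %%
import numpy as np

X = np.array([
    [0.05, 0.00, 0.00, 0.00],
    [0.85, 0.25, 0.15, 0.05],
    [0.75, 0.35, 0.25, 0.10],
    [0.10, 0.05, 0.00, 0.00],
    [0.70, 0.15, 0.90, 0.20],
])
base_scores = X @ X.T / np.sqrt(X.shape[1])  # the T = 1 scores


def attention_with_temperature(scores, T):
    z = scores / T                           # temperature rescales the logits
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


peaks = {}
for T in (0.05, 0.2, 1.0, 5.0):
    peaks[T] = attention_with_temperature(base_scores, T)[4].max()
    print(f"T={T:<4}  largest weight in 'porosity' row: {peaks[T]:.2f}")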

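# %% [markdown]
# A minimal sketch of the multi-head exercise, assuming the simplest possible
# split: head 0 sees embedding dimensions 0-1, head 1 sees dimensions 2-3, and
# each head uses its slice as Q, K and V. Real multi-head attention instead
# learns separate projections per head and applies an output projection after
# concatenating; both are omitted here. `head_attention` is an illustrative
# name.

# %%
import numpy as np

X = np.array([
    [0.05, 0.00, 0.00, 0.00],
    [0.85, 0.25, 0.15, 0.05],
    [0.75, 0.35, 0.25, 0.10],
    [0.10, 0.05, 0.00, 0.00],
    [0.70, 0.15, 0.90, 0.20],
])


def head_attention(H):
    """Self-attention within one head; returns (weighted values, weights)."""
    s = H @ H.T / np.sqrt(H.shape[1])
    s = s - s.max(axis=1, keepdims=True)
    w = np.exp(s)
    w = w / w.sum(axis=1, keepdims=True)
    return w @ H, w


n_heads = 2
heads = np.split(X, n_heads, axis=1)          # two (5, 2) slices
results = [head_attention(H) for H in heads]
head_weights = [w for _, w in results]
multi_out = np.concatenate([o for o, _ in results], axis=1)  # back to (5, 4)
for h, w in enumerate(head_weights):
    print(f"head {h} weights:\n", np.round(w, 2))
print("concatenated output shape:", multi_out.shape)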