Lab: Grad-CAM for classifier explanations (PyTorch)

Goals

  • See where a trained network looks when it chooses a class.
  • Implement a minimal Grad-CAM on ResNet18 using the last convolutional block (layer4).
  • Discuss limits: a Grad-CAM heatmap is a gradient-based approximation of where the class score is sensitive, not ground truth about what caused the prediction.

Prerequisites

pip install torch torchvision matplotlib numpy pillow requests

Note

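Grad-CAM needs the gradient of the class score with respect to the layer4 feature map, so the explanation pass must run with autograd enabled (not inside torch.no_grad(), unlike the prediction cell). We also clone the input tensor and call requires_grad_(True) on it, which keeps the feature map in the autograd graph even if the model's parameters were frozen.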
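A small, self-contained check of that reasoning, using a hypothetical one-layer stand-in network with every parameter frozen (not ResNet18). With the notebook's pretrained model the parameters already require gradients by default, so the input flag mainly matters when weights are frozen; `torch.no_grad()` blocks the graph either way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a feature extractor: one conv + ReLU, with every parameter frozen.
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
for p in net.parameters():
    p.requires_grad_(False)

x = torch.randn(1, 3, 8, 8)

with torch.no_grad():          # like the prediction cell: autograd records nothing
    feat_nograd = net(x)
feat_plain = net(x)            # frozen weights + plain input: still nothing to differentiate
feat_tracked = net(x.clone().detach().requires_grad_(True))  # input requires grad: graph is built

for name, f in [('no_grad', feat_nograd), ('plain input', feat_plain), ('input requires grad', feat_tracked)]:
    fn = type(f.grad_fn).__name__ if f.grad_fn is not None else None
    print(f'{name:20s} requires_grad={f.requires_grad}  grad_fn={fn}')

# Only the tracked feature map can keep a gradient, which is what the Grad-CAM hook relies on.
feat_tracked.retain_grad()
feat_tracked.sum().backward()
print(feat_tracked.grad.shape)  # torch.Size([1, 4, 8, 8])
```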

In [ ]:
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if not torch.cuda.is_available() and getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
    device = torch.device('mps')

weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=weights).to(device)
model.eval()

preprocess = weights.transforms()

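# Demo image: try a Wikimedia cat photo; otherwise fall back to the first CIFAR-10
# training image (this still needs network access unless the dataset is already cached in ./data).
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/320px-Cat03.jpg'
try:
    # Wikimedia asks clients to send a descriptive User-Agent and may reject generic ones
    resp = requests.get(url, timeout=15, headers={'User-Agent': 'gradcam-lab/0.1 (educational notebook)'})
    resp.raise_for_status()
    pil = Image.open(BytesIO(resp.content)).convert('RGB')
except Exception as e:
    print('Using CIFAR-10 fallback (network image failed):', e)
    from torchvision.datasets import CIFAR10

    # Without a transform the dataset yields (PIL image, label); CIFAR-10 images are only 32x32
    pil = CIFAR10('./data', train=True, download=True)[0][0]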
plt.imshow(pil)
plt.title('Input image')
plt.axis('off')
plt.show()

Preprocess and predict top-5 ImageNet classes

In [ ]:
x = preprocess(pil).unsqueeze(0).to(device)
with torch.no_grad():
    logits = model(x)
probs = torch.softmax(logits, dim=1)[0]
top5 = probs.topk(5)
idc = weights.meta['categories']
for i, (p, idx) in enumerate(zip(top5.values.tolist(), top5.indices.tolist())):
    print(f'{i+1}. {idc[idx]:35s}  p={p:.4f}')

Grad-CAM on layer4

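We register a forward hook on layer4 to capture its output feature maps, backpropagate the pre-softmax score (logit) of the target class, average the gradients over spatial positions to get one weight per channel, and apply ReLU to the weighted sum of the feature maps. This is the standard Grad-CAM recipe of Selvaraju et al.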
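In the notation of Selvaraju et al., with $y^c$ the pre-softmax score for class $c$ and $A^k$ the $k$-th feature map of the hooked layer (512 maps of $7\times 7$ for a 224x224 input to ResNet18's layer4), the recipe reads:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_{k} \alpha_k^c A^k \right)
```

Here $Z$ is the number of spatial positions (49 for a 7x7 map). In `grad_cam_resnet18`, `w` holds $\alpha_k^c$ and the ReLU of `(w * a).sum(dim=1)` is $L^c$; the final min-max scaling to [0, 1] is only for display and is not part of the definition.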

In [ ]:
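def grad_cam_resnet18(model_: nn.Module, input_tensor: torch.Tensor, target_layer: nn.Module, class_idx: int | None):
    model_.eval()
    acts = []

    def fwd_hook(_m, _inp, out):
        # Keep the feature map and ask autograd to store its gradient during backward()
        out.retain_grad()
        acts.append(out)

    h = target_layer.register_forward_hook(fwd_hook)
    try:
        # Autograd must be enabled here (no torch.no_grad())
        input_tensor = input_tensor.clone().detach().requires_grad_(True)
        logits = model_(input_tensor)
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1).item())
        model_.zero_grad(set_to_none=True)
        # Backpropagate the raw (pre-softmax) class score
        logits[0, class_idx].backward()
    finally:
        h.remove()  # remove the hook even if the forward or backward pass fails

    a = acts[0]                             # activations, (1, 512, 7, 7) for a 224x224 input
    g = a.grad                              # d(score)/d(activations), same shape
    w = g.mean(dim=(2, 3), keepdim=True)    # one weight per channel: spatially averaged gradients
    cam = (w * a).sum(dim=1, keepdim=True)  # weighted sum of feature maps
    cam = F.relu(cam)                       # keep only evidence that increases the class score
    cam = cam[0, 0].detach().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # min-max scale to [0, 1] for display
    return cam, class_idx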


target = model.layer4
heatmap, cls = grad_cam_resnet18(model, x, target, class_idx=None)
print('Explaining class:', idc[cls])
In [ ]:
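def overlay_heatmap(pil_img: Image.Image, heatmap: np.ndarray, alpha: float = 0.45):
    # Upsample the low-res CAM (7x7 for a 224x224 input) to the image size with bilinear interpolation
    hmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_img.size, Image.Resampling.BILINEAR)
    hmap = np.asarray(hmap).astype(np.float32) / 255.0
    cmap = plt.cm.jet(hmap)[:, :, :3]  # the colormap returns RGBA; keep RGB in [0, 1]
    base = np.asarray(pil_img).astype(np.float32) / 255.0
    out = (1 - alpha) * base + alpha * cmap  # alpha-blend the heatmap over the image
    return np.clip(out, 0, 1)


# The CAM only covers what the model saw: weights.transforms() resizes the shorter side to 256 px
# and center-crops 224x224. Overlay on that same crop, not the full image, so the heatmap lines up.
pil_crop = transforms.CenterCrop(224)(transforms.Resize(256)(pil))

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(heatmap, cmap='jet')
plt.title('Grad-CAM heatmap (low res)')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(overlay_heatmap(pil_crop, heatmap))
plt.title(f'Overlay → {idc[cls]}')
plt.axis('off')
plt.tight_layout()
plt.show()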
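If you prefer to keep the upsampling in floating point instead of round-tripping the CAM through an 8-bit PIL image, `F.interpolate` can do the bilinear upsampling on a tensor. The sketch uses hypothetical helpers `upsample_cam` and `overlay_on_crop`, a synthetic 7x7 CAM with a single hot cell, and a grey 224x224 array standing in for the cropped input; it assumes matplotlib 3.5 or later for `matplotlib.colormaps`. With real data you would pass the 224x224 crop the model saw and the CAM returned by `grad_cam_resnet18`.

```python
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import colormaps


def upsample_cam(cam: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Bilinearly upsample a low-res CAM to size=(H, W) and rescale to [0, 1]."""
    t = torch.from_numpy(cam.astype(np.float32))[None, None]  # (1, 1, h, w)
    up = F.interpolate(t, size=size, mode='bilinear', align_corners=False)[0, 0]
    up = (up - up.min()) / (up.max() - up.min() + 1e-8)
    return up.numpy()


def overlay_on_crop(crop_rgb: np.ndarray, cam_up: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend a jet-coloured CAM over a uint8 RGB image with the same height and width."""
    heat = colormaps['jet'](cam_up)[..., :3]  # (H, W, 3) floats in [0, 1]
    base = crop_rgb.astype(np.float32) / 255.0
    return np.clip((1 - alpha) * base + alpha * heat, 0.0, 1.0)


# Synthetic inputs: one hot cell at (row 2, col 5) and a grey 224x224 "crop".
cam = np.zeros((7, 7), dtype=np.float32)
cam[2, 5] = 1.0
crop = np.full((224, 224, 3), 128, dtype=np.uint8)

cam_up = upsample_cam(cam, (224, 224))
blend = overlay_on_crop(crop, cam_up)
# Each CAM cell spans 224 / 7 = 32 pixels, so the peak lands near pixel (80, 176).
print(cam_up.shape, blend.shape, divmod(int(cam_up.argmax()), 224))
```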

Try this

  1. Run Grad-CAM for the second-best class index instead of the argmax. Does the heatmap still look sensible?
  2. Failure cases: try an abstract texture or an object outside the 1,000 ImageNet classes. What does the model predict, and where does the heatmap fall?
  3. Compare with a wrong prediction (if you can find an image the model mislabels): does the heatmap look confident anyway?

Reference

Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, ICCV 2017.