Lab: Grad-CAM for classifier explanations (PyTorch)¶
Goals¶
- See where a trained network looks when it chooses a class.
- Implement a minimal Grad-CAM on ResNet18 using the last convolutional block (layer4).
- Discuss limits: explanations are not ground truth about physical causality.
Prerequisites¶
pip install torch torchvision matplotlib numpy pillow requests
Note¶
Grad-CAM requires gradients to flow into the feature map. We clone the input tensor and set requires_grad_(True) for the explanation pass.
In [ ]:
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
# Pick a compute device: CUDA if present, otherwise Apple MPS when this torch
# build exposes it and it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Load ResNet18 with the IMAGENET1K_V1 weights, move it to the chosen device
# and switch it to eval mode; `preprocess` is the transform bundled with
# those weights.
weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=weights)
model = model.to(device)
model.eval()
preprocess = weights.transforms()
# Demo image: download a cat photo from Wikimedia; if anything goes wrong
# (offline, HTTP error, undecodable payload) use the first CIFAR-10 training
# image instead so the rest of the notebook still runs.
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/320px-Cat03.jpg'
try:
    resp = requests.get(url, timeout=15)
    resp.raise_for_status()
    downloaded = Image.open(BytesIO(resp.content))
    pil = downloaded.convert('RGB')
except Exception as e:
    print('Using CIFAR-10 fallback (network image failed):', e)
    # Imported lazily: only needed on the fallback path.
    from torchvision.datasets import CIFAR10
    first_sample = CIFAR10('./data', train=True, download=True)[0]
    pil = first_sample[0]  # each sample is (image, label); keep the image

# Show the image that will be classified and explained.
fig, ax = plt.subplots()
ax.imshow(pil)
ax.set_title('Input image')
ax.axis('off')
plt.show()
Preprocess and predict top-5 ImageNet classes¶
In [ ]:
# Preprocess the image, add a leading batch dimension and move it to the
# same device as the model.
x = preprocess(pil).unsqueeze(0).to(device)

# Plain prediction pass: gradients are not needed here.
with torch.no_grad():
    logits = model(x)
    probs = logits.softmax(dim=1)[0]

# Five most probable classes, printed with their human-readable names.
top5 = torch.topk(probs, k=5)
idc = weights.meta['categories']
top_probs = top5.values.tolist()
top_ids = top5.indices.tolist()
for i, (p, idx) in enumerate(zip(top_probs, top_ids)):
    print(f'{i+1}. {idc[idx]:35s} p={p:.4f}')
Grad-CAM on layer4¶
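We hook the output of layer4 and backpropagate the logit of the predicted class to obtain the gradient of that logit with respect to the layer4 activations. Averaging the gradient over the spatial dimensions gives one weight per channel; the heatmap is the ReLU of the weighted sum of the activation channels, rescaled to [0, 1] (standard Grad-CAM recipe).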
In [ ]:
def grad_cam_resnet18(model_: nn.Module, input_tensor: torch.Tensor, target_layer: nn.Module, class_idx: int | None):
model_.eval()
acts = []
def fwd_hook(_m, _inp, out):
out.retain_grad()
acts.append(out)
h = target_layer.register_forward_hook(fwd_hook)
input_tensor = input_tensor.clone().detach().requires_grad_(True)
logits = model_(input_tensor)
if class_idx is None:
class_idx = int(logits.argmax(dim=1).item())
model_.zero_grad(set_to_none=True)
logits[0, class_idx].backward()
h.remove()
a = acts[0]
g = a.grad
w = g.mean(dim=(2, 3), keepdim=True)
cam = (w * a).sum(dim=1, keepdim=True)
cam = F.relu(cam)
cam = cam[0, 0].detach().cpu().numpy()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam, class_idx
# Explain the model's own top prediction (class_idx=None selects the arg-max
# class) using the output of layer4 as the Grad-CAM feature map.
target = model.layer4
heatmap, cls = grad_cam_resnet18(
    model_=model,
    input_tensor=x,
    target_layer=target,
    class_idx=None,
)
print('Explaining class:', idc[cls])
In [ ]:
def overlay_heatmap(pil_img: Image.Image, heatmap: np.ndarray, alpha: float = 0.45):
    """Blend a jet-coloured heatmap over an image.

    Args:
        pil_img: Image to draw on. Any PIL mode is accepted; it is converted
            to RGB so it can be blended with the 3-channel colour map.
        heatmap: 2-D array of intensities, expected in [0, 1]. Values outside
            that range are clipped. It is resized to the image size with
            bilinear interpolation.
        alpha: Weight of the colour map in the blend; the image gets
            ``1 - alpha``.

    Returns:
        Float array of shape (height, width, 3) with values in [0, 1].
    """
    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around (e.g. 1.01 * 255 -> 257 -> 1) and show up as cold pixels.
    heat8 = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
    hmap = Image.fromarray(heat8).resize(pil_img.size, Image.Resampling.BILINEAR)
    hmap = np.asarray(hmap).astype(np.float32) / 255.0
    # The colormap returns RGBA; drop the alpha channel.
    cmap = plt.cm.jet(hmap)[:, :, :3]
    # Force RGB so grayscale / RGBA / palette images broadcast against cmap.
    base = np.asarray(pil_img.convert('RGB')).astype(np.float32) / 255.0
    out = (1 - alpha) * base + alpha * cmap
    return np.clip(out, 0, 1)
# The heatmap describes the image as the network saw it: `preprocess` resizes
# the picture and keeps only a centre crop. Rebuild that same view (without
# tensor conversion or normalisation) and overlay on it. Stretching the heatmap
# over the full uncropped image would put the hot spots in the wrong place.
model_view = transforms.Compose([
    transforms.Resize(preprocess.resize_size, interpolation=preprocess.interpolation),
    transforms.CenterCrop(preprocess.crop_size),
])(pil)

plt.figure(figsize=(10, 4))

# Left: the raw, low-resolution Grad-CAM grid.
plt.subplot(1, 2, 1)
plt.imshow(heatmap, cmap='jet')
plt.title('Grad-CAM heatmap (low res)')
plt.axis('off')

# Right: the heatmap blended over the model's view of the image.
plt.subplot(1, 2, 2)
plt.imshow(overlay_heatmap(model_view, heatmap))
plt.title(f'Overlay → {idc[cls]}')
plt.axis('off')

plt.tight_layout()
plt.show()
Try this¶
- Run Grad-CAM for the second-best class index instead of the argmax. Does the heatmap still look sensible?
- Failure cases: try an abstract texture or an object outside ImageNet—what happens?
- Compare with a wrong prediction (if you can find an image the model mislabels): does the heatmap look confident anyway?
Reference¶
Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, ICCV 2017.