Lab: Convolutions and feature maps (PyTorch)

Goals

  • Apply fixed $3\times3$ filters (edge / blur) with conv2d and see how outputs respond to structure in the image.
  • Inspect learned filters from the first layer of a small trained network on MNIST.
  • Visualize activation maps (channels) after one conv layer.

Prerequisites

pip install torch torchvision matplotlib numpy

In [ ]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST

plt.rcParams['figure.figsize'] = (9, 3)

# Grayscale digit
ds = MNIST(root='./data', train=True, download=True)
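img_uint8 = ds[0][0]  # first training sample as a PIL image (ds[0][1] is its label)
x = np.asarray(img_uint8, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
x = torch.from_numpy(x)[None, None, :, :]  # add batch and channel dims: (1, 1, 28, 28)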
print('Input shape:', tuple(x.shape))

plt.imshow(x[0, 0].numpy(), cmap='gray')
plt.title('MNIST input')
plt.axis('off')
plt.show()

1. Hand-designed filters

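F.conv2d computes cross-correlation: the kernel slides over the image as given, without the 180° flip that the classical “convolution” definition in math and signal-processing texts applies. Modern CNN libraries follow the same convention. For learned filters the distinction does not matter, because the network can learn the flipped kernel just as easily.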

Sobel kernels approximate gradients (edges). A box filter averages locally (blur).
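
To make the correlation-versus-convolution point concrete, here is a minimal sketch that compares F.conv2d with and without an explicit kernel flip. The random `img` tensor is a stand-in chosen for illustration, not the MNIST digit used in the lab.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.rand(1, 1, 8, 8)  # stand-in image (N, C, H, W), not the MNIST digit

sobel_x = torch.tensor(
    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
).view(1, 1, 3, 3)

# F.conv2d slides the kernel as given (cross-correlation).
corr = F.conv2d(img, sobel_x, padding=1)

# Classical convolution flips the kernel along both spatial axes first.
flipped = torch.flip(sobel_x, dims=(2, 3))
conv = F.conv2d(img, flipped, padding=1)

# Sobel X is antisymmetric under a 180° flip, so the results differ only in sign.
sign_only = torch.allclose(conv, -corr)
print('conv == -corr:', sign_only)

# A symmetric kernel such as the box filter is unchanged by the flip.
box = torch.ones(1, 1, 3, 3) / 9.0
same_for_box = torch.allclose(
    F.conv2d(img, box, padding=1),
    F.conv2d(img, torch.flip(box, dims=(2, 3)), padding=1),
)
print('box unchanged by flip:', same_for_box)
```

For the Sobel kernels the flip only reverses the sign of the response, so the gradient magnitude is the same under either definition.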

In [ ]:
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
box = torch.ones(1, 1, 3, 3, dtype=torch.float32) / 9.0

gx = F.conv2d(x, sobel_x, padding=1)
gy = F.conv2d(x, sobel_y, padding=1)
mag = torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)
blur = F.conv2d(x, box, padding=1)

fig, ax = plt.subplots(1, 4, figsize=(12, 3))
for a, t, title in zip(
    ax,
    [gx, gy, mag, blur],
    ['Sobel X', 'Sobel Y', 'Gradient magnitude', 'Box blur 3×3'],
):
    a.imshow(t[0, 0].detach().numpy(), cmap='gray')
    a.set_title(title)
    a.axis('off')
plt.tight_layout()
plt.show()

2. Multiple filters at once

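Stack filters along the output channel dimension (dim 0 of the weight tensor, whose layout is (C_out, C_in, kH, kW)). A single conv layer then produces one output channel per filter, that is, many “views” of the same input. Early CNN layers often end up acting like oriented edge detectors.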

In [ ]:
kernels = torch.cat([sobel_x, sobel_y, box], dim=0)  # 3,1,3,3
multi = F.conv2d(x, kernels, padding=1)
print('Output shape (N, C_out, H, W):', tuple(multi.shape))

fig, ax = plt.subplots(1, 3, figsize=(9, 3))
for i in range(3):
    ax[i].imshow(multi[0, i].detach().numpy(), cmap='gray')
    ax[i].set_title(f'Channel {i}')
    ax[i].axis('off')
plt.tight_layout()
plt.show()
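
As a sanity check on the stacking idea, the sketch below verifies that each channel of a stacked convolution matches a separate convolution with the corresponding kernel. It uses a random stand-in image and only two of the kernels to stay short; those choices are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.rand(1, 1, 28, 28)  # random stand-in for the digit

sobel_x = torch.tensor(
    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
).view(1, 1, 3, 3)
box = torch.ones(1, 1, 3, 3) / 9.0

# Weight layout expected by F.conv2d: (C_out, C_in, kH, kW).
stacked = torch.cat([sobel_x, box], dim=0)  # (2, 1, 3, 3)
out = F.conv2d(img, stacked, padding=1)     # (1, 2, 28, 28)

# Channel i of the stacked output matches a separate conv with kernel i.
match_sobel = torch.allclose(out[:, 0:1], F.conv2d(img, sobel_x, padding=1), atol=1e-6)
match_box = torch.allclose(out[:, 1:2], F.conv2d(img, box, padding=1), atol=1e-6)
print(tuple(out.shape), match_sobel, match_box)
```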

3. Learned first-layer filters (quick MNIST train)

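We train a tiny CNN briefly (150 mini-batches by default, a fraction of one epoch; set max_batches to None for a full epoch) so that the first-layer weights move away from their random initialization, then plot all 16 conv1 kernels.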

In [ ]:
from torch.utils.data import DataLoader
from torchvision import transforms

tfm = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(MNIST('./data', train=True, download=False, transform=tfm), batch_size=128, shuffle=True)

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, t):
        t = self.pool(F.relu(self.conv1(t)))
        t = self.pool(F.relu(self.conv2(t)))
        return self.fc(torch.flatten(t, 1))


m = TinyCNN()
opt = torch.optim.Adam(m.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
m.train()
max_batches = 150  # set to None to run a full epoch (slower)
for bi, (batch_x, batch_y) in enumerate(train_loader):
    if max_batches is not None and bi >= max_batches:
        break
    opt.zero_grad(set_to_none=True)
    loss = loss_fn(m(batch_x), batch_y)
    loss.backward()
    opt.step()

W = m.conv1.weight.detach().cpu()  # (16,1,3,3)
fig, ax = plt.subplots(2, 8, figsize=(12, 3))
for i, a in enumerate(ax.flat):
    a.imshow(W[i, 0].numpy(), cmap='gray', vmin=-0.8, vmax=0.8)
    a.axis('off')
plt.suptitle('Learned conv1 kernels (~150 batches demo)')
plt.tight_layout()
plt.show()
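
The `32 * 7 * 7` input size of the `fc` layer follows from the two pooling steps: 28 → 14 → 7 spatially, with 32 channels after conv2. A minimal sketch that traces the shapes, using freshly created (untrained) layers with the same hyperparameters as TinyCNN:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same hyperparameters as TinyCNN; weights are untrained because only shapes matter here.
conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2)

shapes = []
with torch.no_grad():
    t = torch.zeros(1, 1, 28, 28)
    shapes.append(tuple(t.shape))      # input
    t = pool(F.relu(conv1(t)))
    shapes.append(tuple(t.shape))      # after conv1 + pool
    t = pool(F.relu(conv2(t)))
    shapes.append(tuple(t.shape))      # after conv2 + pool
    flat = torch.flatten(t, 1)
    shapes.append(tuple(flat.shape))   # what the linear layer receives

for s in shapes:
    print(s)
```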

4. Feature map after conv1

Pass the digit through conv1 only (no ReLU or pooling) and display all 16 output channels. Because ReLU is skipped, the maps contain both positive and negative responses.

In [ ]:
m.eval()
with torch.no_grad():
    feat = m.conv1(x)

fig, ax = plt.subplots(2, 8, figsize=(12, 3))
for i, a in enumerate(ax.flat):
    a.imshow(feat[0, i].numpy(), cmap='viridis')
    a.axis('off')
plt.suptitle('Feature maps: conv1 output channels')
plt.tight_layout()
plt.show()
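
Calling `m.conv1(x)` directly works because conv1 is the first layer. For deeper layers, one common alternative is a forward hook, which records a layer's output during a normal forward pass. The sketch below uses a small hypothetical stand-in model (`net`) and a helper name (`save_output`) chosen for illustration; with the lab's network you would presumably register the hook on `m.conv1` instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in model, not the lab's TinyCNN.
net = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

captured = {}

def save_output(module, inputs, output):
    # Store a detached copy so the activation does not keep the autograd graph alive.
    captured['conv1'] = output.detach()

handle = net[0].register_forward_hook(save_output)
with torch.no_grad():
    out = net(torch.rand(1, 1, 28, 28))
handle.remove()  # remove the hook once the activation has been captured

print('conv1 activation:', tuple(captured['conv1'].shape))
print('model output:', tuple(out.shape))
```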

5. Try this

  1. Change padding from 1 to 0 for Sobel. What happens at the borders?
  2. Increase blur kernel to 5×5. How do edges change?
  3. Compare gradient magnitude on a CIFAR-10 RGB channel vs grayscale.
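
As a starting point for the first two exercises, this sketch checks how padding and kernel size affect the output size, using the standard formula floor((n + 2p - k) / s) + 1 for a convolution without dilation. The helper `conv_out_size` and the random input are assumptions introduced here, not part of the lab code.

```python
import torch
import torch.nn.functional as F

def conv_out_size(n, k, padding=0, stride=1):
    # Standard output-size formula for a convolution without dilation.
    return (n + 2 * padding - k) // stride + 1

img = torch.rand(1, 1, 28, 28)          # random stand-in image
k3 = torch.ones(1, 1, 3, 3) / 9.0       # 3x3 box filter
k5 = torch.ones(1, 1, 5, 5) / 25.0      # 5x5 box filter

valid3 = F.conv2d(img, k3, padding=0)   # no padding: the output shrinks
same3 = F.conv2d(img, k3, padding=1)    # padding 1 keeps 28x28 for a 3x3 kernel
same5 = F.conv2d(img, k5, padding=2)    # a 5x5 kernel needs padding 2 to keep 28x28

print(tuple(valid3.shape), conv_out_size(28, 3, padding=0))
print(tuple(same3.shape), conv_out_size(28, 3, padding=1))
print(tuple(same5.shape), conv_out_size(28, 5, padding=2))
```

In general, an odd kernel size k needs padding (k - 1) / 2 to preserve the spatial size at stride 1.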