Your First PyTorch Notebook
Time to build a real, working, end-to-end model. By the end of this notebook you'll have trained a small neural network on a classic dataset, evaluated it, plotted its loss and accuracy curves, and saved and reloaded the checkpoint. Most PyTorch notebooks you write will follow this same shape; what changes from project to project is mainly the data and the model class.
pip install "torch==2.5.1" "torchvision==0.20.1" \
    "scikit-learn==1.5.2" "matplotlib==3.9.2"
1. The Imports
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
torch.manual_seed(0)
DEVICE = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu")
print("device:", DEVICE)
2. The Dataset
X, y = load_digits(return_X_y=True) # 1797 × 64, labels 0-9
X = X.astype("float32") / 16.0 # normalise to [0, 1]
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y)
train_ds = TensorDataset(torch.from_numpy(X_tr),
                         torch.from_numpy(y_tr).long())
test_ds = TensorDataset(torch.from_numpy(X_te),
                        torch.from_numpy(y_te).long())
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)
Always wrap data in a DataLoader, even for toy problems. It handles batching and shuffling now and, later, multi-worker loading through its num_workers argument, so the training loop itself does not need to change as the data pipeline grows.
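A quick way to see what the DataLoader does is to iterate over a tiny TensorDataset and inspect the batch sizes. The toy tensors below are made up for illustration, and drop_last is a standard DataLoader argument that this notebook does not use; treat this as a sketch, not part of the pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten made-up samples with 3 features each, plus integer labels 0..9.
X = torch.arange(30, dtype=torch.float32).reshape(10, 3)
y = torch.arange(10)
ds = TensorDataset(X, y)

# batch_size=4 over 10 samples: two full batches and a final partial batch of 2.
dl = DataLoader(ds, batch_size=4, shuffle=True)
batch_sizes = [xb.shape[0] for xb, yb in dl]
print(batch_sizes)  # [4, 4, 2]

# drop_last=True discards the final partial batch instead of yielding it.
dl_drop = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
drop_sizes = [xb.shape[0] for xb, _ in dl_drop]
print(drop_sizes)  # [4, 4]

# Shuffling reorders samples on each pass but never drops or duplicates any.
seen = sorted(int(label) for _, yb in dl for label in yb)
print(seen == list(range(10)))  # True
```

Because the final batch can be smaller than batch_size, the training loop later weights each batch's loss by xb.size(0) rather than averaging batch means directly.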
3. The Model
class MLP(nn.Module):
    def __init__(self, in_dim=64, hidden=128, out_dim=10, p_drop=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, out_dim),
        )
    def forward(self, x):
        # Returns raw logits; nn.CrossEntropyLoss applies log-softmax itself.
        return self.net(x)
model = MLP().to(DEVICE)
print(model)
print("params:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))
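The parameter count printed above can be checked by hand. An nn.Linear(in, out) layer with a bias stores an out × in weight matrix plus an out-length bias vector, while ReLU and Dropout have no parameters. The helper name linear_params below is hypothetical, used only for this arithmetic.

```python
# Hypothetical helper: parameter count of a Linear layer with bias.
def linear_params(in_f: int, out_f: int) -> int:
    return in_f * out_f + out_f  # weight matrix + bias vector

# The three Linear layers of MLP() with its defaults (in_dim=64, hidden=128, out_dim=10).
layers = [(64, 128), (128, 128), (128, 10)]
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)
print(per_layer)  # [8320, 16512, 1290]
print(total)      # 26122
```

With the default sizes, the "params:" line should therefore report 26122, and the middle hidden-to-hidden layer accounts for most of it.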
Two patterns to internalize:
- Subclass nn.Module. Never build a model from raw tensors when training; a module registers the parameters of its layers automatically, so model.parameters(), .to(DEVICE), and state_dict() all see them.
- The forward method takes a batch and returns the output. nn.Module's __call__ runs any registered hooks around forward, so call model(x), not model.forward(x).
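Both patterns can be observed on a single layer. The sketch below uses a standalone nn.Linear (not the MLP) and register_forward_hook to record when hooks fire; the hook name log_hook is made up for illustration.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)

# Record the output shape every time a forward hook fires.
calls = []
def log_hook(module, inputs, output):  # hypothetical hook, for illustration
    calls.append(tuple(output.shape))

handle = layer.register_forward_hook(log_hook)

x = torch.randn(3, 4)
layer(x)          # goes through nn.Module.__call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped
print(calls)      # [(3, 2)]
handle.remove()   # detach the hook again

# Parameters created in the layer's constructor are registered automatically.
param_names = [name for name, _ in layer.named_parameters()]
print(param_names)  # ['weight', 'bias']
```

Only one entry lands in calls, which is why calling model.forward(x) directly can silently skip behaviour that tools attach through hooks.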
4. The Training Loop
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
def train_one_epoch(model, dl, loss_fn, opt):
    model.train()                   # enable dropout
    total, correct, loss_sum = 0, 0, 0.0
    for xb, yb in dl:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        loss = loss_fn(logits, yb)
        opt.zero_grad()             # clear gradients left over from the previous step
        loss.backward()             # compute fresh gradients
        opt.step()                  # update the weights
        # Weight by batch size so a smaller final batch is averaged correctly.
        loss_sum += loss.item() * xb.size(0)
        correct += (logits.argmax(-1) == yb).sum().item()
        total += xb.size(0)
    return loss_sum / total, correct / total
@torch.no_grad()                    # evaluation needs no autograd graph
def evaluate(model, dl, loss_fn):
    model.eval()                    # disable dropout
    total, correct, loss_sum = 0, 0, 0.0
    for xb, yb in dl:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss_sum += loss.item() * xb.size(0)
        correct += (logits.argmax(-1) == yb).sum().item()
        total += xb.size(0)
    return loss_sum / total, correct / total
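Three details of these two functions are easy to get wrong: gradients accumulate until they are cleared, dropout only acts in train mode, and no_grad stops graph building. The sketch below demonstrates each on tiny made-up tensors rather than the MLP; the SGD optimizer is used only because it is the simplest one to construct.

```python
import torch
from torch import nn

# 1) Gradients accumulate across backward() calls until something clears them.
w = torch.tensor(3.0, requires_grad=True)
(w * 2).backward()              # d(2w)/dw = 2
first = w.grad.item()           # 2.0
(w * 2).backward()              # second backward without clearing
accumulated = w.grad.item()     # 4.0: the new gradient was added to the old one

# opt.zero_grad() clears .grad; in PyTorch 2.x it sets it to None by default.
opt = torch.optim.SGD([w], lr=0.1)
opt.zero_grad()
cleared = w.grad is None or w.grad.item() == 0.0

# 2) Dropout behaves differently in train and eval mode.
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)
drop.train()
y_train = drop(x)               # roughly half zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
y_eval = drop(x)                # identity: nothing dropped, nothing rescaled
train_values = set(y_train.unique().tolist())

# 3) Under no_grad, results do not track gradients, so no graph is built.
with torch.no_grad():
    z = w * 2

print(first, accumulated, cleared)  # 2.0 4.0 True
print(train_values <= {0.0, 2.0}, torch.equal(y_eval, x), z.requires_grad)  # True True False
```

Forgetting model.eval() before evaluation would leave dropout active and make validation numbers noisy and pessimistic.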
5. Run It
EPOCHS = 20
hist = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for ep in range(EPOCHS):
    tl, ta = train_one_epoch(model, train_dl, loss_fn, opt)
    vl, va = evaluate(model, test_dl, loss_fn)
    hist["train_loss"].append(tl); hist["train_acc"].append(ta)
    hist["val_loss"].append(vl); hist["val_acc"].append(va)
    print(f"ep {ep:02} | train {tl:.3f}/{ta:.3f} | val {vl:.3f}/{va:.3f}")
Expected: validation accuracy of roughly 0.97-0.98 by epoch 10 on the digits dataset. On a CPU the whole run takes seconds; on a GPU (or Apple-silicon MPS) the same code runs unchanged, because DEVICE is picked once at the top of the notebook.
6. Plot the Curves
fig, ax = plt.subplots(1, 2, figsize=(10, 3.5))
ax[0].plot(hist["train_loss"], label="train")
ax[0].plot(hist["val_loss"], label="val")
ax[0].set_title("loss"); ax[0].legend()
ax[1].plot(hist["train_acc"], label="train")
ax[1].plot(hist["val_acc"], label="val")
ax[1].set_title("accuracy"); ax[1].legend()
plt.tight_layout(); plt.show()
Always plot training and validation together. Curves that diverge (training loss still falling while validation loss rises) signal overfitting; curves that stay flat signal too little capacity or a learning-rate problem. This short plot catches a surprising fraction of training bugs.
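The same reading of the curves can be expressed as a rough check over the hist dict. This is a minimal sketch under stated assumptions: diagnose_curves is a hypothetical helper, the window and min_improvement thresholds are illustrative rather than standard values, and the three histories are synthetic numbers, not results from this notebook.

```python
# Hypothetical helper; thresholds are illustrative assumptions, not standard values.
# Expects the same layout as `hist` above: lists of per-epoch losses.
def diagnose_curves(hist, window=5, min_improvement=0.1):
    tl, vl = hist["train_loss"], hist["val_loss"]
    if len(tl) <= window:
        return "not enough epochs yet"
    # Flat: training loss barely moved relative to where it started.
    if (tl[0] - tl[-1]) / tl[0] < min_improvement:
        return "flat: check model capacity or learning rate"
    # Diverging: over the last `window` epochs, train loss fell while val loss rose.
    if tl[-1] < tl[-1 - window] and vl[-1] > vl[-1 - window]:
        return "diverging: possible overfitting"
    return "no obvious problem"

# Synthetic histories, made up to exercise each branch.
overfit = {"train_loss": [1.0, 0.6, 0.4, 0.3, 0.2, 0.15, 0.1, 0.08],
           "val_loss":   [1.0, 0.7, 0.5, 0.45, 0.5, 0.55, 0.6, 0.65]}
flat = {"train_loss": [2.30, 2.30, 2.29, 2.30, 2.29, 2.29, 2.30],
        "val_loss":   [2.30, 2.30, 2.30, 2.30, 2.30, 2.30, 2.30]}
healthy = {"train_loss": [1.0, 0.6, 0.4, 0.3, 0.25, 0.2, 0.18],
           "val_loss":   [1.0, 0.65, 0.45, 0.35, 0.3, 0.28, 0.27]}

for name, h in [("overfit", overfit), ("flat", flat), ("healthy", healthy)]:
    print(name, "->", diagnose_curves(h))
```

A check like this is no substitute for looking at the plot; it only flags the two patterns described above so they are harder to miss in long runs.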
7. Save and Load
torch.save(model.state_dict(), "mlp_digits.pt")
# Restore later (or in a different process)
loaded = MLP().to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers
loaded.load_state_dict(
    torch.load("mlp_digits.pt", map_location=DEVICE, weights_only=True))
loaded.eval()
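Two rules:
- Save the state_dict, not the whole model. Pickling the whole module stores a reference to your class by its import path, so loading it later requires the same code layout; the state_dict is just a mapping from parameter names to tensors, which you load into a freshly constructed model.
- Always pass map_location. Tensors remember the device they were saved from, so without it a checkpoint saved on a GPU fails to load on a CPU-only machine.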
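Both rules can be checked with an in-memory round trip. The sketch below assumes a small stand-in nn.Sequential instead of the MLP (make_model is a hypothetical helper) and saves to an io.BytesIO buffer rather than a file; it shows what a state_dict contains and that a fresh model loaded from it produces identical outputs.

```python
import io
import torch
from torch import nn

def make_model():
    # Hypothetical stand-in for MLP; any nn.Module round-trips the same way.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

torch.manual_seed(0)
src = make_model()

# The state_dict maps parameter (and buffer) names to tensors; ReLU has none.
keys = list(src.state_dict().keys())
print(keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']

# torch.save accepts any file-like object, so a BytesIO buffer stands in for a file.
buf = io.BytesIO()
torch.save(src.state_dict(), buf)
buf.seek(0)

# Rebuild the architecture in code, then load only the tensors into it.
dst = make_model()
state = torch.load(buf, map_location="cpu", weights_only=True)
dst.load_state_dict(state)

src.eval()
dst.eval()
x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(src(x), dst(x))
print(same)  # True
```

Because only tensors are stored, the loading side must construct the architecture itself, which is exactly why the notebook calls MLP() before load_state_dict.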