Exercise 4: Mini-batches and DataLoaders
In this exercise, you’ll make the loss versus epoch plots from the previous section, but using a PyTorch DataLoader.
Here are the imports:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
And here’s the dataset:
boston_prices_df = pd.read_csv(
    "data/boston-house-prices.csv", sep=r"\s+", header=None,
    names=["CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE", "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT", "MEDV"],
)

# Pre-normalize the data so we can ignore that part of modeling
boston_prices_df = (boston_prices_df - boston_prices_df.mean()) / boston_prices_df.std()

features = torch.tensor(boston_prices_df.drop(columns="MEDV").values).float()
targets = torch.tensor(boston_prices_df["MEDV"].values).float()[:, np.newaxis]
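The [:, np.newaxis] on the targets turns a 1-D tensor of shape (506,) into a column of shape (506, 1). That matches the (N, 1) output of a model such as nn.Linear(13, 1). If the shapes did not match, nn.MSELoss would broadcast them against each other and issue a warning.

The sketch below checks these shapes without needing the CSV file. It uses a synthetic DataFrame of random values (our assumption), with the same 506 rows and 14 columns as the Boston data, and applies the same normalization and conversion steps.

```python
import numpy as np
import pandas as pd
import torch

# Synthetic stand-in for the Boston table: random values, same column names
# and the same shape (506 rows, 13 feature columns plus MEDV).
rng = np.random.default_rng(0)
columns = ["CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE", "DIS",
           "RAD", "TAX", "PTRATIO", "B", "LSTAT", "MEDV"]
df = pd.DataFrame(rng.normal(size=(506, len(columns))), columns=columns)

# Same z-score normalization as above: each column gets mean 0 and std 1.
df = (df - df.mean()) / df.std()

features = torch.tensor(df.drop(columns="MEDV").values).float()
# np.newaxis adds a second axis, turning shape (506,) into (506, 1).
targets = torch.tensor(df["MEDV"].values).float()[:, np.newaxis]

print(features.shape)  # torch.Size([506, 13])
print(targets.shape)   # torch.Size([506, 1])
```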
The exercise
See PyTorch’s documentation on Datasets and DataLoaders to create a TensorDataset from the features and targets. Then load it into two DataLoaders: one with batch_size=len(features) (i.e. one big batch) and the other with batch_size=50. When complete, each dataloader can be used to iterate over mini-batches like this:
for features_subset, targets_subset in dataloader:
    ...  # compute the loss on this mini-batch and take one optimizer step
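As a rough illustration of what TensorDataset and DataLoader do (not a full solution), the sketch below uses random tensors with the same shapes as features and targets. TensorDataset pairs the tensors row by row, and each DataLoader yields batches of those rows. With batch_size=50 and 506 rows, one epoch consists of ten full batches plus a final batch of 6.

Passing shuffle=True is our own choice, not something the exercise asks for. It reshuffles the rows at the start of every epoch.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Random tensors with the same shapes as features and targets above
# (506 rows, 13 features); they stand in for the real data here.
torch.manual_seed(0)
features = torch.randn(506, 13)
targets = torch.randn(506, 1)

# dataset[i] is the pair (features[i], targets[i]); len(dataset) == 506.
dataset = TensorDataset(features, targets)

one_big_batch = DataLoader(dataset, batch_size=len(features))
mini_batches = DataLoader(dataset, batch_size=50, shuffle=True)

# One full pass over a DataLoader is one epoch.
big_shapes = [tuple(x.shape) for x, _ in one_big_batch]
mini_sizes = [len(x) for x, _ in mini_batches]

print(big_shapes)  # [(506, 13)]
print(mini_sizes)  # ten batches of 50, then one batch of 6
```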
You can use the following code to plot the loss versus epoch for the two batch sizes:
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_vs_epoch) + 1), loss_vs_epoch, label="one big batch")
ax.plot(range(1, len(loss_vs_epoch_batched) + 1), loss_vs_epoch_batched, label="mini-batches")
ax.set_xlabel("number of epochs")
ax.set_ylabel(r"loss ($\chi^2$)")
ax.legend(loc="upper right")
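The plotting code assumes two lists, loss_vs_epoch and loss_vs_epoch_batched, each holding one loss value per epoch. Here is a minimal sketch of how they might be produced. It rests on several assumptions of ours, which may differ from the previous section's model:

- synthetic linear data in place of the Boston features;
- an nn.Linear(13, 1) model;
- nn.MSELoss;
- plain SGD with lr=0.01.

The helper train is hypothetical. The key difference between the two runs is that the full-batch loader makes one parameter update per epoch, while the mini-batch loader makes eleven.

```python
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

# Synthetic linear data standing in for the Boston features and targets.
torch.manual_seed(0)
features = torch.randn(506, 13)
true_weights = torch.randn(13, 1)
targets = features @ true_weights + 0.1 * torch.randn(506, 1)
dataset = TensorDataset(features, targets)


def train(dataloader, n_epochs=50, lr=0.01):
    """Hypothetical helper: fit a fresh linear model, return the mean loss per epoch."""
    model = nn.Linear(13, 1)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_vs_epoch = []
    for _ in range(n_epochs):
        batch_losses = []
        for features_subset, targets_subset in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(features_subset), targets_subset)
            loss.backward()
            optimizer.step()  # one parameter update per batch
            batch_losses.append(loss.item())
        # Unweighted mean of the batch losses seen during this epoch.
        loss_vs_epoch.append(sum(batch_losses) / len(batch_losses))
    return loss_vs_epoch


loss_vs_epoch = train(DataLoader(dataset, batch_size=len(features)))
loss_vs_epoch_batched = train(DataLoader(dataset, batch_size=50, shuffle=True))
```

Each epoch's loss here is the unweighted mean of its batch losses, so the short final batch of 6 counts as much as a full batch of 50. Weighting each batch loss by its size would be a reasonable alternative.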
Something to think about
Are you following this course because you have a particular ML problem in mind? Does its data fit in memory?
Note that you can make subclasses of the generic Dataset class like this:
class CustomDataset(Dataset):
    def __len__(self):
        return np.iinfo(np.int64).max  # or the number of batches, if known

    def __getitem__(self, batch_index):
        # can_get_more_data() and get_more_data() are placeholders for your own loading code
        if can_get_more_data():
            features, targets = get_more_data()
            return features, targets
        raise IndexError("no more data")
For your particular dataset, how would you fit your data-loading procedure into the class above?
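Here is one possible answer, sketched under our own assumptions. Suppose your data arrives in chunks, each small enough to serve as a batch. Then __getitem__ can return one whole chunk per index, and DataLoader(..., batch_size=None) passes each chunk through unchanged.

ChunkedDataset and its in-memory list of chunks are hypothetical. A real version would load chunk batch_index from wherever your data lives.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ChunkedDataset(Dataset):
    """Hypothetical example: each item is one pre-made batch ("chunk").

    The chunks live in an in-memory list here; in a real out-of-memory
    setting, __getitem__ would instead read chunk number batch_index from
    disk, a database, or a network source.
    """

    def __init__(self, chunks):
        self.chunks = chunks  # list of (features, targets) tensor pairs

    def __len__(self):
        return len(self.chunks)  # the number of batches is known here

    def __getitem__(self, batch_index):
        if batch_index >= len(self.chunks):
            raise IndexError("no more data")
        features, targets = self.chunks[batch_index]
        return features, targets


torch.manual_seed(0)
chunks = [(torch.randn(n, 13), torch.randn(n, 1)) for n in (50, 50, 6)]
dataset = ChunkedDataset(chunks)

# batch_size=None turns off automatic batching: each item returned by
# __getitem__ is yielded as-is, since it is already a batch.
loader = DataLoader(dataset, batch_size=None)
batch_sizes = [len(features_subset) for features_subset, targets_subset in loader]
print(batch_sizes)  # [50, 50, 6]
```

Because __len__ returns the true number of chunks here, the DataLoader stops on its own after the last chunk. If the number of batches is not known in advance, torch.utils.data.IterableDataset may be a more natural base class. It defines __iter__ instead of __getitem__, so iteration simply ends when the data runs out.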