DataLoader and Dataset in PyTorch
PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples. From [1]
Dataset
[1] gives a good example using the FashionMNIST dataset:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
Let’s explore this code.
If we run the above code in a Python interpreter, it will first download the data, and then we can explore it:
>>> training_data
Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
>>> type(training_data)
<class 'torchvision.datasets.mnist.FashionMNIST'>
>>> first = training_data[0]
>>> type(first)
<class 'tuple'>
>>> len(first)
2
>>> type(first[0])
<class 'torch.Tensor'>
>>> type(first[1])
<class 'int'>
>>> a, b = first
>>> a.shape
torch.Size([1, 28, 28])
>>> b
9
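The matplotlib import from the snippet above can be used to visualize this sample. A minimal sketch (in FashionMNIST, label 9 corresponds to the class "Ankle boot"):
import matplotlib.pyplot as plt

img, label = training_data[0]
plt.imshow(img.squeeze(), cmap="gray")  # drop the channel dim: (1, 28, 28) -> (28, 28)
plt.title(f"label = {label}")           # 9 is "Ankle boot" in FashionMNIST
plt.axis("off")
plt.show()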
The source code of FashionMNIST can be found here.
The source code of Dataset is pretty short; you can find it here.
Essentially, a subclass of Dataset needs to implement
__len__ and __getitem__, so that we can write:
dataset = MyDataset()
length_of_dataset = len(dataset)
input_data, input_data_label = dataset[i]
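For instance, a minimal custom Dataset wrapping in-memory tensors might look like this (MyDataset and its fields are hypothetical names for illustration, not part of the PyTorch API):
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        # data: tensor of shape (N, ...); labels: tensor of shape (N,)
        self.data = data
        self.labels = labels

    def __len__(self):
        # number of samples in the dataset
        return len(self.data)

    def __getitem__(self, i):
        # return one (sample, label) pair, just like training_data[0] above
        return self.data[i], self.labels[i]

dataset = MyDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
length_of_dataset = len(dataset)           # 100
input_data, input_data_label = dataset[0]  # tensor of shape (1, 28, 28), 0-dim label tensor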
DataLoader
The Dataset retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval. DataLoader is an iterable that abstracts this complexity for us in an easy API. From [1]
>>> from torch.utils.data import DataLoader
>>> train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
>>> type(train_dataloader)
<class 'torch.utils.data.dataloader.DataLoader'>
>>> len(train_dataloader)
938
>>> len(training_data)
60000
>>> 60000/64
937.5
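len(train_dataloader) is 938 rather than 937.5 because by default DataLoader keeps the final, smaller batch (drop_last=False): 60000 = 937 * 64 + 32, so the length is the ceiling of 60000 / 64. A quick sketch, continuing from the objects defined above:
>>> features, labels = next(iter(train_dataloader))
>>> features.shape
torch.Size([64, 1, 28, 28])
>>> labels.shape
torch.Size([64])
>>> len(DataLoader(training_data, batch_size=64, shuffle=True, drop_last=True))
937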
Reference
[1] https://pytorch.org/tutorials/beginner/basics/data_tutorial.html