DataLoader and Dataset in PyTorch

Jimmy (xiaoke) Shen
2 min read · May 15, 2021

PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples. From [1]

Dataset

[1] gives a pretty good example using FashionMNIST:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",           # directory where the data is (or will be) stored
    train=True,            # load the training split
    download=True,         # download the data if it is not found at root
    transform=ToTensor()   # convert PIL images to float tensors in [0, 1]
)

Let’s explore this code.

If we run the above code in a Python interpreter, it will first download the data, and then we can explore it:

>>> training_data
Dataset FashionMNIST
Number of datapoints: 60000
Root location: data
Split: Train
StandardTransform
Transform: ToTensor()
>>> type(training_data)
<class 'torchvision.datasets.mnist.FashionMNIST'>
>>> first = training_data[0]
>>> type(first)
<class 'tuple'>
>>> len(first)
2
>>> type(first[0])
<class 'torch.Tensor'>
>>> type(first[1])
<class 'int'>
>>> a, b = first
>>> a.shape
torch.Size([1, 28, 28])
>>> b
9
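
For completeness, the test split can be loaded the same way by setting train=False (a small sketch; arguments otherwise identical to the training snippet above):

test_data = datasets.FashionMNIST(
    root="data",
    train=False,           # load the test split instead of the training split
    download=True,
    transform=ToTensor()
)
# FashionMNIST has 10,000 test images, so len(test_data) == 10000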

The source code of FashionMNIST can be found here.

The source code of Dataset is pretty short; you can find it here.

Essentially, a subclass of Dataset needs to implement __len__ and __getitem__, so that we can write:

dataset = MyDataset()
length_of_dataset = len(dataset)
input_data, input_data_label = dataset[i]
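
To make this concrete, here is a minimal sketch of a custom Dataset (the name MyDataset and the random toy tensors are purely illustrative):

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # A toy in-memory dataset: parallel tensors of samples and labels.
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        # number of samples in the dataset
        return len(self.data)

    def __getitem__(self, i):
        # return the i-th (sample, label) pair
        return self.data[i], self.labels[i]

dataset = MyDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
length_of_dataset = len(dataset)               # 100
input_data, input_data_label = dataset[0]      # one (sample, label) pair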

DataLoader

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API. From [1]
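
The multiprocessing mentioned above is controlled by the num_workers argument (0 by default, which loads data in the main process). A small sketch, with the worker count chosen arbitrarily:

train_dataloader = DataLoader(
    training_data,
    batch_size=64,
    shuffle=True,
    num_workers=2,  # load batches in 2 background worker processes
)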

>>> from torch.utils.data import DataLoader
>>> train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
>>> type(train_dataloader)
<class 'torch.utils.data.dataloader.DataLoader'>
>>> len(train_dataloader)
938
>>> len(training_data)
60000
>>> 60000/64
937.5
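
Why 938? len(train_dataloader) counts batches, not samples: with 60,000 samples and batch_size=64, the last batch holds only 32 samples, and by default (drop_last=False) it is kept, giving ceil(60000 / 64) = 938. We can verify the batch shapes by drawing one batch, continuing the session above:

>>> train_features, train_labels = next(iter(train_dataloader))
>>> train_features.shape
torch.Size([64, 1, 28, 28])
>>> train_labels.shape
torch.Size([64])
>>> len(DataLoader(training_data, batch_size=64, drop_last=True))
937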

Reference

[1] https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
