Loading data in PyTorch
PyTorch features extensive neural network building blocks with a simple, intuitive, and stable API. PyTorch includes packages to prepare and load common datasets for your model.
Introduction
At the heart of PyTorch's data loading utility is the torch.utils.data.DataLoader class, which represents a Python iterable over a dataset. PyTorch domain libraries also offer built-in, high-quality datasets, implemented as torch.utils.data.Dataset objects, for you to use. These datasets are currently available in torchvision, torchaudio, and torchtext, with more to come.
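The sketch below makes the Dataset/DataLoader relationship concrete. It assumes a map-style dataset, which only needs __len__ and __getitem__. The DataLoader draws indices, fetches the matching items, and collates them into batches. ToyAudioDataset is a hypothetical class: it returns random (waveform, sample_rate, labels) tuples shaped like Yesno items, so the example runs without any download.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyAudioDataset(Dataset):
    """Hypothetical map-style dataset that mimics YESNO items.

    Each item is a (waveform, sample_rate, labels) tuple built from
    random data, so no download is needed.
    """

    def __init__(self, num_items=4, num_samples=16, sample_rate=8000):
        self.num_items = num_items
        self.num_samples = num_samples
        self.sample_rate = sample_rate

    def __len__(self):
        # The default (sequential) sampler uses len() to enumerate indices.
        return self.num_items

    def __getitem__(self, index):
        # DataLoader calls this with integer indices produced by its sampler.
        waveform = torch.randn(1, self.num_samples)  # (channels, time)
        labels = [index % 2] * 8  # eight yes/no labels per item
        return waveform, self.sample_rate, labels


dataset = ToyAudioDataset()
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for waveforms, sample_rates, labels in loader:
    # waveforms: (2, 1, 16); sample_rates: tensor of 2 ints; labels: list of 8 tensors
    print(waveforms.shape, sample_rates)
```

Any object that behaves like this can be passed to DataLoader in the same way as the built-in datasets.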
Using the Yesno dataset from torchaudio.datasets.YESNO, we will demonstrate how to effectively and efficiently load data from a PyTorch Dataset into a PyTorch DataLoader.
Setup
Before we begin, we need to install torchaudio to have access to the dataset.
pip install torchaudio
Steps
- Import all necessary libraries for loading our data
- Access the data in the dataset
- Load the data
- Iterate over the data
- [Optional] Visualize the data
1. Import necessary libraries for loading our data
For this recipe, we will use torch and torchaudio. Depending on which built-in datasets you use, you can also install and import torchvision or torchtext.
import torch
import torchaudio
2. Access the data in the dataset
The Yesno dataset in torchaudio features sixty recordings of one individual saying yes or no in Hebrew, with each recording eight words long.
The torchaudio.datasets.YESNO class creates a dataset for Yesno:
torchaudio.datasets.YESNO(
root,
url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
folder_in_archive='waves_yesno',
download=False,
transform=None,
target_transform=None)
Each item in the dataset is a tuple of the form: (waveform, sample_rate, labels).
You must set a root for the Yesno dataset, which is the directory where the dataset files are stored (and downloaded to, if download is True). The other parameters are optional, with their default values shown. Here is some additional useful information on the other parameters:
- download: If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform: A function/transform that takes in the waveform and returns a transformed version. Transforms let you take data from its source state into a form that is ready for training (for example, normalized or resampled). Each PyTorch domain library supports a growing list of transformations.
- target_transform: A function/transform that takes in the target and transforms it.
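If your installed torchaudio version does not accept the transform and target_transform arguments, or if you want the same pattern for any dataset, you can apply the functions in a wrapper instead. The sketch below assumes items are (waveform, sample_rate, labels) tuples. TransformedDataset and peak_normalize are hypothetical helpers, not torchaudio APIs, and a one-item list stands in for yesno_data.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper that applies transform / target_transform
    to (waveform, sample_rate, labels) items of another dataset."""

    def __init__(self, base, transform=None, target_transform=None):
        self.base = base
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        waveform, sample_rate, labels = self.base[index]
        if self.transform is not None:
            waveform = self.transform(waveform)
        if self.target_transform is not None:
            labels = self.target_transform(labels)
        return waveform, sample_rate, labels


def peak_normalize(waveform):
    # Scale so the largest absolute sample is 1.0 (leave silence untouched).
    peak = waveform.abs().max()
    return waveform / peak if peak > 0 else waveform


# A plain list also supports len() and indexing, so it can stand in for YESNO here.
toy_items = [(torch.tensor([[0.5, -2.0, 1.0]]), 8000, [1, 0, 1])]
wrapped = TransformedDataset(
    toy_items,
    transform=peak_normalize,
    target_transform=torch.tensor,  # list of ints -> 1-D tensor
)
waveform, sample_rate, labels = wrapped[0]
print(waveform, labels)
```

With the real data, the same wrapper would be used as TransformedDataset(yesno_data, transform=peak_normalize).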
A data point in Yesno is a tuple (waveform, sample_rate, labels), where labels is a list of integers with 1 for yes and 0 for no.
Let’s access our Yesno data:
yesno_data = torchaudio.datasets.YESNO('./', download=True)
# Pick data point number 3 to see an example of the yesno_data:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))
When using this data in practice, it is best to split the data into a “training” dataset and a “testing” dataset. This ensures that you have out-of-sample data to test the performance of your model.
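torch.utils.data.random_split is one way to make such a split. The sketch below assumes an 80/20 train/test ratio, which is an arbitrary choice, and uses a list of 60 integers as a stand-in for the 60 Yesno recordings so it runs without a download.

```python
import torch
from torch.utils.data import random_split

# Stand-in for yesno_data: any object with len() and integer indexing works.
full_dataset = list(range(60))  # YESNO has sixty recordings

num_test = len(full_dataset) // 5  # hold out 20% for testing (assumed ratio)
num_train = len(full_dataset) - num_test

# A seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(full_dataset, [num_train, num_test], generator=generator)

print(len(train_set), len(test_set))
```

With the real data, you would pass yesno_data in place of full_dataset. Each resulting Subset can then be given to its own DataLoader.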
3. Load the data
Now that we have access to the dataset, we must pass it to torch.utils.data.DataLoader. The DataLoader combines the dataset and a sampler, returning an iterable over the dataset.
data_loader = torch.utils.data.DataLoader(yesno_data,
batch_size=1,
shuffle=True)
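The sampler is created implicitly in the call above: shuffle=True makes DataLoader build a RandomSampler, and shuffle=False uses a SequentialSampler. The sketch below passes both samplers explicitly. This makes the dataset/sampler split visible and lets you swap in another sampler. A list of integers stands in for yesno_data.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Stand-in dataset: a list supports len() and indexing, like yesno_data.
dataset = list(range(10))

# Explicit equivalent of shuffle=True: a RandomSampler over the dataset.
generator = torch.Generator().manual_seed(0)
shuffled_loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset, generator=generator))

# SequentialSampler yields indices 0, 1, 2, ... (what shuffle=False uses).
ordered_loader = DataLoader(dataset, batch_size=4, sampler=SequentialSampler(dataset))

print([batch.tolist() for batch in ordered_loader])
print([batch.tolist() for batch in shuffled_loader])
```

Either way, every index is visited once per epoch. Only the order changes.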
4. Iterate over the data
Our data is now iterable using the data_loader. This will be necessary when we begin training our model! You will notice that each batch from the data_loader is now a list holding the waveform as a tensor with an added batch dimension, the sample rate as a tensor, and the labels as a list of tensors (one per word).
for data in data_loader:
print("Data: ", data)
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
break
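The data_loader above uses batch_size=1, which lets the default collation work unchanged. With larger batches, the default collation stacks waveforms and raises an error if the recordings differ in length. One common workaround is a custom collate_fn that zero-pads each batch to its longest waveform. pad_collate below is a hypothetical helper, and zero-padding is an assumed choice. It also assumes every item has the same number of labels, which holds for Yesno's eight-word recordings. Toy tensors stand in for real recordings.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Hypothetical collate_fn for (waveform, sample_rate, labels) items.

    Waveforms have shape (channels, time) and may differ in length, so the
    default collate cannot stack them. Zero-pad them along time instead.
    """
    waveforms, sample_rates, labels = zip(*batch)
    lengths = torch.tensor([w.shape[-1] for w in waveforms])
    # pad_sequence pads the first dimension, so move time to the front.
    padded = pad_sequence([w.t() for w in waveforms], batch_first=True)  # (batch, time, channels)
    padded = padded.transpose(1, 2)  # back to (batch, channels, time)
    return padded, lengths, torch.tensor(sample_rates), torch.tensor(labels)


# Toy items of unequal length standing in for YESNO recordings.
items = [
    (torch.ones(1, 3), 8000, [1, 0, 1, 0, 1, 0, 1, 0]),
    (torch.ones(1, 5), 8000, [0, 0, 0, 0, 1, 1, 1, 1]),
]
loader = DataLoader(items, batch_size=2, collate_fn=pad_collate)

waveforms, lengths, sample_rates, labels = next(iter(loader))
print(waveforms.shape, lengths.tolist(), labels.shape)
```

With the real dataset, you would call DataLoader(yesno_data, batch_size=4, shuffle=True, collate_fn=pad_collate). Returning the original lengths lets a model ignore the padded samples.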
5. [Optional] Visualize the data
You can optionally visualize your data to further understand the output from your DataLoader.
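import matplotlib.pyplot as plt
# data[0] has shape (batch, channels, time); take the first item in the batch
batch_waveform = data[0][0]
print(batch_waveform.numpy())
plt.figure()
plt.plot(batch_waveform.t().numpy())  # transpose to (time, channels) for plotting
plt.show()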
Congratulations! You have successfully loaded data in PyTorch.
Learn More
Take a look at these other recipes to continue your learning: