Loading data


Datasets and data loaders

Central to data ingestion and preprocessing are datasets and data loaders.

torch comes equipped with a set of datasets, mostly related to image recognition and natural language processing (e.g., mnist_dataset()), which can be iterated over by means of data loaders:

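# ...
ds <- mnist_dataset(
  dir,  # download directory, assumed to be set in the elided setup above
  download = TRUE, 
  transform = function(x) {
    # convert pixel values to floats and scale them
    x <- x$to(dtype = torch_float())/256
    x
  }
)
dl <- dataloader(ds, batch_size = 32, shuffle = TRUE)

for (b in enumerate(dl)) {
  # ...
}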

See vignettes/examples/mnist-cnn.R for a complete example.

What if you want to train on a different dataset? In that case, you subclass Dataset, an abstract container that needs to know how to access the given data. To that end, your subclass implements .getitem(), which returns the item at a given index; the data loader calls it repeatedly and combines the results into batches.

In .getitem(), you can implement whatever preprocessing you require. Additionally, you should implement .length(), which returns the number of items in the dataset; the data loader relies on it to know which indices it can request.

While this may sound complicated, it is not at all. The base logic is straightforward; complexity will naturally grow with how involved your preprocessing is. As a simple but functional prototype, we now create a dataset for training on Allison Horst's penguins.

A custom dataset



Datasets are R6 classes created using the dataset() constructor. You can pass a name and various member functions. Among those should be initialize(), to create instance variables, .getitem(), to indicate how the data should be returned, and .length(), to say how many items we have.

In addition, any number of helper functions can be defined.
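To see these members in isolation, here is a minimal sketch of a dataset that stores no data at all: item i is the pair (i, i^2), computed on demand in .getitem(). The class name squares_dataset and the helper square() are hypothetical, chosen only for illustration.

```r
library(torch)

# Hypothetical dataset: item i is the pair (i, i^2), computed on demand
squares_dataset <- dataset(
  name = "squares_dataset",
  initialize = function(n) {
    # instance variable: how many items the dataset has
    self$n <- n
  },
  .getitem = function(index) {
    x <- torch_tensor(index, dtype = torch_float())
    y <- torch_tensor(self$square(index), dtype = torch_float())
    list(x, y)
  },
  .length = function() {
    self$n
  },
  # hypothetical helper, callable from the other members via self$
  square = function(i) {
    i^2
  }
)

sq <- squares_dataset(5)
sq$.length()       # 5
sq$.getitem(3)     # a list holding the tensors 3 and 9
```

Since nothing is loaded up front, all the work happens in .getitem(); a dataset backed by files on disk would follow the same pattern.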

Here, we assume the penguins data (from the palmerpenguins package) have already been loaded, and all preprocessing consists of removing rows with NA values, transforming factors to numbers starting from 1 (torch for R uses 1-based class indices), and converting from R data types to torch tensors.
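The factor-to-number step relies on as.numeric() returning each value's position among the factor levels, so the codes start at 1. A small sketch with a made-up vector of species names (levels are sorted alphabetically by default):

```r
# Made-up species vector; factor() sorts the levels alphabetically
species <- factor(c("Gentoo", "Adelie", "Chinstrap", "Adelie"))

levels(species)       # "Adelie" "Chinstrap" "Gentoo"
as.numeric(species)   # 3 1 2 1: each value's position among the levels
```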

In .getitem(), we decide how this data is going to be used: all variables besides species go into x, the predictors, and species constitutes y, the target. Predictors and target are returned in a list; after batching, they are accessed as batch[[1]] and batch[[2]] during training.

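library(palmerpenguins)  # provides the penguins data frame

penguins_dataset <- dataset(
  name = "penguins_dataset",
  initialize = function() {
    self$data <- self$prepare_penguin_data()
  },
  .getitem = function(index) {
    # predictors: every column except the first (species)
    x <- self$data[index, 2:-1]
    # target: the species code, as an integer class index
    y <- self$data[index, 1]$to(dtype = torch_long())
    list(x, y)
  },
  .length = function() {
    # number of rows (penguins) in the tensor
    self$data$size()[[1]]
  },
  prepare_penguin_data = function() {
    input <- na.omit(penguins) 
    # conveniently, the categorical data are already factors
    input$species <- as.numeric(input$species)
    input$island <- as.numeric(input$island)
    input$sex <- as.numeric(input$sex)
    input <- as.matrix(input)
    torch_tensor(input)
  }
)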
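The expression self$data[index, 2:-1] uses torch's tensor indexing, where a negative index counts from the end, so 2:-1 selects the second through the last column. This differs from base R, where 2:-1 would be the vector c(2, 1, 0, -1). A small sketch on a made-up 2 x 4 tensor:

```r
library(torch)

# Made-up 2 x 4 tensor: row 1 holds 1..4, row 2 holds 5..8
m <- torch_tensor(matrix(c(1, 2, 3, 4, 5, 6, 7, 8), nrow = 2, byrow = TRUE))

m[1, 2:-1]   # row 1, columns 2 through the last: 2 3 4
m[1, 1]      # row 1, column 1: 1
```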

Let’s create the dataset, query for its length, and look at its first item:

tuxes <- penguins_dataset()
tuxes$.length()
tuxes$.getitem(1)

To be able to iterate over tuxes, we need a data loader (we override the default batch size of 1):

dl <- tuxes %>% dataloader(batch_size = 8)

Calling .length() on a data loader (as opposed to a dataset) returns the number of batches:

dl$.length()

And we can create an iterator to inspect the first batch:

iter <- dl$.iter()
b <- iter$.next()
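How many batches a data loader yields, and what a batch looks like, can be checked on a small self-contained example. This sketch uses tensor_dataset() with made-up random data instead of the penguins: with 10 samples and a batch size of 4, the loader yields three batches, the last one holding only two samples.

```r
library(torch)

# Made-up data: 10 samples with 2 features each, plus a 1-based class label
x <- torch_randn(10, 2)
y <- torch_tensor(rep(c(1, 2), 5), dtype = torch_long())
ds <- tensor_dataset(x, y)

dl <- dataloader(ds, batch_size = 4)
dl$.length()       # 3 batches: 4 + 4 + 2 samples

iter <- dl$.iter()
b <- iter$.next()
b[[1]]$shape       # 4 2: four samples, two features
b[[2]]$shape       # 4: one label per sample
```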

To train a network, we can use enumerate to iterate over batches.

Training with data loaders

Our example network is very simple: it maps the seven predictors to log-probabilities for the three species. (In reality, we would want to treat island as the categorical variable it is, and either one-hot-encode or embed it.)

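net <- nn_module(
  initialize = function() {
    # 7 predictors in, 32 hidden units out
    self$fc1 <- nn_linear(7, 32)
    # 32 hidden units in, one score per species out
    self$fc2 <- nn_linear(32, 3)
  },
  forward = function(x) {
    # log_softmax over dim 2, the class dimension (dim 1 is the batch)
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2() %>% 
      nnf_log_softmax(dim = 2)
  }
)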

model <- net()
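Because torch for R numbers dimensions from 1, dim 1 of a (batch x classes) output is the batch dimension and dim 2 the class dimension. A quick sketch on a made-up score matrix shows which dimension each choice normalizes:

```r
library(torch)

# Made-up scores: 2 samples (rows) x 3 classes (columns)
scores <- torch_tensor(rbind(c(1, 2, 3), c(0, 0, 0)))

# dim = 2 normalizes across classes: each row's probabilities sum to 1
torch_exp(nnf_log_softmax(scores, dim = 2))$sum(dim = 2)

# dim = 1 normalizes across the batch instead: each column sums to 1
torch_exp(nnf_log_softmax(scores, dim = 1))$sum(dim = 1)
```

For classification with nnf_nll_loss(), each sample needs a distribution over classes, which is what dim = 2 provides.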

We still need an optimizer:

optimizer <- optim_sgd(model$parameters, lr = 0.01)
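Each update step follows the same pattern: reset the gradients with zero_grad(), compute the loss, call backward() on it, and let the optimizer apply step(). A minimal sketch on a hypothetical one-weight regression problem (fitting w in y = 2x), independent of the penguin data:

```r
library(torch)

# Hypothetical toy problem: learn w such that w * x matches y = 2 * x
w <- torch_zeros(1, requires_grad = TRUE)
opt <- optim_sgd(list(w), lr = 0.1)
x <- torch_tensor(c(1, 2, 3))
y <- x * 2

for (i in 1:50) {
  opt$zero_grad()                  # clear gradients from the previous step
  loss <- nnf_mse_loss(w * x, y)   # forward pass and loss
  loss$backward()                  # compute d(loss)/dw
  opt$step()                       # update w in place
}
w$item()   # close to 2
```

Skipping zero_grad() would make gradients accumulate across steps, since backward() adds to any gradient already stored.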

And we’re ready to train:

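for (epoch in 1:10) {
  l <- c()
  for (b in enumerate(dl)) {
    optimizer$zero_grad()                # reset gradients from the previous batch
    output <- model(b[[1]])
    loss <- nnf_nll_loss(output, b[[2]])
    loss$backward()                      # compute gradients
    optimizer$step()                     # update the weights
    l <- c(l, loss$item())
  }
  cat(sprintf("Loss at epoch %d: %.3f\n", epoch, mean(l)))
}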