
Take a sample of the MNIST dataset


I am working with the MNIST dataset and running several classification methods on it, but my runtimes are ridiculous, so I am looking for a way to use only a portion of the training set while keeping the test set at 10K. I have tried a number of different options, but nothing is working.

I need to either take a sample from the entire set, or reduce the training x and y from 60000 to maybe 20000.

My current code:


library(keras)

mnist <- dataset_mnist()

train_images <- mnist$train$x 
train_labels <- mnist$train$y 
test_images <- mnist$test$x   
test_labels <- mnist$test$y 

I have tried to use the sample() function and other types of splits to no avail.


Solution

  • In the following example I download MNIST myself and load it through reticulate / numpy instead of calling keras::dataset_mnist(); the list is built to have the same structure, so it shouldn't make much difference. When you want a sample with sample(), you usually draw a sample of indices and then use those indices for subsetting. To get a balanced sample, you might want to draw a specific number or proportion from each label group:

    library(reticulate)
    library(dplyr)
    
    # Download the MNIST dataset as a numpy .npz file, load it through reticulate,
    # and build something along the lines of the keras::dataset_mnist() output
    np <- import("numpy")
    mnist_npz <- curl::curl_download("https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", "mnist.npz")
    mnist_np <- np$load(mnist_npz)
    
    mnist_lst <- list(
      train = list(
        x = mnist_np[["x_train"]],
        y = mnist_np[["y_train"]]
      ),
      test = list(
        x = mnist_np[["x_test"]],
        y = mnist_np[["y_test"]]
      )
    )
    
    train_images <- mnist_lst$train$x 
    train_labels <- mnist_lst$train$y 
    test_images  <- mnist_lst$test$x   
    test_labels  <- mnist_lst$test$y 
    
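    # sample row indices, 100 per class to keep the dataset balanced
    # (n = 2000 per class would give 20000 training rows);
    # the `_` pipe placeholder needs R >= 4.2, `by =` needs dplyr >= 1.1.0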
    sample_idx <- 
      train_labels |>
      tibble(y = _) |>
      tibble::rowid_to_column("idx") |>
      slice_sample(n = 100, by = y ) |>
      arrange(idx) |>
      pull(idx)
    
    # use sample_idx for subsetting
    train_images_sample <- train_images[sample_idx,,] 
    train_labels_sample <- train_labels[sample_idx]
    
    str(train_images_sample)
    #>  int [1:1000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
    str(train_labels_sample)
    #>  int [1:1000(1d)] 9 7 5 6 8 7 7 5 2 9 ...
    
    # original label distribution
    table(train_labels)
    #> train_labels
    #>    0    1    2    3    4    5    6    7    8    9 
    #> 5923 6742 5958 6131 5842 5421 5918 6265 5851 5949
    
    # sample distribution
    table(train_labels_sample)
    #> train_labels_sample
    #>   0   1   2   3   4   5   6   7   8   9 
    #> 100 100 100 100 100 100 100 100 100 100
    

    Created on 2024-03-29 with reprex v2.1.0
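
If you would rather stay in base R, the same index-sampling idea can be sketched without dplyr. This is only a minimal sketch and not part of the reprex above: the arrays are random stand-ins with the same layout as mnist$train$x and mnist$train$y (an assumption, so the example runs without downloading anything), and n_keep and n_per_class are illustrative values.

```r
# Synthetic stand-ins shaped like mnist$train$x (n x 28 x 28) and mnist$train$y
# (assumption: smaller than the real 60000 rows so it runs quickly).
set.seed(42)
n_train <- 6000
train_images <- array(
  sample(0:255, n_train * 28 * 28, replace = TRUE),
  dim = c(n_train, 28, 28)
)
train_labels <- sample(0:9, n_train, replace = TRUE)

# Simple random sample: draw row indices, then subset with them.
n_keep <- 2000
simple_idx <- sort(sample(seq_along(train_labels), n_keep))

# Balanced sample: group row indices by label and draw the same number
# from each group. Indexing via sample.int() avoids the sample(x, n)
# surprise when a group happens to contain a single index.
n_per_class <- 200
balanced_idx <- sort(unlist(
  lapply(
    split(seq_along(train_labels), train_labels),
    function(idx) idx[sample.int(length(idx), n_per_class)]
  ),
  use.names = FALSE
))

# Subset images and labels with the SAME indices so they stay aligned.
train_images_sample <- train_images[balanced_idx, , , drop = FALSE]
train_labels_sample <- train_labels[balanced_idx]

dim(train_images_sample)
table(train_labels_sample)
```

With the real data, the synthetic arrays would be replaced by mnist$train$x and mnist$train$y, and n_per_class <- 2000 would reduce the training set from 60000 to 20000 rows. The test set is left untouched, so it stays at 10K.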