
Reading MNIST Image Database binary file in MATLAB


I have a binary file from the MNIST image database, renamed as "myfile.dat". It consists of four unsigned 32-bit integers, followed by a sequence of unsigned 8-bit integers. I want to read this file and store its contents as an array. Here is my code:

file_id = fopen('myfile.dat', 'rb');
data = fread(file_id, 'int');
size(data)
class(data)

And the output is:

ans =

    2502           1


ans =

double

The size of (2502, 1) is as expected. But why is it telling me that the data is double, when I have specified it to be int?

I know what the first few numbers should be and the output data is not as expected. I have also tried int32, uint and uint32 which give the same problem.


Solution

  • As we worked out in the comments, the problem is byte order. When MATLAB reads a 4-byte integer, it uses the machine's native byte order, which is little-endian on typical (x86) computers, while the MNIST database files store their integers in big-endian order. The first four bytes of the file are 0x00, 0x00, 0x08, 0x03, as we expect, but MATLAB assembles them as though they were 0x03, 0x08, 0x00, 0x00. Converting that to an integer gives 50855936 (0x03080000) instead of 2051 (0x00000803), which is not what we want.
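    To see the byte-order effect in isolation, the sketch below reinterprets the four header bytes directly with typecast, which uses the host's native byte order. The printed values assume a little-endian host (such as an x86 machine), which the code checks first; the variable names are illustrative.

```matlab
% The first four bytes of an MNIST image file, in the order they appear on disk
headerBytes = uint8([0 0 8 3]);

% typecast reinterprets the bytes using the host's native byte order
nativeValue = typecast(headerBytes, 'uint32');

% computer's third output reports the host byte order: 'L' (little) or 'B' (big)
[~, ~, endianness] = computer;
if endianness == 'L'
    fprintf('Native (little-endian) reading: %d\n', nativeValue);  % 50855936
    fprintf('After swapbytes: %d\n', swapbytes(nativeValue));      % 2051
end
```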

    One workaround is to read multi-byte values one byte at a time by specifying the uint8 data type, which stores each byte in a separate element of an array. We can then rebuild the number by bit-shifting each byte by the appropriate amount (24, 16, 8 and 0 bits, most significant byte first) and summing the results. We need to do this for each of the four 32-bit header values at the start of the file.
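    A minimal sketch of the bit-shifting approach, applied to byte values we already have in memory. The helper name bigEndianToUint32 is hypothetical; it uses the same bitshift-and-sum expression as the commented-out lines in the answer's code, and the second set of bytes is the big-endian encoding of 10000 (0x00002710), the image count of the test file.

```matlab
% Hypothetical helper: combine four bytes (most significant first) into one
% unsigned 32-bit value, returned as a double
bigEndianToUint32 = @(b) sum(bitshift(double(b(:)'), [24 16 8 0]));

magicNumber = bigEndianToUint32([0 0 8 3]);    % 0x00000803 -> 2051
numImages   = bigEndianToUint32([0 0 39 16]);  % 0x00002710 -> 10000
```

    This version does not depend on the host's byte order, because each byte's position in the final number is fixed explicitly by its shift amount.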

    Alternatively, we can use swapbytes, as you suggested in your comments, which gives exactly the same result. Read one element of type uint32 (4 bytes), then swap the order of its bytes so that they are interpreted in big-endian order. Bear in mind that even when you read the data as uint32, fread returns it as double by default, which is also why class(data) reports double in your code. You therefore need to cast the number to uint32 before passing it to swapbytes; otherwise swapbytes reverses the 8 bytes of the double representation instead. A precision such as 'uint32=>uint32' also makes fread return uint32 directly.
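    A small self-contained check of both points: fread's default double output, and the cast-then-swapbytes fix. It writes a hypothetical 4-byte temporary file holding the big-endian value 2051, and it assumes a little-endian host, which is the only case where the swap is needed.

```matlab
% Write a hypothetical 4-byte file holding the big-endian value 2051 (0x00000803)
fname = [tempname '.bin'];
fid = fopen(fname, 'w');
fwrite(fid, [0 0 8 3], 'uint8');
fclose(fid);

% Default precision string: the value is read as uint32 but returned as double
fid = fopen(fname, 'r');
A = fread(fid, 1, 'uint32');
fclose(fid);
disp(class(A))                        % double
magicNumber = swapbytes(uint32(A));   % cast first, then reverse the 4 bytes

% 'uint32=>uint32' keeps the result as uint32, so no cast is needed
fid = fopen(fname, 'r');
B = fread(fid, 1, 'uint32=>uint32');
fclose(fid);
magicNumber2 = swapbytes(B);

delete(fname);
```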

    Once we get to the actual image data, we can read numRows*numCols bytes at a time (one image), then reshape the array so that it becomes an image. We can store each image in a cell array. Without further ado, here's the code.

    clear all;
    close all;
    
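    %//Open file and stop with a clear message if it cannot be found
    fid = fopen('t10k-images-idx3-ubyte', 'r');
    if fid == -1
        error('Could not open t10k-images-idx3-ubyte');
    end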
    
    %//Read in magic number
    %//A = fread(fid, 4, 'uint8');
    %//magicNumber = sum(bitshift(A', [24 16 8 0]));
    
    %//OR
    A = fread(fid, 1, 'uint32');
    magicNumber = swapbytes(uint32(A));
    
    %//Read in total number of images
    %//A = fread(fid, 4, 'uint8');
    %//totalImages = sum(bitshift(A', [24 16 8 0]));
    
    %//OR
    A = fread(fid, 1, 'uint32');
    totalImages = swapbytes(uint32(A));
    
    %//Read in number of rows
    %//A = fread(fid, 4, 'uint8');
    %//numRows = sum(bitshift(A', [24 16 8 0]));
    
    %//OR
    A = fread(fid, 1, 'uint32');
    numRows = swapbytes(uint32(A));
    
    %//Read in number of columns
    %//A = fread(fid, 4, 'uint8');
    %//numCols = sum(bitshift(A', [24 16 8 0]));
    
    %// OR
    A = fread(fid, 1, 'uint32');
    numCols = swapbytes(uint32(A));
    
    %//For each image, store into an individual cell
    imageCellArray = cell(1, totalImages);
    for k = 1 : totalImages
        %//Read in numRows*numCols pixels at a time
        A = fread(fid, numRows*numCols, 'uint8');
        %//Reshape so that it becomes a matrix.
        %//The file stores each image row by row, but reshape fills
        %//column by column, so reshape to numCols x numRows, then transpose
        imageCellArray{k} = reshape(uint8(A), numCols, numRows)';
    end
    
    %//Close the file
    fclose(fid);
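    The reshape-then-transpose step in the loop is easiest to see on a tiny example. The sketch below uses a made-up 2-by-3 image whose pixels are stored row by row, as in the MNIST file; the values are hypothetical.

```matlab
% A hypothetical 2-by-3 "image" stored row by row, as in the MNIST file
numRows = 2;
numCols = 3;
pixels = uint8([1 2 3 4 5 6]);   % row 1 = 1 2 3, row 2 = 4 5 6

% reshape fills column by column, so build numCols-by-numRows and transpose
img = reshape(pixels, numCols, numRows)'    % [1 2 3; 4 5 6]

% Reshaping straight to numRows-by-numCols scrambles the image
wrong = reshape(pixels, numRows, numCols)   % [1 3 5; 2 4 6]
```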
    

    If you check the magic number (stored as magicNumber), the total number of images (stored as totalImages) and the number of rows and columns (stored as numRows and numCols), they should be equal to 2051, 10000, 28 and 28 respectively. After this code runs, the kth element of imageCellArray stores the kth digit in the file (here, the MNIST test set). If you do imshow(imageCellArray{k});, where k is any integer from 1 to 10000, you should be able to see a digit.
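    As an alternative that is not part of the original answer, fopen can be told the file is big-endian via its machine-format argument ('ieee-be'), so no swapping is needed, and all images can be read in one call into a numRows-by-numCols-by-totalImages array instead of a cell array. To stay self-contained, this sketch first writes a tiny hypothetical file in the same layout (2 images of 2-by-3 pixels); for real data you would skip that part and point fname at a file such as 't10k-images-idx3-ubyte'.

```matlab
% Build a tiny hypothetical IDX image file: 2 images of 2-by-3 pixels
fname = [tempname '.idx'];
fid = fopen(fname, 'w', 'ieee-be');
fwrite(fid, [2051 2 2 3], 'uint32');   % magic number, images, rows, columns
fwrite(fid, 1:12, 'uint8');            % pixel bytes, row by row, image by image
fclose(fid);

% Open as big-endian so fread assembles multi-byte values in file order
fid = fopen(fname, 'r', 'ieee-be');
if fid == -1
    error('Could not open %s', fname);
end
header = fread(fid, 4, 'uint32=>double');
magicNumber = header(1);
totalImages = header(2);
numRows     = header(3);
numCols     = header(4);
if magicNumber ~= 2051
    fclose(fid);
    error('Unexpected magic number %d', magicNumber);
end
pixels = fread(fid, numRows*numCols*totalImages, 'uint8=>uint8');
fclose(fid);
delete(fname);

% reshape fills column-major, so build numCols-by-numRows-by-totalImages and
% swap the first two dimensions to get one numRows-by-numCols image per page
images = permute(reshape(pixels, numCols, numRows, totalImages), [2 1 3]);
```

    With this layout, images(:,:,k) plays the role of imageCellArray{k}.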

    Also, one final note: fread returns the pixel data as double by default, so we cast it to uint8, which is the type the images are stored as in the database. This also matters for display, since imshow expects double images to lie in the range [0, 1].
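    Instead of casting afterwards, the precision string can ask fread for uint8 output directly. The sketch below, using a hypothetical 3-byte temporary file, shows that 'uint8=>uint8' and its shorthand '*uint8' give the same result as reading with 'uint8' and casting.

```matlab
% Hypothetical file with three pixel bytes
fname = [tempname '.bin'];
fid = fopen(fname, 'w');
fwrite(fid, [0 128 255], 'uint8');
fclose(fid);

fid = fopen(fname, 'r');
asDouble = fread(fid, 3, 'uint8');         % values 0..255, class double
frewind(fid);
asUint8 = fread(fid, 3, 'uint8=>uint8');   % same values, class uint8
frewind(fid);
asUint8b = fread(fid, 3, '*uint8');        % shorthand for 'uint8=>uint8'
fclose(fid);
delete(fname);

% Casting after the fact matches reading directly into uint8
isequal(uint8(asDouble), asUint8, asUint8b)   % true
```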

    Good luck!