In Lua Torch, the product of two zero matrices has NaN entries


I have encountered a strange behavior of the torch.mm function in Lua/Torch. Here is a simple program that demonstrates the problem.

iteration = 0;
a = torch.Tensor(2, 2);
b = torch.Tensor(2, 2);
prod = torch.Tensor(2,2);

a:zero();
b:zero();

repeat
   prod = torch.mm(a,b);
   ent  = prod[{2,1}]; 
   iteration = iteration + 1;
until ent ~= ent

print ("error at iteration " .. iteration);
print (prod);

The program consists of a single loop in which it multiplies two zero 2x2 matrices and tests whether the entry ent of the product is NaN (the ent ~= ent test works because NaN is the only value that is not equal to itself). It seems the program should run forever, since the product should always be zero and hence ent should be 0. However, the program prints:

error at iteration 548   
0.000000 0.000000
nan nan
[torch.DoubleTensor of size 2x2]

Why is this happening?

Update:

  1. The problem disappears if I replace prod = torch.mm(a,b) with torch.mm(prod,a,b), which suggests that something is wrong with how the result's memory is allocated.
  2. My version of Torch was compiled without the BLAS & LAPACK libraries. After I recompiled Torch with OpenBLAS, the problem disappeared. However, I am still interested in its cause.

Solution

  • The code that auto-generates the Lua wrapper for torch.mm can be found here.

    When you write prod = torch.mm(a,b) within your loop, it corresponds to the following C code behind the scenes (generated by this wrapper via cwrap):

    /* this is the tensor that will hold the results */
    arg1 = THDoubleTensor_new(); 
    THDoubleTensor_resize2d(arg1, arg5->size[0], arg6->size[1]);
    arg3 = arg1;
    /* .... */
    luaT_pushudata(L, arg1, "torch.DoubleTensor");
    /* effective matrix multiplication operation that will fill arg1 */
    THDoubleTensor_addmm(arg1,arg2,arg3,arg4,arg5,arg6);
    

    So:

    • a new result tensor is created and resized with the proper dimensions,
    • but this new tensor is NOT initialized: there is no calloc or explicit fill here, so it points to junk memory and could already contain NaNs (see the sketch below),
    • this tensor is pushed on the stack so as to be available on the Lua side as the return value.

    The last point means that this returned tensor is different from the initial prod one (i.e. within the loop, the assignment rebinds prod to the new tensor, discarding the initial one).
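
    You can observe this junk-memory behavior directly from Lua. A minimal sketch (the output depends on whatever happens to be in the freshly allocated memory, so it will vary from run to run):

      -- A freshly constructed tensor is allocated but NOT initialized.
      local t = torch.Tensor(2, 2)   -- no zero(), no fill()
      print(t)                       -- prints whatever junk is in that memory,
                                     -- which may occasionally include NaNs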

    On the other hand, calling torch.mm(prod,a,b) does use your initial prod tensor to store the results (behind the scenes there is no need to create a dedicated tensor in that case). But since in your code snippet you do not initialize / fill it with given values, it could also contain junk.

    In both cases the core operation is a gemm multiplication of the form C = beta * C + alpha * A * B, here with beta=0 and alpha=1. The naive (non-BLAS) implementation looks like this:

      real *a_ = a;
      for(i = 0; i < m; i++)
      {
        real *b_ = b;
        for(j = 0; j < n; j++)
        {
          real sum = 0;
          for(l = 0; l < k; l++)
            sum += a_[l*lda]*b_[l];
          b_ += ldb;
          /*
           * WARNING: beta*c[j*ldc+i] could give NaN even if beta=0
           *          if the other operand c[j*ldc+i] is NaN!
           */
          c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
        }
        a_++;
      }
    

    Comments are mine.

    So:

    1. with torch.mm(a,b): at each iteration, a new result tensor is created without being initialized, so it could contain NaNs. Every iteration therefore carries a risk of returning NaNs (see the warning above),
    2. with torch.mm(prod,a,b): the same risk exists since you do not initialize the prod tensor. BUT the risk only exists at the first iteration of the repeat / until loop, since right after that prod is filled with 0s and re-used for the subsequent iterations.

    This is why you do not observe the problem in this case: it is far less likely to occur.
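
    The numerical fact behind the warning is that multiplying by zero does not clear a NaN: under IEEE 754 arithmetic, 0 * NaN is NaN. A quick Lua check (using 0/0 as one way to produce a NaN):

      local nan = 0/0      -- produce a NaN
      print(nan ~= nan)    -- true: NaN is the only value not equal to itself
      print(0 * nan)       -- nan: beta=0 does not sanitize a NaN already in C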

    In case 1: this should be improved at the Torch level, i.e. make sure the wrapper initializes the output (e.g. with THDoubleTensor_fill(arg1, 0);).

    In case 2: you should initialize prod up front and use the torch.mm(prod,a,b) form to avoid any NaN problem, e.g. as sketched below.
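
    A minimal sketch of that workaround, applied to the snippet from the question:

      a    = torch.Tensor(2, 2):zero()
      b    = torch.Tensor(2, 2):zero()
      prod = torch.Tensor(2, 2):zero()  -- initialized once, up front

      for i = 1, 1000 do
         torch.mm(prod, a, b)  -- re-uses prod as the output buffer, so no fresh
                               -- uninitialized tensor is created per iteration
      end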

    --

    EDIT: this problem is now fixed (see this pull request).