Tags: deep-learning, pytorch, linear-algebra, chainer

PyTorch and Chainer implementations of the Linear layer - are they equivalent?


I want to use a Linear, fully-connected layer as one of the input layers in my network. The input has shape (batch_size, in_channels, num_samples). It is based on the encoder prenet part of the Tacotron paper: https://arxiv.org/pdf/1703.10135.pdf. It seems to me that Chainer and PyTorch have different implementations of the Linear layer - are they really performing the same operations, or am I misunderstanding something?

In PyTorch, the behavior of the Linear layer follows the documentation: https://pytorch.org/docs/0.3.1/nn.html#torch.nn.Linear, according to which the shapes of the input and output data are as follows:

Input: (N, *, in_features) where * means any number of additional dimensions

Output: (N, *, out_features) where all but the last dimension are the same shape as the input.

Now, let's try creating a linear layer in PyTorch and performing the operation. I want an output with 8 channels, and the input data will have 3 channels.

import numpy as np
import torch
from torch import nn
linear_layer_pytorch = nn.Linear(3, 8)

Let's create some dummy input data of shape (1, 4, 3) - (batch_size, num_samples, in_channels):

data = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], dtype=np.float32).reshape(1, 4, 3)
data_pytorch = torch.from_numpy(data)

and finally, perform the operation:

results_pytorch = linear_layer_pytorch(data_pytorch)
results_pytorch.shape

The shape of the output is torch.Size([1, 4, 8]). Taking a look at the source of the PyTorch implementation:

def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

It transposes the weight matrix that is passed to it, broadcasts it along the batch_size axis, and performs a matrix multiplication. Having in mind how a linear layer works, I imagine it as 8 nodes, each connected through a weighted synapse to every channel of an input sample, so in my case there are 3*8 weights. And that is exactly the shape I see in the debugger: (8, 3).
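As a quick sanity check (a minimal sketch, reusing data_pytorch, results_pytorch and the layer created above), the operation can be reproduced by hand:

# Redo y = x @ W^T + b manually with the layer's own parameters
manual = data_pytorch.matmul(linear_layer_pytorch.weight.t()) + linear_layer_pytorch.bias
print(torch.allclose(manual, results_pytorch))  # True
print(linear_layer_pytorch.weight.shape)        # torch.Size([8, 3])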

Now, let's jump to Chainer. The Chainer Linear layer documentation is available here: https://docs.chainer.org/en/stable/reference/generated/chainer.links.Linear.html#chainer.links.Linear. According to this documentation, the Linear layer wraps the function linear, which flattens the input along the non-batch dimensions, and the shape of its weight matrix is (output_size, flattened_input_size).

import chainer
linear_layer_chainer = chainer.links.Linear(8)
results_chainer = linear_layer_chainer(data)
results_chainer.shape  # Out: (1, 8)
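For reference, here is a quick sketch (reusing data and the lazily initialised layer from above) showing that this default call is indeed equivalent to flattening the non-batch axes first:

# Chainer flattens everything after the batch axis, so the (1, 4, 3)
# input is treated as (1, 12) and W is created with shape (8, 12)
flat = data.reshape(1, -1)  # (1, 12)
results_flat = linear_layer_chainer(flat)
print(linear_layer_chainer.W.shape)  # (8, 12)
print(np.allclose(results_chainer.data, results_flat.data))  # True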

Creating the layer as linear_layer_chainer = chainer.links.Linear(3, 8) and calling it causes a size mismatch. So in the case of Chainer, I get totally different results: this time around the weight matrix has shape (8, 12) and the output has shape (1, 8).

So here is my question: since the results are clearly different, and both the weight matrices and the outputs have different shapes, how can I make them equivalent, and what should the desired output be? The PyTorch implementation of Tacotron seems to use the PyTorch approach as is (https://github.com/mozilla/TTS/blob/master/layers/tacotron.py - see Prenet). If that is the case, how can I make Chainer produce the same results (I have to implement this in Chainer)? I will be grateful for any insight; sorry that the post has gotten this long.


Solution

  • Chainer's Linear layer (a bit frustratingly) does not apply the transformation to the last axis; instead, it flattens all the remaining axes. You need to tell it how many batch axes there are via the n_batch_axes argument (see the documentation), which is 2 in your case:

    # data.shape == (1, 4, 3)
    results_chainer = linear_layer_chainer(data, n_batch_axes=2)
    # 2 batch axes (1,4) means you apply linear to (..., 3)
    # results_chainer.shape == (1, 4, 8)
    

    You can also use l(data, n_batch_axes=len(data.shape)-1) to always apply the transformation to the last dimension, which is the default behaviour in PyTorch, Keras, etc.
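
    To verify the equivalence end to end, here is a minimal sketch (assuming a Chainer version recent enough to support n_batch_axes) that copies the PyTorch parameters into the Chainer link - both frameworks store W with shape (out_features, in_features) - and checks that the outputs match:

    import numpy as np
    import torch
    from torch import nn
    import chainer

    data = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
                    dtype=np.float32).reshape(1, 4, 3)

    linear_layer_pytorch = nn.Linear(3, 8)
    # Passing in_size=3 explicitly makes Chainer create W eagerly as (8, 3)
    linear_layer_chainer = chainer.links.Linear(3, 8)
    # Copy PyTorch's parameters into the Chainer link, in place
    linear_layer_chainer.W.data[...] = linear_layer_pytorch.weight.detach().numpy()
    linear_layer_chainer.b.data[...] = linear_layer_pytorch.bias.detach().numpy()

    results_pytorch = linear_layer_pytorch(torch.from_numpy(data))
    results_chainer = linear_layer_chainer(data, n_batch_axes=2)
    print(np.allclose(results_pytorch.detach().numpy(), results_chainer.data))  # True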