
Understanding batching in pytorch models


I have the following model, which forms one of the steps in my overall model pipeline:

import torch
import torch.nn as nn

class NPB(nn.Module):
    def __init__(self, d, nhead, num_layers, dropout=0.1):
        super(NPB, self).__init__()
            
        self.te = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        ) 

        self.t_emb = nn.Parameter(torch.randn(1, d))
        
        self.L = nn.Parameter(torch.randn(1, d)) 

        self.td = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        ) 

        self.ffn = nn.Linear(d, 6)
    
    def forward(self, t_v, t_i):
        print("--------------- t_v, t_i -----------------")
        print('t_v: ', tuple(t_v.shape))
        print('t_i: ', tuple(t_i.shape))

        print("--------------- t_v + t_i + t_emb -----------------")
        _x = t_v + t_i + self.t_emb
        print(tuple(_x.shape))

        print("--------------- te ---------------")
        _x = self.te(_x)
        print(tuple(_x.shape))
        
        print("--------------- td ---------------")
        _x = self.td(self.L, _x)
        print(tuple(_x.shape))

        print("--------------- ffn ---------------")
        _x = self.ffn(_x)
        print(tuple(_x.shape))

        return _x

Here t_v and t_i are inputs from earlier encoder blocks. I pass them with shape (4, 256), where 256 is the number of features and 4 is the batch size. t_emb is a temporal embedding. L is a learned matrix representing the query embedding. I tested this module block with the following code:

t_v = torch.randn((4,256))
t_i = torch.randn((4,256))
npb = NPB(d=256, nhead=8, num_layers=2)
npb(t_v, t_i)

It printed:

=============== NPB ===============
--------------- t_v, t_i -----------------
t_v:  (4, 256)
t_i:  (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(1, 256)
--------------- ffn ---------------
(1, 6)
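A small check, not part of the original post, of how PyTorch reads the 2-D test input: with batch_first=True, a tensor of shape (4, 256) is treated as a single unbatched sequence of 4 tokens (S = 4, E = 256), not as 4 independent samples. The sketch below assumes a recent PyTorch version that accepts unbatched inputs; the names `encoder`, `x` and `x_changed` are illustrative only. It modifies only the 4th row and shows that the encoder output for row 0 changes as well, because the rows attend to each other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 256
layer = nn.TransformerEncoderLayer(d_model=d, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()  # eval() disables dropout

x = torch.randn(4, d)      # read by PyTorch as ONE sequence of 4 tokens (S=4, E=256)
x_changed = x.clone()
x_changed[3] += 1.0        # modify only the 4th "sample"

with torch.no_grad():
    y = encoder(x)
    y_changed = encoder(x_changed)

print(y.shape)                             # torch.Size([4, 256])
# Row 0 changes even though only row 3 was modified: self-attention mixes the rows.
print(torch.allclose(y[0], y_changed[0]))  # False
```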

I expected the output to have shape (4, 6): 6 values for each sample in the batch of size 4. Instead, the output had shape (1, 6). After a lot of tweaking, I changed the shape of t_emb and L from (1,d) to (4,d), since I did not want all samples to share these parameters (through broadcasting):

self.t_emb = nn.Parameter(torch.randn(4, d)) # [n, d] = [4, 256]     
self.L = nn.Parameter(torch.randn(4, d)) 

This gives the desired output of shape (4, 6):

--------------- t_v, t_i -----------------
t_v:  (4, 256)
t_i:  (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(4, 256)
--------------- ffn ---------------
(4, 6)
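To separate the two effects in the change above, here is a sketch (not from the original post; variable names are illustrative) showing that broadcasting a (1, d) embedding in the addition behaves as expected, while the decoder's output length simply follows the length of its first argument, tgt. It assumes the same d = 256 and nhead = 8 as the test code and uses a freshly initialized decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 256
t_v, t_i = torch.randn(4, d), torch.randn(4, d)
t_emb = torch.randn(1, d)

# Broadcasting the (1, d) embedding works: every row gets the same vector added.
summed = t_v + t_i + t_emb
print(summed.shape)                                               # torch.Size([4, 256])
print(torch.allclose(summed, t_v + t_i + t_emb.expand(4, d)))     # True

# The decoder's output length follows tgt (first argument), not memory.
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=d, nhead=8, batch_first=True), num_layers=2
).eval()
with torch.no_grad():
    out_1 = decoder(torch.randn(1, d), summed)  # tgt (1, d), memory (4, d)
    out_4 = decoder(torch.randn(4, d), summed)  # tgt (4, d), memory (4, d)
print(out_1.shape, out_4.shape)  # torch.Size([1, 256]) torch.Size([4, 256])
```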

I have the following doubts:

Q1. Why exactly did changing the shape of L and t_emb from (1,d) to (4,d) work? Why did it not work with (1,d) through broadcasting?
Q2. Am I batching the right way, or is the output only artificially correct while, under the hood, the model does something different from what I expect (predicting 6 values for each sample in a batch of size 4)?


Solution

  • Check the docs for nn.Transformer and nn.TransformerDecoder.

    For unbatched (2-D) inputs, where src has shape (S, E) and tgt has shape (T, E), the output has shape (T, E).

    In the transformer decoder, the first argument is tgt, and it determines the output's sequence length.

    Since you define your tgt parameter L as torch.randn(1, d), your transformer decoder output has shape (1, d).

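    This has nothing to do with broadcasting; it is simply how the transformer's input/output shapes work. The same rule also answers Q2: with batch_first=True, your (4, 256) inputs are themselves unbatched, i.e. a single sequence of S = 4 tokens with E = 256 features. The four samples therefore attend to each other in the encoder's self-attention and in the decoder's cross-attention, so each row of the (4, 6) output depends on all four samples. To batch them properly, give the inputs an explicit batch dimension, e.g. shape (4, 1, 256) with batch_first=True.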
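One possible way to batch this properly, sketched under the assumption that each sample is a single token (the summed t_v + t_i vector): add a sequence dimension of length 1 so the input becomes (B, 1, d), keep t_emb and L as single shared (1, 1, d) parameters, and expand L to one query per sample. `BatchedNPB` and `n_out` are hypothetical names, not from the original post. The usage part checks that changing sample 1 leaves sample 0's output unchanged.

```python
import torch
import torch.nn as nn

class BatchedNPB(nn.Module):
    """Hypothetical variant of NPB: each sample is its own length-1 sequence."""

    def __init__(self, d, nhead, num_layers, dropout=0.1, n_out=6):
        super().__init__()
        self.te = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        )
        self.t_emb = nn.Parameter(torch.randn(1, 1, d))  # shared across the batch via broadcasting
        self.L = nn.Parameter(torch.randn(1, 1, d))      # one learned query token
        self.td = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        )
        self.ffn = nn.Linear(d, n_out)

    def forward(self, t_v, t_i):
        # (B, d) -> (B, 1, d): a batch of B sequences, each one token long
        x = (t_v + t_i).unsqueeze(1) + self.t_emb
        memory = self.te(x)                        # (B, 1, d)
        query = self.L.expand(x.size(0), -1, -1)   # (1, 1, d) -> (B, 1, d)
        out = self.td(query, memory)               # (B, 1, d)
        return self.ffn(out).squeeze(1)            # (B, n_out)

torch.manual_seed(0)
model = BatchedNPB(d=256, nhead=8, num_layers=2).eval()  # eval() disables dropout
t_v, t_i = torch.randn(4, 256), torch.randn(4, 256)
t_v2 = t_v.clone()
t_v2[1] += 1.0  # change sample 1 only

with torch.no_grad():
    out = model(t_v, t_i)
    out2 = model(t_v2, t_i)

print(out.shape)                         # torch.Size([4, 6])
print(torch.allclose(out[0], out2[0]))   # True: sample 0 is unaffected by sample 1
```

With one token per sample, the encoder's self-attention can only attend to that token itself; if each sample actually consists of several tokens (for example, t_v and t_i as two separate tokens), stacking them along dim 1 instead of summing them would give attention something to work with.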