Tags: deep-learning, pytorch, neural-network, huggingface-transformers, transformer-model

How do I fuse two parallel branches in an encoder design?


It seems I am not designing my encoder correctly, which is why I would like an expert opinion on this; I am a beginner with transformers and deep learning model design.

I have two different types of Transformer networks in an encoder, as follows:

[Image: diagram of the encoder with the two parallel transformer branches (not available)]

The embedding dimension of each branch is 256, and the two branch outputs are fused by a Linear layer:

self.fusion_head = nn.Linear(2 * self.num_features, self.num_features)  # self.num_features = 256

My encoder has the following feature-forward function:

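def transformer_forward(self, x):
    """
    :param x: The embeddings + pos_embed, shape (batch, seq_len, num_features)
    :return: The fused features, shape (batch, seq_len, num_features)
    """

    x_t1 = self.transformer_type1(x)  # torch.Size([1, 1280, 256])

    # Call the module itself rather than .forward() so that registered hooks run
    x_t2 = self.transformer_type2(x)  # torch.Size([1, 1280, 256])

    # x = x_t1 + x_t2  # alternative: additive fusion
    x = torch.cat([x_t1, x_t2], dim=2)  # torch.Size([1, 1280, 512])

    x = self.fusion_head(x)  # torch.Size([1, 1280, 256])

    return x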
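For reference, here is a minimal self-contained sketch of the same concat-then-Linear fusion. It assumes stand-in branches built from nn.TransformerEncoder (the real transformer_type1 and transformer_type2 are not shown in the question), and the class name TwoBranchEncoder is hypothetical:

```python
import torch
from torch import nn


class TwoBranchEncoder(nn.Module):
    """Hypothetical encoder: two parallel branches fused by concat + Linear."""

    def __init__(self, num_features: int = 256, nhead: int = 8):
        super().__init__()
        # Stand-in branches; the question's transformer_type1/type2 are not shown.
        self.transformer_type1 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(num_features, nhead, batch_first=True), num_layers=1
        )
        self.transformer_type2 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(num_features, nhead, batch_first=True), num_layers=1
        )
        # Projects the concatenated 2 * num_features channels back to num_features.
        self.fusion_head = nn.Linear(2 * num_features, num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_t1 = self.transformer_type1(x)  # (batch, seq_len, num_features)
        x_t2 = self.transformer_type2(x)  # (batch, seq_len, num_features)
        fused = torch.cat([x_t1, x_t2], dim=-1)  # (batch, seq_len, 2 * num_features)
        return self.fusion_head(fused)  # (batch, seq_len, num_features)


enc = TwoBranchEncoder().eval()
with torch.no_grad():
    out = enc(torch.randn(1, 1280, 256))
print(out.shape)  # torch.Size([1, 1280, 256])
```

Both branches read the same input x, and fusion_head only runs once both branch outputs exist, because that is the order in which forward calls them.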

However, after training the model and loading the checkpoint, I noticed that self.fusion_head is placed after the transformer_type1 modules in the list of saved keys:

. ... 3.0.fn.to_qkv.weight', 'module.encoder.transformer_type1.3.layers.3.0.fn.to_out.0.weight', 'module.encoder.transformer_type1.3.layers.3.0.fn.to_out.0.bias', 'module.encoder.transformer_type1.3.layers.3.1.norm.weight', 'module.encoder.transformer_type1.3.layers.3.1.norm.bias', 'module.encoder.transformer_type1.3.layers.3.1.fn.net.0.weight', 'module.encoder.transformer_type1.3.layers.3.1.fn.net.0.bias', 'module.encoder.transformer_type1.3.layers.3.1.fn.net.3.weight', 'module.encoder.transformer_type1.3.layers.3.1.fn.net.3.bias', 'module.encoder.mlp_head.0.weight', 'module.encoder.mlp_head.0.bias', 'module.encoder.mlp_head.1.weight', 'module.encoder.mlp_head.1.bias', 'module.encoder.fusion_head.weight', 'module.encoder.fusion_head.bias', 'module.encoder.transformer_type2.pos_embed', 'module.encoder.transformer_type2.patch_embed.proj.weight', 'module.encoder.transformer_type2.patch_embed.proj.bias', 'module.encoder.transformer_type2.patch_embed.norm.weight', 'module.encoder.transformer_type2.patch_embed.norm.bias', 'module.encoder.transformer_type2.blocks.0.norm1.weight', 'module.encoder.transformer_type2.blocks.0.norm1.bias', 'module.encoder.transformer_type2.blocks.0.filter.complex_weight', 'module.encoder.transformer_type2.blocks.0.norm2.weight', 'module.encoder.transformer_type2.blocks.0.norm2.bias', 'module.encoder.transformer_type2.blocks.0.mlp.fc1.weight', ...

Is the placement of this concatenation layer (i.e., fusion_head) correct in the forward function? Why is it placed after transformer_type1? Shouldn't fusion_head come after both transformer_type1 and transformer_type2 in terms of order?


Solution

  • The order you see is not determined by forward. nn.Module records submodules in the order in which they are registered (assigned as attributes) inside the __init__ method of your network, and both its __repr__ and state_dict() list them in that order. It makes sense that forward does not define the order, because you can call a module several times, or not at all, inside forward.

    import torch
    from torch import nn
    class Bla(nn.Module):
        def __init__(self):
            super().__init__()
            self.b1 = nn.Linear(256, 128)
            self.b2 = nn.GELU()
            self.b3 = nn.Linear(128,5)
            self.b0 = nn.Embedding(100,256)
        def forward(self, x):
            emb = self.b0(x)
            emb = self.b1(emb)
            emb = self.b2(emb)
            emb = self.b3(emb)
            return emb
    
    net = Bla()
    print(net)
    

    Output:

    Bla(
      (b1): Linear(in_features=256, out_features=128, bias=True)
      (b2): GELU(approximate='none')
      (b3): Linear(in_features=128, out_features=5, bias=True)
      (b0): Embedding(100, 256)
    )
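The same registration order shows up in state_dict(), which is what the checkpoint keys in the question come from. Because load_state_dict() matches tensors by key name, the position of fusion_head in that listing has no effect on the computation. A small sketch with a hypothetical toy model (ToyFusion, with made-up layer sizes) makes this concrete:

```python
import torch
from torch import nn


class ToyFusion(nn.Module):
    """Hypothetical toy model; fusion_first only changes registration order."""

    def __init__(self, fusion_first: bool):
        super().__init__()
        if fusion_first:
            self.fusion_head = nn.Linear(8, 4)
        self.branch1 = nn.Linear(4, 4)
        self.branch2 = nn.Linear(4, 4)
        if not fusion_first:
            self.fusion_head = nn.Linear(8, 4)

    def forward(self, x):
        # forward always runs both branches first, then fuses them
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        return self.fusion_head(torch.cat([x1, x2], dim=-1))


a = ToyFusion(fusion_first=True)
b = ToyFusion(fusion_first=False)
print(list(a.state_dict()))
# ['fusion_head.weight', 'fusion_head.bias', 'branch1.weight', 'branch1.bias', 'branch2.weight', 'branch2.bias']
print(list(b.state_dict()))
# ['branch1.weight', 'branch1.bias', 'branch2.weight', 'branch2.bias', 'fusion_head.weight', 'fusion_head.bias']

# load_state_dict matches parameters by key name, so the listing order does not matter
b.load_state_dict(a.state_dict())
x = torch.randn(2, 4)
print(torch.equal(a(x), b(x)))  # True
```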