
Is there a way to generate nn.Embedding efficiently using a for loop?


I'm a PyTorch newbie and I wonder whether I can generate multiple nn.Embedding layers efficiently using a for loop.

import torch.nn as nn

class Example(nn.Module):
    def __init__(self, df):
        super().__init__()
        self.A_embed_dim = 3
        self.B_embed_dim = 3
        self.C_embed_dim = 5

        # one embedding table per column; vocab size is max category id + 1
        self.A_embedding = nn.Embedding(
            df.A.max() + 1, self.A_embed_dim
        )
        self.B_embedding = nn.Embedding(
            df.B.max() + 1, self.B_embed_dim
        )
        self.C_embedding = nn.Embedding(
            df.C.max() + 1, self.C_embed_dim
        )

In this case there are only 3 columns, so it is easy to create the embeddings. But if the DataFrame has more columns (e.g. 16 columns, A to P), the code gets long and doesn't look clean. Is there a way to create multiple nn.Embedding layers using a for loop?


Solution

  • Yes, you can use nn.ModuleList or nn.ModuleDict to do so (a ModuleDict sketch follows the ModuleList example below).

    ModuleList:

    # e.g. embedding_args = [(5, 10), (2, 8)], a list of (vocab_size, dim) pairs
    self.embeddings = nn.ModuleList(
        [
            nn.Embedding(vocab_size, dim)
            for vocab_size, dim in embedding_args
        ]
    )
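
    For completeness, here is a minimal sketch of the ModuleDict variant applied to the DataFrame case from the question. The embed_dims mapping, the batch format in forward, and the constructor signature are illustrative assumptions, not something fixed by the original answer.

    ModuleDict:

    import torch
    import torch.nn as nn

    class Example(nn.Module):
        def __init__(self, df, embed_dims):
            super().__init__()
            # embed_dims is assumed to map column name -> embedding dim,
            # e.g. {"A": 3, "B": 3, "C": 5}
            self.embeddings = nn.ModuleDict(
                {
                    col: nn.Embedding(int(df[col].max()) + 1, dim)
                    for col, dim in embed_dims.items()
                }
            )

        def forward(self, batch):
            # batch is assumed to map column name -> LongTensor of category ids
            return torch.cat(
                [self.embeddings[col](batch[col]) for col in self.embeddings],
                dim=-1,
            )

    Either container registers the embeddings as submodules, so their weights appear in model.parameters() and move with the model on .to(device); a plain Python list or dict would silently skip them.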