
Retrieving original data from PyTorch nn.Embedding


I'm passing a dataframe with 5 categories (e.g. car, bus, ...) into nn.Embedding.

When I call embedding.parameters(), I can see that there are 5 tensors, but how do I know which index corresponds to which original input (e.g. car, bus, ...)?


Solution

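  • You can't. embedding.parameters() yields a single weight tensor of shape (num_embeddings, embedding_dim), and its rows are identified only by their integer index, never by a name (only dimensions can be named, see PyTorch's Named Tensors). You have to keep the names in a separate data container, for example (4 categories here):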

    import pandas as pd
    import torch
    
    df = pd.DataFrame(
        {
            "bus": [1.0, 2, 3, 4, 5],
            "car": [6.0, 7, 8, 9, 10],
            "bike": [11.0, 12, 13, 14, 15],
            "train": [16.0, 17, 18, 19, 20],
        }
    )
    
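    # Each column becomes one row of the embedding matrix (4 x 5);
    # cast to float32, the default dtype of nn.Embedding weights
    df_data = torch.tensor(df.to_numpy().T, dtype=torch.float32)
    # Keep the names separately: df_names[i] is the name of row i
    df_names = list(df.columns)
    
    # Build the layer from the data instead of overwriting weight.data;
    # freeze=False keeps the weights trainable
    embedding = torch.nn.Embedding.from_pretrained(df_data, freeze=False)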
    

    Now you can simply use it with any index you want:

    index = 1
    embedding(torch.tensor(index)), df_names[index]
    

    This returns a tuple with the embedding row for index 1 (the values 6, 7, 8, 9, 10) and its column name, "car", so you get the data together with its name.
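If the categories live in a single dataframe column instead of one column per category, a common alternative is to let pandas assign the integer indices. The sketch below assumes a hypothetical column named `vehicle` and an arbitrary `embedding_dim=3`; `.cat.codes` supplies the indices fed to `nn.Embedding`, and `.cat.categories` plays the role of the separate container that maps each index back to its name.

```python
import pandas as pd
import torch

# Hypothetical input: one categorical column, as described in the question
df = pd.DataFrame({"vehicle": ["car", "bus", "car", "bike", "train", "bus"]})

# pandas gives each distinct category an integer code 0..n-1;
# .cat.categories lists the names in code order (sorted for strings)
vehicle = df["vehicle"].astype("category")
codes = torch.tensor(vehicle.cat.codes.to_numpy(), dtype=torch.long)
names = list(vehicle.cat.categories)                        # index -> name
name_to_index = {name: i for i, name in enumerate(names)}  # name -> index

torch.manual_seed(0)
embedding = torch.nn.Embedding(num_embeddings=len(names), embedding_dim=3)

# parameters() yields a single weight tensor; row i belongs to names[i]
(weight,) = embedding.parameters()
print(weight.shape)  # torch.Size([4, 3])

# Look up one category by name
car_vector = embedding(torch.tensor(name_to_index["car"]))

# Map every row of the weight matrix back to its category name
vectors_by_name = {name: weight[i].detach() for i, name in enumerate(names)}

# Embed the whole column: one vector per dataframe row
batch = embedding(codes)  # shape (len(df), 3)

# Re-encode new data against the same category list so indices stay stable;
# values outside `names` would get code -1 and must be handled before lookup
new = pd.Series(["train", "car"]).astype(pd.CategoricalDtype(categories=names))
new_codes = torch.tensor(new.cat.codes.to_numpy(), dtype=torch.long)
new_vectors = embedding(new_codes)
```

Keep `names` (or the `CategoricalDtype` built from it) alongside the model: calling `astype("category")` on a new dataset with a different set of values would produce different codes, silently shifting which embedding row each category maps to.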