python, pytorch

How to remove all occurrences of a node from an edge_index in PyTorch


I have an edge_index tensor and want to remove one node, n = 3, from it, i.e. drop every edge in which that node appears.

import torch

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])

nodes = torch.unique(edges)
n = nodes[-1]  # node 3: I want to remove it from edge_index
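In this edge_index layout, row 0 holds the source nodes and row 1 the target nodes, so each column is one edge. The following sketch, which assumes the same edges tensor as the question, lists the edges as pairs and flags the columns in which n appears:

```python
import torch

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])
n = torch.unique(edges)[-1]  # tensor(3), the largest node id

# Transpose so that each row is one (source, target) pair
pairs = edges.t().tolist()
print(pairs)  # [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [3, 2]]

# True for every column (edge) where n is the source or the target
touches_n = (edges == n).any(dim=0)
print(touches_n.tolist())  # [False, False, False, False, True, True]
```

So the task is to drop the last two columns, not a row.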

I tried this, but it doesn't work:

arr = edges[~(edges == [n]).all(axis=1)]
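One reason this attempt cannot work: .all(axis=1) reduces across each row, so the mask has one entry per row (2) rather than one per edge (6), and indexing with it selects whole rows instead of columns. The sketch below compares the two reduction directions; it uses a plain comparison with the integer 3 instead of the list [n], which is a simplification for illustration:

```python
import torch

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])
n = 3

not_n = edges != n  # shape (2, 6): True where the entry is not n

# Reducing over dim=1 collapses each row: one value per row
row_mask = not_n.all(dim=1)
print(row_mask.shape)  # torch.Size([2])

# Reducing over dim=0 collapses each column: one value per edge
col_mask = not_n.all(dim=0)
print(col_mask.shape)  # torch.Size([6])
print(col_mask.tolist())  # [True, True, True, True, False, False]
```

Both rows contain a 3, so (edges == 3).all(axis=1) is [False, False]; negated, it becomes [True, True] and selects both rows, returning the tensor unchanged.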

Solution

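  • Build a mask over the columns that is True for every edge in which neither endpoint equals n, repeat it for both rows, select, and reshape back to two rows: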

    arr = edges[(~(edges == n)).all(axis=0).unsqueeze(0).repeat(edges.shape[0], 1)].reshape(edges.shape[0], -1)
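The same column mask can also be applied directly with slicing (edges[:, mask]), which keeps both rows and makes the repeat and reshape steps unnecessary. A sketch, assuming the edges and n from the question. The torch.isin variant (available in PyTorch 1.10 and later) is an illustrative extension for removing several nodes at once, where drop is a hypothetical example set:

```python
import torch

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])
n = torch.unique(edges)[-1]  # node 3

# Keep a column (edge) only if neither endpoint equals n
keep = (edges != n).all(dim=0)
arr = edges[:, keep]
print(arr.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]

# Hypothetical extension: drop every edge touching any node in `drop`
drop = torch.tensor([0, 3])
keep_many = ~torch.isin(edges, drop).any(dim=0)
arr_many = edges[:, keep_many]
print(arr_many.tolist())  # [[1, 2], [2, 1]]
```

Slicing with a 1-D boolean mask on the column dimension always returns a (2, k) tensor, so the result keeps the edge_index shape without any reshaping.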