I have an edge_index and want to remove one node from it, n = 3, i.e. drop every edge that has node 3 as an endpoint:
import torch
edges = torch.tensor([
[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]])
nodes = torch.unique(edges)
n = nodes[-1] # I want to remove this from edge_index
I tried this, but it doesn't work:
arr = edges[~(edges == [n]).all(axis=1)]
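The attempt reduces along the wrong axis. In an edge_index of shape [2, num_edges], each edge is a column, so the keep-mask needs one entry per column, which means reducing over dim=0. The sketch below compares the two reductions. It assumes a plain scalar comparison `edges != n` rather than the `[n]` list used in the attempt.

```python
import torch

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])
n = 3

not_n = edges != n            # elementwise comparison, shape [2, 6]
row_mask = not_n.all(dim=1)   # one value per row  -> shape [2]
col_mask = not_n.all(dim=0)   # one value per edge -> shape [6]

print(row_mask.tolist())  # [False, False]
print(col_mask.tolist())  # [True, True, True, True, False, False]
```

A per-row mask like `row_mask` can only select whole rows of edge_index. Here it selects none, so it cannot drop individual edges.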
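Each edge is a column of edge_index, so build the mask per column (keep the edges where neither endpoint equals n) and apply it to both rows. Change the code to: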
arr = edges[(~(edges == n)).all(axis=0).unsqueeze(0).repeat(edges.shape[0], 1)].reshape(edges.shape[0], -1)
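The unsqueeze/repeat/reshape steps only broadcast the per-column mask to both rows. Indexing the column dimension directly with that mask should give the same result with less code. The following is a minimal sketch, where `remove_node_edges` is a hypothetical helper name:

```python
import torch

def remove_node_edges(edge_index: torch.Tensor, node: int) -> torch.Tensor:
    """Drop every edge (column) of a [2, E] edge_index that touches `node`."""
    # True for edges where neither endpoint equals `node`
    keep = (edge_index != node).all(dim=0)
    # Boolean mask applied to the column (edge) dimension only
    return edge_index[:, keep]

edges = torch.tensor([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2]])
n = int(torch.unique(edges)[-1])  # largest node id, 3

print(remove_node_edges(edges, n).tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

Node ids are not relabelled. Node 3 simply no longer appears in the result.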