I designed a Graph Attention Network.
However, the values computed inside the layer all end up equal.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    # in_features = out_features = 1024
    def __init__(self, in_features, out_features, dropout):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        self.a1 = nn.Parameter(torch.zeros(size=(out_features, 1)))
        self.a2 = nn.Parameter(torch.zeros(size=(out_features, 1)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)
        nn.init.xavier_normal_(self.a1.data, gain=1.414)
        nn.init.xavier_normal_(self.a2.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU()

    def forward(self, input, adj):
        h = torch.mm(input, self.W)                # (N, out_features)
        a_input1 = torch.mm(h, self.a1)            # (N, 1)
        a_input2 = torch.mm(h, self.a2)            # (N, 1)
        a_input = torch.mm(a_input1, a_input2.transpose(1, 0))  # (N, N)
        e = self.leakyrelu(a_input)
        zero_vec = torch.zeros_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # most of the values are close to 0
        attention = F.softmax(attention, dim=1)  # all values are 0.0014, i.e. 1/707 (attention is 707 x 707)
        # Pass training=self.training so dropout is disabled in eval mode.
        attention = F.dropout(attention, self.dropout, training=self.training)
        return attention
```
The shape of `attention` is (707 x 707), and I observed that the attention values are close to 0 before the softmax.
After the softmax, all values are 0.0014, which is 1/707.
How can I keep the values properly normalized and prevent this?
Thanks
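The observation can be reproduced in isolation. This is a minimal sketch, assuming random pre-softmax scores of magnitude around 1e-4 as an arbitrary stand-in for "close to 0". A row-wise softmax over 707 nearly equal logits gives every entry approximately 1/707 ≈ 0.0014:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n = 707  # number of nodes, matching the 707 x 707 attention matrix

# Hypothetical pre-softmax scores: tiny values near 0.
logits = 1e-4 * torch.randn(n, n)

# Row-wise softmax, as in F.softmax(attention, dim=1).
attention = F.softmax(logits, dim=1)

print(attention.min().item(), attention.max().item())  # both very close to 1/707
print(1.0 / n)
print(attention.sum(dim=1)[:3])  # each row sums to 1
```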
Since you say this happens during training, I would assume it is at the start. With random initialization you often get nearly identical values at the end of the network early in training.
When all values are more or less equal, the output of the softmax will be 1/num_elements
for every element, so that they sum to 1 over the dimension you chose. In your case you get 1/707
for every value, which to me just suggests that your weights are freshly initialized and the outputs are mostly random at this stage.
I would let it train for a while and see whether this changes.
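Separately from initialization, one detail in the posted forward pass may also contribute. Non-edges are filled with 0 before the softmax, and exp(0) = 1, so every non-neighbour still receives about as much weight as a neighbour whose score is near 0. A common alternative in GAT implementations is to fill non-edges with a large negative number so that the softmax gives them effectively zero weight. The toy sketch below compares the two. The 4-node graph, the scores, and the -9e15 fill value are illustrative assumptions, not taken from the question:

```python
import torch
import torch.nn.functional as F

# Hypothetical 4-node graph: node 0 is connected only to nodes 0 and 1.
adj = torch.tensor([[1., 1., 0., 0.],
                    [1., 1., 1., 0.],
                    [0., 1., 1., 1.],
                    [0., 0., 1., 1.]])

# Hypothetical pre-softmax scores, all close to 0.
e = torch.full((4, 4), 0.01)

# Masking with 0 (as in the question): non-edges still get exp(0) weight.
att_zero = F.softmax(torch.where(adj > 0, e, torch.zeros_like(e)), dim=1)

# Masking with a large negative value: non-edges get (numerically) zero weight.
att_neg = F.softmax(torch.where(adj > 0, e, torch.full_like(e, -9e15)), dim=1)

print(att_zero[0])  # roughly uniform over all 4 nodes, including non-neighbours
print(att_neg[0])   # 0.5 on nodes 0 and 1, 0 on nodes 2 and 3
```

With the zero fill, each row is spread over all 707 nodes regardless of the adjacency matrix, which on its own pushes the result toward 1/707 whenever the scores are small.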