python theano rounding-error

How to replace maximal values in columns by 1.0 using Theano?


I have a real-valued matrix in Theano and I want to generate another matrix such that each column of the new matrix contains a single 1.0 and 0.0 everywhere else. The 1.0 should mark the position of the maximal value in the corresponding column of the input matrix.

For example, if the following matrix is used as the input:

1.0   2.0   3.0   5.0
2.1   0.0   4.0   0.0
0.0   3.0   1.0   4.0

The following matrix has to be generated as the output:

0.0   0.0   0.0   1.0
1.0   0.0   1.0   0.0
0.0   1.0   0.0   0.0

The solution that I use so far is as follows:

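import theano.tensor as T  # inp is assumed to be a symbolic matrix, e.g. T.matrix('inp')
tmp = T.max(inp, axis=0).dimshuffle('x', 0)  # column maxima as a broadcastable row
out = T.switch(T.eq(tmp, inp), 1.0, 0.0)  # 1.0 where an entry equals its column maximum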

This solution seems to work, but I am not sure how robust it is. My main concern is that I check whether the current value is exactly equal to the maximum in the column. Can it happen that, because of some "rounding" error, the maximal value will not be recognized as such?


Solution

  • Try this:

    out = T.eye(3)[T.argmax(inp, axis=0)].T  # replace 3 with number of rows
    

    This will output:

    0.0   0.0   0.0   1.0
    1.0   0.0   1.0   0.0
    0.0   1.0   0.0   0.0
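
On the rounding concern itself: as far as I can tell, taking a maximum only selects one of the existing entries and does no arithmetic on it, so comparing against it with exact equality should not miss the maximal value. Where the two approaches differ is ties: the equality test marks every entry that equals the column maximum, while argmax picks a single row. The sketch below uses NumPy rather than Theano to illustrate this, on the assumption that the Theano ops behave the same way; the helper names `one_hot_by_equality` and `one_hot_by_argmax` are made up for this example.

```python
import numpy as np

def one_hot_by_equality(inp):
    # Analogue of the question's approach: compare each entry
    # with the maximum of its column (kept as a 1 x n_cols row).
    col_max = inp.max(axis=0, keepdims=True)
    return (inp == col_max).astype(inp.dtype)

def one_hot_by_argmax(inp):
    # Analogue of the answer's approach: pick rows of an identity
    # matrix by the row index of each column's maximum, then transpose.
    n_rows = inp.shape[0]
    return np.eye(n_rows, dtype=inp.dtype)[inp.argmax(axis=0)].T

# The example matrix from the question.
inp = np.array([[1.0, 2.0, 3.0, 5.0],
                [2.1, 0.0, 4.0, 0.0],
                [0.0, 3.0, 1.0, 4.0]])

eq_out = one_hot_by_equality(inp)
arg_out = one_hot_by_argmax(inp)

# A matrix whose first column holds the maximum 7.0 twice.
tied = np.array([[7.0, 1.0],
                 [7.0, 2.0]])

tied_eq = one_hot_by_equality(tied)   # first column gets two 1.0 entries
tied_arg = one_hot_by_argmax(tied)    # first column gets a single 1.0 (first occurrence)
```

Without ties, both functions produce the output matrix shown in the question. If a column can contain duplicate maxima and exactly one 1.0 per column is required, the argmax version is the safer choice.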