I have one-hot encoded my images (masks), which have dimensions (img_width x img_height x 1), like this:
```python
import numpy as np

def OneHotEncoding(im, n_classes):
    one_hot = np.zeros((im.shape[0], im.shape[1], n_classes), dtype=np.uint8)
    for i, unique_value in enumerate(np.unique(im)):
        one_hot[:, :, i][im == unique_value] = 1
    return one_hot
```
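One thing worth checking in this encoder: if the masks really have shape (img_width x img_height x 1), then `im == unique_value` is a 3-D boolean mask, and using it to index the 2-D slice `one_hot[:, :, i]` raises an IndexError. Also, `np.unique` only sees the values present in each individual mask, so a mask that lacks some class shifts the remaining classes into the wrong channels. A minimal sketch of an alternative, assuming the mask values are integer class indices 0 to n_classes-1 (`one_hot_encode` is a hypothetical name, not from the original):

```python
import numpy as np

def one_hot_encode(mask, n_classes):
    """One-hot encode an integer label mask (hypothetical helper).

    Assumes mask values are class indices 0 .. n_classes-1 and that
    mask has shape (H, W) or (H, W, 1).
    """
    if mask.ndim == 3:
        mask = mask[..., 0]  # drop the trailing singleton channel axis
    # Compare every pixel with every class index via broadcasting:
    # (H, W, 1) == (n_classes,) -> (H, W, n_classes)
    return (mask[..., None] == np.arange(n_classes)).astype(np.uint8)

mask = np.array([[0, 2], [2, 2]])[..., None]  # shape (2, 2, 1); class 1 never appears
encoded = one_hot_encode(mask, n_classes=3)
print(encoded.shape)  # (2, 2, 3)
```

Because the channel is chosen by the class value itself rather than by its position in `np.unique`, class 2 still lands in channel 2 even though class 1 is absent from this mask.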
After passing the data through a deep learning model, the softmax activation function produces probabilities instead of 0 and 1 values, so in my decoder I wanted to implement the following approach:
- Threshold the output to obtain only 0 or 1.
- Multiply each channel by a weight equal to the channel index.
- Take the max between labels along the channel axis.
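Applied to a hard one-hot mask, these three steps invert the encoding. A minimal sketch using NumPy broadcasting, assuming labels 0 to n_classes-1 with the channel axis last (the label values are made-up example data):

```python
import numpy as np

n_classes = 3
labels = np.array([[0, 2, 1],
                   [1, 1, 0]])                                        # (H, W) integer labels
one_hot = (labels[..., None] == np.arange(n_classes)).astype(float)  # (H, W, C)

# Step 1: threshold to 0/1 (a no-op here, since one_hot is already hard)
binary = (one_hot >= 0.5).astype(np.uint8)
# Step 2: weight each channel by its index; np.arange broadcasts over the last axis
weighted = binary * np.arange(n_classes)
# Step 3: take the maximum over the channel axis (the last axis, -1)
decoded = weighted.max(axis=-1)

print(np.array_equal(decoded, labels))                   # True
print(np.array_equal(decoded, one_hot.argmax(axis=-1)))  # True
```

When exactly one channel per pixel passes the threshold, the weighted max gives the same result as `argmax`. The two diverge when no channel passes (the weighted max then returns 0, indistinguishable from class 0) or when several do (the highest index wins).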
```python
import numpy as np

arr = np.array([
    [[0.1,0.2,0,5],[0.2,0.4,0.7],[0.3,0.5,0.8]],
    [[0.3,0.6,0 ],[0.4,0.9,0.1],[0 ,0 ,0.2]],
    [[0.7,0.1,0.1],[0,6,0.1,0.1],[0.6,0.6,0.3]],
    [[0.6,0.2,0.3],[0.4,0.5,0.3],[0.1,0.2,0.7]]
])
# print(arr.dtype,arr.shape)

def oneHotDecoder(img):
    # Thresholding
    img[img<0.5]=0
    img[img>=0.5]=1
    # weights of the labels
    img = [i*img[:,:,i] for i in range(img.shape[2])]
    # take the max label
    img = np.amax(img,axis=2)
    print(img.shape)
    return img

arr2 = oneHotDecoder(arr)
print(arr2)
```
- How do I get rid of this error?

```
line 15, in oneHotDecoder
    img[img<0.5]=0
TypeError: '<' not supported between instances of 'list' and 'float'
```

- Are there any other issues in my implementation that you would suggest improving?

Thanks in advance.
You have typos with commas and dots in some of your items: e.g. your first list should be [0.1, 0.2, 0.5] instead of [0.1, 0.2, 0, 5], and in the third row [0,6,0.1,0.1] should be [0.6, 0.1, 0.1].
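This is also why the error mentions a `list`: because the inner lists have different lengths, the data cannot form a rectangular float array. Older NumPy versions then build an object array whose elements are Python lists, so `img<0.5` ends up comparing a list with a float; NumPy 1.24 and later refuse to build such a ragged array and raise a ValueError instead. A small sketch, assuming you want malformed input to fail early, is to pass `dtype=float` when creating the array:

```python
import numpy as np

# Two of the inner lists from the question, one with the "0, 5" typo
ragged = [[0.1, 0.2, 0, 5], [0.2, 0.4, 0.7]]

try:
    # dtype=float requires a rectangular, numeric nesting, so the
    # inconsistent inner lengths are reported when the array is created
    np.array(ragged, dtype=float)
except ValueError as exc:
    print("malformed input:", exc)

fixed = np.array([[0.1, 0.2, 0.5], [0.2, 0.4, 0.7]], dtype=float)
print(fixed.shape)  # (2, 3)
```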
The fixed list is:

```python
l = [
    [[0.1,0.2,0.5],[0.2,0.4,0.7],[0.3,0.5,0.8]],
    [[0.3,0.6,0 ],[0.4,0.9,0.1],[0 ,0 ,0.2]],
    [[0.7,0.1,0.1],[0.6,0.1,0.1],[0.6,0.6,0.3]],
    [[0.6,0.2,0.3],[0.4,0.5,0.3],[0.1,0.2,0.7]]
]
```
Then you could do:

```python
np.array(l)  # shape (4, 3, 3); note that np.dstack(l) would give shape (3, 3, 4) instead
```

which yields:
```
array([[[0.1, 0.2, 0.5],
        [0.2, 0.4, 0.7],
        [0.3, 0.5, 0.8]],

       [[0.3, 0.6, 0. ],
        [0.4, 0.9, 0.1],
        [0. , 0. , 0.2]],

       [[0.7, 0.1, 0.1],
        [0.6, 0.1, 0.1],
        [0.6, 0.6, 0.3]],

       [[0.6, 0.2, 0.3],
        [0.4, 0.5, 0.3],
        [0.1, 0.2, 0.7]]])
```
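As for other issues in the decoder: it thresholds `img` in place, so the caller's array is overwritten, and after the list comprehension the stacked result has shape (n_classes, H, W), so `np.amax(img, axis=2)` reduces over the width rather than over the channels. A minimal sketch of the same three steps that avoids both, assuming the channel axis is last (`one_hot_decoder` is a hypothetical name), compared with `np.argmax`, which is a common alternative for softmax outputs because it skips the threshold altogether:

```python
import numpy as np

def one_hot_decoder(probs, threshold=0.5):
    """Decode (H, W, C) channel scores into an (H, W) label map.

    Hypothetical helper following the question's three steps;
    `probs` itself is left unmodified.
    """
    binary = (probs >= threshold).astype(np.int64)  # step 1: 0/1 per channel (new array)
    weighted = binary * np.arange(probs.shape[-1])  # step 2: weight channel c by c
    return weighted.max(axis=-1)                    # step 3: max over the channel axis

# The corrected example from above: H=4, W=3, C=3
arr = np.array([
    [[0.1, 0.2, 0.5], [0.2, 0.4, 0.7], [0.3, 0.5, 0.8]],
    [[0.3, 0.6, 0.0], [0.4, 0.9, 0.1], [0.0, 0.0, 0.2]],
    [[0.7, 0.1, 0.1], [0.6, 0.1, 0.1], [0.6, 0.6, 0.3]],
    [[0.6, 0.2, 0.3], [0.4, 0.5, 0.3], [0.1, 0.2, 0.7]],
])

labels_weighted = one_hot_decoder(arr)
labels_argmax = arr.argmax(axis=-1)  # most probable channel per pixel
print(labels_weighted)
print(labels_argmax)
```

On this array the two approaches disagree at two pixels. For [0, 0, 0.2] no channel reaches 0.5, so the weighted max returns 0 (the same as class 0) while argmax returns 2. For [0.6, 0.6, 0.3] two channels pass the threshold, so the weighted max picks the higher index (1) while argmax returns the first maximum (0). With real softmax outputs, which sum to 1, at most one channel can be strictly above 0.5, but a pixel can still have none.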