I am going through some CNN articles and see that they transform the input image to (channels, height, width).
Here is a code example, taken from the MXNet CNN tutorial:
import numpy as np
from mxnet import nd

def transform(data, label):
    # (2, 0, 1): (height, width, channels) -> (channels, height, width)
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)
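For example, with a dummy 32x64 RGB image (my own check, not part of the tutorial) the shape changes as follows:

img = nd.random.uniform(shape=(32, 64, 3))   # (height, width, channels), as loaded from disk
out, _ = transform(img, nd.array([0]))
print(out.shape)                             # (3, 32, 64)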
Can anyone explain why we do this transformation?
There are several image formats for 2-dimensional convolution; the main ones are:

- NCHW format, i.e., (batch, channels, height, width).
- NHWC format, i.e., (batch, height, width, channels).

They are basically equivalent and can easily be converted from one to another, though there is evidence that certain low-level implementations perform more efficiently when a particular data format is used (see this question).
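The conversion is a single transpose. A plain-NumPy sketch (the array and its shape are made up for illustration):

import numpy as np

nhwc = np.zeros((8, 32, 64, 3))      # (batch, height, width, channels)
nchw = nhwc.transpose(0, 3, 1, 2)    # -> (batch, channels, height, width)
back = nchw.transpose(0, 2, 3, 1)    # -> NHWC again
print(nchw.shape, back.shape)        # (8, 3, 32, 64) (8, 32, 64, 3)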
Computational engines usually accept both formats but have different defaults, e.g.:

- TensorFlow uses NHWC by default.
- Caffe uses the NCHW format.
- Keras uses NHWC by default.

MXNet accepts both formats too, but the default is NCHW:
"The default data layout is NCHW, namely (batch_size, channel, height, width). We can choose other layouts such as NHWC."
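You can see this default in action with a Gluon convolution (a minimal sketch of my own, runnable on CPU):

from mxnet import nd
from mxnet.gluon import nn

conv = nn.Conv2D(channels=16, kernel_size=3)   # no layout argument: NCHW is assumed
conv.initialize()
x = nd.random.uniform(shape=(8, 3, 32, 64))    # (batch, channels, height, width)
print(conv(x).shape)                           # (8, 16, 30, 62): still NCHW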
This default is pretty much the only reason to transpose the tensors in preprocessing: it simply avoids having to pass a layout argument to every layer in the network.
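For completeness, here is the alternative this avoids (my own sketch; note that, as far as I know, MXNet only supports NHWC convolutions on GPU):

from mxnet.gluon import nn

# Keep images in NHWC and tell each layer about the layout instead;
# every convolution/pooling layer would need this argument.
conv_nhwc = nn.Conv2D(channels=16, kernel_size=3, layout='NHWC')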