Torch - Using optim package with CNN


I am trying to train a CNN using the optim package. As a reference I am using this code, obtained from a video tutorial (see around 24:01). That example uses a plain fully connected neural network. I also used this reference.

My code for the CNN can be found here. The problem is that when the input X contains more than one image, I get the following error:

In 14 module of nn.Sequential:
/home/ubuntu/torch/install/share/lua/5.1/nn/Linear.lua:57: size mismatch at /tmp/luarocks_cutorch-scm-1-1695/cutorch/lib/THC/generic/THCTensorMathBlas.cu:52

When I run on the CPU instead of the GPU, the error is clearer:

size mismatch, [1024 x 2048], [409600]

For convenience, here is my complete model:

-- Model begins
local model = nn.Sequential()
model:add(nn.SpatialConvolution(1, 64, 3, 3))
model:add(nn.ReLU())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))

model:add(nn.SpatialConvolution(64, 128, 3, 3))
model:add(nn.ReLU())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))

model:add(nn.SpatialConvolution(128, 256, 3, 3))
model:add(nn.ReLU())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))

model:add(nn.SpatialConvolution(256, 512, 3, 3))
model:add(nn.ReLU())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))

model:add(nn.View(-1))
model:add(nn.Linear(2048, 1024))
model:add(nn.ReLU())
model:add(nn.Linear(1024, 5))
model:add(nn.LogSoftMax())
-- Model ends
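To see where the 2048 and the 409600 in the error come from, the shape arithmetic of the model can be worked through in plain Lua. This is a sketch assuming single-channel 64×64 input images (a hypothetical size, since the question does not state it, chosen because it is consistent with nn.Linear(2048, 1024)); conv_out, pool_out and flattened_features are illustrative helpers, not Torch functions.

```lua
-- Shape arithmetic for the model above (pure Lua, no Torch needed).
-- ASSUMPTION: single-channel 64x64 input images (hypothetical size).

-- Output size of nn.SpatialConvolution with a k x k kernel, stride 1, no padding
local function conv_out(n, k)
  return n - k + 1
end

-- Output size of nn.SpatialMaxPooling with a k x k window and stride s
-- (floor mode, which is the default)
local function pool_out(n, k, s)
  return math.floor((n - k) / s) + 1
end

-- Spatial side length and flattened feature count after the four
-- conv -> ReLU -> pool stages (ReLU does not change the shape)
local function flattened_features(input_size)
  local n, channels = input_size, 1
  for _, c in ipairs({64, 128, 256, 512}) do
    n = pool_out(conv_out(n, 3), 2, 2)
    channels = c
  end
  return channels * n * n, n
end

local features, side = flattened_features(64)
print(features, side)  -- prints 2048 and 2 (512 channels x 2 x 2)

-- nn.View(-1) flattens the WHOLE batch into one vector, so the element
-- count in the CPU error message is batch_size * features.
local batch_size = math.floor(409600 / features)
assert(batch_size * features == 409600)  -- divides exactly
print(batch_size)  -- prints 200
```

So the 409600-element vector that reaches nn.Linear is a batch of 200 images, each with 2048 features, flattened into one long vector.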

1) Is it right to use nn.View(-1)?

2) I understand that the input to the first Linear layer no longer has 2048 elements when X contains more than one image. But then how does optim.sgd work correctly with the entire training set as input X in the fully connected network from the first reference?

3) What would be the best way to use optim.sgd (or preferably optim.adam) in this problem?


Solution

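  • The nn.View module is used incorrectly. nn.View(-1) flattens the whole minibatch into a single 1D vector (here 409600 = 200 × 2048 elements), so nn.Linear no longer receives one row per image. The layer should be written as nn.View(-1, out_channels * out_height * out_width), which keeps the batch dimension and produces a batchSize × (out_channels * out_height * out_width) matrix. In the model above this is nn.View(-1, 512 * 2 * 2), i.e. nn.View(-1, 2048). Source: see the comments section of this.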
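For question 3, a common pattern is minibatch training: shuffle the data each epoch, slice out a minibatch, and pass optim.adam a closure that returns the loss and the gradient for that minibatch. The following is a minimal sketch, assuming Torch7 with the nn and optim packages installed and 1×64×64 inputs; the random data and the names N, X, Y, batchSize and the learning rate are hypothetical stand-ins for the real training set and settings.

```lua
require 'torch'
require 'nn'
require 'optim'

torch.manualSeed(0)

-- Same architecture as in the question, but nn.View keeps the batch dimension
local model = nn.Sequential()
model:add(nn.SpatialConvolution(1, 64, 3, 3)):add(nn.ReLU()):add(nn.SpatialMaxPooling(2, 2, 2, 2))
model:add(nn.SpatialConvolution(64, 128, 3, 3)):add(nn.ReLU()):add(nn.SpatialMaxPooling(2, 2, 2, 2))
model:add(nn.SpatialConvolution(128, 256, 3, 3)):add(nn.ReLU()):add(nn.SpatialMaxPooling(2, 2, 2, 2))
model:add(nn.SpatialConvolution(256, 512, 3, 3)):add(nn.ReLU()):add(nn.SpatialMaxPooling(2, 2, 2, 2))
model:add(nn.View(-1, 2048))            -- batchSize x 2048, not one long vector
model:add(nn.Linear(2048, 1024)):add(nn.ReLU())
model:add(nn.Linear(1024, 5)):add(nn.LogSoftMax())

local criterion = nn.ClassNLLCriterion()

-- HYPOTHETICAL data: N random 1x64x64 "images" with class labels in 1..5
local N = 100
local X = torch.randn(N, 1, 64, 64)
local Y = torch.Tensor(N):random(1, 5)

-- Flattened views of all weights and gradients, as optim expects
local params, gradParams = model:getParameters()
local optimState = { learningRate = 1e-3 }  -- hypothetical setting
local batchSize = 25

for epoch = 1, 2 do
  local perm = torch.randperm(N):long()  -- reshuffle every epoch
  for i = 1, N, batchSize do
    local idx = perm:narrow(1, i, math.min(batchSize, N - i + 1))
    local inputs, targets = X:index(1, idx), Y:index(1, idx)

    -- Returns the minibatch loss and dLoss/dParams, as optim requires
    local function feval(p)
      if p ~= params then params:copy(p) end
      gradParams:zero()
      local outputs = model:forward(inputs)
      local loss = criterion:forward(outputs, targets)
      model:backward(inputs, criterion:backward(outputs, targets))
      return loss, gradParams
    end

    local _, fs = optim.adam(feval, params, optimState)
    print(string.format('epoch %d: minibatch loss %.4f', epoch, fs[1]))
  end
end
```

Switching to optim.sgd only means replacing the optim.adam call (and adjusting the config table, e.g. adding momentum). As for question 2: nn.Linear accepts either a 1D input or a 2D batchSize × inputSize matrix, so in the fully connected network from the first reference the whole training set can go through a single forward pass; the CNN failed only because nn.View(-1) collapsed the batch dimension.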