Tags: lua, neural-network, torch, lstm, recurrent-neural-network

How to train an LSTM for simple function recognition


I'm learning LSTM networks and decided to try a synthetic test. I want an LSTM network, fed with points (x, y), to distinguish between three basic functions:

  • line: y = k*x + b
  • parabola: y = k*x^2 + b
  • sqrt: y = k*sqrt(x) + b

I'm using lua + torch.

The dataset is entirely virtual: it is created on the fly by the 'dataset' object. When the training loop asks for another minibatch of samples, the function mt.__index returns a sample created dynamically. It randomly selects one of the three functions described above and picks some random points on it.
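
For readers not familiar with Lua metatables, here is a minimal, purely illustrative sketch of that mechanism (the names are hypothetical; the real dataset is built in the script below): indexing the table falls through to mt.__index, which builds a sample on demand.

-- Illustrative only: indexing an (otherwise empty) table triggers __index,
-- which can fabricate a sample on the fly instead of storing data.
local dataset = {}
local mt = {}
mt.__index = function (self, i)
  return { x = math.random(), y = math.random() } -- a fresh "sample" per access
end
setmetatable(dataset, mt)

print(dataset[1].x, dataset[42].y) -- both samples are generated on demand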

The idea is that the LSTM network would learn some features that allow it to recognize what kind of function the most recent points belong to.

The full (yet simple) source script is included below:

require "torch"
require "nn"
require "rnn"

-- hyper-parameters 
batchSize = 8
rho = 5 -- sequence length
hiddenSize = 100
outputSize = 3
lr = 0.001

-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.size = function (self)
  return 1000
end
local mt = {}
mt.__index = function (self, i)
  local class = math.random(3)

  local t = torch.Tensor(3):zero()
  t[class] = 1
  local targets = {}
  for i = 1,batchSize do table.insert(targets, class) end

  local inputs = {}
  local k = math.random()
  local b = math.random()*5

  -- Line
  if class == 1 then
    for i = 1,batchSize do
      local x = math.random()*10 + 5
      local y = k*x + b
      input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end

  -- Parabola
  elseif class == 2 then
    for i = 1,batchSize do
      local x = math.random()*10 + 5
      local y = k*x*x + b
      input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end

  -- Sqrt
  else
    for i = 1,batchSize do
      local x = math.random()*5 + 5
      local y = k*math.sqrt(x) + b
      input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end
  end

  return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)

-- Initialize random number generator
math.randomseed( os.time() )

-- build simple recurrent neural network
local model = nn.Sequencer(
  nn.Sequential()
    :add( nn.LSTM(2, hiddenSize, rho) )
    :add( nn.Linear(hiddenSize, outputSize) )
    :add( nn.LogSoftMax() )
)

print(model)

-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )

-- training
model:training()

local epoch = 1
while true do

  print ("Epoch "..tostring(epoch).." started")

  for iteration = 1, dataset:size() do
    -- 1. Load minibatch of samples
    local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
    local inputs = sample[1]
    local targets = sample[2]

    -- 2. Perform forward run and calculate error
    local outputs = model:forward(inputs)
    local err = criterion:forward(outputs, targets)

    print(string.format("Epoch %d Iteration %d Error = %f", epoch, iteration, err))

    -- 3. Backward sequence through the model (i.e. backprop through time)
    local gradOutputs = criterion:backward(outputs, targets)
    -- Sequencer handles the backwardThroughTime internally
    model:backward(inputs, gradOutputs)
    model:updateParameters(lr)
    model:zeroGradParameters()     

  end -- for dataset

  epoch = epoch + 1
end -- while epoch

The problem is that the network does not converge. Could you share any ideas about what I'm doing wrong?


Solution

  • I decided to post my own answer since I solved the problem and received good results.

    First, about the applicability of an LSTM to this kind of task. As stated, an LSTM is good at dealing with time series. You may also think of a line, parabola and sqrt as a kind of function of time, so an LSTM is perfectly applicable here. Say you are receiving experimental results, one vector at a time, and you want to find out what kind of function could describe your series.

    One may argue that in the code above we always feed the NN a fixed number of points (i.e. batch_size). So why use an LSTM? Maybe try a Linear or Convolutional network instead?

    Well, don't forget that this is a synthetic test. In a real-life application you would feed the NN a significant number of data points and expect it to recognize the form of the function.

    For instance, in the code below we train the NN with 8 points at a time (batch_size), but when we test it we use only 4 points (test_size).

    And we get pretty good results: after about 1000 iterations the NN gives ~99% correct answers.

    But a one-layer NN is not a magician; it can't learn any features if we change the shape of the functions on every iteration. That is, in the original code k and b are regenerated on every request to the dataset. What we should do instead is generate them once at startup and leave them unchanged.
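
    The essential change, sketched in isolation (the full working script follows):

    -- Before (original code): a fresh k and b on every dataset access,
    -- so the function shape changes constantly and there is nothing stable to learn.
    mt.__index = function (self, i)
      local k = math.random()
      local b = math.random()*5
      -- ... build the sample from k and b ...
    end

    -- After (fixed code): k and b are drawn once at startup and reused for every sample.
    dataset.k = math.random()
    dataset.b = math.random()*5
    mt.__index = function (self, i)
      local k = self.k
      local b = self.b
      -- ... build the sample from k and b ...
    end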

    So here is the working code:

    require "torch"
    require "nn"
    require "rnn"
    
    -- Initialize random number generator
    math.randomseed( os.time() )
    
    -- hyper-parameters 
    batch_size = 8
    test_size = 4
    rho = 5 -- sequence length
    hidden_size = 100
    output_size = 3
    learning_rate = 0.001
    
    -- Initialize synthetic dataset
    -- dataset[index] returns table of the form: {inputs, targets}
    -- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
    -- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
    local dataset = {}
    dataset.k = math.random()
    dataset.b = math.random()*5
    dataset.size = function (self)
      return 1000
    end
    local mt = {}
    mt.__index = function (self, i)
      local class = math.random(3)
    
      local t = torch.Tensor(3):zero()
      t[class] = 1
      local targets = {}
      for i = 1,batch_size do table.insert(targets, class) end
    
      local inputs = {}
      local k = self.k
      local b = self.b
    
      -- Line
      if class == 1 then
        for i = 1,batch_size do
          local x = math.random()*10 + 5
          local y = k*x + b
          input = torch.Tensor(2)
          input[1] = x
          input[2] = y
          table.insert(inputs, input)
        end
    
      -- Parabola
      elseif class == 2 then
        for i = 1,batch_size do
          local x = math.random()*10 + 5
          local y = k*x*x + b
          input = torch.Tensor(2)
          input[1] = x
          input[2] = y
          table.insert(inputs, input)
        end
    
      -- Sqrt
      else
        for i = 1,batch_size do
          local x = math.random()*5 + 5
          local y = k*math.sqrt(x) + b
          input = torch.Tensor(2)
          input[1] = x
          input[2] = y
          table.insert(inputs, input)
        end
      end
    
      return { inputs, targets }
    end -- dataset.__index meta function
    setmetatable(dataset, mt)
    
    
    -- build simple recurrent neural network
    local model = nn.Sequencer(
      nn.Sequential()
        :add( nn.LSTM(2, hidden_size, rho) )
        :add( nn.Linear(hidden_size, output_size) )
        :add( nn.LogSoftMax() )
    )
    
    print(model)
    
    -- build criterion
    local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
    
    
    local epoch = 1
    local err = 0
    local pos = 0
    local N = math.floor( dataset:size() * 0.1 )
    
    while true do
    
      print ("Epoch "..tostring(epoch).." started")
    
      -- training
      model:training()
      for iteration = 1, dataset:size() do
        -- 1. Load minibatch of samples
        local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
        local inputs = sample[1]
        local targets = sample[2]
    
        -- 2. Perform forward run and calculate error
        local outputs = model:forward(inputs)
        local _err = criterion:forward(outputs, targets)
    
        print(string.format("Epoch %d (pos=%f) Iteration %d Error = %f", epoch, pos, iteration, _err))
    
        -- 3. Backward sequence through the model (i.e. backprop through time)
        local gradOutputs = criterion:backward(outputs, targets)
        -- Sequencer handles the backwardThroughTime internally
        model:backward(inputs, gradOutputs)
        model:updateParameters(learning_rate)
        model:zeroGradParameters()     
    
      end -- for training
    
      -- Testing
      model:evaluate()
      err = 0
      pos = 0
      for iteration = 1, N do
        -- 1. Load minibatch of samples
        local sample = dataset[ math.random(dataset:size()) ]
        local inputs = sample[1]
        local targets = sample[2]
        -- Drop trailing points so that only test_size points remain
        for i = #inputs, test_size + 1, -1 do
          inputs[i] = nil
          targets[i] = nil
        end
    
        -- 2. Perform forward run and calculate error
        local outputs = model:forward(inputs)
        err = err + criterion:forward(outputs, targets)
    
        local p = 0
        for i = 1, #outputs do
          local _, oi = torch.max(outputs[i], 1)
          if oi[1] == targets[i] then p = p + 1 end
        end
        pos = pos + p/#outputs
    
      end -- for testing
      err = err / N
      pos = pos / N
      print(string.format("Epoch %d testing results: pos=%f err=%f", epoch, pos, err))
    
      if (pos > 0.95) then break end
    
      epoch = epoch + 1
    end -- while epoch
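
    Once training stops (pos > 0.95), the trained model can be queried on a brand-new, short sequence. A minimal inference sketch, assuming the model, dataset and test_size from the script above are still in scope:

    model:evaluate()
    model:forget() -- clear any leftover LSTM state before feeding a new sequence

    local k, b = dataset.k, dataset.b
    local seq = {}
    for i = 1, test_size do
      local x = math.random()*10 + 5
      local point = torch.Tensor(2)
      point[1] = x
      point[2] = k*x + b              -- points on a line, so the expected class is 1
      table.insert(seq, point)
    end

    local outputs = model:forward(seq)                -- table of log-probabilities, one per point
    local _, class = torch.max(outputs[#outputs], 1)  -- take the prediction at the last point
    print("predicted class:", class[1])               -- 1=line, 2=parabola, 3=sqrt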