I have a very big training set, too big for MATLAB to handle in one go, so I need to do large-scale training.
Is it possible to split the training set into parts, train the network iteratively on each part, and update the existing "net" at each iteration instead of overwriting it?
The code below shows the idea, but it doesn't work: in each iteration the net is updated based only on the current chunk of data.
TF1 = 'tansig'; TF2 = 'tansig'; TF3 = 'tansig'; % transfer functions of the layers; TF3 is the output layer's
net = newff(trainSamples.P,trainSamples.T,[NodeNum1,NodeNum2,NodeOutput],{TF1 TF2 TF3},'traingdx');% Network created
net.trainFcn = 'traingdm'; % replaces the 'traingdx' training function set in newff
net.trainParam.epochs = 1000;
net.trainParam.min_grad = 0;
net.trainParam.max_fail = 2000; % large value, effectively "infinity" (no early stopping)
for s = 1:10:size(trainSamples.P, 2)      % iteratively take 10 data points at a time
    k = s:min(s + 9, size(trainSamples.P, 2));
    p = trainSamples.P(:, k);             % next 10 input samples
    t = trainSamples.T(:, k);             % next 10 target samples
    [net,tr] = train(net, p, t, [], []);
end
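To see the principle behind the question in isolation, here is a minimal sketch in plain MATLAB (no toolbox) that swaps the neural network for a simple linear model trained by mini-batch gradient descent. It is not the toolbox's `train`; it only illustrates that the model state (`w`) is created once and then updated by each chunk, so later chunks refine earlier updates instead of replacing them, and that cycling over all chunks for several epochs with a reshuffle each time keeps the model from fitting only the last chunk. The toy data, chunk size `M`, learning rate `lr` and epoch count are made-up illustration values.

```matlab
% Sketch (assumption: a linear model stands in for the network).
% The weights w persist across chunks, so every update builds on the previous ones.
rng(0);                                   % reproducible toy data
N = 1000; D = 3; M = 10;                  % samples, features, chunk size (illustrative)
X = randn(D, N);                          % inputs, one sample per column
wTrue = [2; -1; 0.5];                     % "true" weights used to generate targets
y = wTrue' * X + 0.01 * randn(1, N);      % noisy linear targets (1 x N)

w = zeros(D, 1);                          % model state, initialised only once
lr = 0.1;                                 % learning rate (illustrative)
for epoch = 1:20
    idx = randperm(N);                    % reshuffle every epoch
    for s = 1:M:N
        k = idx(s:min(s + M - 1, N));     % indices of the current chunk
        err = w' * X(:, k) - y(k);        % residuals on this chunk only
        grad = X(:, k) * err' / numel(k); % gradient of 0.5 * mean squared error
        w = w - lr * grad;                % update in place; w carries over to the next chunk
    end
end
disp(w')
```

Because `w` is never re-created inside the loops, the final weights reflect all chunks, which is the behaviour the question is after.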
Here is an example of how to train a neural network iteratively (mini-batch) in MATLAB:
% create a toy dataset
[x, t] = building_dataset;
% mini-batch size and number of mini-batches (M*imax must not exceed size(x,2))
M = 420;
imax = 10;
% compare direct training with mini-batch training
net = feedforwardnet(70,'trainscg');
dnet = feedforwardnet(70,'trainscg');
% standard training: one single call with the whole dataset
dnet.trainParam.epochs=100;
% 'useGPU','only' requires a supported GPU; drop these options to train on the CPU
[ dnet tr y ] = train( dnet, x, t ,'useGPU','only','showResources','no');
% error measure: mean absolute error (MAE); MSE or any other metric works too
dperf = mean(mean(abs(t-dnet(x))))
% iterative part: one epoch per train() call
net.trainParam.epochs=1;
e=1;
perf = inf; % initialise so that perf(end) is defined on the first check
% keep going until we reach the error of the direct method, to compare epoch counts
while perf(end)>dperf
% very important: reshuffle the data at each epoch!
idx = randperm(size(x,2));
% train iteratively on all the data chunks
for i=1:imax
k = idx(1+M*(i-1) : M*i);
[ net tr ] = train( net, x( : , k ), t( : , k ) );
end
% compute the performance after each epoch
perf(e) = mean(mean(abs(t-net(x))))
e=e+1;
end
% check the performance: we want a fairly smooth, exp(-x)-like curve
plot(perf)
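A variant of the mini-batch loop with a few safeguards, as a sketch assuming the Deep Learning Toolbox (`feedforwardnet`, `configure`, `train`) and its `building_dataset` sample data. The network is configured once on the full data so that input/output normalization is not derived from a single chunk, `divideFcn` is set to `'dividetrain'` so each small chunk is not further split into validation and test sets, the training window is suppressed, and a hypothetical `maxEpochs` cap stops the loop if the mini-batch error never reaches the target. The target `dperf` here is an arbitrary assumed value rather than one computed from a full-batch run.

```matlab
% Sketch: mini-batch training with the Deep Learning Toolbox.
% Assumptions: building_dataset is available; dperf and maxEpochs are hypothetical choices.
[x, t] = building_dataset;
N = size(x, 2);
M = 420;                                 % chunk (mini-batch) size
nChunks = floor(N / M);                  % use as many full chunks as fit

net = feedforwardnet(70, 'trainscg');
net.divideFcn = 'dividetrain';           % every sample of a chunk is used for training
net.trainParam.epochs = 1;               % one pass per train() call
net.trainParam.showWindow = false;       % avoid opening a window for every chunk
net = configure(net, x, t);              % sizes and normalization from the full dataset

dperf = 0.05;                            % hypothetical target MAE
maxEpochs = 200;                         % hypothetical safety cap
perf = inf;                              % makes the while condition defined on entry
e = 0;
while perf(end) > dperf && e < maxEpochs
    e = e + 1;
    idx = randperm(N);                   % reshuffle every epoch
    for i = 1:nChunks
        k = idx(1 + M*(i-1) : M*i);
        net = train(net, x(:, k), t(:, k));  % continues from the current weights
    end
    perf(e) = mean(mean(abs(t - net(x))));   % MAE on the full dataset
end
plot(perf), xlabel('epoch'), ylabel('MAE')
```

The cap trades a guaranteed stop for possibly not reaching `dperf`; raise `maxEpochs` or relax `dperf` depending on which matters more for your data.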