Tags: r, deep-learning, mxnet

How to save a model when using MXNet


I am using MXNet to train a CNN in R, and the model trains without errors using the following code:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

Because training is time-consuming, I run it overnight on a server, and I want to save the model so I can use it after training has finished.

I tried:

save(list = ls(), file="mymodel.RData")

and

mx.model.save("mymodel", 10)

Neither of them saves the model. For example, when I load "mymodel.RData", I cannot predict the labels for the test set.

Similarly, when I load "mymodel.RData" and try to plot the network with the following code:

graph.viz(model$symbol$as.json())

I get the following error:

Error in model$symbol$as.json() : external pointer is not valid

Can anybody suggest a way to save this model and load it again for later use?

Thanks


Solution

  • You can save the model after every epoch by passing mx.callback.save.checkpoint() as epoch.end.callback:

    model <- mx.model.FeedForward.create(symbol=network,
                                         X=train.iter,
                                         ctx=mx.gpu(0),
                                         num.round=20,
                                         array.batch.size=batch.size,
                                         learning.rate=0.1,
                                         momentum=0.1,
                                         eval.metric=mx.metric.accuracy,
                                         wd=0.001,
                                         # save model_prefix-symbol.json and model_prefix-NNNN.params after each epoch
                                         epoch.end.callback=mx.callback.save.checkpoint("model_prefix"),
                                         batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )
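To use the saved model later, the checkpoint files have to be read back with mx.model.load() rather than with load(). An MXNet model in R most likely holds external pointers to objects living in the C++ backend; save() cannot serialize those, so after loading the .RData file the pointers are dead, which would explain the "external pointer is not valid" error. The sketch below shows the full round trip, assuming the mxnet R package is installed. The toy data, the tiny fully connected network and the prefixes "model_prefix"/"mymodel" are hypothetical stand-ins for the question's network, train.iter and test set. It also shows a one-off mx.model.save() call after training; note that the model object is its first argument, which the mx.model.save("mymodel", 10) call in the question omits.

```r
library(mxnet)

# Hypothetical toy data standing in for the question's train.iter and test set:
# 200 training rows and 20 test rows, 10 features each, two classes (0/1)
set.seed(42)
train.x <- matrix(runif(200 * 10), nrow = 200, ncol = 10)
train.y <- as.numeric(rowSums(train.x) > 5)
test.x  <- matrix(runif(20 * 10), nrow = 20, ncol = 10)

# Tiny stand-in for the question's CNN `network`
data    <- mx.symbol.Variable("data")
fc1     <- mx.symbol.FullyConnected(data, name = "fc1", num_hidden = 2)
network <- mx.symbol.SoftmaxOutput(fc1, name = "sm")

num.round <- 5
model <- mx.model.FeedForward.create(symbol = network,
                                     X = train.x, y = train.y,
                                     ctx = mx.cpu(),
                                     num.round = num.round,
                                     array.batch.size = 20,
                                     array.layout = "rowmajor",
                                     learning.rate = 0.1,
                                     eval.metric = mx.metric.accuracy,
                                     # after every epoch: writes model_prefix-symbol.json
                                     # and model_prefix-0001.params ... model_prefix-0005.params
                                     epoch.end.callback = mx.callback.save.checkpoint("model_prefix"))

# Alternative: save once after training (model object first, then prefix and epoch)
mx.model.save(model, prefix = "mymodel", iteration = num.round)

# Later, e.g. in a fresh R session: rebuild the model from the files on disk
loaded <- mx.model.load("model_prefix", iteration = num.round)

# predict() returns a (classes x samples) matrix of class probabilities
preds      <- predict(loaded, test.x, array.layout = "rowmajor")
pred.label <- max.col(t(preds)) - 1
print(pred.label)
```

The iteration argument of mx.model.load() selects which epoch's .params file is restored, so with num.round=20 as in the question, mx.model.load("model_prefix", 20) would load the weights from the final epoch.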