Tags: python, recurrent-neural-network, mxnet

mxnet: how to make predictions using a trained RNN model


I'm testing the RNN model in mxnet. The tutorial here does not work: the error messages say that many of its functions have been deprecated, and I could not find an up-to-date RNN tutorial. The mxnet project still includes some examples, but the RNN examples only show how to train a model on a training set, not how to use the trained model to make predictions. The training code is as follows:

    import mxnet as mx

    # model is an mxnet module; data_train and data_val are DataIters, and args holds
    # the command-line options (all defined elsewhere in the example script)
    model.fit(
        train_data          = data_train,
        eval_data           = data_val,
        eval_metric         = mx.metric.Perplexity(invalid_label),
        kvstore             = args.kv_store,
        optimizer           = args.optimizer,
        optimizer_params    = { 'learning_rate': args.lr,
                                'momentum': args.mom,
                                'wd': args.wd },
        initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
        num_epoch           = args.num_epochs,
        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches))

Does anyone know how to use a trained RNN model for inference or prediction?

To clarify, I'm asking how to make predictions with an RNN model, not a CNN or other kinds of models.

Thank you very much for your help!


Solution

  • The model is usually a module that extends the BaseModule class, and BaseModule provides a predict method. predict accepts the same kind of input as fit, a DataIter, with one difference: it takes no train_data, only the data to evaluate (eval_data). The module must already be bound and have its parameters initialized, which is the case after fit. So the prediction step itself can be as simple as this:

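    # predict() expects a DataIter (here the validation iterator from the question),
    # not an iterator's next method; it runs forward() over every batch and
    # concatenates the outputs
    result = model.predict(data_val)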