I'm testing the RNN model in MXNet. The tutorial here does not work, and the error messages say that many of the functions it uses have been deprecated. I could not find an up-to-date tutorial for RNNs. There are still some examples in the MXNet project, but for RNNs they only show how to train a model on a training set; they don't show how to use the trained model to make predictions. The training code is as follows:
```python
model.fit(
    train_data         = data_train,
    eval_data          = data_val,
    eval_metric        = mx.metric.Perplexity(invalid_label),
    kvstore            = args.kv_store,
    optimizer          = args.optimizer,
    optimizer_params   = { 'learning_rate': args.lr,
                           'momentum': args.mom,
                           'wd': args.wd },
    initializer        = mx.init.Xavier(factor_type="in", magnitude=2.34),
    num_epoch          = args.num_epochs,
    batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches))
```
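The `eval_metric` above reports perplexity while skipping positions whose label equals `invalid_label` (the padding id). As a rough illustration of what that number means, here is a plain-Python sketch. It is not MXNet's implementation. The function name `perplexity` and the list-of-lists input format are assumptions made for the example; the small clip before `log` mirrors the usual guard against `log(0)`.

```python
import math


def perplexity(probs, labels, ignore_label=None):
    """Illustrative perplexity over a sequence of predicted distributions.

    probs:        list of probability vectors, one per token position
    labels:       list of integer target ids, same length as probs
    ignore_label: target id to skip (e.g. padding), like invalid_label above
    """
    total_nll = 0.0  # accumulated negative log-likelihood
    count = 0        # number of positions that actually count
    for p, y in zip(probs, labels):
        if ignore_label is not None and y == ignore_label:
            continue  # padded positions are excluded from the metric
        total_nll -= math.log(max(p[y], 1e-10))  # clip to avoid log(0)
        count += 1
    # Perplexity is exp of the mean negative log-likelihood
    return math.exp(total_nll / count) if count else float("nan")


# A model that is uniform over 4 tokens has perplexity 4
print(perplexity([[0.25] * 4] * 3, [1, 2, 3]))
```

Lower is better: a perplexity of 4 means the model is, on average, as uncertain as a uniform choice among 4 tokens.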
Does anyone know how to use the trained RNN model for inference or prediction?
To clarify: I'm looking for how to make predictions with an RNN model, not with a CNN or another kind of model.
Thank you very much for your help!
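Usually the model extends the `BaseModule` class (in the bucketing RNN examples it is an `mx.mod.BucketingModule`), and `BaseModule` has the method `predict`. This method accepts the same type of input that `fit` uses, a `DataIter`, with one difference: it does not take `train_data`, only `eval_data`. It runs the network over every batch of that iterator and returns the outputs merged across batches. So the actual prediction can be implemented as simply as this:

```python
# Pass the iterator object itself, e.g. the validation iterator used in fit
result = model.predict(data_val)
```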
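To make the behaviour of `predict` more concrete, here is a rough NumPy sketch of the loop it performs: one inference-mode forward pass per batch, removal of the filler rows that pad the last batch, and concatenation of the results. This is an illustration, not MXNet's source. The function `predict_like`, the `(data, pad)` pairs (standing in for a `DataBatch` and its `pad` count) and the `forward` callable (standing in for the module's forward pass) are all hypothetical.

```python
import numpy as np


def predict_like(batches, forward):
    """Hypothetical sketch of the loop behind BaseModule.predict.

    batches: iterable of (data, pad) pairs; pad is the number of filler rows
             appended to a batch so it reaches the full batch size
    forward: callable mapping a data batch to an output array whose first
             axis is the batch axis (stand-in for an inference forward pass)
    """
    outputs = []
    for data, pad in batches:
        out = forward(data)
        outputs.append(out[:out.shape[0] - pad])  # drop padded rows
    # Merge per-batch outputs into one array along the batch axis
    return np.concatenate(outputs, axis=0)


# Two batches of size 3; the second one has 1 padded row
batches = [(np.array([[1], [2], [3]]), 0), (np.array([[4], [5], [0]]), 1)]
result = predict_like(batches, lambda x: x * 2)
print(result.ravel().tolist())
```

Because padding is stripped before merging, the result has one row per real example, which is why `predict` can be called directly on the same kind of iterator used for validation.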