Search code examples
Tags: python, iterator, mxnet

Example for custom iterator not working


I am following the description and example for creating a custom iterator in the MXNet data tutorial: http://mxnet.io/tutorials/basic/data.html

The following code produces a ValueError:

mod.fit(data_iter, num_epoch=5)

ValueError: Shape of labels 0 does not match shape of predictions 1

My questions:

  • Can anyone reproduce this problem?
  • Does anyone know a solution?

I am using Jupyter on a Mac with everything freshly installed, including Python. I have also tested with plain Python:

Python 3.6.1 |Anaconda custom (x86_64)| (default, May 11 2017, 13:04:09) 
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)] on darwin

Code:

import mxnet as mx
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


# Built-in iterators: 100 random samples with 3 features each and integer
# class labels in [0, 10), served in batches of 30.
import numpy as np
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])

# Output: 4 batches. 100 samples do not divide evenly into batches of 30,
# so the last batch reports pad=20 (4 * 30 - 100).
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 20]


# Save `data` to a CSV file (data.csv in the current working directory) and
# read it back with CSVIter. Only the data is read back; no labels are
# written to the file.
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

# Output: same batching and padding as the NDArrayIter example.
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 20]

class SimpleIter(mx.io.DataIter):
    """Data iterator that yields batches of freshly generated synthetic data.

    Each call to ``next()`` builds one ``mx.io.DataBatch`` by calling every
    generator with its declared shape. After ``num_batches`` batches it raises
    ``StopIteration`` until ``reset()`` is called.

    Parameters
    ----------
    data_names : sequence of str
        Names of the data inputs (e.g. ``['data']``).
    data_shapes : sequence of tuple
        Shape of each data input for one batch, in the same order as
        ``data_names``.
    data_gen : sequence of callable
        One generator per data input. Each is called as ``gen(shape)`` and
        its result is converted with ``mx.nd.array``.
    label_names, label_shapes, label_gen :
        Same as the ``data_*`` parameters, but for the labels.
    num_batches : int, optional
        Number of batches produced per epoch (default 10).
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        super(SimpleIter, self).__init__()
        # Materialise the (name, shape) pairs. On Python 3 ``zip`` returns a
        # one-shot iterator, so the first consumer of provide_data /
        # provide_label exhausted it. next() then built batches with no data
        # or labels, which surfaced as
        # "Shape of labels 0 does not match shape of predictions 1".
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch so the iterator can be consumed again."""
        self.cur_batch = 0

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2-style next().
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the labels."""
        return self._provide_label

    def next(self):
        """Return the next ``mx.io.DataBatch`` of generated data and labels.

        Raises
        ------
        StopIteration
            Once ``num_batches`` batches have been produced since the last
            ``reset()``.
        """
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        self.cur_batch += 1
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        return mx.io.DataBatch(data, label)


# Symbolic network: a two-layer MLP.
# 'data' -> fc1 (64 hidden units) -> ReLU -> fc2 (num_classes outputs) -> SoftmaxOutput.
import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())

# The SoftmaxOutput layer named 'softmax' contributes the 'softmax_label'
# argument shown below. The training iterator supplies its labels under
# that name.
#['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
#['softmax_output']



import logging
logging.basicConfig(level=logging.INFO)

# Batch size: every generated batch has n samples.
n = 32
# 100 uniform features in [-1, 1) per sample, and one integer class label
# in [0, num_classes) per sample. The label name matches the network's
# 'softmax_label' argument.
data_iter = SimpleIter(
    data_names=['data'],
    data_shapes=[(n, 100)],
    data_gen=[lambda s: np.random.uniform(-1, 1, s)],
    label_names=['softmax_label'],
    label_shapes=[(n,)],
    label_gen=[lambda s: np.random.randint(0, num_classes, s)],
)

mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)

Error:

ValueError                                Traceback (most recent call last)
 <ipython-input-57-6ceb7dd11508> in <module>()
         9 
        10 mod = mx.mod.Module(symbol=net)
        ---> 11 mod.fit(data_iter, num_epoch=5)
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/base_module.py in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor)
            493                     end_of_batch = True
            494 
        --> 495                 self.update_metric(eval_metric, data_batch.label)
            496 
            497                 if monitor is not None:
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/module.py in update_metric(self, eval_metric, labels)
            678             Typically ``data_batch.label``.
            679         """
        --> 680         self._exec_group.update_metric(eval_metric, labels)
            681 
            682     def _sync_params_from_devices(self):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/executor_group.py in update_metric(self, eval_metric, labels)
            561             labels_ = OrderedDict(zip(self.label_names, labels_slice))
            562             preds = OrderedDict(zip(self.output_names, texec.outputs))
        --> 563             eval_metric.update_dict(labels_, preds)
            564 
            565     def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update_dict(self, label, pred)
             89             label = label.values()
             90 
        ---> 91         self.update(label, pred)
             92 
             93     def update(self, labels, preds):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update(self, labels, preds)
            369             Predicted values.
            370         """
        --> 371         check_label_shapes(labels, preds)
            372 
            373         for label, pred_label in zip(labels, preds):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in check_label_shapes(labels, preds, shape)
             22     if label_shape != pred_shape:
             23         raise ValueError("Shape of labels {} does not match shape of "
        ---> 24                          "predictions {}".format(label_shape, pred_shape))
             25 
             26 
ValueError: Shape of labels 0 does not match shape of predictions 1

Solution

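  • At first this looked like a Python-version problem. The example works when everything is installed and compiled for Python 2.7, but fails on Python 3.x, and the error message is not very helpful.
  
  The actual cause is in `SimpleIter.__init__`. On Python 3, `zip()` returns a one-shot iterator instead of a list. As a result, `self._provide_data` and `self._provide_label` are empty once they have been iterated over a single time (e.g. while `mod.fit` sets up the module). `next()` then returns batches with no labels, hence "Shape of labels 0".
  
  Wrapping them in `list(...)` makes the example work on Python 3 as well:
  - `self._provide_data = list(zip(data_names, data_shapes))`
  - `self._provide_label = list(zip(label_names, label_shapes))`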