
MXNet data type is float64, but it keeps saying it's float32


I'm a PyTorch and TensorFlow user. I came across MXNet in order to use AWS SageMaker's Elastic Inference.

MXNet's Gluon Dataset API seems to be very similar to PyTorch's Dataset.

import mxnet
import pandas as pd

class CustomDataset(mxnet.gluon.data.Dataset):
    def __init__(self):
        self.train_df = pd.read_csv('/shared/KTUTOR/test_summary_data.csv')

    def __getitem__(self, idx):
        # Both the features and the label are explicitly created as float64
        features = mxnet.nd.array(self.train_df.loc[idx, ['TT', 'TF', 'FT', 'FF']], dtype='float64')
        label = mxnet.nd.array(self.train_df.loc[idx, ['p1']], dtype='float64')
        return features, label

    def __len__(self):
        return len(self.train_df)

I defined my CustomDataset like above and set the data types to float64.
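For instance, indexing the dataset directly shows float64 for both arrays (a quick sanity check, reusing the class above):

ds = CustomDataset()
x, y = ds[0]
print(x.dtype, y.dtype)  # both report float64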

test_data = mxnet.gluon.data.DataLoader(CustomDataset(), batch_size=8, shuffle=True, num_workers=2)

I wrapped my dataset in a DataLoader, and there was no error up to this point. The error arises when I pass the data to the network.

for epoch in range(1):
    for data, label in test_data:
        print(data.dtype)
        print(label.dtype)
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(batch_size)

The error is raised in net(data), and the error message looks like this:

MXNetError: [07:53:55] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float64, got float32
Stack trace:
  [bt] (0) /root/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4b09db) [0x7f00f96519db] ...

When I print the types of data and label, they are both float64, but MXNet tells me that the data type of the data is float32. Can someone explain why this is happening? Thanks in advance.


Solution

  • Is your network in float64 or float32? Gluon initializes parameters as float32 by default, so the mismatch most likely comes from the network's weights rather than your data. Try casting the weights to float64:

    net.cast('float64')  # Block.cast casts the parameters in place
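
    To verify the fix, you can print each parameter's dtype before and after the cast (a minimal sketch, assuming net is the Gluon network from your training loop):

        for name, param in net.collect_params().items():
            print(name, param.dtype)  # float32 before the cast, float64 after

    net should then accept the float64 batches produced by your DataLoader.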

    That being said, in my experience it is not common to train deep learning models in float64; float32 and float16 are much more common for training. And MXNet allows you to easily use float16 precision for training, either explicitly or automatically with the AMP tool (Automatic Mixed Precision), as sketched below.
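
    A rough outline of the AMP workflow (a sketch based on the mxnet.contrib.amp API; the 'sgd' optimizer and learning rate are placeholder choices, and net, softmax_cross_entropy, data, label, and batch_size come from the question):

        from mxnet import autograd, gluon
        from mxnet.contrib import amp

        amp.init()  # call once, before creating the trainer

        trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
        amp.init_trainer(trainer)

        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
            # Scale the loss to avoid float16 gradient underflow, then backprop
            with amp.scale_loss(loss, trainer) as scaled_loss:
                autograd.backward(scaled_loss)
        trainer.step(batch_size)

    Roughly speaking, AMP runs the operations that are safe in float16 and keeps the rest in float32, so you get most of the speedup without casting anything by hand.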