MXNet NDArray with dtype string?


I want to train a CV network with MXNet. I wrote a custom Dataset class, which works well, but I would now like it to also return the name of the image file (a string).

The problem is that MXNet refuses to batch strings.

So I tried to create an NDArray from my string, but NDArray does not accept strings as a dtype. What should I do?

>>> import numpy as np
>>> import mxnet.ndarray as nd
>>> nd.array(["blabla"])
ValueError: could not convert string to float: blabla

Solution

  • @Manon Rmn, the solution is to use a custom batchify function that stacks the arrays but leaves the strings in a plain Python list.

    import mxnet as mx
    import numpy as np
    from gluonnlp.data.batchify import Tuple, Stack
    
    class MyDataset(mx.gluon.data.Dataset):
    
        def __init__(self, size=10):
            self.size = size
            self.data = [np.array((1,2,3,4))]*size
            self.data_text = ["this is a test"]*size
    
        def __getitem__(self, idx):
            return self.data[idx], self.data_text[idx]
    
        def __len__(self):
            return self.size
    
    dataset = MyDataset()
    print(dataset[0])
    
    
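    class List:
        """Batchify function that returns the samples unchanged, as a Python list."""
        def __call__(self, x):
            return x
    
    # Stack() batches the arrays into one NDArray; List() keeps the strings as a list.
    data = mx.gluon.data.DataLoader(dataset, batchify_fn=Tuple(Stack(), List()), batch_size=2)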
    
    for matrix, text in data:
        break
    print(matrix)
    print(text)

    Output:

    (array([1, 2, 3, 4]), 'this is a test')
    
    [[1 2 3 4]
     [1 2 3 4]]
    <NDArray 2x4 @cpu_shared(0)>
    ['this is a test', 'this is a test']
    

    I have opened a PR to add this to GluonNLP: https://github.com/dmlc/gluon-nlp/pull/812
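
To make the idea behind `Tuple(Stack(), List())` concrete, here is a minimal, dependency-free sketch of a per-field batchify function. It is not GluonNLP's implementation; the names `FieldwiseBatchify` and `KeepAsList` are hypothetical, and plain Python lists stand in for the image arrays so the example runs without MXNet installed.

```python
class KeepAsList:
    """Return the field values unchanged, as a plain Python list."""

    def __call__(self, values):
        return list(values)


class FieldwiseBatchify:
    """Apply one batchify function per field of the samples.

    Hypothetical helper for illustration only: each sample is a tuple such as
    (image, file_name), and the i-th function receives the list of all i-th
    fields in the batch.
    """

    def __init__(self, *field_fns):
        self.field_fns = field_fns

    def __call__(self, samples):
        # Transpose [(a0, b0), (a1, b1)] into [(a0, a1), (b0, b1)].
        fields = list(zip(*samples))
        if len(fields) != len(self.field_fns):
            raise ValueError(
                "expected %d fields per sample, got %d"
                % (len(self.field_fns), len(fields))
            )
        return tuple(fn(list(field)) for fn, field in zip(self.field_fns, fields))


# Two fake samples: a list of numbers standing in for an image, plus a file name.
samples = [([1, 2, 3, 4], "img_0.jpg"), ([5, 6, 7, 8], "img_1.jpg")]

batchify = FieldwiseBatchify(KeepAsList(), KeepAsList())
pixels, names = batchify(samples)
print(pixels)
print(names)
```

In a real pipeline the first field function would presumably stack the arrays into an NDArray (for instance the `Stack()` used in the answer above), and the resulting object would be passed to the `DataLoader` through `batchify_fn`. Only the string field needs the pass-through treatment.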