
How do you concatenate symbols in MXNet?


I have two symbols in MXNet and would like to concatenate them. How can I do this?

For example, given a = [100,200] and b = [300,400], I'd like to get

c = [100,200,300,400]


Solution

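  • You can do this with the Concat operator, mx.sym.Concat. Its dim argument defaults to 1, so for 1-D inputs like these you need to pass dim=0 explicitly: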

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    c = mx.sym.Concat(a, b, dim=0)  # join along the first (only) axis
    

    To verify this, bind the symbol to concrete input arrays and run a forward pass with the resulting executor:

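    # Bind input arrays to the variables by name; bind() returns an Executor
    e = c.bind(mx.cpu(), {'a': mx.nd.array([100, 200]), 'b': mx.nd.array([300, 400])})
    y = e.forward()  # list of output NDArrays
    y[0].asnumpy()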
    

    You will get the output:

    array([ 100.,  200.,  300.,  400.], dtype=float32)