python machine-learning deep-learning mxnet

MXNet print intermediate symbol values


How do I find the actual numerical values held in an MXNet symbol?

Suppose I have:

import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y

If x = [100, 200] and y = [300, 400], I want to print:

z = [400, 600]

sort of like TensorFlow's eval() method.


Solution

  • After looking around a bit, I found that you can do this by binding the symbol to concrete values and running it:

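    import mxnet as mx

    x = mx.sym.Variable('x')
    y = mx.sym.Variable('y')
    z = x + y
    # Bind concrete NDArrays to the variables by name; this returns an Executor
    executor = z.bind(mx.cpu(), {'x': mx.nd.array([100, 200]), 'y': mx.nd.array([300, 400])})
    # Run the graph; forward() returns a list of output NDArrays
    output = executor.forward()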
    

    The result, output, is a list holding a single NDArray:

    [<NDArray 2 @cpu(0)>]
    

    To see the actual numerical values, convert the NDArray to a NumPy array:

    >>> output[0].asnumpy()
    array([ 400.,  600.], dtype=float32)
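Why the values are not stored on the symbol itself: a symbol only describes a computation graph, and numbers exist only after the graph is bound to concrete arrays and executed. The sketch below is a hypothetical, pure-Python toy that mimics the Variable, bind and forward steps of the answer using plain lists. The names Sym, variable and Executor are made up for illustration and are not part of MXNet's API.

```python
# Hypothetical toy model of symbolic (deferred) evaluation, loosely
# mirroring MXNet's Variable -> bind -> forward pattern. These classes are
# illustrative only; they are NOT part of the MXNet API.


class Sym:
    """A node in a symbolic graph: it records *how* to compute, not values."""

    def __init__(self, name=None, op=None, inputs=()):
        self.name = name      # set only for placeholder variables
        self.op = op          # callable combining input values, or None
        self.inputs = inputs  # child nodes

    def __add__(self, other):
        # Build an element-wise addition node; nothing is computed yet.
        return Sym(op=lambda a, b: [u + v for u, v in zip(a, b)],
                   inputs=(self, other))

    def bind(self, args):
        """Attach concrete values to the variables, returning an executor."""
        return Executor(self, args)


def variable(name):
    # A placeholder with a name but no operation.
    return Sym(name=name)


class Executor:
    def __init__(self, sym, args):
        self.sym = sym
        self.args = args

    def _eval(self, node):
        if node.op is None:
            # A variable: look up its bound value by name.
            return self.args[node.name]
        # An operation: evaluate the children first, then combine them.
        return node.op(*(self._eval(child) for child in node.inputs))

    def forward(self):
        # Return a list of outputs, as MXNet's Executor.forward() does.
        return [self._eval(self.sym)]


x = variable('x')
y = variable('y')
z = x + y                       # z is still just a graph, no numbers yet
executor = z.bind({'x': [100, 200], 'y': [300, 400]})
output = executor.forward()
print(output[0])                # [400, 600]
```

The same graph z can be bound again to different values without rebuilding it, which is the point of separating the symbol from its data. If you are on MXNet 1.x, Symbol.eval(ctx=..., **kwargs) combines bind and forward in one call, e.g. z.eval(ctx=mx.cpu(), x=mx.nd.array([100, 200]), y=mx.nd.array([300, 400])), which is the closest analogue to TensorFlow's eval(); since it binds on every call, an explicit executor is preferable when evaluating repeatedly.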