
TensorFlow Variable.assign_add function is mysterious in this example


I'm trying to learn TensorFlow from working examples online, but I came across this example and I'm genuinely wondering how it works. Can anyone explain the maths behind this particular TensorFlow function, and how ns gets its value out of a boolean data type?

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
Z = X+1j*Y

c = tf.constant(Z, np.complex64)
zs = tf.Variable(c)
ns = tf.Variable(tf.zeros_like(c, tf.float32))

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

zs_ = zs*zs + c

not_diverged = tf.abs(zs_) > 4

step = tf.group(zs.assign(zs_),
 ns.assign_add(tf.cast(not_diverged, tf.float32)))

nx = tf.reduce_sum(ns)
zx = tf.reduce_sum(zs_)
cx = tf.reduce_sum(c)
zf = tf.reduce_all(not_diverged)

for i in range(200): 
    step.run()
    print(sess.run([nx,zx,cx,zf]))

plt.imshow(ns.eval())
plt.show()

Solution

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    # this defines a grid of points over a region of the complex plane
    Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
    Z = X+1j*Y
    c = tf.constant(Z, np.complex64)
    
    # tensors are immutable in TensorFlow,
    # but variables aren't, so use a tf.Variable
    # to be able to update values later on
    zs = tf.Variable(c)
    
    # ns keeps a per-point count of the iterations in which z has diverged
    ns = tf.Variable(tf.zeros_like(c, tf.float32))
    
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    
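    # the Mandelbrot set M is defined as
    # c \in M \iff |P_c^n(0)| <= 2 for all n >= 1,
    # where P_c(z) = z^2 + c and P_c^n means P_c applied n times.
    # tf.abs returns the modulus |z|, so the test below, |z| > 4,
    # is a looser bound than 2; it is still valid, because any
    # orbit that leaves the disk of radius 2 diverges, it is just
    # detected later.
    # The question's variable name (not_diverged) is confusing, as the
    # condition actually means the opposite, so I renamed it below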
    zs_ = zs*zs + c
    diverged = tf.abs(zs_) > 4
    
    # ns gets its value from a bool cast to a float:
    # tf.cast maps True -> 1.0 and False -> 0.0.
    # assign_add then adds tf.cast(diverged, tf.float32) element-wise
    # to the variable ns and stores the result back in ns,
    # so each step increments ns by 1 wherever |z| > 4
    step = tf.group(
        zs.assign(zs_),
        ns.assign_add(tf.cast(diverged, tf.float32)))
    
    
    # here we iterate n up to whatever we like (200 here);
    # each step.run() moves one term further along the
    # sequence P_c^n(0), which must stay inside the
    # disk of radius 2 for every n for c to be in M
    for i in range(200):
        step.run()
    
    # anywhere with value > 0 in the plot is not in the Mandelbrot set
    # anywhere with value = 0 MIGHT be in the Mandelbrot set;
    # we don't know for sure if it is in the set,
    # because we can only ever take n to be some
    # finite number, but to be in the Mandelbrot set the sequence
    # has to be bounded for all n!
    plt.imshow(ns.eval())
    plt.show()
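To see why ns ends up holding numbers rather than booleans, here is a minimal pure-Python sketch (no TensorFlow, plain lists instead of tensors) of what ns.assign_add(tf.cast(diverged, tf.float32)) does element-wise. cast_and_add is a hypothetical helper written only for illustration; it is not a TensorFlow API.

```python
# Pure-Python sketch of ns.assign_add(tf.cast(diverged, tf.float32)).
# `cast_and_add` is a hypothetical helper name used only for illustration;
# it is not part of TensorFlow.


def cast_and_add(ns, mask):
    """Add float(flag) to ns[i] for every boolean flag in mask, in place."""
    for i, flag in enumerate(mask):
        ns[i] += float(flag)  # the cast: True -> 1.0, False -> 0.0
    return ns


if __name__ == "__main__":
    ns = [0.0, 0.0, 0.0]                    # like tf.zeros_like(c, tf.float32)
    cast_and_add(ns, [True, False, True])   # step 1: points 0 and 2 diverged
    cast_and_add(ns, [True, False, False])  # step 2: only point 0 diverged
    print(ns)  # [2.0, 0.0, 1.0]
```

After k steps, each entry of ns is simply the number of steps in which that point's flag was True.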
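The per-point divergence test can be sketched the same way in plain Python, using a hypothetical first_escape helper. It starts from z = c (like zs = tf.Variable(c)), repeats z = z*z + c (like zs_ = zs*zs + c), and reports the first step at which abs(z) exceeds the threshold. In the TensorFlow version a pixel ends with ns > 0 when, and only when, some step had abs(zs_) > 4. Two assumptions differ from the graph: the sketch stops at the first escape (the graph keeps squaring every pixel, so escaped values eventually overflow to inf/nan in complex64), and Python's complex is double precision, so pixels right at the boundary could be classified differently.

```python
# Escape test for a single complex c, mirroring zs_ = zs*zs + c from the answer.
# `first_escape` is a hypothetical helper name used only for illustration.


def first_escape(c, n_iter=200, threshold=4.0):
    """Return the 1-based step at which |z| first exceeds threshold, or None.

    z starts at c (like zs = tf.Variable(c)); each step computes z*z + c.
    Stopping at the first escape keeps the value from overflowing.
    """
    z = c
    for step in range(1, n_iter + 1):
        z = z * z + c
        if abs(z) > threshold:
            return step  # diverged: c is not in the Mandelbrot set
    return None  # bounded so far: c MIGHT be in the set


if __name__ == "__main__":
    for c in (0j, -1 + 0j, 1j, 1 + 0j, 0.5 + 0j, 2 + 2j):
        print(c, first_escape(c), first_escape(c, threshold=2.0))
```

With threshold=2.0 (the usual escape radius), c = 0.5 is flagged at step 4 instead of step 5. Any point flagged with 4.0 is also flagged with 2.0, and an orbit that passes radius 2 eventually reaches radius 4 as well, so the two thresholds differ only in how soon escape is detected.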