Tags: python-3.x, tensorflow, tensorflow-probability

Creating a custom gradient function for HMC for large datasets


I'm trying to infer the parameters of a Gaussian process using HMC in tensorflow-probability.

I have multiple independent data sequences that are generated from the same underlying process and I want to infer the kernel parameters that they all share.

To calculate the likelihood, I'm using eager mode and looping over each independent sequence. I can compute the likelihood itself, but I run into resource-exhausted errors when trying to compute its gradient.

I know this will be very slow, but I would like to be able to use HMC with a dataset of any size without running out of memory.

I use the following code to create synthetic data; it creates N samples from a GP, each with p data points.

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels  # tfp.positive_semidefinite_kernels in older TFP releases

L = 5
variance = 2
m_noise = 0.05

kernel = psd_kernels.ExponentiatedQuadratic(np.float64(variance), np.float64(L))


def gram_matrix(xs):
    # kernel covariance plus observation noise on the diagonal
    return kernel.matrix(xs, xs).numpy() + m_noise * np.identity(xs.shape[0])


observation_index_points = []
observations = []
N=200
p = 2000
for i in range(0, N):

    xs = np.sort(np.random.uniform(0,100,p))[...,None]
    mean = [0 for x in xs]
    gram = gram_matrix(xs)

    ys = np.random.multivariate_normal(mean, gram)

    observation_index_points.append(xs)

    observations.append(ys)

for i in range(0, N):

    plt.plot(observation_index_points[i],observations[i])
plt.show()

The following code to calculate the log likelihood will run with the sampler for small values of N, but fails with a resource-exhausted error for larger values of N. The error occurs when the gradient of the likelihood is computed. (A sketch of how the log prob is handed to the sampler follows the listing.)

@tf.function()
def gp_log_prob(amplitude, length_scale, seg_index_points, noise_variance, seg_observations):

    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
    gp = tfd.GaussianProcess(kernel=kernel,
                                index_points=seg_index_points,
                                observation_noise_variance=noise_variance)


    return gp.log_prob(seg_observations)

rv_amplitude = tfd.LogNormal(np.float64(0.), np.float64(1))
rv_length_scale = tfd.LogNormal(np.float64(0.), np.float64(1))
rv_noise_variance = tfd.LogNormal(np.float64(0.), np.float64(1))

def joint_log_prob_no_grad(amplitude, length_scale, noise_variance):
    ret_val = rv_amplitude.log_prob(amplitude) \
                + rv_length_scale.log_prob(length_scale) \
                + rv_noise_variance.log_prob(noise_variance)

    for i in range(N):
        ret_val = ret_val + gp_log_prob(amplitude, 
                                        length_scale, 
                                        observation_index_points[i], 
                                        noise_variance, 
                                        observations[i])

    return ret_val
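
For reference, this is roughly how a function like this gets handed to the sampler. This is a minimal sketch: the step size, leapfrog steps, and chain lengths below are placeholder values, not my actual settings.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=joint_log_prob_no_grad,  # HMC differentiates this itself
    step_size=np.float64(0.01),                 # placeholder value
    num_leapfrog_steps=5)                       # placeholder value

samples, kernel_results = tfp.mcmc.sample_chain(
    num_results=100,                            # placeholder value
    num_burnin_steps=50,                        # placeholder value
    current_state=[tf.constant(np.float64(1.0)),   # amplitude
                   tf.constant(np.float64(1.0)),   # length scale
                   tf.constant(np.float64(0.1))],  # noise variance
    kernel=hmc_kernel)

In practice the three positive parameters would usually also be constrained, e.g. with tfp.mcmc.TransformedTransitionKernel and Softplus or Exp bijectors, but that is orthogonal to the memory issue here.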

However, I can calculate the gradient for large N by using a GradientTape inside the loop. This code runs for any N and returns the correct likelihood and gradient:

def joint_log_prob(amplitude, length_scale, noise_variance):

    with tf.GradientTape() as tape:
        tape.watch(amplitude)
        tape.watch(length_scale)
        tape.watch(noise_variance)

        ret_val = rv_amplitude.log_prob(amplitude) \
                    + rv_length_scale.log_prob(length_scale) \
                    + rv_noise_variance.log_prob(noise_variance)

    grads = tape.gradient(ret_val, [amplitude, length_scale, noise_variance])

    for i in range(N):
        with tf.GradientTape() as tape:
            tape.watch([amplitude, length_scale, noise_variance])
            gp_prob = gp_log_prob(amplitude, length_scale, 
                                  observation_index_points[i], noise_variance, observations[i])

        gp_grads = tape.gradient(gp_prob, [amplitude, length_scale, noise_variance])


        grads = [a+b for a,b in zip(grads,gp_grads)]
        ret_val = ret_val + gp_prob

    return ret_val, grads

x = tf.convert_to_tensor(np.float64(1.0))
y = tf.convert_to_tensor(np.float64(1.0))
z = tf.convert_to_tensor(np.float64(0.1))

joint_log_prob(x,y,z) # correct output even for large N
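
For a small N (small enough that plain autodiff over the whole sum fits in memory), the result can be sanity-checked against joint_log_prob_no_grad; a minimal check along these lines:

with tf.GradientTape() as check_tape:
    check_tape.watch([x, y, z])
    lp_ref = joint_log_prob_no_grad(x, y, z)
grads_ref = check_tape.gradient(lp_ref, [x, y, z])

lp, grads = joint_log_prob(x, y, z)
print(lp_ref.numpy(), lp.numpy())     # log probs should agree
print([g.numpy() for g in grads_ref])
print([g.numpy() for g in grads])     # gradients should agree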

If I then turn this into a tf.custom_gradient, it fails again:

@tf.custom_gradient
def joint_log_prob_cg(amplitude, length_scale, noise_variance):

    with tf.GradientTape() as tape:
        tape.watch(amplitude)
        tape.watch(length_scale)
        tape.watch(noise_variance)

        ret_val = rv_amplitude.log_prob(amplitude) \
                    + rv_length_scale.log_prob(length_scale) \
                    + rv_noise_variance.log_prob(noise_variance)

    grads = tape.gradient(ret_val, [amplitude, length_scale, noise_variance])

    for i in range(N):
        with tf.GradientTape() as tape:
            tape.watch([amplitude, length_scale, noise_variance])
            gp_prob = gp_log_prob(amplitude, length_scale, 
                                  observation_index_points[i], noise_variance, observations[i])

        gp_grads = tape.gradient(gp_prob, [amplitude, length_scale, noise_variance])


        grads = [a+b for a,b in zip(grads,gp_grads)]
        ret_val = ret_val + gp_prob

    def grad(dy):
        return grads
    return ret_val, grad

with tf.GradientTape() as t:
    t.watch([x,y,z])
    lp = joint_log_prob_cg(x,y,z)

t.gradient(lp, [x,y,z]) # fails for large N

My question is: how can I get the gradients from the joint_log_prob function above (which I know can be computed for any size of dataset) into the HMC sampler? It seems that if the whole function is wrapped in the GradientTape call, the for loop is unrolled and it runs out of memory. Is there a way around this?


Solution

  • In case anyone is interested, I was able to solve this by using a custom gradient and stopping the tape recording in the for loop. I needed to import the tape utilities:

    from tensorflow.python.eager import tape
    

    then stop recording around the for loop:

    with tape.stop_recording():
        for i in range(N):
            ...
    

    This stops the tracing; I then compute the gradient of each sequence in graph mode and add the gradients together manually, which avoids the OOM error. A sketch of the assembled function is below.
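
    Putting the pieces together, the final target log-prob function looks roughly like this. It is a reconstruction from the description above, so details such as helper names may differ from my actual code.

    from tensorflow.python.eager import tape

    @tf.function()
    def gp_log_prob_and_grad(amplitude, length_scale, seg_index_points,
                             noise_variance, seg_observations):
        # graph-mode log prob and gradient for a single sequence
        with tf.GradientTape() as g:
            g.watch([amplitude, length_scale, noise_variance])
            lp = gp_log_prob(amplitude, length_scale, seg_index_points,
                             noise_variance, seg_observations)
        return lp, g.gradient(lp, [amplitude, length_scale, noise_variance])

    @tf.custom_gradient
    def joint_log_prob_cg(amplitude, length_scale, noise_variance):
        # gradients of the priors via a small tape
        with tf.GradientTape() as prior_tape:
            prior_tape.watch([amplitude, length_scale, noise_variance])
            ret_val = rv_amplitude.log_prob(amplitude) \
                        + rv_length_scale.log_prob(length_scale) \
                        + rv_noise_variance.log_prob(noise_variance)
        grads = prior_tape.gradient(ret_val, [amplitude, length_scale, noise_variance])

        # keep the outer (HMC) tape from tracing and unrolling the loop
        with tape.stop_recording():
            for i in range(N):
                lp, gp_grads = gp_log_prob_and_grad(
                    amplitude, length_scale, observation_index_points[i],
                    noise_variance, observations[i])
                ret_val = ret_val + lp
                grads = [a + b for a, b in zip(grads, gp_grads)]

        def grad(dy):
            # scale the accumulated gradients by the incoming cotangent
            return [dy * g for g in grads]

        return ret_val, grad

    joint_log_prob_cg can then be passed to the HMC kernel as the target_log_prob_fn in the usual way.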