Search code examples
Tags: tensorflow, regression, gpflow, gaussian-process

How can I compile batched training of a gpflow GPR into a tf.function?


I need to train a GPR model in multiple batches per epoch using a custom loss function. I would like to do this with GPflow, and I would like to compile the training step with tf.function to make it more efficient. However, gpflow.models.GPR must be re-instantiated each time new data is supplied, so tf.function has to re-trace on every batch. This makes the code slower rather than faster.

This is the initial setup:

import numpy as np
from itertools import islice
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
from sklearn.model_selection import train_test_split
import gpflow
from gpflow.kernels import SquaredExponential
import time

# Experiment configuration
data_size = 1000  # total number of dummy samples generated below
train_fract = 0.8  # fraction of the samples kept for training
batch_size = 250  # samples per training batch
n_epochs = 3  # number of passes of the training loop
# Batches drawn per epoch; int() truncates 0.8 * 1000 / 250 = 3.2 down to 3.
iterations_per_epoch = int(train_fract * data_size/batch_size)
# Seeds TensorFlow's global random generator (NumPy's generator is separate).
tf.random.set_seed(3)

# Generate dummy data
# Use a seeded NumPy generator so the noisy targets are identical on every
# run; tf.random.set_seed does not seed NumPy's random numbers.
rng = np.random.default_rng(3)
x = np.arange(data_size)
# Targets follow the line y = x with uniform noise drawn from [0, 1).
y = x + rng.random(data_size)

# Slice into train and validate sets
_splits = train_test_split(x, y, random_state=1, test_size=1 - train_fract)


def _as_float64_column(values):
    """Return the 1-D array `values` as an (N, 1) float64 TensorFlow constant."""
    return tf.constant(values[:, np.newaxis], dtype=np.float64)


# Convert data into tensorflow constants
# train_test_split yields the arrays in the order
# (x_train, x_validate, y_train, y_validate).
x_train, x_validate, y_train, y_validate = map(_as_float64_column, _splits)

# Batch data
# Shuffle the (x, y) training pairs, repeat them without end, and cut the
# resulting stream into batches of `batch_size` samples.
_train_pairs = tf.data.Dataset.from_tensor_slices((x_train, y_train))
_shuffled_pairs = _train_pairs.shuffle(buffer_size=len(x_train), seed=1)
batched_dataset = _shuffled_pairs.repeat(count=None).batch(batch_size)

# Create kernel
# Bijector mapping an unconstrained real value u to exp(u) + tiny, which keeps
# the constrained value strictly positive.
constrain_positive = tfb.Shift(np.finfo(np.float64).tiny)(tfb.Exp())


def _positive_variable(initial_value, name):
    """Build a float64 TransformedVariable constrained by `constrain_positive`."""
    return tfp.util.TransformedVariable(
        initial_value=initial_value,
        bijector=constrain_positive,
        dtype=np.float64,
        name=name,
    )


amplitude = _positive_variable(1, "amplitude")
len_scale = _positive_variable(10, "len_scale")
kernel = SquaredExponential(
    variance=amplitude,
    lengthscales=len_scale,
    name="squared_exponential_kernel",
)
obs_noise = _positive_variable(1e-3, "observation_noise")


# Define custom loss function
@tf.function(autograph=False, experimental_compile=False)
def my_custom_loss(y_predict, y_true):
    """Return the mean squared error between `y_predict` and `y_true`.

    Args:
        y_predict: Predicted values.
        y_true: Reference values the predictions are compared against.

    Returns:
        The mean, taken over all elements, of (y_predict - y_true) ** 2.
    """
    squared_errors = tf.math.squared_difference(y_predict, y_true)
    return tf.math.reduce_mean(squared_errors)

# Plain SGD is the optimizer in use; Adam is kept as a commented-out alternative.
#optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

This is how I train without a tf.function:

# Initial model on the full training set; the loop below reads its noise
# variance when it builds the first per-batch model.
gpr_model_j_i = gpflow.models.GPR(data=(x_train, y_train), kernel=kernel, noise_variance=obs_noise)

# Start training loop
for _epoch in range(n_epochs):
    for x_batch, y_batch in islice(batched_dataset, iterations_per_epoch):
        with tf.GradientTape() as tape:
            # Rebuild the model on the current batch, reusing the shared kernel
            # and the previous model's noise variance.
            gpr_model_j_i = gpflow.models.GPR(
                data=(x_batch, y_batch),
                kernel=kernel,
                noise_variance=gpr_model_j_i.likelihood.variance,
            )
            # Only the first element returned by predict_f is scored, and it
            # is scored against the validation targets.
            predicted_mean = gpr_model_j_i.predict_f(x_validate)[0]
            batch_loss = my_custom_loss(predicted_mean, y_validate)

        trainable = gpr_model_j_i.trainable_variables
        batch_grads = tape.gradient(batch_loss, trainable)
        optimizer.apply_gradients(zip(batch_grads, trainable))

This is how I train with a tf.function:

@tf.function(autograph=False, experimental_compile=False)
def tf_function_attempt_3(model):
    """Apply one optimizer step to `model` using the validation-set loss.

    The validation tensors, `my_custom_loss` and `optimizer` are read from
    module scope rather than passed in.

    Args:
        model: Model exposing `predict_f` and `trainable_variables`.
    """
    variables = model.trainable_variables
    with tf.GradientTape() as tape:
        predicted_mean = model.predict_f(x_validate)[0]
        loss_value = my_custom_loss(predicted_mean, y_validate)

    gradients = tape.gradient(loss_value, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    # A Python print runs only while the function is being traced, so each
    # "TRACING..." in the output marks one (re)trace.
    print("TRACING...", end="")



for _epoch in range(n_epochs):
    for x_batch, y_batch in islice(batched_dataset, iterations_per_epoch):
        # A new GPR object is built for every batch and handed to the
        # compiled step as a Python argument.
        gpr_model_j_i = gpflow.models.GPR(
            data=(x_batch, y_batch),
            kernel=kernel,
            noise_variance=gpr_model_j_i.likelihood.variance,
        )
        tf_function_attempt_3(gpr_model_j_i)

The tf.function retraces for each batch and is significantly slower than the normal training.

Is there a way to speed up the batched training of my GPR model with tf.function while using a custom loss function and GPflow? If not, I am open to suggestions for an alternative approach.


Solution

  • You don't have to re-instantiate GPR each time. You can construct tf.Variable holders whose leading dimension is left unspecified (None), and then .assign each new batch to them:

    import gpflow
    import numpy as np
    import tensorflow as tf
    
    input_dim = 1
    # Empty placeholders with zero rows (or use your first batch)
    initial_x = np.zeros((0, input_dim))
    initial_y = np.zeros((0, 1))
    # shape=(None, ...) leaves the number of rows unspecified, so arrays with a
    # different row count can be assigned to the variables later.
    x_var = tf.Variable(initial_x, shape=(None, input_dim), dtype=tf.float64)
    y_var = tf.Variable(initial_y, shape=(None, 1), dtype=tf.float64)
    # in principle you could also set shape=(None, None)...

    m = gpflow.models.GPR(data=(x_var, y_var), kernel=gpflow.kernels.SquaredExponential())
    loss = m.training_loss_closure()  # compile=True default wraps in tf.function()
    
    # First batch: N1 rows written into the model's data variables.
    N1 = 3
    x1 = np.random.randn(N1, input_dim)
    y1 = np.random.randn(N1, 1)
    x_holder, y_holder = m.data
    x_holder.assign(x1)
    y_holder.assign(y1)
    loss()  # traces the first time
    
    # Second batch with a different row count, written into the same variables.
    N2 = 7
    x2 = np.random.randn(N2, input_dim)
    y2 = np.random.randn(N2, 1)
    for holder, batch in zip(m.data, (x2, y2)):
        holder.assign(batch)
    loss()  # does not trace again