
Poor fits for simple 2D Gaussian processes in `GPyTorch`


I'm having a lot of difficulty fitting a simple two-dimensional GP out of the box with GPyTorch. The fit I get is very poor, and it does not improve much when I swap the RBF kernel for something else, such as a Matérn kernel. The optimization does appear to converge, but not to anything sensible.

import gpytorch


class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)

        # Learned constant mean plus a scaled RBF kernel with one
        # lengthscale per input dimension (ARD, 2 input dimensions).
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=2),
        )

    def forward(self, x):
        # Prior over f(x): mean vector and covariance matrix at the inputs x.
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


Does anyone have good tutorial examples beyond the ones included in the docs?


Solution

  • I ran into similar issues when fitting high-dimensional Gaussian processes. A few suggestions (I can't promise they will work):

    1. Try using a ZeroMean rather than a ConstantMean. Having more hyperparameters (here, the learned constant of the mean) could lead the optimization of the negative marginal log likelihood (-mll) into a local minimum rather than the global minimum. A different optimizer may help with this as well, e.g. L-BFGS, a quasi-Newton method that builds an approximation of second-order curvature, instead of first-order methods such as Adam or SGD.

    2. Try normalizing your inputs with min-max scaling (to [0, 1] in each dimension) and standardizing your targets to zero mean and unit variance, N(0, 1). In short, these steps put your data on the scale that the default hyperparameter settings and priors of these GPR models assume. For more info on this, check out this GitHub issue.

    3. Try changing the learning rate used to optimize your hyperparameters, or the number of training iterations.

    4. If you see any numerical-precision issues (particularly after normalization, if you use it), try switching your model and data to double (64-bit) precision by converting your torch tensors with tensor.double() and your model with model.double().

    Again, I can't guarantee that these will fix the issues you're running into, but hopefully they help!

    Here are some tutorials I've put together as well.