
Setting the task covariance matrix to the correlation matrix in GPyTorch

I am trying to set a precomputed correlation matrix as the task covariance in GPyTorch for a numerical experiment. In case you are wondering, the data is already normalized.

import torch

corr_matrix = ...  # 23x23 correlation matrix as a NumPy array
corr_matrix = torch.from_numpy(corr_matrix)

I define a very simple model as follows:

import gpytorch

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        # One constant mean per task (23 tasks)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=23
        )
        # RBF kernel over the inputs combined with a full-rank (rank=23) task covariance
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=23, rank=23
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=23)
model = MultitaskGPModel(train_x.float(), train_y.float(), likelihood)

From what I understand, one way to fix the parameter would be something like this:

model.covar_module.task_covar_module._parameters['covar_factor'] = torch.from_numpy(corr_matrix)

I haven't changed anything in the optimizer, since the tensor I use for the covariance matrix has requires_grad=False.
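On the optimizer side, a more conventional pattern than replacing entries of _parameters is to overwrite the existing Parameter values in place under torch.no_grad(), call requires_grad_(False) on them, and build the optimizer only from parameters that still require gradients. The sketch below uses a hypothetical TaskKernelStandIn module that only mimics IndexKernel's parameter names (covar_factor, raw_var) so it runs with plain PyTorch; with the real model the same steps would presumably be applied to model.covar_module.task_covar_module. Filling covar_factor with the Cholesky factor of the correlation matrix and raw_var with -30.0 are assumptions, chosen on the premise that IndexKernel computes covar_factor @ covar_factor.T plus a softplus-transformed diagonal.

```python
import torch


class TaskKernelStandIn(torch.nn.Module):
    """Hypothetical stand-in that only mimics IndexKernel's parameter names."""

    def __init__(self, num_tasks, rank):
        super().__init__()
        self.covar_factor = torch.nn.Parameter(torch.randn(num_tasks, rank))
        self.raw_var = torch.nn.Parameter(torch.randn(num_tasks))


task_kernel = TaskKernelStandIn(num_tasks=3, rank=3)
corr = torch.tensor([[1.0, 0.5, 0.2],
                     [0.5, 1.0, 0.3],
                     [0.2, 0.3, 1.0]])

# Overwrite the values in place so the registered Parameter objects stay the
# same, instead of swapping entries in the module's _parameters dict
with torch.no_grad():
    task_kernel.covar_factor.copy_(torch.linalg.cholesky(corr))
    task_kernel.raw_var.fill_(-30.0)  # softplus(-30) is about 1e-13 (assumed transform)

# Freeze both tensors so no optimizer step can change them
task_kernel.covar_factor.requires_grad_(False)
task_kernel.raw_var.requires_grad_(False)

# Hand the optimizer only the parameters that are still trainable
trainable = [p for p in task_kernel.parameters() if p.requires_grad]
print(len(trainable))  # 0 for this stand-in
```

In the full model the filtered list would still contain the RBF lengthscale, the means and the likelihood noise, so those hyperparameters keep training while the task covariance stays fixed.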

Nonetheless, I got this error:

RuntimeError: KroneckerProductLinearOperator expects lazy tensors with the same batch shapes. Got [torch.Size([]), torch.Size([23])].

I don't really understand how to insert this parameter so that it respects the Kronecker product evaluation, so I would appreciate any help.
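To make "respecting the Kronecker product" concrete, here is a minimal NumPy sketch of the structure, assuming MultitaskKernel forms the joint covariance as kron(K_x, B), where K_x is the RBF covariance over the inputs and B is the task covariance (23x23 in the question, 3x3 here). The helpers rbf and multitask_covariance are hypothetical names for illustration, not GPyTorch functions. The point is that both factors must be plain square matrices; an extra leading dimension on either one is the NumPy analogue of the batch-shape mismatch in the error message.

```python
import numpy as np


def rbf(x, lengthscale=1.0):
    """Squared-exponential kernel matrix for 1-D inputs x of shape (n,)."""
    sq_dist = (x[:, None] - x[None, :]) ** 2
    return np.exp(-0.5 * sq_dist / lengthscale**2)


def multitask_covariance(k_x, b):
    """Joint covariance over (input, task) pairs, assumed to be kron(K_x, B)."""
    # Refuse factors with extra leading (batch) dimensions
    if k_x.ndim != 2 or b.ndim != 2:
        raise ValueError(f"expected 2-D factors, got shapes {k_x.shape} and {b.shape}")
    return np.kron(k_x, b)


n, t = 5, 3
x = np.linspace(0.0, 1.0, n)
corr = np.array([[1.0, 0.5, 0.2],
                 [0.5, 1.0, 0.3],
                 [0.2, 0.3, 1.0]])

k_full = multitask_covariance(rbf(x), corr)
print(k_full.shape)  # (15, 15)
```

Because the RBF kernel equals 1 on its diagonal, each diagonal t x t block of k_full is exactly the task covariance B.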


  • I found the problem: the dimensions of my tensors were set up incorrectly. This is what fixed it:

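    # Reset raw_var to zeros, keeping its original shape (num_tasks,)
    model.covar_module.task_covar_module._parameters['raw_var'] = torch.zeros_like(model.covar_module.task_covar_module._parameters['raw_var'])
    # Use the 23x23 correlation matrix (converted to float32) as covar_factor
    model.covar_module.task_covar_module._parameters['covar_factor'] = torch.from_numpy(corr_matrix).float()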

    That solved the error.
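One thing worth checking about this fix: if IndexKernel (the task_covar_module inside MultitaskKernel) computes its task covariance as covar_factor @ covar_factor.T + diag(softplus(raw_var)), which is an assumption based on its default Positive constraint, then the assignment above yields C C^T + log(2) I rather than the correlation matrix C itself. The minimal NumPy sketch below re-implements that assumed formula by hand (index_kernel_covariance is a hypothetical helper, not a GPyTorch call) and compares the assignment above with an alternative that uses the Cholesky factor of C and a large negative raw_var; the value -30.0 is an arbitrary choice that makes the diagonal term negligible.

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)); assumed positivity transform for raw_var
    return np.logaddexp(0.0, z)


def index_kernel_covariance(covar_factor, raw_var):
    """Assumed IndexKernel formula: F F^T + diag(softplus(raw_var))."""
    return covar_factor @ covar_factor.T + np.diag(softplus(raw_var))


corr = np.array([[1.0, 0.5, 0.2],
                 [0.5, 1.0, 0.3],
                 [0.2, 0.3, 1.0]])

# The assignment used above: covar_factor = C, raw_var = 0
naive = index_kernel_covariance(corr, np.zeros(3))  # equals C @ C.T + log(2) * I

# Alternative: a Cholesky factor gives F F^T == C, and softplus(-30) is about 1e-13
chol = np.linalg.cholesky(corr)
exact = index_kernel_covariance(chol, np.full(3, -30.0))

print(np.allclose(naive, corr))  # False
print(np.allclose(exact, corr))  # True
```

If the experiment only needs a fixed task covariance proportional to some PSD matrix, the naive version may be acceptable; if the covariance should equal the correlation matrix, the factorized form is the one to compare against.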