
Different access methods to Pyro Paramstore give different results


I am following the introductory Pyro tutorial on forecasting. When I try to access the learned parameters after training the model, different access methods give different results for some of them (while giving identical results for others).

Here is the stripped-down reproducible code from the tutorial:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster

pyro.enable_validation(True)
pyro.clear_param_store()

pyro.__version__
# '1.3.1'
torch.__version__
# '1.5.0+cu101'

# import & prepare the data
dataset = load_bart_od()
T, O, D = dataset["counts"].shape
data = dataset["counts"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()
data = data.unsqueeze(-1)
T0 = 0              # beginning
T2 = data.size(-2)  # end
T1 = T2 - 52        # train/test split

# define the model class
class Model1(ForecastingModel):

    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)  
        feature_dim = covariates.size(-1)

        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
        prediction = bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)

        self.predict(noise_dist, prediction)

# fit the model
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.stack([time], dim=-1)
forecaster = Forecaster(Model1(), data[:T1], covariates[:T1], learning_rate=0.1)

So far so good. Now I want to inspect the learned latent parameters stored in the ParamStore. There seems to be more than one way to do this. Using the get_all_param_names() method:

for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())

I get

AutoNormal.locs.bias [14.585433]
AutoNormal.scales.bias [0.00631594]
AutoNormal.locs.weight [0.11947815]
AutoNormal.scales.weight [0.00922901]
AutoNormal.locs.noise_scale [-2.0719821]
AutoNormal.scales.noise_scale [0.03469057]

But using the named_parameters() method:

pyro.get_param_store().named_parameters()

gives the same values for the location (locs) parameters, but different values for all the scales parameters:

dict_items([
('AutoNormal.locs.bias', Parameter containing: tensor([14.5854], requires_grad=True)), 
('AutoNormal.scales.bias', Parameter containing: tensor([-5.0647], requires_grad=True)), 
('AutoNormal.locs.weight', Parameter containing: tensor([0.1195], requires_grad=True)), 
('AutoNormal.scales.weight', Parameter containing: tensor([-4.6854], requires_grad=True)),
('AutoNormal.locs.noise_scale', Parameter containing: tensor([-2.0720], requires_grad=True)), 
('AutoNormal.scales.noise_scale', Parameter containing: tensor([-3.3613], requires_grad=True))
])

How is this possible? According to the documentation, the ParamStore is a simple key-value store, and it contains only these six keys:

pyro.get_param_store().get_all_param_names() # .keys() method gives identical result
# result
dict_keys([
'AutoNormal.locs.bias',
'AutoNormal.scales.bias', 
'AutoNormal.locs.weight', 
'AutoNormal.scales.weight', 
'AutoNormal.locs.noise_scale', 
'AutoNormal.scales.noise_scale'])

So there is no way that one method accesses one set of items and the other method a different set.

Am I missing something here?


Solution

  • Here is the situation, as revealed in the GitHub thread I opened in parallel with this question:

    The ParamStore is no longer just a simple key-value store; it also performs constraint transformations. Quoting a Pyro developer from that thread:

    here's some historical background. The ParamStore was originally just a key-value store. Then we added support for constrained parameters; this introduced a new layer of separation between user-facing constrained values and internal unconstrained values. We created a new dict-like user-facing interface that exposed only constrained values, but to keep backwards compatibility with old code we kept the old interface around. The two interfaces are distinguished in the source files [...] but as you observe it looks like we forgot to mark the old interface as DEPRECATED.

    I guess in clarifying docs we should:

    1. clarify that the ParamStore is no longer a simple key-value store but also performs constraint transforms;

    2. mark all "old" style interface methods as DEPRECATED;

    3. remove "old" style interface usage from examples and tutorials.

    As a consequence, pyro.param() returns values in the constrained (user-facing) space, while the older named_parameters() method returns the unconstrained values intended for internal use. This explains the apparent discrepancy.
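For a positive constraint, the unconstrained value $u$ and the constrained scale $\sigma$ are linked by the exponential. This assumes the default exp-based positivity transform, which the numbers indicate was used by the AutoNormal guide in Pyro 1.3.1. Plugging in the rounded values printed in the question:

```latex
\begin{align*}
\sigma &= \exp(u), \qquad u = \log\sigma \\
\exp(-5.0647) &\approx 0.006316 \quad (\text{scales.bias} = 0.00631594) \\
\exp(-4.6854) &\approx 0.009229 \quad (\text{scales.weight} = 0.00922901) \\
\exp(-3.3613) &\approx 0.03469 \quad (\text{scales.noise\_scale} = 0.03469057)
\end{align*}
```

The small residual differences come only from the four-decimal rounding of the tensors displayed by named_parameters().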

    It is indeed easy to verify that the scales values returned by the two methods are related by a logarithmic transformation:

    import numpy as np

    store = pyro.get_param_store()
    # look up raw values by name instead of relying on matching iteration order
    unconstrained = dict(store.named_parameters())  # unconstrained space

    for name in store.keys():
        if 'scales' in name:
            temp = np.log(pyro.param(name).item())  # constrained space, mapped back to log space
            raw = unconstrained[name].item()
            print(temp, raw, np.allclose(temp, raw))
    
    # result:
    -5.027793402915326 -5.0277934074401855 True
    -4.600319371162187 -4.6003193855285645 True
    -3.3920585732532835 -3.3920586109161377 True
    

    Why does this discrepancy affect only the scales parameters? Because scales (i.e. standard deviations) are by definition constrained to be positive. Locs (i.e. means) are unconstrained, so their constrained and unconstrained representations coincide.
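The difference between locs and scales can be seen with plain PyTorch by applying the transforms that `torch.distributions.transform_to` registers for each constraint. As far as I can tell, Pyro's ParamStore uses this same registry. This is only an illustrative sketch: the same raw tensor is pushed through both transforms purely for comparison.

```python
import torch
from torch.distributions import constraints, transform_to

# Example unconstrained values (taken from the question's output, for illustration only).
raw = torch.tensor([-5.0647, 14.5854])

real_t = transform_to(constraints.real)          # unconstrained params such as locs
positive_t = transform_to(constraints.positive)  # positive params such as scales

locs_view = real_t(raw)        # identity: values are returned unchanged
scales_view = positive_t(raw)  # elementwise exp(raw)

print(locs_view)
print(scales_view)
```

Because the transform for `constraints.real` is the identity, the two ParamStore views of a loc agree. The exponential used for `constraints.positive` is what separates the two views of a scale.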

    Following this question, a new bullet has been added to the ParamStore documentation, giving a relevant hint:

    in general parameters are associated with both constrained and unconstrained values. for example, under the hood a parameter that is constrained to be positive is represented as an unconstrained tensor in log space.

    as well as in the documentation of the named_parameters() method of the old interface:

    Note that, in the event the parameter is constrained, unconstrained_value is in the unconstrained space implicitly used by the constraint.