Search code examples
pythontensorflowtensorflow-probability

Tensorflow probability: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)


I'm trying to get NUTS sampler working on a toymodel. I've run into the issue mentioned in the title.

This is the code to reproduce the error:

import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np

tfd = tfp.distributions


A = tf.random.normal(
    [10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])

x1 = tfd.Normal(0, 10 * tf.ones(10)).sample()
x1 = x1[..., tf.newaxis]
y = tf.linalg.matmul(A, x1) + noise_std


model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.), #sigma 
    tfd.Normal(0, 10 * tf.ones(10)),
    lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matmul(A, x_rv[...,tf.newaxis]) + sigma, scale=1.0)
])

def target_log_prob_fn(sigma, x_rv):
    return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])


def trace_fn(_, pkr):  
    return (
        pkr.inner_results.inner_results.target_log_prob,
        pkr.inner_results.inner_results.leapfrogs_taken,
        pkr.inner_results.inner_results.has_divergence,
        pkr.inner_results.inner_results.energy,
        pkr.inner_results.inner_results.log_accept_ratio)

n_chains = 2

def run_nuts_template(
    trace_fn,
    target_log_prob_fn,
    inits,
    bijectors_list=None, 
    num_steps=500,
    num_burnin=500,
    n_chains=n_chains):
    
    step_size = np.random.rand(n_chains, 1)*.5 + 1.
    
    if not isinstance(inits, list):
        inits = [inits]
        
    if bijectors_list is None:
        bijectors_list = [tfb.Identity()]*len(inits)

    kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.NoUTurnSampler(
                target_log_prob_fn,
                step_size=[step_size]*len(inits)
            ),
            bijector=bijectors_list
        ),
        target_accept_prob=.8,
        num_adaptation_steps=int(0.8*num_burnin),
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
              inner_results=pkr.inner_results._replace(step_size=new_step_size)
          ),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
    )
    
    res = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=num_burnin,
        current_state=inits,
        kernel=kernel,
        trace_fn=trace_fn
    )
    return res



inits = model.sample(n_chains)
run_nuts = partial(run_nuts_template, trace_fn)


inits = [tf.random.uniform(s.shape, -2, 2, tf.float32, name="initializer") for s in inits]

run_nuts(target_log_prob_fn, inits[:-1])


Error: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)

Full stacktrace: https://pastebin.com/zAA58P53

ValueError                                Traceback (most recent call last)
<ipython-input-17-ab5eae0dd51a> in <module>
      3 ]
      4 
----> 5 run_nuts(
      6     target_log_prob_fn,
      7             inits[:-1]
 
<ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
     45     )
     46 
---> 47     res = tfp.mcmc.sample_chain(
     48         num_results=num_steps,
     49         num_burnin_steps=num_burnin,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
    359       return seed, next_state, current_kernel_results
    360 
--> 361     (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
    362         loop_fn=_trace_scan_fn,
    363         initial_state=(seed, current_state, previous_kernel_results),
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
    462       return i + 1, state, num_steps_traced, trace_arrays
    463 
--> 464     _, final_state, _, trace_arrays = tf.while_loop(
    465         cond=lambda i, *_: i < length,
    466         body=_body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
    452     def _body(i, state, num_steps_traced, trace_arrays):
    453       elem = elems_array.read(i)
--> 454       state = loop_fn(state, elem)
    455 
    456       trace_arrays, num_steps_traced = ps.cond(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
    352 
    353     def _trace_scan_fn(seed_state_and_results, num_steps):
--> 354       seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
    355           loop_num_iter=num_steps,
    356           body_fn=_seeded_one_step,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
    346       # where while/LoopCond needs it.
    347       loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 348       return tf.while_loop(
    349           cond=lambda i, *args: i < loop_num_iter,
    350           body=lambda i, *args: [i + 1] + list(body_fn(*args)),
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
    348       return tf.while_loop(
    349           cond=lambda i, *args: i < loop_num_iter,
--> 350           body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    351           loop_vars=[np.int32(0)] + initial_loop_vars,
    352           parallel_iterations=parallel_iterations
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
    349       one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
    350       return [passalong_seed] + list(
--> 351           kernel.one_step(*state_and_results, **one_step_kwargs))
    352 
    353     def _trace_scan_fn(seed_state_and_results, num_steps):
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
    454       # Step the inner kernel.
    455       inner_kwargs = {} if seed is None else dict(seed=seed)
--> 456       new_state, new_inner_results = self.inner_kernel.one_step(
    457           current_state, inner_results, **inner_kwargs)
    458 
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
    397         self.name, 'transformed_kernel', 'one_step')):
    398       inner_kwargs = {} if seed is None else dict(seed=seed)
--> 399       transformed_next_state, kernel_results = self._inner_kernel.one_step(
    400           previous_kernel_results.transformed_state,
    401           previous_kernel_results.inner_results,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
    392           )
    393 
--> 394       _, _, _, new_step_metastate = tf.while_loop(
    395           cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
    396               (iter_ < self.max_tree_depth) &
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
    396               (iter_ < self.max_tree_depth) &
    397               tf.reduce_any(metastate.continue_tree)),
--> 398           body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
    399               previous_kernel_results.step_size,
    400               previous_kernel_results.momentum_state_memory,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
    570           momentum_subtree_cumsum,
    571           leapfrogs_taken
--> 572       ] = self._build_sub_tree(
    573           directions_expanded,
    574           integrator,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
    750           final_not_divergence,
    751           momentum_state_memory,
--> 752       ] = tf.while_loop(
    753           cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
    754                       leapfrogs_taken, state, state_c, continue_tree,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
    758                       leapfrogs_taken, state, state_c, continue_tree,
    759                       not_divergence, momentum_state_memory: (
--> 760                           self._loop_build_sub_tree(
    761                               directions, integrator, current_step_meta_info,
    762                               iter_, energy_diff_sum, init_momentum_cumsum,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
    811           next_target,
    812           next_target_grad_parts
--> 813       ] = integrator(prev_tree_state.momentum,
    814                      prev_tree_state.state,
    815                      prev_tree_state.target,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
    295           next_target,
    296           next_target_grad_parts,
--> 297       ] = tf.while_loop(
    298           cond=lambda i, *_: i < self.num_steps,
    299           body=lambda i, *args: [i + 1] + list(_one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
    297       ] = tf.while_loop(
    298           cond=lambda i, *_: i < self.num_steps,
--> 299           body=lambda i, *args: [i + 1] + list(_one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
    300               self.target_fn, self.step_sizes, get_velocity_parts, *args)),
    301           loop_vars=[
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
    353               next_target_grad_parts))
    354 
--> 355     tensorshape_util.set_shape(next_target, target.shape)
    356     for ng, g in zip(next_target_grad_parts, target_grad_parts):
    357       tensorshape_util.set_shape(ng, g.shape)
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
    326   """
    327   if hasattr(tensor, 'set_shape'):
--> 328     tensor.set_shape(shape)
    329 
    330 
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
   1213   def set_shape(self, shape):
   1214     if not self.shape.is_compatible_with(shape):
-> 1215       raise ValueError(
   1216           "Tensor's shape %s is not compatible with supplied shape %s" %
   1217           (self.shape, shape))
 
ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)

The only variable with shape of (2, ) is sigma, but I was not able to figure how is turns into a (2, 2) tensor.


Solution

  • The issue is with the shapes of your step sizes. As written, the chain state parts have (including the n_chains batch shape) shape [2] and [2, 10], resp. However your step sizes are all being initialized with shape [2, 1]. This is correct for the second state part ([2, 1] broadcasts with [2, 10]) but not for the first -- you end up with a [2, 2] somewhere, presumably in the step_size * grad(tlp, state[0]) term (paraphrasing) in the integrator.

    I rewrote the step size initialization to this, which could probably use some code-golfing, but works as intended:

        step_size = np.random.rand(n_chains) * .5 + 1.
        step_size = [np.reshape(step_size, [n_chains] + [1] * (x.shape.ndims - 1))
                     for x in inits]
        # now step size shapes are [2] and [2, 1]
    

    Some other notes:

    1. Check out tf.linalg.matvec -- it will save you some chars (and possibly some flops) by avoiding the newaxis + matmul
    2. it looks like you're doing stan-style initialization, with random inits in the [-2, 2]^n hypercube. In general you'll want to do this in the unconstrained space, so you will probably want to push those random init's through your constraining bijectors, at least when they're not the defaults (Identity). TransformedTransitionKernel assumes the user-provided states are in the constrained (sample) space. Hopefully this is clear...please let me know if not!
    3. There's a new experimental feature, "pinning" that might simplify some things for you here. You'll want to first make sure your distributions have names (pass a name arg to the constructors), but then you can call pinned = model.pinned(y=y) to get an instance of tfp.experimental.distributions.JointDistributionPinned. You can use pinned.unnormalized_log_prob in place of your def target_log_prob_fn (it does the same thing), and you can also call pinned.experimental_default_event_space_bijector() to get a bijector that "just does the right thing" for transforming the un-pinned variables. Ie, you can just hand that thing off to TransformedTransitionKernel. It actually is a "multipart" (or "joint") bijector, so it eats lists and returns lists; you no longer need to list-wrap your passed-in bijector. TTK, as of recently, knows how to use these multipart bijectors. As the names suggest, these are all pretty new and experimental/subject to API tweaks, but should be in good working order; please let us know if you try them and run into issues!

    Here's a colab w/ the slightly modified step size code: https://colab.research.google.com/drive/1o-nygALqdq2ppj5rU9d6UVafZM5SBCd4

    HTH!