I'm trying to get the NUTS sampler working on a toy model, and I've run into the issue mentioned in the title.
This is the code to reproduce the error:
import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np
tfd = tfp.distributions
A = tf.random.normal(
    [10, 10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])
x1 = tfd.Normal(0, 10 * tf.ones(10)).sample()
x1 = x1[..., tf.newaxis]
y = tf.linalg.matmul(A, x1) + noise_std
model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),  # sigma
    tfd.Normal(0, 10 * tf.ones(10)),
    lambda x_rv, sigma: tfd.Normal(loc=tf.linalg.matmul(A, x_rv[..., tf.newaxis]) + sigma, scale=1.0)
])
def target_log_prob_fn(sigma, x_rv):
    return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])
def trace_fn(_, pkr):
    return (
        pkr.inner_results.inner_results.target_log_prob,
        pkr.inner_results.inner_results.leapfrogs_taken,
        pkr.inner_results.inner_results.has_divergence,
        pkr.inner_results.inner_results.energy,
        pkr.inner_results.inner_results.log_accept_ratio)
n_chains = 2
def run_nuts_template(
        trace_fn,
        target_log_prob_fn,
        inits,
        bijectors_list=None,
        num_steps=500,
        num_burnin=500,
        n_chains=n_chains):
    step_size = np.random.rand(n_chains, 1)*.5 + 1.
    if not isinstance(inits, list):
        inits = [inits]
    if bijectors_list is None:
        bijectors_list = [tfb.Identity()]*len(inits)
    kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.NoUTurnSampler(
                target_log_prob_fn,
                step_size=[step_size]*len(inits)
            ),
            bijector=bijectors_list
        ),
        target_accept_prob=.8,
        num_adaptation_steps=int(0.8*num_burnin),
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
            inner_results=pkr.inner_results._replace(step_size=new_step_size)
        ),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
    )
    res = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=num_burnin,
        current_state=inits,
        kernel=kernel,
        trace_fn=trace_fn
    )
    return res
inits = model.sample(n_chains)
run_nuts = partial(run_nuts_template, trace_fn)
inits = [tf.random.uniform(s.shape, -2, 2, tf.float32, name="initializer") for s in inits]
run_nuts(target_log_prob_fn, inits[:-1])
Error: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
Full stacktrace: https://pastebin.com/zAA58P53
ValueError Traceback (most recent call last)
<ipython-input-17-ab5eae0dd51a> in <module>
3 ]
4
----> 5 run_nuts(
6 target_log_prob_fn,
7 inits[:-1]
<ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
45 )
46
---> 47 res = tfp.mcmc.sample_chain(
48 num_results=num_steps,
49 num_burnin_steps=num_burnin,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
359 return seed, next_state, current_kernel_results
360
--> 361 (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
362 loop_fn=_trace_scan_fn,
363 initial_state=(seed, current_state, previous_kernel_results),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
462 return i + 1, state, num_steps_traced, trace_arrays
463
--> 464 _, final_state, _, trace_arrays = tf.while_loop(
465 cond=lambda i, *_: i < length,
466 body=_body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
452 def _body(i, state, num_steps_traced, trace_arrays):
453 elem = elems_array.read(i)
--> 454 state = loop_fn(state, elem)
455
456 trace_arrays, num_steps_traced = ps.cond(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
--> 354 seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
355 loop_num_iter=num_steps,
356 body_fn=_seeded_one_step,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
346 # where while/LoopCond needs it.
347 loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
--> 350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
351 loop_vars=[np.int32(0)] + initial_loop_vars,
352 parallel_iterations=parallel_iterations
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
349 one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
350 return [passalong_seed] + list(
--> 351 kernel.one_step(*state_and_results, **one_step_kwargs))
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
454 # Step the inner kernel.
455 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 456 new_state, new_inner_results = self.inner_kernel.one_step(
457 current_state, inner_results, **inner_kwargs)
458
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
397 self.name, 'transformed_kernel', 'one_step')):
398 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 399 transformed_next_state, kernel_results = self._inner_kernel.one_step(
400 previous_kernel_results.transformed_state,
401 previous_kernel_results.inner_results,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
392 )
393
--> 394 _, _, _, new_step_metastate = tf.while_loop(
395 cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
396 (iter_ < self.max_tree_depth) &
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
396 (iter_ < self.max_tree_depth) &
397 tf.reduce_any(metastate.continue_tree)),
--> 398 body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
399 previous_kernel_results.step_size,
400 previous_kernel_results.momentum_state_memory,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
570 momentum_subtree_cumsum,
571 leapfrogs_taken
--> 572 ] = self._build_sub_tree(
573 directions_expanded,
574 integrator,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
750 final_not_divergence,
751 momentum_state_memory,
--> 752 ] = tf.while_loop(
753 cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
754 leapfrogs_taken, state, state_c, continue_tree,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
758 leapfrogs_taken, state, state_c, continue_tree,
759 not_divergence, momentum_state_memory: (
--> 760 self._loop_build_sub_tree(
761 directions, integrator, current_step_meta_info,
762 iter_, energy_diff_sum, init_momentum_cumsum,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
811 next_target,
812 next_target_grad_parts
--> 813 ] = integrator(prev_tree_state.momentum,
814 prev_tree_state.state,
815 prev_tree_state.target,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
295 next_target,
296 next_target_grad_parts,
--> 297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
--> 299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
300 self.target_fn, self.step_sizes, get_velocity_parts, *args)),
301 loop_vars=[
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
353 next_target_grad_parts))
354
--> 355 tensorshape_util.set_shape(next_target, target.shape)
356 for ng, g in zip(next_target_grad_parts, target_grad_parts):
357 tensorshape_util.set_shape(ng, g.shape)
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
326 """
327 if hasattr(tensor, 'set_shape'):
--> 328 tensor.set_shape(shape)
329
330
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
1213 def set_shape(self, shape):
1214 if not self.shape.is_compatible_with(shape):
-> 1215 raise ValueError(
1216 "Tensor's shape %s is not compatible with supplied shape %s" %
1217 (self.shape, shape))
ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
The only variable with shape (2,) is sigma, but I was not able to figure out how it turns into a (2, 2) tensor.
The issue is with the shapes of your step sizes. As written, the chain state parts have (including the n_chains batch shape) shape [2] and [2, 10], respectively. However, your step sizes are all being initialized with shape [2, 1]. This is correct for the second state part ([2, 1] broadcasts with [2, 10]) but not for the first -- you end up with a [2, 2] somewhere, presumably in the step_size * grad(tlp, state[0]) term (paraphrasing) in the integrator.
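To see where a [2, 2] can come from, here is a minimal illustration of the offending broadcast -- not the integrator's actual code, just the shape arithmetic:

import tensorflow as tf

step_size = tf.ones([2, 1])   # [n_chains, 1], as initialized in the question
grad_sigma = tf.ones([2])     # gradient for the sigma part, shape [n_chains]
grad_x = tf.ones([2, 10])     # gradient for the x part, shape [n_chains, 10]

print((step_size * grad_sigma).shape)  # (2, 2)  <- the broadcasting surprise
print((step_size * grad_x).shape)      # (2, 10) <- broadcasts as intended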
I rewrote the step size initialization to this, which could probably use some code-golfing, but works as intended:
step_size = np.random.rand(n_chains) * .5 + 1.
step_size = [np.reshape(step_size, [n_chains] + [1] * (x.shape.ndims - 1))
             for x in inits]
# now step size shapes are [2] and [2, 1]
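One wiring note (my inference, not spelled out above): since step_size is now already a list with one entry per state part, inside run_nuts_template you would hand it to the sampler directly rather than tiling it:

# inside run_nuts_template, after building the per-part step sizes
inner_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn,
    step_size=step_size  # one step size per state part; no [step_size] * len(inits)
)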
Some other notes:

- Take a look at tf.linalg.matvec -- it will save you some chars (and possibly some flops) by avoiding the newaxis + matmul dance (a small sketch follows after these notes).
- You're initializing the chains uniformly over the [-2, 2]^n hypercube. In general you'll want to do this in the unconstrained space, so you will probably want to push those random inits through your constraining bijectors, at least when they're not the defaults (Identity). TransformedTransitionKernel assumes the user-provided states are in the constrained (sample) space. Hopefully this is clear...please let me know if not!
- If you give your model parts names (the name arg to the distribution constructors), you can call pinned = model.pinned(y=y) to get an instance of tfp.experimental.distributions.JointDistributionPinned. You can use pinned.unnormalized_log_prob in place of your target_log_prob_fn (it does the same thing), and you can also call pinned.experimental_default_event_space_bijector() to get a bijector that "just does the right thing" for transforming the un-pinned variables, i.e. you can hand that thing straight off to TransformedTransitionKernel. It actually is a "multipart" (or "joint") bijector, so it eats lists and returns lists; you no longer need to list-wrap your passed-in bijector. TTK, as of recently, knows how to use these multipart bijectors. As the names suggest, these are all pretty new and experimental/subject to API tweaks, but they should be in good working order; please let us know if you try them and run into issues! (A rough sketch of this workflow also follows after these notes.)

Here's a colab w/ the slightly modified step size code: https://colab.research.google.com/drive/1o-nygALqdq2ppj5rU9d6UVafZM5SBCd4
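Regarding the matvec note, a tiny standalone sketch of the difference (just the toy model's mean computation, nothing NUTS-specific):

import tensorflow as tf

A = tf.random.normal([10, 10])
x = tf.random.normal([10])

# The question's pattern: pad x to a column, matmul, and carry a trailing dim of 1.
mean_matmul = tf.linalg.matmul(A, x[..., tf.newaxis])  # shape [10, 1]

# matvec: no newaxis needed, and the result is a plain vector.
mean_matvec = tf.linalg.matvec(A, x)  # shape [10]

tf.debugging.assert_near(mean_matvec, tf.squeeze(mean_matmul, axis=-1))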
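And a rough sketch of the pinned-model workflow from the last note, under a few assumptions of mine: the model parts are given the names 'sigma', 'x', and 'y'; y_observed is a hypothetical stand-in for the question's y with its trailing unit dim squeezed; and the pinning method is spelled experimental_pin here, which is the spelling I've seen in recent TFP releases -- adjust to whatever your version exposes:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Same toy model, with named parts so the observation can be pinned by keyword,
# and with matvec so the observed event is a plain length-10 vector.
# A, n_chains, and step_size are as built earlier; y_observed is assumed.
model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1., name='sigma'),
    tfd.Normal(0., 10. * tf.ones(10), name='x'),
    lambda x, sigma: tfd.Normal(
        loc=tf.linalg.matvec(A, x) + sigma, scale=1.0, name='y'),
])

pinned = model.experimental_pin(y=y_observed)  # pin the observed data

target_log_prob_fn = pinned.unnormalized_log_prob
bijector = pinned.experimental_default_event_space_bijector()  # multipart bijector

# Draw inits in the unconstrained space and map them through the constraining
# bijector, since TransformedTransitionKernel expects constrained-space states.
prior_draws = model.sample(n_chains)[:-1]  # just to get the unpinned parts' shapes
unconstrained = [tf.random.uniform(p.shape, -2., 2.) for p in prior_draws]
inits = bijector.forward(unconstrained)  # list in, list out

kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size),
    bijector=bijector)  # no list-wrapping needed for the multipart bijector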
HTH!