
Gen: How to combine multiple generative function traces in a higher-order generative function?


I'm going through the "Introduction to Modeling in Gen" Notebook at https://github.com/probcomp/gen-quickstart

Section 5 (Calling other generative functions) asks the reader to "Construct a data set for which it is ambiguous whether the line or sine wave model is best".

I'm having a hard time understanding how to work with the traces (and return values) of the component functions to create a meaningful higher-order trace that I can use.

To me, the most straightforward "ambiguous" model is line(xs) .+ sine(xs). So I called Gen.simulate on the line and sine models to get their traces and added them together, like this:

@gen function combo(xs::Vector{Float64})
    my_sin = simulate(sine_model_2,(xs,))
    my_lin = simulate(line_model_2,(xs,))
    if @trace(bernoulli(0.5), :is_line)
        @trace(normal(get_choices(my_lin)[:slope], 0.01), :slope)
        @trace(normal(get_choices(my_lin)[:intercept], 0.01), :intercept)
        @trace(normal(get_choices(my_lin)[:noise], 0.01), :noise)        
    else
        @trace(normal(get_choices(my_sin)[:phase], 0.01), :phase)
        @trace(normal(get_choices(my_sin)[:period], 0.01), :period)
        @trace(normal(get_choices(my_sin)[:amplitude], 0.01), :amplitude)
        @trace(normal(get_choices(my_sin)[:noise], 0.01), :noise)
    end
    combo = [get_choices(my_sin)[(:y, i)] + get_choices(my_lin)[(:y, i)] for i=1:length(xs)]
    for (i, c) in enumerate(combo)
        @trace(normal(c, 0.1), (:y, i))
    end
end;

This is clearly wrong, and I know I'm missing something fundamental about the whole idea of traces and probabilistic programming in Gen.

I'd expect to be able to introspect sine/line_model's trace from within combo, and do element-wise addition on the traces to get a new trace. And not have to randomly pick a number close to :intercept, :phase, etc. so I can include it in my trace later on.

By the way, when I do:

traces = [Gen.simulate(combo,(xs,)) for _=1:12];
grid(render_combined, traces)

I get the following plot: [image: failed attempt at function]

Please help, thanks!


Solution

  • Hi there, thanks for your interest in Gen! :)

    Addresses of the combined model's trace

    The combined model from the tutorial looks like this:

    @gen function combined_model(xs::Vector{Float64})
        if @trace(bernoulli(0.5), :is_line)
            @trace(line_model_2(xs))
        else
            @trace(sine_model_2(xs))
        end
    end;
    

    Its traces will have the following addresses:

    • :is_line, storing a Boolean indicating whether the generated dataset was linear or not.
    • Any addresses from line_model_2 or sine_model_2, depending on which was called.

    Note that traces of both line_model_2 and sine_model_2 contain the addresses (:y, i) for each integer i between 1 and length(xs). Because of this, so will combined_model's traces: these are the addresses representing the final sampled y values, regardless of which of the two processes generated them.
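
    To make this address structure concrete, here is a minimal sketch that simulates combined_model and inspects the resulting choicemap. It assumes the Gen.jl package and its @trace syntax; the line_model_2 and sine_model_2 bodies are hypothetical, simplified stand-ins that reuse the addresses from the question, with priors chosen for illustration rather than copied from the notebook.

```julia
using Gen

# Hypothetical, simplified stand-ins for the notebook's models: the address
# names match the question, but the priors are assumptions.
@gen function line_model_2(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, noise), (:y, i))
    end
    return nothing
end

@gen function sine_model_2(xs::Vector{Float64})
    phase = @trace(uniform(0, 2 * pi), :phase)
    period = @trace(gamma(5, 1), :period)
    amplitude = @trace(gamma(1, 1), :amplitude)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(amplitude * sin(2 * pi * x / period + phase), noise), (:y, i))
    end
    return nothing
end

# The combined model from the answer: each sub-model call is spliced (no
# address), so its choices land directly at the top level of the trace.
@gen function combined_model(xs::Vector{Float64})
    if @trace(bernoulli(0.5), :is_line)
        @trace(line_model_2(xs))
    else
        @trace(sine_model_2(xs))
    end
end

xs = collect(range(-5, stop=5, length=20))
tr = simulate(combined_model, (xs,))
choices = get_choices(tr)

is_line = choices[:is_line]
# Only the branch that actually ran contributes its parameter addresses...
println("is_line = $is_line, has :slope = $(has_value(choices, :slope)), ",
        "has :period = $(has_value(choices, :period))")
# ...but every datapoint lives at (:y, i) either way.
ys = [choices[(:y, i)] for i in 1:length(xs)]
```

    Whichever branch runs, the comprehension over (:y, i) succeeds, while only one of :slope and :period is present in the choicemap.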

    Constructing a new dataset

    The question to "construct a data set for which it is ambiguous whether the line or sine wave model is best" does not require writing a new generative function (with @gen), but rather, constructing a list of xs and a list of ys (in plain Julia) that you think might make a difficult-to-disambiguate dataset. You can then pass your xs and ys into the do_inference function defined earlier in the notebook, to see what the system concludes about your dataset. Note that the do_inference function constructs a constraint choicemap that constrains each (:y, i) to the value ys[i] from the dataset you passed in. This works because (:y, i) is always the name of the ith datapoint, no matter the value of :is_line.
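
    As an illustration of this workflow, the sketch below builds a plain-Julia dataset (xs on [-1, 1] with ys = sin.(xs), a guess at an ambiguous case, since sin(x) is close to x near zero), constrains each (:y, i) the way do_inference is described as doing, and checks that the same constraints apply under either value of :is_line. The model bodies are hypothetical stand-ins, and the final loop is a hypothetical stand-in for do_inference built on Gen's importance_resampling, not the notebook's actual function.

```julia
using Gen

# Hypothetical, simplified stand-ins (same addresses as the question; the
# priors are assumptions).
@gen function line_model_2(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, noise), (:y, i))
    end
    return nothing
end

@gen function sine_model_2(xs::Vector{Float64})
    phase = @trace(uniform(0, 2 * pi), :phase)
    period = @trace(gamma(5, 1), :period)
    amplitude = @trace(gamma(1, 1), :amplitude)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(amplitude * sin(2 * pi * x / period + phase), noise), (:y, i))
    end
    return nothing
end

@gen function combined_model(xs::Vector{Float64})
    if @trace(bernoulli(0.5), :is_line)
        @trace(line_model_2(xs))
    else
        @trace(sine_model_2(xs))
    end
end

# Plain Julia data: near x = 0, sin(x) is close to the line y = x, so either
# model might plausibly explain it.
xs = collect(range(-1, stop=1, length=20))
ys = sin.(xs)

# Constrain each (:y, i) to ys[i].
observations = choicemap()
for (i, y) in enumerate(ys)
    observations[(:y, i)] = y
end

# The same y constraints work whichever branch is taken, because both
# sub-models name their datapoints (:y, i).
matches = Dict{Bool,Bool}()
for flag in (true, false)
    constraints = choicemap((:is_line, flag))
    for (i, y) in enumerate(ys)
        constraints[(:y, i)] = y
    end
    tr, _ = generate(combined_model, (xs,), constraints)
    matches[flag] = all(get_choices(tr)[(:y, i)] == ys[i] for i in 1:length(ys))
end
println("y constraints respected for is_line = true / false: ", matches)

# Hypothetical do_inference stand-in: repeat importance resampling and count
# how often the sampled trace uses the line model.
n_runs = 20
line_votes = count(1:n_runs) do _
    tr, _ = importance_resampling(combined_model, (xs,), observations, 100)
    get_choices(tr)[:is_line]
end
println("fraction of runs choosing the line model: ", line_votes / n_runs)
```

    The printed fraction depends on the random seed and on the assumed priors, so treat it only as a rough indication of how ambiguous the dataset is.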

    Updating / manipulating traces

    You write:

    I'd expect to be able to introspect sine/line_model's trace from within combo, and do element-wise addition on the traces to get a new trace. And not have to randomly pick a number close to :intercept, :phase, etc. so I can include it in my trace later on.

    You can certainly call simulate twice to get two traces, outside a generative function like combo. But traces cannot be manipulated in arbitrary ways (e.g. "elementwise addition"): as data structures, traces maintain certain invariants, like always knowing the exact probability of their current values under the model that generated them, and always holding values that actually could have been generated from the model.
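
    A minimal sketch of that outside-the-model approach, assuming hypothetical simplified versions of line_model_2 and sine_model_2: simulate each model once, read the (:y, i) values out of each trace's choicemap, and add them in plain Julia. get_score shows the log probability each trace keeps track of; the elementwise sum is an ordinary vector of numbers (a candidate dataset), not a trace of any model.

```julia
using Gen

# Hypothetical, simplified stand-ins (same addresses as the question; the
# priors are assumptions).
@gen function line_model_2(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, noise), (:y, i))
    end
    return nothing
end

@gen function sine_model_2(xs::Vector{Float64})
    phase = @trace(uniform(0, 2 * pi), :phase)
    period = @trace(gamma(5, 1), :period)
    amplitude = @trace(gamma(1, 1), :amplitude)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(amplitude * sin(2 * pi * x / period + phase), noise), (:y, i))
    end
    return nothing
end

xs = collect(range(-5, stop=5, length=20))

# Two independent traces, simulated outside any generative function.
line_trace = simulate(line_model_2, (xs,))
sine_trace = simulate(sine_model_2, (xs,))

# Each trace records the log probability of its choices under its own model.
println("line log-prob: ", get_score(line_trace))
println("sine log-prob: ", get_score(sine_trace))

# Read the y values out of the traces and add them in plain Julia.
line_ys = [get_choices(line_trace)[(:y, i)] for i in 1:length(xs)]
sine_ys = [get_choices(sine_trace)[(:y, i)] for i in 1:length(xs)]
summed_ys = line_ys .+ sine_ys
```

    summed_ys can then be treated like any other hand-built dataset, for example as the ys you constrain before running inference.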

    The dictionary-like data structure you're looking for is a choicemap. Choicemaps are mutable and can be built up to include arbitrary values at arbitrary addresses. For example, you can write:

    observations = Gen.choicemap()
    for (i, y) in enumerate(ys)
        observations[(:y, i)] = y
    end
    

    Choicemaps can be used as constraints to generate new traces (using Gen.generate), as arguments to Gen's low-level Gen.update method (which allows you to update a trace while recomputing any relevant probabilities, and errors if your updates are invalid), and in several other places.
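
    A short sketch of both uses, assuming a hypothetical simplified line_model_2 and made-up data: generate returns a trace whose (:y, i) choices equal the constraints (plus a log importance weight), and update overwrites :slope while Gen recomputes the score. The discard choicemap returned by update holds the value that was overwritten.

```julia
using Gen

# Hypothetical, simplified stand-in for the notebook's line_model_2
# (the priors are assumptions).
@gen function line_model_2(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    noise = @trace(gamma(1, 1), :noise)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, noise), (:y, i))
    end
    return nothing
end

xs = collect(range(-5, stop=5, length=20))
ys = 0.5 .* xs .+ 1.0  # made-up "observed" data, for illustration only

observations = choicemap()
for (i, y) in enumerate(ys)
    observations[(:y, i)] = y
end

# generate: the (:y, i) choices are fixed to the constraints; :slope,
# :intercept and :noise are sampled; `weight` is a log importance weight.
tr, weight = generate(line_model_2, (xs,), observations)

# update: same arguments (NoChange), new value for :slope. Gen re-scores the
# trace and returns the old :slope value in `discard`.
new_tr, w, retdiff, discard = update(tr, (xs,), (NoChange(),), choicemap((:slope, 0.5)))

println("old slope: ", discard[:slope], ", new slope: ", get_choices(new_tr)[:slope])
println("log weight of the update: ", w)
```

    Because this update introduces no new random choices, its weight w should equal the difference between the new and old traces' scores.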

    Hope that helps :)