Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)

What is this notebook about?

In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, it was necessary to define a generative function that accepted as an argument the previous trace, and proposed a new trace by “acting out” the model, that is, by sampling new proposed values of random choices at the same addresses used by the model.

For example, a random walk proposal for the addresses :x and :y might have looked like this:

@gen function random_walk_proposal(previous_trace)
    x ~ normal(previous_trace[:x], 0.1)
    y ~ normal(previous_trace[:y], 0.2)
end

This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.

In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.

The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.

Outline

Section 1. Recap of the piecewise-function model

Section 2. Basic Metropolis-Hastings inference

Section 3. Reversible-Jump “Split-Merge” proposals

Section 4. Bayesian program synthesis of GP kernels

Section 5. A tree regeneration proposal

using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")

Logging.disable_logging(Logging.Info);

1. Recap of the piecewise-constant function model

In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.

Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments.

It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a y value from a normal distribution. Finally, we sample the observed y values from normal distributions centered on the piecewise-constant function described by the segments.
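To make the arithmetic concrete, here is a tiny sketch in plain Julia (the numbers are made up for illustration) of how a fractions vector translates into segment lengths:

```julia
# Sketch of the segment-length arithmetic described above (illustrative numbers).
xmin, xmax = -5.0, 5.0
fractions = [0.3, 0.5, 0.2]           # e.g. a draw from dirichlet([1.0, 1.0, 1.0])

# Each segment's length is its fraction times the total interval length.
lengths = fractions .* (xmax - xmin)

# The segment lengths partition the interval exactly.
@assert sum(lengths) ≈ xmax - xmin
```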

Using @dist to define new distributions for convenience

To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:

# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;

Distributions declared with @dist can be used to make random choices inside of Gen models. Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf, or log density, of our newly defined distribution. So we can ask, e.g., what the density of poisson_plus_one(1) is at the point 3:

logpdf(poisson_plus_one, 3, 1)
-1.6931471805599454

Note that this is the same as the logpdf of poisson(1) at the point 2. @dist’s main job is to automate the logic of converting the above call into this one:

logpdf(poisson, 2, 1)
-1.6931471805599454
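As a quick sanity check (a sketch that assumes Gen is installed; it re-declares the distribution from above), we can confirm that the shift holds at several points, not just one:

```julia
using Gen

# Re-declare the shifted distribution from above.
@dist poisson_plus_one(rate) = poisson(rate) + 1

# The density of poisson_plus_one at k should match the density of
# poisson at k - 1, for every k in the support.
for k in 1:5
    @assert isapprox(logpdf(poisson_plus_one, k, 1), logpdf(poisson, k - 1, 1))
end
```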

Writing the model

We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:

@gen function piecewise_constant(xs::Vector{Float64})
    # Generate a number of segments (at least 1)
    segment_count ~ poisson_plus_one(1)
    
    # To determine changepoints, draw a vector on the simplex from a Dirichlet
    # distribution. This gives us the proportions of the entire interval that each
    # segment takes up. (The entire interval is determined by the minimum and maximum
    # x values.)
    fractions ~ dirichlet([1.0 for i=1:segment_count])
    
    # Generate values for each segment
    segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]
    
    # Determine a global noise level
    noise ~ gamma(1, 1)
    
    # Generate the y points for the input x points
    xmin, xmax = extrema(xs)
    cumfracs = cumsum(fractions)
    # Correct numeric issue: `cumfracs[end]` might be 0.999999
    @assert cumfracs[end] ≈ 1.0
    cumfracs[end] = 1.0

    inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin),
                      cumfracs) for x in xs]
    segment_values = segments[inds]
    for (i, val) in enumerate(segment_values)
        {(:y, i)} ~ normal(val, noise)
    end
end;
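The findfirst line deserves a closer look: each x is normalized to a position in [0, 1], and the index of the first cumulative fraction at or above that position is the segment the point belongs to. Here is the same logic in plain Julia, with made-up numbers:

```julia
# Pure-Julia sketch of the segment-assignment step in the model above.
xs = [-5.0, 0.0, 4.9]
xmin, xmax = extrema(xs)
cumfracs = [0.3, 0.8, 1.0]   # cumulative segment fractions (illustrative)

# For each x, find the first segment whose cumulative fraction covers
# x's normalized position within the interval.
inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin), cumfracs) for x in xs]
@assert inds == [1, 2, 3]
```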

Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.

xs_dense = collect(range(-5, stop=5, length=50));

Don’t worry about understanding the following code, which we use for visualization.

function trace_to_dict(tr)
    Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
         :fracs  => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
         :ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;

function visualize_trace(tr; title="")
    xs, = get_args(tr)
    tr = trace_to_dict(tr)
    
    scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")
    
    cumfracs = [0.0, cumsum(tr[:fracs])...]
    xmin = minimum(xs)
    xmax = maximum(xs)
    for i in 1:tr[:n]
        segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
        segment_ys = fill(tr[:values][i], 2)
        plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
    end
    Plots.title!(title)
end

traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)


Many of the samples involve only one segment, but others involve more. The level of noise also varies from sample to sample.

2. Basic Metropolis-Hastings inference

Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.

ys_simple  = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
ys_medium  = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
                          2, 0) .+ randn(length(xs_dense)) * 0.1;
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;

We’ll need a helper function for creating a choicemap of constraints from a vector of ys:

function make_constraints(ys)
    choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;

As we saw in the last problem set, importance sampling does a decent job on the simple dataset:

NUM_CHAINS = 9

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)


But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:9]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores)-log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)
Log mean score: -37.681098516181876


Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem.

First, some visualization code (feel free to ignore it!).

function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)    
    traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
    viz = Plots.@animate for i in 1:frames
        
        if i*iters_per_frame % 100 == 0 && verbose
            println("Iteration $(i*iters_per_frame)")
        end
        
        for j in 1:N
            for k in 1:iters_per_frame
                traces[j] = update(traces[j], xs, ys)
            end
        end
       
        Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)
    end
    scores = [Gen.get_score(t) for t in traces]
    println("Log mean score: $(logsumexp(scores) - log(N))")
    gif(viz)
end;

And now the MH algorithm itself.

We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables:

  1. :segment_count and :fractions, resimulated jointly (changing the number of segments requires a fractions vector of matching length).
  2. :fractions alone.
  3. :noise.
  4. Each segment value (:segments, i), one at a time.

function simple_update(tr, xs, ys)
    tr, = mh(tr, select(:segment_count, :fractions))
    tr, = mh(tr, select(:fractions))
    tr, = mh(tr, select(:noise))
    for i=1:tr[:segment_count]
        tr, = mh(tr, select((:segments, i)))
    end
    tr
end
simple_update (generic function with 1 method)

Our algorithm makes quick work of the simple dataset:

visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)
Iteration 100
Log mean score: 46.82126477547658

On the medium dataset, it does an OK job with less computation than we used in importance sampling:

visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937

But on the complex dataset, it is still unreliable:

visualize_mh_alg(xs_dense, ys_complex, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: -30.43054336514916


Exercise

One problem with the simple block resimulation algorithm is that the proposals for (:segments, i) are totally uninformed by the data. In this problem, you’ll write a custom proposal (using the techniques we covered in Problem Set 1B) that uses the data to propose good values of y for each segment.

Write a generative function segments_proposal that can serve as a smart proposal distribution for this problem. It should:

  1. Propose :segment_count and :fractions (resimulating them as in the model is fine).
  2. For each segment, propose (:segments, i) from a normal distribution centered at the mean of the observed ys that fall inside that segment.

We will use this proposal to replace the “Block 1” move from our previous algorithm. This should make it easier to have proposals accepted, because whenever we propose a new segmentation of the interval, we propose it with reasonable y values attached.

We have provided some starter code:

@gen function segments_proposal(t, xs, ys)
    xmin, xmax = minimum(xs), maximum(xs)
    x_range = xmax - xmin
    
    # Propose segment_count and fractions
    # <edit the next two lines>
    segment_count = nothing
    fractions = nothing
    
    for i=1:segment_count
        # Propose (:segments, i) from a normal distribution
        # <your code here; and edit the next line>
        {(:segments, i)} ~ normal(nothing, nothing)
    end
end;
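For reference, here is one possible completion. This is a sketch, not the only valid answer; it assumes the poisson_plus_one distribution defined earlier and the dirichlet distribution from dirichlet.jl are in scope, and the fallback for empty segments is a hypothetical choice:

```julia
# One possible completion of the exercise (a sketch, not the unique answer).
@gen function segments_proposal_solution(t, xs, ys)
    xmin, xmax = minimum(xs), maximum(xs)
    x_range = xmax - xmin

    # Resimulate the segmentation from the prior (addresses must match the model).
    segment_count ~ poisson_plus_one(1)
    fractions ~ dirichlet([1.0 for i=1:segment_count])

    cumfracs = [0.0, cumsum(fractions)...]
    for i=1:segment_count
        lo = xmin + cumfracs[i] * x_range
        hi = xmin + cumfracs[i+1] * x_range
        relevant_ys = [y for (x, y) in zip(xs, ys) if lo <= x <= hi]
        # Hypothetical guard: if a segment contains no data, fall back to the
        # overall mean of the ys.
        mean_y = isempty(relevant_ys) ? sum(ys) / length(ys) :
                                        sum(relevant_ys) / length(relevant_ys)
        {(:segments, i)} ~ normal(mean_y, 0.3)
    end
end;
```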

We define custom_update to use segments_proposal in place of the first block from simple_update.

function custom_update(tr, xs, ys)
    (tr, _) = mh(tr, segments_proposal, (xs, ys))
    (tr, _) = mh(tr, select(:fractions))
    (tr, _) = mh(tr, select(:noise))
    for i=1:(tr[:segment_count])
        (tr, _) = mh(tr, select((:segments, i)))
    end
    tr
end
custom_update (generic function with 1 method)

Let’s see how this one does on each dataset:

visualize_mh_alg(xs_dense, ys_medium, custom_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 40.45738893036163

visualize_mh_alg(xs_dense, ys_complex, custom_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: -19.739460952009143

This will often outperform the simplest MH solution, but still leaves something to be desired. The smart segments_proposal helps find good function values for each segment, but doesn’t help us in cases where the segment proportions are wrong; in these cases, the model just decides that noise must be high.


3. Involution MH for Split-Merge proposals

How might we improve on the MH algorithms from the previous section, to more effectively search for good values of the fractions variable?

One approach is to add new proposals to the mix that iteratively refine the fractions, rather than relying on blind resimulation. A natural strategy might be to add split and merge proposals:

  1. A split proposal chooses an existing segment, divides it in two at a random point, and samples new values for the two resulting pieces.
  2. A merge proposal chooses two adjacent segments and combines them into one, sampling a new value near the mean of the two old ones.

MCMC, Metropolis-Hastings, and Reversibility

Note that alone, neither of these two proposals makes for a valid Metropolis-Hastings transition proposal. Why? One requirement of MCMC algorithms is that each transition kernel (kernel, not proposal: in MH, the kernel includes the accept-reject step) leave the posterior distribution stationary. What does this mean? Suppose we somehow obtained an oracle that let us sample a trace from the exact posterior, $t \sim p(t \mid \text{observations})$. Then if we run an MCMC transition kernel $T$ on the trace to obtain a new trace $t’ \sim T(t’ \leftarrow t)$, the marginal distribution of $t’$ should still be the posterior. In simpler terms, no one should be able to tell the difference between “traces sampled from the posterior” (i.e., using the output of the oracle directly) and “traces sampled from the posterior then sent through a transition kernel” (i.e., passing the output of the oracle as the input to the transition kernel and using the result). As a formula,

\[\int p(t \mid \text{observations}) \, T(t' \leftarrow t) \, \mathrm{d}t = p(t' \mid \text{observations})\]

Now suppose we set our transition kernel $T$ to a Metropolis-Hastings split move. The split move always increases the number of segments. So if t comes from the true posterior, then the distribution of t', in terms of number of segments, will necessarily be shifted upward (unless MH deterministically rejects every proposal – which is what MH in Gen will do if given a split proposal). The same issue exists, in reverse, for the merge proposal.

In general, this “stationarity” rule means that our Metropolis-Hastings proposals must be reversible, meaning that if there is some probability that a proposal can take you from one region of the state space to another, it must also have some probability of sending you back from the new region to the old region. If this criterion is satisfied, then the MH accept-reject step can accept and reject proposals with the proper probabilities to ensure that the stationarity property described above holds.
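To make stationarity concrete, here is a toy numeric check in plain Julia (the numbers are made up for illustration): a two-state kernel T whose rows are proper distributions, and a target p that T leaves unchanged. This kernel also satisfies detailed balance, the "reversibility" just described.

```julia
# Toy sketch (plain Julia, no Gen): a 2-state transition kernel that leaves
# a target distribution stationary.
p = [0.25, 0.75]   # target distribution over two states
T = [0.7 0.3;      # T[i, j] = probability of moving from state i to state j
     0.1 0.9]

# Stationarity: sum_i p[i] * T[i, j] == p[j] for every j.
p_next = [sum(p[i] * T[i, j] for i in 1:2) for j in 1:2]
@assert all(isapprox.(p_next, p))

# Detailed balance (reversibility): flow from 1 to 2 equals flow from 2 to 1.
@assert isapprox(p[1] * T[1, 2], p[2] * T[2, 1])
```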

To make “split” and “merge” fulfill this “reversibility” criterion, we can think of them as constituting a single proposal, which randomly chooses whether to split or merge at each iteration. This is an example of a reversible-jump proposal [1].

References:

  1. Green, Peter J., and David I. Hastie. “Reversible jump MCMC.” (2009).

Implementing Split-Merge in Gen

This is a sensible proposal. But if we try to write this proposal in Gen, we quickly hit several roadblocks. For example:

  1. The new :fractions vector is computed deterministically from the old one (and from the split point), but a standard proposal can only sample values at model addresses from distributions.
  2. The set of addresses changes: a split creates a new (:segments, i) choice and a merge destroys one, so the proposal’s choices cannot line up address-for-address with the model’s.

To get around these issues, Gen provides a variant of Metropolis-Hastings that is a bit trickier to use but is ultimately more flexible.

The idea is this: first, we write a generative function that samples all the randomness the proposal will require. In our case, this will involve:

  1. a coin flip deciding whether to split or to merge;
  2. an index at which to split or merge;
  3. for a split, the relative location of the split point and values for the two new segments;
  4. for a merge, a value for the merged segment.

@gen function split_merge_proposal_randomness(t)
    old_n = t[:segment_count]
    
    # Choose whether to split (T) or to merge (F), keeping in mind
    # that if old_n == 1, then our decision is made for us
    # (split).
    if ({:split_or_merge} ~ bernoulli(old_n == 1 ? 1 : 0.3))
        
        # split
        # What index to split at?
        index ~ uniform_discrete(1,old_n)
        # Where is the splitting point, relative to the segment being split?
        split_percentage ~ uniform(0, 1)
        # New values for the two new segments
        new_value_1 ~ normal(t[(:segments, index)], 0.1)
        new_value_2 ~ normal(t[(:segments, index)], 0.1)
    else
        # merge
        # What index to merge at? (Merge index i and i + 1)
        index ~ uniform_discrete(1, old_n-1) # merge i and i+1
        
        # Sample a new value for the merged segment, near the mean of the
        # two existing segments.
        new_value ~ normal((t[(:segments, index)] + t[(:segments, index+1)]) / 2.0, 0.1)
    end
end;

Let’s look at what this function samples:

tr, = generate(piecewise_constant, (xs_dense,), make_constraints(ys_complex));
get_choices(simulate(split_merge_proposal_randomness, (tr,)))
│
├── :index : 1
│
├── :split_or_merge : true
│
├── :new_value_2 : -0.12211827468079874
│
├── :split_percentage : 0.7179861487246315
│
└── :new_value_1 : -0.07557249242625629

We now need to write an ordinary Julia function that tells Gen how to use the proposal randomness generated by the generative function above to create a proposed next trace for Metropolis-Hastings.

The function takes four inputs:

  1. the previous trace t;
  2. the choicemap forward_choices sampled by the proposal-randomness function;
  3. the return value of the proposal-randomness function;
  4. the arguments that were passed to the proposal-randomness function.

It is supposed to return three outputs:

  1. the proposed new trace;
  2. a choicemap backward_choices holding the proposal randomness that would reverse this move;
  3. the incremental log weight returned by update.

We call the Julia function an involution, because it must have the following property. Suppose we propose a new trace using proposal_randomness_func and then our Julia function f:

forward_choices = get_choices(simulate(proposal_randomness_func, (old_trace,)))
new_trace, backward_choices, = f(old_trace, forward_choices, ...)

Now suppose we use backward_choices and ask what f would do to new_trace:

back_to_the_first_trace_maybe, backward_backward_choices, = f(new_trace, backward_choices, ...)

Then we require that backward_backward_choices == forward_choices, and that back_to_the_first_trace_maybe == old_trace.
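The property is easiest to see on a toy example that has nothing to do with Gen: a deterministic map on (state, auxiliary value) pairs that undoes itself when applied twice.

```julia
# Toy involution (plain Julia, not the Gen API): scale the state by u, and
# return the auxiliary value 1/u needed to undo the move.
f(state, u) = (state * u, 1 / u)

state, u = 2.0, 0.5
new_state, u_back = f(state, u)           # the moved state, plus how to reverse
back_state, u_fwd = f(new_state, u_back)  # applying f again undoes the move
@assert back_state == state && u_fwd == u
```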

Here’s an involution for Split-Merge:

function involution(t, forward_choices, forward_retval, proposal_args)
    # Goal: based on `forward_choices`, return a new trace
    # and the `backward_choices` required to reverse this
    # proposal.
    new_trace_choices = choicemap()
    backward_choices  = choicemap()
    
    # First, check whether we're dealing with a split or a merge.
    split_or_merge = forward_choices[:split_or_merge]
    
    # To reverse a split, use a merge (and vice versa)
    backward_choices[:split_or_merge] = !split_or_merge
        
    # Where is the split / merge occurring?
    # Reversing the proposal would use the same index.
    index = forward_choices[:index]
    backward_choices[:index] = index
    
    # Now we handle the segment values and proportions.
    # First, pull out the existing values from the previous
    # trace `t`. IMPORTANT NOTE: we need to `copy`
    # `t[:fractions]`, because we intend to perform mutating
    # Julia operations like `insert!` and `deleteat!` on it,
    # but do not wish to change the memory underlying the 
    # original trace `t` (in case the proposal is rejected
    # and we need to return to the original trace).
    fractions      = copy(t[:fractions])
    segment_values = [t[(:segments, i)] for i=1:t[:segment_count]]
    
    # How we update `fractions` and `segment_values` depends on
    # whether this is a split or a merge.
    if split_or_merge
        # If this is a split, then add a new element at `index`,
        # according to the split proportion.
        insert!(fractions, index, fractions[index] * forward_choices[:split_percentage])
        fractions[index + 1] *= (1 - forward_choices[:split_percentage])
        
        # Segment values. Insert the second new value just *after* the first,
        # so that a reverse merge at `index` reads them back in order
        # (new_value_1 at `index`, new_value_2 at `index + 1`).
        backward_choices[:new_value] = segment_values[index]
        segment_values[index] = forward_choices[:new_value_1]
        insert!(segment_values, index + 1, forward_choices[:new_value_2])
    else
        # If this is a merge, then combine the two segments `index`
        # and `index + 1`.
        proportion = fractions[index] / (fractions[index] + fractions[index + 1])
        fractions[index] += fractions[index + 1]
        deleteat!(fractions, index + 1)
        
        # Set the relevant segment values.
        backward_choices[:new_value_1] = segment_values[index]
        backward_choices[:new_value_2] = segment_values[index + 1]
        backward_choices[:split_percentage] = proportion
        segment_values[index] = forward_choices[:new_value]
        deleteat!(segment_values, index + 1)
    end
    
    # Fill a choicemap of the newly proposed trace's values
    new_trace_choices[:fractions] = fractions
    for (i, value) in enumerate(segment_values)
        new_trace_choices[(:segments, i)] = value
    end
    new_trace_choices[:segment_count] = length(fractions)
    
    # Obtain an updated trace matching the choicemap, and a weight
    new_trace, weight, = update(t, get_args(t), (NoChange(),), new_trace_choices)
    (new_trace, backward_choices, weight)
end
involution (generic function with 1 method)

Look at the last two lines of the involution. The involution is responsible for returning a new trace, not just a new choicemap. So, the common pattern to follow in involutions is:

  1. Build a choicemap describing the proposed changes to the trace.
  2. Call update with that choicemap to obtain the new trace and an incremental log weight.
  3. Return the new trace, the backward_choices choicemap, and the weight.

We can create a new MH update that uses our proposal, and test it on the new dataset:

@gen function mean_segments_proposal(t, xs, ys, i)
    xmin, xmax = minimum(xs), maximum(xs)
    x_range = xmax - xmin
    fracs = t[:fractions]
    min = xmin + x_range * sum(fracs[1:i-1])
    max = xmin + x_range * sum(fracs[1:i])
    relevant_ys = [y for (x,y) in zip(xs,ys) if x >= min && x <= max]
    {(:segments, i)} ~ normal(sum(relevant_ys)/length(relevant_ys), 0.3)
end

function custom_update_inv(tr, xs, ys)
    tr, accepted = mh(tr, split_merge_proposal_randomness, (), involution)
    for i=1:tr[:segment_count]
        tr, = mh(tr, mean_segments_proposal, (xs, ys, i))
    end
    tr, = mh(tr, select(:noise))
    tr
end

visualize_mh_alg(xs_dense, ys_complex, custom_update_inv, 75, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Iteration 600
Iteration 700
Log mean score: 26.43941395921412