Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)

What is this notebook about?

In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, it was necessary to define a generative function that accepted as an argument the previous trace, and proposed a new trace by “acting out” the model, that is, by sampling new proposed values of random choices at the same addresses used by the model.

For example, a random walk proposal for the addresses :x and :y might have looked like this:

@gen function random_walk_proposal(previous_trace)
    x ~ normal(previous_trace[:x], 0.1)
    y ~ normal(previous_trace[:y], 0.2)
end

This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.

In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.

The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.

Outline

Section 1. Recap of the piecewise-constant function model

Section 2. Basic Metropolis-Hastings inference

Section 3. Reversible-Jump “Split-Merge” proposals

Section 4. Bayesian program synthesis of GP kernels

Section 5. A tree regeneration proposal

using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")

Logging.disable_logging(Logging.Info);

1. Recap of the piecewise-constant function model

In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.

Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments.

It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a value from a normal distribution; this is the height of the function on that segment. Finally, we sample each y value from a normal distribution centered on the piecewise constant function described by the segments, using a shared noise level.
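The arithmetic in the paragraph above can be checked with a few lines of plain Julia. This is only a sketch: the fractions and the interval (-5, 5) are hypothetical values picked for illustration, not output from the model.

```julia
# Hypothetical values chosen for illustration; not taken from a model run.
fractions = [0.3, 0.5, 0.2]        # proportions of the interval, summing to 1
xmin, xmax = -5.0, 5.0             # range covered by the data

# Length of each segment: its fraction times the total interval length.
segment_lengths = fractions .* (xmax - xmin)

# Right endpoint of each segment; the last one should coincide with xmax.
right_edges = xmin .+ cumsum(segment_lengths)

println("lengths:     ", segment_lengths)
println("right edges: ", right_edges)
```

The changepoints are the interior right edges, so a vector of three fractions yields two changepoints.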

Using @dist to define new distributions for convenience

To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:

# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;

Distributions declared with @dist can be used to make random choices inside of Gen models. Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf, or log density, of our newly defined distribution. So we can ask, e.g., what the density of poisson_plus_one(1) is at the point 3:

logpdf(poisson_plus_one, 3, 1)
-1.6931471805599454

Note that this is the same as the logpdf of poisson(1) at the point 2. @dist’s main job is to automate the logic of converting the above call into this one:

logpdf(poisson, 2, 1)
-1.6931471805599454
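To see where this number comes from, the shift can be written out by hand without Gen. The sketch below assumes the standard Poisson log pmf, k*log(rate) - rate - log(k!); the helper names poisson_logpmf and poisson_plus_one_logpmf are hypothetical and are not part of Gen.

```julia
# Log pmf of a Poisson(rate) variable at a non-negative integer k.
poisson_logpmf(k, rate) = k * log(rate) - rate - log(factorial(k))

# poisson_plus_one takes the value x exactly when the underlying Poisson
# draw is x - 1, so its log pmf is the Poisson log pmf shifted by one.
poisson_plus_one_logpmf(x, rate) = poisson_logpmf(x - 1, rate)

println(poisson_plus_one_logpmf(3, 1))
```

With rate 1 and x = 3 this reduces to -1 - log(2), which agrees with the value Gen reports above.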

Writing the model

We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:

@gen function piecewise_constant(xs::Vector{Float64})
    # Generate a number of segments (at least 1)
    segment_count ~ poisson_plus_one(1)
    
    # To determine changepoints, draw a vector on the simplex from a Dirichlet
    # distribution. This gives us the proportions of the entire interval that each
    # segment takes up. (The entire interval is determined by the minimum and maximum
    # x values.)
    fractions ~ dirichlet([1.0 for i=1:segment_count])
    
    # Generate values for each segment
    segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]
    
    # Determine a global noise level
    noise ~ gamma(1, 1)
    
    # Generate the y points for the input x points
    xmin, xmax = extrema(xs)
    cumfracs = cumsum(fractions)
    # Correct numeric issue: `cumfracs[end]` might be 0.999999
    @assert cumfracs[end] ≈ 1.0
    cumfracs[end] = 1.0

    inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin),
                      cumfracs) for x in xs]
    segment_values = segments[inds]
    for (i, val) in enumerate(segment_values)
        {(:y, i)} ~ normal(val, noise)
    end
end;
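
The correction of cumfracs[end] matters because of how points are assigned to segments. The following sketch isolates the lookup rule from the model in plain Julia; the cumulative fractions are hypothetical, and segment_index is a helper name introduced only for this illustration.

```julia
# Hypothetical cumulative fractions whose last entry fell just short of 1.0.
cumfracs = [0.3, 0.8, 0.9999999999999999]

# Same rule as in the model: pick the first segment whose cumulative
# fraction reaches the normalized position (x - xmin) / (xmax - xmin).
segment_index(pos, cumfracs) = findfirst(frac -> frac >= pos, cumfracs)

println(segment_index(0.25, cumfracs))   # a point inside the first segment
println(segment_index(1.0, cumfracs))    # the position of x == xmax

cumfracs[end] = 1.0                      # the correction used in the model
println(segment_index(1.0, cumfracs))
```

Without the correction, the lookup for the rightmost point finds no segment and returns nothing, which would make the later indexing into segments fail.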

Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.

xs_dense = collect(range(-5, stop=5, length=50));

Don’t worry about understanding the following code, which we use for visualization.

function trace_to_dict(tr)
    Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
         :fracs  => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
         :ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;

function visualize_trace(tr; title="")
    xs, = get_args(tr)
    tr = trace_to_dict(tr)
    
    scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")
    
    cumfracs = [0.0, cumsum(tr[:fracs])...]
    xmin = minimum(xs)
    xmax = maximum(xs)
    for i in 1:tr[:n]
        segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
        segment_ys = fill(tr[:values][i], 2)
        plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
    end
    Plots.title!(title)
end

traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)


Many of the samples involve only one segment, but many others involve more. The level of noise also varies from sample to sample.

2. Basic Metropolis-Hastings inference

Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.

ys_simple  = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
ys_medium  = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
                          2, 0) .+ randn(length(xs_dense)) * 0.1;
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;
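
To get a feel for how hard each dataset is, it can help to count the constant pieces in the noise-free targets. The sketch below reuses the formulas above without the randn term; count_runs is a hypothetical helper, not part of Gen.

```julia
xs_dense = collect(range(-5, stop=5, length=50))

# Noise-free versions of the medium and complex targets
# (same formulas as above, with the randn term dropped).
medium_levels  = ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0, 2, 0)
complex_levels = Int.(floor.(abs.(xs_dense ./ 2))) .% 5

# Number of constant runs = 1 + number of places where neighbours differ.
count_runs(v) = 1 + count(i -> v[i] != v[i-1], 2:length(v))

println("medium:  ", count_runs(medium_levels), " segments")
println("complex: ", count_runs(complex_levels), " segments")
```

On this grid the medium target has three constant pieces and the complex target has five, while the simple target is a single piece at height 1.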

We’ll need a helper function for creating a choicemap of constraints from a vector of ys:

function make_constraints(ys)
    choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;

As we saw in the last problem set, importance sampling does a decent job on the simple dataset:

NUM_CHAINS = 9

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)


But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:NUM_CHAINS]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores)-log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)
Log mean score: -37.681098516181876
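
The quantity printed above is the log of the average of exp(score) across the traces, computed in log space so that very negative scores do not underflow. As a rough illustration of the same computation without Gen's logsumexp, here is a plain-Julia sketch; log_mean_exp and the example scores are hypothetical.

```julia
# Numerically stable log of the mean of exp.(scores).
function log_mean_exp(scores)
    m = maximum(scores)   # subtract the maximum before exponentiating
    return m + log(sum(exp.(scores .- m)) / length(scores))
end

# Hypothetical log scores, for illustration only.
example_scores = [-40.0, -38.0, -1000.0]
println(log_mean_exp(example_scores))
```

Because the mean is taken on the probability scale, the result is dominated by the best-scoring traces; a single very poor trace, like the -1000.0 here, contributes almost nothing beyond increasing the divisor.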


Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem.

First, some visualization code (feel free to ignore it!).

function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)    
    traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
    viz = Plots.@animate for i in 1:frames
        
        if i*iters_per_frame % 100 == 0 && verbose
            println("Iteration $(i*iters_per_frame)")
        end
        
        for j in 1:N
            for k in 1:iters_per_frame
                traces[j] = update(traces[j], xs, ys)
            end
        end
       
        Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)
    end
    scores = [Gen.get_score(t) for t in traces]
    println("Log mean score: $(logsumexp(scores) - log(N))")
    gif(viz)
end;

And now the MH algorithm itself.

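We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables: :segment_count together with :fractions, then :fractions alone, then :noise, and finally each (:segments, i) in turn.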

function simple_update(tr, xs, ys)
    tr, = mh(tr, select(:segment_count, :fractions))
    tr, = mh(tr, select(:fractions))
    tr, = mh(tr, select(:noise))
    for i=1:tr[:segment_count]
        tr, = mh(tr, select((:segments, i)))
    end
    tr
end
simple_update (generic function with 1 method)
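
For intuition about what a single resimulation move does, here is a toy sketch in plain Julia. It is not Gen's implementation and not the tutorial's model: it assumes a hypothetical one-parameter model, mu ~ Normal(0, 1) with observations y_i ~ Normal(mu, 0.1), and all function names are made up for illustration. The move proposes a fresh value of mu from the prior and accepts it with probability min(1, likelihood ratio); the prior terms cancel because the proposal is the prior itself.

```julia
using Random

# Log density of Normal(mu, sigma) at x.
normal_logpdf(x, mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2pi)

# Toy model (hypothetical): mu ~ Normal(0, 1), y_i ~ Normal(mu, 0.1).
log_likelihood(mu, ys) = sum(normal_logpdf(y, mu, 0.1) for y in ys)

# One resimulation MH step on mu: propose from the prior, then accept
# with probability min(1, likelihood ratio).
function resimulation_step(rng, mu, ys)
    proposed = randn(rng)   # fresh draw from Normal(0, 1)
    log_ratio = log_likelihood(proposed, ys) - log_likelihood(mu, ys)
    return log(rand(rng)) < log_ratio ? proposed : mu
end

function run_chain(rng, ys, iters)
    mu = randn(rng)         # initialize from the prior
    for _ in 1:iters
        mu = resimulation_step(rng, mu, ys)
    end
    return mu
end

ys_toy = fill(1.0, 50)      # noise-free toy observations at height 1
mu_final = run_chain(MersenneTwister(1), ys_toy, 5000)
println(mu_final)
```

The toy also hints at the weakness of this strategy: proposals are drawn blindly from the prior, so most are rejected once the chain has found a value that explains the data well. In the piecewise-constant model the first block additionally changes the number of segments, which this one-parameter sketch does not capture.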

Our algorithm makes quick work of the simple dataset:

visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)
Iteration 100
Log mean score: 46.82126477547658

On the medium dataset, it does an OK job with less computation than we used in importance sampling:

visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937