Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)
What is this notebook about?
In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, it was necessary to define a generative function that accepted as an argument the previous trace, and proposed a new trace by “acting out” the model, that is, by sampling new proposed values of random choices at the same addresses used by the model.
For example, a random walk proposal for the addresses :x and :y might have looked like this:
@gen function random_walk_proposal(previous_trace)
    x ~ normal(previous_trace[:x], 0.1)
    y ~ normal(previous_trace[:y], 0.2)
end
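To use such a proposal, we would pass it to Gen’s mh operator along with the current trace. Here is a minimal sketch, assuming a generative function called model (a hypothetical name) with random choices at :x and :y:

```julia
# Hypothetical usage sketch: apply the random-walk proposal via Gen's `mh`.
# `model` is assumed to be a generative function with choices at :x and :y.
tr, _ = generate(model, ())
for iter in 1:1000
    # `mh` passes the current trace as the proposal's first argument;
    # the remaining proposal arguments are given in the trailing tuple.
    global tr, accepted = mh(tr, random_walk_proposal, ())
end
```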
This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.
In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.
The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.
Outline
Section 1. Recap of the piecewise-constant function model
Section 2. Basic Metropolis-Hastings inference
Section 3. Reversible-Jump “Split-Merge” proposals
Section 4. Bayesian program synthesis of GP kernels
Section 5. A tree regeneration proposal
using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")
Logging.disable_logging(Logging.Info);
1. Recap of the piecewise-constant function model
In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.
Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments. It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a y value from a normal distribution. Finally, we sample the y values near the piecewise constant function described by the segments.
Using @dist to define new distributions for convenience
To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:
# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;
Distributions declared with @dist can be used to make random choices inside of Gen models. Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf, or log density, of our newly defined distribution. So we can ask, e.g., what the density of poisson_plus_one(1) is at the point 3:
logpdf(poisson_plus_one, 3, 1)
-1.6931471805599454
Note that this is the same as the logpdf of poisson(1) at the point 2 — @dist’s main job is to automate the logic of converting the above call into this one:
logpdf(poisson, 2, 1)
-1.6931471805599454
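As a quick sanity check, we can reproduce this number by hand in plain Julia (independent of Gen and @dist), using the definition of the Poisson log-pmf:

```julia
# Poisson log-pmf: log p(k; λ) = -λ + k*log(λ) - log(k!)
poisson_logpmf(k, λ) = -λ + k * log(λ) - log(factorial(k))

# The density of poisson_plus_one(1) at 3 equals the Poisson(1) density at 2:
poisson_logpmf(2, 1.0)  # ≈ -1.6931471805599453
```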
Writing the model
We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:

- List comprehensions allow you to create a list without writing an entire for loop. For example, [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count] creates a list of elements, one for each i in the range [1, ..., segment_count], each of which is generated using the expression {(:segments, i)} ~ normal(0, 1).
- The Dirichlet distribution is a distribution over the simplex of vectors whose elements sum to 1. We use it to generate the fractions of the entire interval (xmin, xmax) that each segment of our piecewise function covers.
- The logic to generate the y points is as follows: we compute the cumulative fractions cumfracs = cumsum(fractions), such that xmin + (xmax - xmin) * cumfracs[j] is the x-value of the right endpoint of the j-th segment. Then we sample at each address (:y, i) a normal whose mean is the y-value of the segment that contains xs[i].
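To make the index-lookup logic concrete, here is a small worked example in plain Julia (the fractions and interval below are made up for illustration):

```julia
# Made-up example: three segments covering 30%, 50%, and 20% of (-5, 5).
fractions = [0.3, 0.5, 0.2]
cumfracs = cumsum(fractions)   # [0.3, 0.8, 1.0]
xmin, xmax = -5.0, 5.0

# Index of the first segment whose right endpoint lies at or beyond x:
segment_index(x) = findfirst(frac -> frac >= (x - xmin) / (xmax - xmin), cumfracs)

segment_index(-4.0)  # 1: (-4 + 5)/10 = 0.1 <= 0.3
segment_index(0.0)   # 2: 0.5 falls in (0.3, 0.8]
segment_index(4.5)   # 3: 0.95 falls in (0.8, 1.0]
```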
@gen function piecewise_constant(xs::Vector{Float64})
    # Generate a number of segments (at least 1)
    segment_count ~ poisson_plus_one(1)
    # To determine changepoints, draw a vector on the simplex from a Dirichlet
    # distribution. This gives us the proportions of the entire interval that each
    # segment takes up. (The entire interval is determined by the minimum and maximum
    # x values.)
    fractions ~ dirichlet([1.0 for i=1:segment_count])
    # Generate values for each segment
    segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]
    # Determine a global noise level
    noise ~ gamma(1, 1)
    # Generate the y points for the input x points
    xmin, xmax = extrema(xs)
    cumfracs = cumsum(fractions)
    # Correct a numeric issue: `cumfracs[end]` might be 0.999999
    @assert cumfracs[end] ≈ 1.0
    cumfracs[end] = 1.0
    inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin), cumfracs)
            for x in xs]
    segment_values = segments[inds]
    for (i, val) in enumerate(segment_values)
        {(:y, i)} ~ normal(val, noise)
    end
end;
Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.
xs_dense = collect(range(-5, stop=5, length=50));
Don’t worry about understanding the following code, which we use for visualization.
function trace_to_dict(tr)
    Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
         :fracs => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
         :ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;
function visualize_trace(tr; title="")
    xs, = get_args(tr)
    tr = trace_to_dict(tr)
    scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")
    cumfracs = [0.0, cumsum(tr[:fracs])...]
    xmin = minimum(xs)
    xmax = maximum(xs)
    for i in 1:tr[:n]
        segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
        segment_ys = fill(tr[:values][i], 2)
        plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
    end
    Plots.title!(title)
end
traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)
Many of the samples involve only one segment, but many of them involve more. The level of noise also varies from sample to sample.
2. Basic Metropolis-Hastings inference
Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.
ys_simple = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
ys_medium = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
                         2, 0) .+ randn(length(xs_dense)) * 0.1;
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;
We’ll need a helper function for creating a choicemap of constraints from a vector of ys:
function make_constraints(ys)
    choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;
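For instance, on a hypothetical two-point dataset, the resulting choicemap binds each (:y, i) address to the corresponding observation:

```julia
# Hypothetical usage of the helper above (requires Gen's `choicemap`):
constraints = make_constraints([1.0, 2.0])
constraints[(:y, 1)]  # 1.0
constraints[(:y, 2)]  # 2.0
```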
As we saw in the last problem set, importance sampling does a decent job on the simple dataset:
NUM_CHAINS = 9
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)
But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:NUM_CHAINS]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores) - log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)
Log mean score: 37.681098516181876
Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem.
First, some visualization code (feel free to ignore it!).
function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)
    traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
    viz = Plots.@animate for i in 1:frames
        if i*iters_per_frame % 100 == 0 && verbose
            println("Iteration $(i*iters_per_frame)")
        end
        for j in 1:N
            for k in 1:iters_per_frame
                traces[j] = update(traces[j], xs, ys)
            end
        end
        Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)
    end
    scores = [Gen.get_score(t) for t in traces]
    println("Log mean score: $(logsumexp(scores) - log(N))")
    gif(viz)
end;
And now the MH algorithm itself.
We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables:

- Block 1: :segment_count and :fractions. Resampling these tries proposing a completely new division of the interval into pieces. However, it reuses the (:segments, i) values wherever possible; that is, if we currently have three segments and :segment_count is proposed to change to 5, only two new segment values will be sampled.
- Block 2: :fractions. This proposal tries leaving the number of segments the same, but resamples their relative lengths.
- Block 3: :noise. This proposal adjusts the global noise parameter.
- Blocks 4 and up: (:segments, i). Tries separately proposing new values for each segment (and accepts or rejects each proposal independently).
function simple_update(tr, xs, ys)
    tr, = mh(tr, select(:segment_count, :fractions))
    tr, = mh(tr, select(:fractions))
    tr, = mh(tr, select(:noise))
    for i=1:tr[:segment_count]
        tr, = mh(tr, select((:segments, i)))
    end
    tr
end
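If you just want to run the sampler without the animation, a bare loop suffices. This sketch assumes the model, helpers, and datasets defined above:

```julia
# Minimal sketch: run the block-resimulation sampler directly on ys_simple.
tr, _ = generate(piecewise_constant, (xs_dense,), make_constraints(ys_simple))
for iter in 1:100
    global tr = simple_update(tr, xs_dense, ys_simple)
end
println("Final score: ", get_score(tr))
```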
simple_update (generic function with 1 method)
Our algorithm makes quick work of the simple dataset:
visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)
Iteration 100
Log mean score: 46.82126477547658
On the medium dataset, it does an OK job with less computation than we used in importance sampling:
visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937