Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)
What is this notebook about?
In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, we defined a generative function that accepted the previous trace as an argument and proposed a new trace by “acting out” the model, that is, by sampling new values for random choices at the same addresses the model uses.
For example, a random walk proposal for the addresses :x and :y might have looked like this:
@gen function random_walk_proposal(previous_trace)
x ~ normal(previous_trace[:x], 0.1)
y ~ normal(previous_trace[:y], 0.2)
end
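(As a quick reminder, sketched here for reference rather than taken from the original notebook: Gen's mh operator passes the current trace to such a proposal as its first argument, and any remaining proposal arguments as a tuple, here empty.)
# Sketch: assumes `trace` is an existing trace of a model with addresses :x and :y.
# Each call proposes a new trace via `random_walk_proposal` and accepts or rejects it.
trace, accepted = mh(trace, random_walk_proposal, ())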
This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.
In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.
The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.
Outline
Section 1. Recap of the piecewise-constant function model
Section 2. Basic Metropolis-Hastings inference
Section 3. Reversible-Jump “Split-Merge” proposals
Section 4. Bayesian program synthesis of GP kernels
Section 5. A tree regeneration proposal
using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")
Logging.disable_logging(Logging.Info);
1. Recap of the piecewise-constant function model
In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.
Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments. It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a value (its height) from a normal distribution. Finally, we sample the observed y values from normal distributions centered on the piecewise-constant function described by the segments.
Using @dist to define new distributions for convenience
To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:
# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;
Distributions declared with @dist can be used to make random choices inside of Gen models. Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf, or log density, of our newly defined distribution. So we can ask, e.g., what the log density of poisson_plus_one(1) is at the point 3:
logpdf(poisson_plus_one, 3, 1)
-1.6931471805599454
Note that this is the same as the logpdf of poisson(1) at the point 2; @dist’s main job is to automate the logic of converting the above call into this one:
logpdf(poisson, 2, 1)
-1.6931471805599454
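As an optional sanity check (not part of the original notebook), we can confirm that the two calls agree, and that draws from the new distribution are always at least 1:
# The shifted distribution's logpdf matches the underlying Poisson logpdf...
@assert logpdf(poisson_plus_one, 3, 1) ≈ logpdf(poisson, 2, 1)
# ...and calling the distribution directly samples a value, which is always >= 1.
@assert all(poisson_plus_one(1) >= 1 for _ in 1:1000)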
Writing the model
We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:
- List comprehensions allow you to create a list without writing an entire for loop. For example, [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count] creates a list of elements, one for each i in the range [1, ..., segment_count], each of which is generated using the expression {(:segments, i)} ~ normal(0, 1).
- The Dirichlet distribution is a distribution over the simplex of vectors whose elements sum to 1. We use it to generate the fractions of the entire interval (x_min, x_max) that each segment of our piecewise function covers.
- The logic to generate the y points is as follows: we compute the cumulative fractions cumfracs = cumsum(fractions), such that xmin + (xmax - xmin) * cumfracs[j] is the x-value of the right endpoint of the jth segment. Then we sample at each address (:y, i) a normal whose mean is the y-value of the segment that contains xs[i]. (See the worked sketch just after this list.)
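For concreteness, here is a small worked example of the segment-lookup logic, using made-up numbers rather than anything sampled by the model:
# Suppose three segments cover (xmin, xmax) = (-5.0, 5.0) with these proportions:
fractions = [0.3, 0.5, 0.2]
cumfracs = cumsum(fractions)                     # [0.3, 0.8, 1.0]
xmin, xmax = -5.0, 5.0
x = 2.0
frac_of_interval = (x - xmin) / (xmax - xmin)    # 0.7: x sits 70% of the way across
# The first cumulative fraction >= 0.7 is 0.8, so segment 2 contains x = 2.0:
findfirst(frac -> frac >= frac_of_interval, cumfracs)   # returns 2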
@gen function piecewise_constant(xs::Vector{Float64})
# Generate a number of segments (at least 1)
segment_count ~ poisson_plus_one(1)
# To determine changepoints, draw a vector on the simplex from a Dirichlet
# distribution. This gives us the proportions of the entire interval that each
# segment takes up. (The entire interval is determined by the minimum and maximum
# x values.)
fractions ~ dirichlet([1.0 for i=1:segment_count])
# Generate values for each segment
segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]
# Determine a global noise level
noise ~ gamma(1, 1)
# Generate the y points for the input x points
xmin, xmax = extrema(xs)
cumfracs = cumsum(fractions)
# Correct numeric issue: `cumfracs[end]` might be 0.999999
@assert cumfracs[end] ≈ 1.0
cumfracs[end] = 1.0
inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin),
cumfracs) for x in xs]
segment_values = segments[inds]
for (i, val) in enumerate(segment_values)
{(:y, i)} ~ normal(val, noise)
end
end;
Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.
xs_dense = collect(range(-5, stop=5, length=50));
Don’t worry about understanding the following code, which we use for visualization.
function trace_to_dict(tr)
Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
:fracs => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
:ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;
function visualize_trace(tr; title="")
xs, = get_args(tr)
tr = trace_to_dict(tr)
scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")
cumfracs = [0.0, cumsum(tr[:fracs])...]
xmin = minimum(xs)
xmax = maximum(xs)
for i in 1:tr[:n]
segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
segment_ys = fill(tr[:values][i], 2)
plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
end
Plots.title!(title)
end
traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)
Many of the samples involve only one segment, but others involve more. The level of noise also varies from sample to sample.
2. Basic Metropolis-Hastings inference
Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.
# Each dataset is a noisy sample of a piecewise-constant function.
# Simple: one flat segment at y = 1.
ys_simple = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
# Medium: 2 for |x| < 3, 0 for |x| >= 3 (two changepoints).
ys_medium = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
                         2, 0) .+ randn(length(xs_dense)) * 0.1;
# Complex: a symmetric staircase rising by 1 every 2 units of |x| (four changepoints).
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;
We’ll need a helper function for creating a choicemap of constraints from a vector of ys:
function make_constraints(ys)
choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;
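For instance (a quick illustration, not in the original notebook), each observation address in the resulting choicemap is constrained to the corresponding y value:
constraints = make_constraints(ys_simple)
constraints[(:y, 1)] == ys_simple[1]   # true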
As we saw in the last problem set, importance sampling does a decent job on the simple dataset:
NUM_CHAINS = 9
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)
But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:NUM_CHAINS]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores)-log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)
Log mean score: -37.681098516181876
Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem.
First, some visualization code (feel free to ignore it!).
function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)
traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
viz = Plots.@animate for i in 1:frames
if i*iters_per_frame % 100 == 0 && verbose
println("Iteration $(i*iters_per_frame)")
end
for j in 1:N
for k in 1:iters_per_frame
traces[j] = update(traces[j], xs, ys)
end
end
Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)#, layout=l)
end
scores = [Gen.get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores) - log(N))")
gif(viz)
end;
And now the MH algorithm itself.
We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables:
- Block 1: :segment_count and :fractions. Resampling these proposes a completely new division of the interval into pieces. However, it reuses the (:segments, i) values wherever possible; that is, if we currently have three segments and :segment_count is proposed to change to 5, only two new segment values will be sampled.
- Block 2: :fractions. This proposal leaves the number of segments the same, but resamples their relative lengths.
- Block 3: :noise. This proposal adjusts the global noise parameter.
- Blocks 4 and up: (:segments, i). These proposals separately propose a new value for each segment (and each is accepted or rejected independently).
function simple_update(tr, xs, ys)
tr, = mh(tr, select(:segment_count, :fractions))
tr, = mh(tr, select(:fractions))
tr, = mh(tr, select(:noise))
for i=1:tr[:segment_count]
tr, = mh(tr, select((:segments, i)))
end
tr
end
simple_update (generic function with 1 method)
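(If you prefer to run the kernel without the animation helper, here is a minimal sketch; run_chain is a hypothetical convenience function, not part of the notebook's helper code.)
# Generate an initial trace consistent with the observations,
# then apply the block-resimulation update repeatedly.
function run_chain(xs, ys, n_iters)
    tr, = generate(piecewise_constant, (xs,), make_constraints(ys))
    for _ in 1:n_iters
        tr = simple_update(tr, xs, ys)
    end
    tr
end
tr = run_chain(xs_dense, ys_simple, 100)
println("final log joint score: ", get_score(tr))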
Our algorithm makes quick work of the simple dataset:
visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)
Iteration 100
Log mean score: 46.82126477547658
On the medium dataset, it does an OK job with less computation than we used in importance sampling:
visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937