Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)
What is this notebook about?
In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, it was necessary to define a generative function that accepted as an argument the previous trace, and proposed a new trace by “acting out” the model, that is, by sampling new proposed values of random choices at the same addresses used by the model.
For example, a random walk proposal for the addresses :x and :y might have looked like this:
@gen function random_walk_proposal(previous_trace)
x ~ normal(previous_trace[:x], 0.1)
y ~ normal(previous_trace[:y], 0.2)
end
This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.
In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.
The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.
Outline
Section 1. Recap of the piecewise-function model
Section 2. Basic Metropolis-Hastings inference
Section 3. Reversible-Jump “Split-Merge” proposals
Section 4. Bayesian program synthesis of GP kernels
Section 5. A tree regeneration proposal
using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")
Logging.disable_logging(Logging.Info);
1. Recap of the piecewise-constant function model
In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.
Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments.
It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a function value from a normal distribution. Finally, we sample the y values near the piecewise-constant function described by the segments.
Using @dist to define new distributions for convenience
To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:
# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;
Distributions declared with @dist can be used to make random choices inside of Gen models.
Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf,
or log density, of our newly defined distribution. So we can ask, e.g., what the density of
poisson_plus_one(1) is at the point 3:
logpdf(poisson_plus_one, 3, 1)
-1.6931471805599454
Note that this is the same as the logpdf of poisson(1) at the point 2 — @dist’s main job is to
automate the logic of converting the above call into this one:
logpdf(poisson, 2, 1)
-1.6931471805599454
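To see where this number comes from, we can reproduce it by hand in plain Julia (no Gen required), using the Poisson log-pmf, log p(k) = k log(rate) - rate - log(k!):

```julia
# Poisson(rate = 1) log-pmf evaluated by hand at k = 2:
k, rate = 2, 1.0
manual_logpdf = k * log(rate) - rate - log(factorial(k))
println(manual_logpdf)   # ≈ -1.6931471805599454, i.e. -1 - log(2)
```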
Writing the model
We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:
- List comprehensions allow you to create a list without writing an entire for loop. For example, [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count] creates a list of elements, one for each i in the range [1, ..., segment_count], each of which is generated using the expression {(:segments, i)} ~ normal(0, 1).
- The Dirichlet distribution is a distribution over the simplex of vectors whose elements sum to 1. We use it to generate the fractions of the entire interval (xmin, xmax) that each segment of our piecewise function covers.
- The logic to generate the y points is as follows: we compute the cumulative fractions cumfracs = cumsum(fractions), such that xmin + (xmax - xmin) * cumfracs[j] is the x-value of the right endpoint of the jth segment. Then we sample at each address (:y, i) a normal whose mean is the y-value of the segment that contains xs[i].
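To make the indexing concrete before reading the model, here is the same logic on made-up numbers (plain Julia, no Gen; the fractions below are purely illustrative):

```julia
fractions = [0.3, 0.5, 0.2]          # segment proportions (sum to 1)
cumfracs = cumsum(fractions)         # [0.3, 0.8, 1.0]
xmin, xmax = -5.0, 5.0
x = 1.0                              # lies 60% of the way across the interval
# Find the first segment whose right endpoint is at or past x
j = findfirst(f -> f >= (x - xmin) / (xmax - xmin), cumfracs)
println(j)                           # 2: x falls in the second segment
```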
@gen function piecewise_constant(xs::Vector{Float64})
# Generate a number of segments (at least 1)
segment_count ~ poisson_plus_one(1)
# To determine changepoints, draw a vector on the simplex from a Dirichlet
# distribution. This gives us the proportions of the entire interval that each
# segment takes up. (The entire interval is determined by the minimum and maximum
# x values.)
fractions ~ dirichlet([1.0 for i=1:segment_count])
# Generate values for each segment
segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]
# Determine a global noise level
noise ~ gamma(1, 1)
# Generate the y points for the input x points
xmin, xmax = extrema(xs)
cumfracs = cumsum(fractions)
# Correct numeric issue: `cumfracs[end]` might be 0.999999
@assert cumfracs[end] ≈ 1.0
cumfracs[end] = 1.0
inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin),
cumfracs) for x in xs]
segment_values = segments[inds]
for (i, val) in enumerate(segment_values)
{(:y, i)} ~ normal(val, noise)
end
end;
Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.
xs_dense = collect(range(-5, stop=5, length=50));
Don’t worry about understanding the following code, which we use for visualization.
function trace_to_dict(tr)
Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
:fracs => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
:ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;
function visualize_trace(tr; title="")
xs, = get_args(tr)
tr = trace_to_dict(tr)
scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")
cumfracs = [0.0, cumsum(tr[:fracs])...]
xmin = minimum(xs)
xmax = maximum(xs)
for i in 1:tr[:n]
segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
segment_ys = fill(tr[:values][i], 2)
plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
end
Plots.title!(title)
end
traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)
Many of the samples involve only one segment, but others involve more. The level of noise also varies from sample to sample.
2. Basic Metropolis-Hastings inference
Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.
ys_simple = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
ys_medium = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
2, 0) .+ randn(length(xs_dense)) * 0.1;
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;
We’ll need a helper function for creating a choicemap of constraints from a vector of ys:
function make_constraints(ys)
choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;
As we saw in the last problem set, importance sampling does a decent job on the simple dataset:
NUM_CHAINS = 9
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)
But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:
traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:NUM_CHAINS]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores)-log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)
Log mean score: -37.681098516181876
Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem.
First, some visualization code (feel free to ignore it!).
function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)
traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
viz = Plots.@animate for i in 1:frames
if i*iters_per_frame % 100 == 0 && verbose
println("Iteration $(i*iters_per_frame)")
end
for j in 1:N
for k in 1:iters_per_frame
traces[j] = update(traces[j], xs, ys)
end
end
Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)
end
scores = [Gen.get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores) - log(N))")
gif(viz)
end;
And now the MH algorithm itself.
We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables:
- Block 1: :segment_count and :fractions. Resampling these proposes a completely new division of the interval into pieces. However, it reuses the (:segments, i) values wherever possible; that is, if we currently have three segments and :segment_count is proposed to change to 5, only two new segment values will be sampled.
- Block 2: :fractions. This proposal leaves the number of segments the same, but resamples their relative lengths.
- Block 3: :noise. This proposal adjusts the global noise parameter.
- Blocks 4 and up: (:segments, i). These moves separately propose new values for each segment (and accept or reject each proposal independently).
function simple_update(tr, xs, ys)
tr, = mh(tr, select(:segment_count, :fractions))
tr, = mh(tr, select(:fractions))
tr, = mh(tr, select(:noise))
for i=1:tr[:segment_count]
tr, = mh(tr, select((:segments, i)))
end
tr
end
simple_update (generic function with 1 method)
Our algorithm makes quick work of the simple dataset:
visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)
Iteration 100
Log mean score: 46.82126477547658
On the medium dataset, it does an OK job with less computation than we used in importance sampling:
visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937
But on the complex dataset, it is still unreliable:
visualize_mh_alg(xs_dense, ys_complex, simple_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: -30.43054336514916
Exercise
One problem with the simple block resimulation algorithm is that the proposals for (:segments, i) are
totally uninformed by the data. In this problem, you’ll write a custom proposal (using the techniques we
covered in Problem Set 1B) that uses the data to propose good values of y for each segment.
Write a generative function segments_proposal that can serve as a smart proposal distribution for
this problem. It should:
- Propose a new :segment_count from poisson_plus_one(1) (the prior).
- Propose a new :fractions from dirichlet([1.0 for i=1:segment_count]) (the prior).
- In each segment, propose the function value (:segments, i) to be (a noisy version of) the average of the y values in our dataset from the given segment. Draw (:segments, i) from a normal distribution with that mean, and a small standard deviation (e.g. 0.3).
We will use this proposal to replace the “Block 1” move from our previous algorithm. This should
make it easier to have proposals accepted, because whenever we propose a new segmentation of the
interval, we propose it with reasonable y values attached.
We have provided some starter code:
@gen function segments_proposal(t, xs, ys)
xmin, xmax = minimum(xs), maximum(xs)
x_range = xmax - xmin
# Propose segment_count and fractions
# <edit the next two lines>
segment_count = nothing
fractions = nothing
for i=1:segment_count
# Propose (:segments, i) from a normal distribution
# <your code here; and edit the next line>
{(:segments, i)} ~ normal(nothing, nothing)
end
end;
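For reference, one possible completion is sketched below (the name segments_proposal_solution and the empty-segment fallback are our additions, and this is not the only valid answer; any proposal satisfying the three bullets above works):

```julia
@gen function segments_proposal_solution(t, xs, ys)
    xmin, xmax = minimum(xs), maximum(xs)
    x_range = xmax - xmin
    # Propose segment_count and fractions from the prior
    segment_count ~ poisson_plus_one(1)
    fractions ~ dirichlet([1.0 for i=1:segment_count])
    cumfracs = cumsum(fractions)
    left = xmin
    for i=1:segment_count
        right = xmin + x_range * cumfracs[i]
        # Mean of the observed ys falling in segment i
        ys_in_segment = [y for (x, y) in zip(xs, ys) if left <= x <= right]
        # Fall back to 0.0 if a proposed segment contains no data points
        mean_y = isempty(ys_in_segment) ? 0.0 :
                 sum(ys_in_segment) / length(ys_in_segment)
        {(:segments, i)} ~ normal(mean_y, 0.3)
        left = right
    end
end;
```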
We define custom_update to use segments_proposal in place of the first block from simple_update.
function custom_update(tr, xs, ys)
(tr, _) = mh(tr, segments_proposal, (xs, ys))
(tr, _) = mh(tr, select(:fractions))
(tr, _) = mh(tr, select(:noise))
for i=1:(tr[:segment_count])
(tr, _) = mh(tr, select((:segments, i)))
end
tr
end
custom_update (generic function with 1 method)
Let’s see how this one does on each dataset:
visualize_mh_alg(xs_dense, ys_medium, custom_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 40.45738893036163
visualize_mh_alg(xs_dense, ys_complex, custom_update, 50, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: -19.739460952009143
This will often outperform the simplest MH solution, but still leaves something to be desired. The smart segments_proposal helps find good function values for each segment, but doesn’t help us in cases where the segment proportions are wrong; in these cases, the model just decides that the noise must be high.
3. Involution MH for Split-Merge proposals
How might we improve on the MH algorithms from the previous section, to more effectively search for good values of the fractions variable?
One approach is to add new proposals to the mix that iteratively refine the fractions, rather than
relying on blind resimulation. A natural strategy might be to add split and merge proposals:
- A split chooses a segment to break into two pieces at a random point (and chooses new values for the two segments).
- A merge chooses two adjacent segments to merge together into one segment (with a shared value).
MCMC, Metropolis-Hastings, and Reversibility
Note that alone, neither of these two proposals makes for a valid Metropolis-Hastings transition proposal. Why? One requirement of MCMC algorithms is that each transition kernel (kernel, not proposal: in MH, the kernel includes the accept-reject step) leave the posterior distribution stationary. What does this mean? Suppose we somehow obtained an oracle that let us sample a trace from the exact posterior, $t \sim p(t \mid \text{observations})$. Then if we run an MCMC transition kernel $T$ on the trace to obtain a new trace $t’ \sim T(t’ \leftarrow t)$, the marginal distribution of $t’$ should still be the posterior. In simpler terms, no one should be able to tell the difference between “traces sampled from the posterior” (i.e., using the output of the oracle directly) and “traces sampled from the posterior then sent through a transition kernel” (i.e., passing the output of the oracle as the input to the transition kernel and using the result). As a formula,
\[\int p(t \mid \text{observations}) \, T(t' \leftarrow t) \, \mathrm{d}t = p(t' \mid \text{observations})\]
Now suppose we set our transition kernel $T$ to a Metropolis-Hastings split move. The split move always increases the number of segments. So if t comes from the true posterior, then the distribution of t', in terms of number of segments, will necessarily be shifted upward (unless MH deterministically rejects every proposal, which is what MH in Gen will do if given a split-only proposal). The same issue exists, in reverse, for the merge proposal.
In general, this “stationarity” rule means that our Metropolis-Hastings proposals must be reversible, meaning that if there is some probability that a proposal can take you from one region of the state space to another, it must also have some probability of sending you back from the new region to the old region. If this criterion is satisfied, then the MH accept-reject step can accept and reject proposals with the proper probabilities to ensure that the stationarity property described above holds.
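For reference, the standard Metropolis-Hastings accept-reject step enforces stationarity by accepting a proposed move from $t$ to $t'$ with probability

\[\alpha(t \to t') = \min\left(1, \; \frac{p(t' \mid \text{observations}) \, q(t \mid t')}{p(t \mid \text{observations}) \, q(t' \mid t)}\right)\]

where $q(\cdot \mid \cdot)$ is the proposal density. The ratio is only nonzero when the reverse move $t' \to t$ itself has positive proposal density, which is exactly the reversibility requirement described above.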
To make “split” and “merge” fulfill this “reversibility” criterion, we can think of them as constituting a single proposal, which randomly chooses whether to split or merge at each iteration. This is an example of a reversible-jump proposal [1].
References:
- Green, Peter J., and David I. Hastie. “Reversible jump MCMC.” (2009).
Implementing Split-Merge in Gen
This is a sensible proposal. But if we try to write this proposal in Gen, we quickly hit several roadblocks. For example:
- The proposal needs to make several random choices that are not meant to serve as proposals for corresponding random choices in the model. For example, the proposal must decide whether to “split” or “merge,” and then it needs to decide at which index it will split or merge. But Gen interprets every traced random choice made by a proposal as corresponding to some choice in the model.
- Once we choose to split (or merge), it’s unclear how we should propose to the various relevant addresses: from what distribution should the proposal sample fractions, for example? What we need is to propose a deterministic value for fractions, based on the random choices the proposal makes.
To get around these issues, Gen provides a variant of Metropolis-Hastings that is a bit trickier to use but is ultimately more flexible.
The idea is this: first, we write a generative function that samples all the randomness the proposal will require. In our case, this will involve
- choosing whether to split or merge
- choosing at what index the split or merge will happen
- if splitting, choosing where in a segment to split
- choosing new values for merged segments or newly created split segments
@gen function split_merge_proposal_randomness(t)
old_n = t[:segment_count]
# Choose whether to split (T) or to merge (F), keeping in mind
# that if old_n == 1, then our decision is made for us
# (split).
if ({:split_or_merge} ~ bernoulli(old_n == 1 ? 1 : 0.3))
# split
# What index to split at?
index ~ uniform_discrete(1,old_n)
# Where is the splitting point, relative to the segment being split?
split_percentage ~ uniform(0, 1)
# New values for the two new segments
new_value_1 ~ normal(t[(:segments, index)], 0.1)
new_value_2 ~ normal(t[(:segments, index)], 0.1)
else
# merge
# What index to merge at? (Merge index i and i + 1)
index ~ uniform_discrete(1, old_n-1) # merge i and i+1
# Sample a new value for the merged segment, near the mean of the
# two existing segments.
new_value ~ normal((t[(:segments, index)] + t[(:segments, index+1)]) / 2.0, 0.1)
end
end;
Let’s look at what this function samples:
tr, = generate(piecewise_constant, (xs_dense,), make_constraints(ys_complex));
get_choices(simulate(split_merge_proposal_randomness, (tr,)))
│
├── :index : 1
│
├── :split_or_merge : true
│
├── :new_value_2 : -0.12211827468079874
│
├── :split_percentage : 0.7179861487246315
│
└── :new_value_1 : -0.07557249242625629
We now need to write an ordinary Julia function that tells Gen how to use the proposal randomness generated by the generative function above to create a proposed next trace for Metropolis-Hastings.
The function takes four inputs:
- The previous trace t.
- The proposal randomness generated by our generative function above, forward_randomness.
- The return value of the generative function we defined, forward_ret; we won’t use this for now.
- The arguments, other than t, that were used to generate forward_randomness from the generative function written above. In our case, that function has no additional arguments, so this will always be empty.
It is supposed to return three outputs:
- new_trace: an updated trace to propose. We can construct this however we want based on t and forward_choices.
- backward_choices: a choicemap for the proposal-randomness generative function, capable of “sending new_trace back to the old trace t”. For example, if we are enacting a split proposal, this would specify the precise merge proposal necessary to undo the split. This serves as “proof” that our proposal really is reversible.
- weight: usually, get_score(new_trace) - get_score(t). The reason we need to return this is that it is possible to write proposals that need to return something different here; if your proposal involves deterministically transforming continuous random variables in non-volume-preserving ways, see the Gen documentation for details on how weight should change. In this notebook, weight will always be get_score(new_trace) - get_score(t).
We call the Julia function an involution, because it must have the following property.
Suppose we propose a new trace using proposal_randomness_func and then our Julia function
f:
forward_choices = get_choices(simulate(proposal_randomness_func, (old_trace,)))
new_trace, backward_choices, = f(old_trace, forward_choices, ...)
Now suppose we use backward_choices and ask what f would do to new_trace:
back_to_the_first_trace_maybe, backward_backward_choices, = f(new_trace, backward_choices, ...)
Then we require that backward_backward_choices == forward_choices, and that back_to_the_first_trace_maybe == old_trace.
Here’s an involution for Split-Merge:
function involution(t, forward_choices, forward_retval, proposal_args)
# Goal: based on `forward_choices`, return a new trace
# and the `backward_choices` required to reverse this
# proposal.
new_trace_choices = choicemap()
backward_choices = choicemap()
# First, check whether we're dealing with a split or a merge.
split_or_merge = forward_choices[:split_or_merge]
# To reverse a split, use a merge (and vice versa)
backward_choices[:split_or_merge] = !split_or_merge
# Where is the split / merge occurring?
# Reversing the proposal would use the same index.
index = forward_choices[:index]
backward_choices[:index] = index
# Now we handle the segment values and proportions.
# First, pull out the existing values from the previous
# trace `t`. IMPORTANT NOTE: we need to `copy`
# `t[:fractions]`, because we intend to perform mutating
# Julia operations like `insert!` and `deleteat!` on it,
# but do not wish to change the memory underlying the
# original trace `t` (in case the proposal is rejected
# and we need to return to the original trace).
fractions = copy(t[:fractions])
segment_values = [t[(:segments, i)] for i=1:t[:segment_count]]
# How we update `fractions` and `segment_values` depends on
# whether this is a split or a merge.
if split_or_merge
# If this is a split, then add a new element at `index`,
# according to the split proportion.
insert!(fractions, index, fractions[index] * forward_choices[:split_percentage])
fractions[index + 1] *= (1 - forward_choices[:split_percentage])
# Segment values
backward_choices[:new_value] = segment_values[index]
segment_values[index] = forward_choices[:new_value_1]
# Insert the second new value *after* index, so that new_value_1
# becomes the left segment and new_value_2 the right one. This
# matches how the merge branch below fills backward_choices,
# which is required for the function to be a true involution.
insert!(segment_values, index + 1, forward_choices[:new_value_2])
else
# If this is a merge, then combine the two segments `index`
# and `index + 1`.
proportion = fractions[index] / (fractions[index] + fractions[index + 1])
fractions[index] += fractions[index + 1]
deleteat!(fractions, index + 1)
# Set the relevant segment values.
backward_choices[:new_value_1] = segment_values[index]
backward_choices[:new_value_2] = segment_values[index + 1]
backward_choices[:split_percentage] = proportion
segment_values[index] = forward_choices[:new_value]
deleteat!(segment_values, index + 1)
end
# Fill a choicemap of the newly proposed trace's values
new_trace_choices[:fractions] = fractions
for (i, value) in enumerate(segment_values)
new_trace_choices[(:segments, i)] = value
end
new_trace_choices[:segment_count] = length(fractions)
# Obtain an updated trace matching the choicemap, and a weight
new_trace, weight, = update(t, get_args(t), (NoChange(),), new_trace_choices)
(new_trace, backward_choices, weight)
end
involution (generic function with 1 method)
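As a sanity check, we can run the involution forward and then backward by hand, and confirm that the round trip restores the original trace and the forward choices (a sketch assuming the definitions above; we compare a few representative addresses rather than entire choicemaps):

```julia
# Sample a starting trace and some proposal randomness
old_tr, = generate(piecewise_constant, (xs_dense,), make_constraints(ys_complex))
fwd_choices = get_choices(simulate(split_merge_proposal_randomness, (old_tr,)))
# Apply the involution forward, then again using backward_choices
new_tr, bwd_choices, = involution(old_tr, fwd_choices, nothing, ())
round_tr, fwd_again, = involution(new_tr, bwd_choices, nothing, ())
# The round trip should restore the trace...
@assert round_tr[:segment_count] == old_tr[:segment_count]
@assert round_tr[:fractions] ≈ old_tr[:fractions]
# ...and reproduce the forward randomness.
@assert fwd_again[:split_or_merge] == fwd_choices[:split_or_merge]
@assert fwd_again[:index] == fwd_choices[:index]
```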
Look at the last two lines of the involution. The involution is responsible for returning a new trace, not just a new choicemap. So, the common pattern to follow in involutions is:
- Throughout the involution, fill a choicemap (here, new_trace_choices) with the updates that you want to make to the old trace t.
- At the end, call Gen’s update function, passing in (1) the old trace t, (2) the model function’s arguments, (3) a tuple of “argdiffs” (in our case, we know that the argument xs has not changed, so we pass in NoChange()), and (4) the choicemap of updates. update returns two useful values: the new trace, and a weight, which is equal to get_score(t') - get_score(t). This is the weight we need to return from the involution.
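As a minimal illustration of this pattern outside any involution (a sketch assuming the model above; :noise is just an arbitrary continuous address to overwrite), we can change a single choice with update and check the returned weight:

```julia
# Sample a trace from the prior, then overwrite one address
tr, = generate(piecewise_constant, (xs_dense,))
constraints = choicemap(:noise => 0.5)
new_tr, weight, = update(tr, get_args(tr), (NoChange(),), constraints)
# When we only overwrite existing choices (no addresses added or
# removed), the weight is exactly the difference in model scores.
@assert isapprox(weight, get_score(new_tr) - get_score(tr))
```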
We can create a new MH update that uses our proposal, and test it on the new dataset:
@gen function mean_segments_proposal(t, xs, ys, i)
xmin, xmax = minimum(xs), maximum(xs)
x_range = xmax - xmin
fracs = t[:fractions]
min = xmin + x_range * sum(fracs[1:i-1])
max = xmin + x_range * sum(fracs[1:i])
relevant_ys = [y for (x,y) in zip(xs,ys) if x >= min && x <= max]
{(:segments, i)} ~ normal(sum(relevant_ys)/length(relevant_ys), 0.3)
end
function custom_update_inv(tr, xs, ys)
tr, accepted = mh(tr, split_merge_proposal_randomness, (), involution)
for i=1:tr[:segment_count]
tr, = mh(tr, mean_segments_proposal, (xs, ys, i))
end
tr, = mh(tr, select(:noise))
tr
end
visualize_mh_alg(xs_dense, ys_complex, custom_update_inv, 75, 10)
Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Iteration 600
Iteration 700
Log mean score: 26.43941395921412