# Tutorial: Reversible-Jump MCMC in Gen (with applications to Program Synthesis)

### What is this notebook about?

In earlier tutorials, we saw how to write custom Metropolis-Hastings proposals and stitch them together to craft our own MCMC algorithms. Recall that to write a custom MH proposal, it was necessary to define a generative function that accepted as an argument the previous trace, and proposed a new trace by “acting out” the model, that is, by sampling new proposed values of random choices at the same addresses used by the model.

For example, a random walk proposal for the addresses :x and :y might have looked like this:

@gen function random_walk_proposal(previous_trace)
    x ~ normal(previous_trace[:x], 0.1)
    y ~ normal(previous_trace[:y], 0.2)
end


This pattern implies a severe restriction on the kinds of proposals we can write: the new x and y values must be sampled directly from a Gen distribution (e.g., normal). For example, it is impossible to use this pattern to define a proposal that deterministically swaps the current x and y values.

In this notebook, you’ll learn a more flexible approach to defining custom MH proposals. The technique is widely applicable, but is particularly well-suited to models with discrete and continuous parameters, where the discrete parameters determine which continuous parameters exist.

The polynomial curve-fitting demo we saw in the last problem set is one simple example: the discrete parameter degree determines which coefficient parameters exist. This situation is also common in program synthesis: discrete variables determine a program’s structure, but the values inside the program may be continuous.

## Outline

Section 1. Recap of the piecewise-constant function model

Section 2. Basic Metropolis-Hastings inference

Section 3. Reversible-Jump “Split-Merge” proposals

Section 4. Bayesian program synthesis of GP kernels

Section 5. A tree regeneration proposal

using Gen, GenDistributions, Plots, Logging
include("dirichlet.jl")

Logging.disable_logging(Logging.Info);


## 1. Recap of the piecewise-constant function model

In the intro to modeling tutorial, you worked with a model of piecewise constant functions, with unknown changepoints. Here, we model the same scenario, but somewhat differently.

Given a dataset of xs, our model will randomly divide the range (xmin, xmax) into a random number of segments.

It does this by sampling a number of segments (:segment_count), then sampling a vector of proportions from a Dirichlet distribution (:fractions). The vector is guaranteed to sum to 1: if there are, say, three segments, this vector might be [0.3, 0.5, 0.2]. The length of each segment is the fraction of the interval assigned to it, times the length of the entire interval, e.g. 0.2 * (xmax - xmin). For each segment, we generate a y value from a normal distribution. Finally, we sample the y values near the piecewise constant function described by the segments.
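To make the arithmetic concrete, here is a small plain-Julia sketch (no Gen required) using a hypothetical fractions vector; the values are made up for illustration:

```julia
# Hypothetical example: three segments over the interval (-5, 5)
xmin, xmax = -5.0, 5.0
fractions = [0.3, 0.5, 0.2]           # sums to 1, as a Dirichlet draw would

# Each segment's length is its fraction times the interval length
lengths = fractions .* (xmax - xmin)  # ≈ [3.0, 5.0, 2.0]

# Cumulative fractions give the right endpoint of each segment
right_endpoints = xmin .+ cumsum(fractions) .* (xmax - xmin)  # ≈ [-2.0, 3.0, 5.0]
```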

### Using @dist to define new distributions for convenience

To sample the number of segments, we need a distribution with support only on the positive integers. We create one using the Distributions DSL:

# A distribution that is guaranteed to be 1 or higher.
@dist poisson_plus_one(rate) = poisson(rate) + 1;


Distributions declared with @dist can be used to make random choices inside of Gen models. Behind the scenes, @dist has analyzed the code and figured out how to evaluate the logpdf, or log density, of our newly defined distribution. So we can ask, e.g., what the density of poisson_plus_one(1) is at the point 3:

logpdf(poisson_plus_one, 3, 1)

-1.6931471805599454


Note that this is the same as the logpdf of poisson(1) at the point 2. @dist’s main job is to automate the logic of converting the above call into this one:

logpdf(poisson, 2, 1)

-1.6931471805599454
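You can check this shift rule by hand with a plain-Julia Poisson log-pmf; the helper functions below are hypothetical stand-ins for illustration, not part of Gen:

```julia
# Log-pmf of a Poisson(rate) at k: k*log(rate) - rate - log(k!)
poisson_logpdf(k, rate) = k * log(rate) - rate - sum(log.(1:k))

# A shifted variable y = x + 1 has log density at y equal to the original
# log density at y - 1 -- exactly the conversion @dist performs.
shifted_logpdf(y, rate) = poisson_logpdf(y - 1, rate)

shifted_logpdf(3, 1)  # ≈ -1.6931, matching logpdf(poisson_plus_one, 3, 1)
```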


### Writing the model

We can now write the model itself. It is relatively straightforward, though there are a few things you may not have seen before:

• List comprehensions allow you to create a list without writing an entire for loop. For example, [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count] creates a list of elements, one for each i in the range [1, ..., segment_count], each of which is generated using the expression {(:segments, i)} ~ normal(0, 1).

• The Dirichlet distribution is a distribution over the simplex of vectors whose elements sum to 1. We use it to generate the fractions of the entire interval (x_min, x_max) that each segment of our piecewise function covers.

• The logic to generate the y points is as follows: we compute the cumulative fractions cumfracs = cumsum(fractions), such that xmin + (xmax - xmin) * cumfracs[j] is the x-value of the right endpoint of the jth segment. Then we sample at each address (:y, i) a normal whose mean is the y-value of the segment that contains xs[i].
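As a sanity check of that indexing logic, here is the same cumsum-plus-findfirst computation in plain Julia, on hypothetical values:

```julia
xs = [-5.0, 0.0, 4.0]
xmin, xmax = -5.0, 5.0
fractions = [0.3, 0.5, 0.2]
cumfracs = cumsum(fractions)   # ≈ [0.3, 0.8, 1.0]

# For each x, find the first segment whose right endpoint (as a fraction of
# the interval) is at or beyond x's relative position in the interval.
inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin), cumfracs) for x in xs]
# x = -5.0 falls in segment 1, x = 0.0 in segment 2, x = 4.0 in segment 3
```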

@gen function piecewise_constant(xs::Vector{Float64})
    # Generate a number of segments (at least 1)
    segment_count ~ poisson_plus_one(1)

    # To determine changepoints, draw a vector on the simplex from a Dirichlet
    # distribution. This gives us the proportions of the entire interval that each
    # segment takes up. (The entire interval is determined by the minimum and maximum
    # x values.)
    fractions ~ dirichlet([1.0 for i=1:segment_count])

    # Generate values for each segment
    segments = [{(:segments, i)} ~ normal(0, 1) for i=1:segment_count]

    # Determine a global noise level
    noise ~ gamma(1, 1)

    # Generate the y points for the input x points
    xmin, xmax = extrema(xs)
    cumfracs = cumsum(fractions)
    # Correct numeric issue: cumfracs[end] might be 0.999999
    @assert cumfracs[end] ≈ 1.0
    cumfracs[end] = 1.0

    inds = [findfirst(frac -> frac >= (x - xmin) / (xmax - xmin), cumfracs)
            for x in xs]
    segment_values = segments[inds]
    for (i, val) in enumerate(segment_values)
        {(:y, i)} ~ normal(val, noise)
    end
end;


Let’s understand its behavior by visualizing several runs of the model. We begin by creating a simple dataset of xs, evenly spaced between -5 and 5.

xs_dense = collect(range(-5, stop=5, length=50));


Don’t worry about understanding the following code, which we use for visualization.

function trace_to_dict(tr)
    Dict(:values => [tr[(:segments, i)] for i=1:(tr[:segment_count])],
         :fracs => tr[:fractions], :n => tr[:segment_count], :noise => tr[:noise],
         :ys => [tr[(:y, i)] for i=1:length(xs_dense)])
end;

function visualize_trace(tr; title="")
    xs, = get_args(tr)
    tr = trace_to_dict(tr)

    scatter(xs, tr[:ys], label=nothing, xlabel="X", ylabel="Y")

    cumfracs = [0.0, cumsum(tr[:fracs])...]
    xmin = minimum(xs)
    xmax = maximum(xs)
    for i in 1:tr[:n]
        segment_xs = [xmin + cumfracs[i] * (xmax - xmin), xmin + cumfracs[i+1] * (xmax - xmin)]
        segment_ys = fill(tr[:values][i], 2)
        plot!(segment_xs, segment_ys, label=nothing, linewidth=4)
    end
    Plots.title!(title)
end

traces = [simulate(piecewise_constant, (xs_dense,)) for _ in 1:9]
plot([visualize_trace(t) for t in traces]...)


Many of the samples involve only one segment, but others involve more. The noise level also varies from sample to sample.

## 2. Basic Metropolis-Hastings inference

Let’s create three synthetic datasets, each more challenging than the last, to test out our inference capabilities.

ys_simple  = ones(length(xs_dense)) .+ randn(length(xs_dense)) * 0.1
ys_medium  = Base.ifelse.(Int.(floor.(abs.(xs_dense ./ 3))) .% 2 .== 0,
2, 0) .+ randn(length(xs_dense)) * 0.1;
ys_complex = Int.(floor.(abs.(xs_dense ./ 2))) .% 5 .+ randn(length(xs_dense)) * 0.1;


We’ll need a helper function for creating a choicemap of constraints from a vector of ys:

function make_constraints(ys)
    choicemap([(:y, i) => ys[i] for i=1:length(ys)]...)
end;


As we saw in the last problem set, importance sampling does a decent job on the simple dataset:

NUM_CHAINS = 9

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_simple), 5000)) for _ in 1:NUM_CHAINS]
Plots.plot([visualize_trace(t) for t in traces]...)


But on the complex dataset, it takes many more particles (here, we use 50,000) to do even an OK job:

traces = [first(importance_resampling(piecewise_constant, (xs_dense,), make_constraints(ys_complex), 50000)) for _ in 1:NUM_CHAINS]
scores = [get_score(t) for t in traces]
println("Log mean score: $(logsumexp(scores) - log(NUM_CHAINS))")
Plots.plot([visualize_trace(t) for t in traces]...)

Log mean score: -37.681098516181876


Let’s try instead to write a simple Metropolis-Hastings algorithm to tackle the problem. First, some visualization code (feel free to ignore it!).

function visualize_mh_alg(xs, ys, update, frames=200, iters_per_frame=1, N=NUM_CHAINS; verbose=true)
    traces = [first(generate(piecewise_constant, (xs,), make_constraints(ys))) for _ in 1:N]
    viz = Plots.@animate for i in 1:frames
        if i*iters_per_frame % 100 == 0 && verbose
            println("Iteration $(i*iters_per_frame)")
        end

        for j in 1:N
            for k in 1:iters_per_frame
                traces[j] = update(traces[j], xs, ys)
            end
        end

        Plots.plot([visualize_trace(t; title=(j == 2 ? "Iteration $(i*iters_per_frame)/$(frames*iters_per_frame)" : "")) for (j,t) in enumerate(traces)]...)
    end
    scores = [Gen.get_score(t) for t in traces]
    println("Log mean score: $(logsumexp(scores) - log(N))")
    gif(viz)
end;


And now the MH algorithm itself.

We’ll use a basic Block Resimulation sampler, which cycles through the following blocks of variables:

• Block 1: :segment_count and :fractions. Resampling these tries proposing a completely new division of the interval into pieces. However, it reuses the (:segments, i) values wherever possible; that is, if we currently have three segments and :segment_count is proposed to change to 5, only two new segment values will be sampled.

• Block 2: :fractions. This proposal tries leaving the number of segments the same, but resamples their relative lengths.

• Block 3: :noise. This proposal adjusts the global noise parameter.

• Blocks 4 and up: (:segments, i). Tries separately proposing new values for each segment (and accepts or rejects each proposal independently).

function simple_update(tr, xs, ys)
    tr, = mh(tr, select(:segment_count, :fractions))
    tr, = mh(tr, select(:fractions))
    tr, = mh(tr, select(:noise))
    for i=1:tr[:segment_count]
        tr, = mh(tr, select((:segments, i)))
    end
    tr
end

simple_update (generic function with 1 method)


Our algorithm makes quick work of the simple dataset:

visualize_mh_alg(xs_dense, ys_simple, simple_update, 100, 1)

Iteration 100
Log mean score: 46.82126477547658


On the medium dataset, it does an OK job with less computation than we used in importance sampling:

visualize_mh_alg(xs_dense, ys_medium, simple_update, 50, 10)

Iteration 100
Iteration 200
Iteration 300
Iteration 400
Iteration 500
Log mean score: 17.75535817726937