Tutorial: Basics of Iterative Inference Programming in Gen
This tutorial introduces the basics of inference programming in Gen. In particular, in this notebook we’ll focus on iterative inference programs, which include Markov chain Monte Carlo algorithms.
The task: curve-fitting with outliers
Suppose we have a dataset of points in the $x,y$ plane that is mostly explained by a linear relationship, but which also has several outliers. Our goal will be to automatically identify the outliers, and to find a linear relationship (a slope and intercept, as well as an inherent noise level) that explains the rest of the points:
This is a simple inference problem. But it has two features that make it ideal for introducing concepts in modeling and inference.
- First, we want not only to estimate the slope and intercept of the line that best fits the data, but also to classify each point as an inlier or outlier; that is, there are a large number of latent variables of interest, enough to make importance sampling an unreliable method (absent a more involved custom proposal that does the heavy lifting).
- Second, several of the parameters we’re estimating (the slope and intercept) are continuous and amenable to gradient-based search techniques, which will allow us to explore Gen’s optimization capabilities.
Let’s get started!
Outline
Section 1. Writing the model: a first attempt
Section 2. Visualizing the model’s behavior
Section 3. The problem with generic importance sampling
Section 4. MCMC Inference Part 1: Block Resimulation
Section 5. MCMC Inference Part 2: Gaussian Drift
Section 6. MCMC Inference Part 3: Proposals based on heuristics
Section 7. MAP Optimization
import Random, Logging
using Gen, Plots
# Disable logging, because @animate is verbose otherwise
Logging.disable_logging(Logging.Info);
1. Writing the model
We begin, as usual, by writing a model: a generative function responsible (conceptually) for simulating a synthetic dataset.
Our model will take as input a vector of x coordinates, and produce as output corresponding y coordinates.
We will also use this opportunity to introduce some syntactic sugar.
As described in the previous notebook, random choices in Gen are given addresses using the syntax {addr} ~ distribution(...). But this can be a bit verbose, and often leads to code that looks like the following:
x = {:x} ~ normal(0, 1)
slope = {:slope} ~ normal(0, 1)
In these examples, the variable name is duplicated as the address of the random choice. Because this is a common pattern, Gen provides syntactic sugar that makes it nicer to use:
# Desugars to "x = {:x} ~ normal(0, 1)"
x ~ normal(0, 1)
# Desugars to "slope = {:slope} ~ normal(0, 1)"
slope ~ normal(0, 1)
Note that sometimes, it is still necessary to use the {...} form, for example in loops:
# INVALID:
for i=1:10
y ~ normal(0, 1) # The name :y will be used more than once!!
println(y)
end
# VALID:
for i=1:10
y = {(:y, i)} ~ normal(0, 1) # OK: the address is different each time.
println(y)
end
We’ll use this new syntax for writing our model of linear regression with outliers. As we’ve seen before, the model generates parameters from a prior, and then simulates data based on those parameters:
@gen function regression_with_outliers(xs::Vector{<:Real})
# First, generate some parameters of the model. We make these
# random choices, because later, we will want to infer them
# from data. The distributions we use here express our assumptions
# about the parameters: we think the slope and intercept won't be
# too far from 0; that the noise is relatively small; and that
# the proportion of the dataset that doesn't fit a linear relationship
# (outliers) could be anything between 0 and 1.
slope ~ normal(0, 2)
intercept ~ normal(0, 2)
noise ~ gamma(1, 1)
prob_outlier ~ uniform(0, 1)
# Next, we generate the actual y coordinates.
n = length(xs)
ys = Float64[]
for i = 1:n
# Decide whether this point is an outlier, and set
# mean and standard deviation accordingly
if ({:data => i => :is_outlier} ~ bernoulli(prob_outlier))
(mu, std) = (0., 10.)
else
(mu, std) = (xs[i] * slope + intercept, noise)
end
# Sample a y value for this point
push!(ys, {:data => i => :y} ~ normal(mu, std))
end
ys
end;
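Before visualizing anything, here is a quick sketch (not part of the original notebook) of running the model forward once and reading a few of its random choices back by address; the xs below is just an example input:
# Run the model forward once and inspect a few choices by their addresses.
xs = collect(range(-5, stop=5, length=20))  # example x coordinates
trace = Gen.simulate(regression_with_outliers, (xs,))
println("slope = ", trace[:slope])
println("intercept = ", trace[:intercept])
println("is the first point an outlier? ", trace[:data => 1 => :is_outlier])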
2. What our model is doing: visualizing the prior
Let’s visualize what our model is doing by drawing some samples from the prior.
# Generate nine traces and visualize them
include("visualization/regression_viz.jl")
xs = collect(range(-5, stop=5, length=20))
traces = [Gen.simulate(regression_with_outliers, (xs,)) for i in 1:9];
Plots.plot([visualize_trace(t) for t in traces]...)
Legend:
- red points: outliers;
- blue points: inliers (i.e. regular data);
- dark grey shading: noise associated with inliers; and
- light grey shading: noise associated with outliers.
Note that an outlier can occur anywhere — including close to the line — and that our model is capable of generating datasets in which the vast majority of points are outliers.
3. The problem with generic importance sampling
To motivate the need for more complex inference algorithms, let’s begin by using the simple importance sampling method from the previous tutorial, and thinking about where it fails.
First, let us create a synthetic dataset to do inference about.
function make_synthetic_dataset(n)
Random.seed!(1)
prob_outlier = 0.2
true_inlier_noise = 0.5
true_outlier_noise = 5.0
true_slope = -1
true_intercept = 2
xs = collect(range(-5, stop=5, length=n))
ys = Float64[]
for (i, x) in enumerate(xs)
if rand() < prob_outlier
y = randn() * true_outlier_noise
else
y = true_slope * x + true_intercept + randn() * true_inlier_noise
end
push!(ys, y)
end
(xs, ys)
end
(xs, ys) = make_synthetic_dataset(20);
Plots.scatter(xs, ys, color="black", xlabel="X", ylabel="Y",
label=nothing, title="Observations - regular data and outliers")
We will express our observations as a ChoiceMap that constrains the values of certain random choices to equal their observed values. Here, we want to constrain the values of the choices with address :data => i => :y (that is, the sampled $y$ coordinates) to equal the observed $y$ values. Let’s write a helper function that takes in a vector of $y$ values and creates a ChoiceMap that we can use to constrain our model:
function make_constraints(ys::Vector{Float64})
constraints = Gen.choicemap()
for i=1:length(ys)
constraints[:data => i => :y] = ys[i]
end
constraints
end;
We can apply it to our dataset’s vector of ys to make a set of constraints for doing inference:
observations = make_constraints(ys);
Now, we use the library function importance_resampling to draw approximate posterior samples given those observations:
function logmeanexp(scores)
logsumexp(scores) - log(length(scores))
end;
traces = [first(Gen.importance_resampling(regression_with_outliers, (xs,), observations, 2000)) for i in 1:9]
log_probs = [get_score(t) for t in traces]
println("Average log probability: $(logmeanexp(log_probs))")
Plots.plot([visualize_trace(t) for t in traces]...)
Average log probability: -51.8570956904098
We see here that importance resampling hasn’t completely failed: it generally finds a reasonable position for the line. But the details are off: there is little logic to the outlier classification, and the inferred noise around the line is too wide. The problem is that there are just too many variables to get right, and so sampling everything in one go is highly unlikely to produce a perfect hit.
In the remainder of this notebook, we’ll explore techniques for finding the right solution iteratively, beginning with an initial guess and making many small changes, until we achieve a reasonable posterior sample.
4. MCMC Inference Part 1: Block Resimulation
What is MCMC?
Markov Chain Monte Carlo (“MCMC”) methods are a powerful family of algorithms for iteratively producing approximate samples from a distribution; when applied to Bayesian inference problems, that distribution is the posterior over unknown (hidden) model variables given the data.
There is a rich theory behind MCMC methods, but we focus on applying MCMC in Gen and introducing theoretical ideas only when necessary for understanding. As we will see, Gen provides abstractions that hide and automate much of the math necessary for implementing MCMC algorithms correctly.
The general shape of an MCMC algorithm is as follows. We begin by sampling an initial setting of all unobserved variables; in Gen, we produce an initial trace consistent with (but not necessarily probable given) our observations. Then, in a long-running loop, we make small, stochastic changes to the trace; in order for the algorithm to be asymptotically correct, these stochastic updates must satisfy certain probabilistic properties.
One common way of ensuring that the updates do satisfy those properties is to compute a Metropolis-Hastings acceptance ratio. Essentially, after proposing a change to a trace, we add an “accept or reject” step that stochastically decides whether to commit the update or to revert it. This is an over-simplification, but generally speaking, this step ensures we are more likely to accept changes that make our trace fit the observed data better, and to reject ones that make our current trace worse. The algorithm also tries not to go down dead ends: it is more likely to take an exploratory step into a low-probability region if it knows it can easily get back to where it came from.
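For reference (this is the standard Metropolis-Hastings rule, not anything specific to Gen): writing $p$ for the unnormalized posterior density over traces and $q$ for the proposal density, a proposed move from trace $t$ to trace $t'$ is accepted with probability

$$\alpha = \min\left(1, \frac{p(t')\, q(t \mid t')}{p(t)\, q(t' \mid t)}\right).$$

Accepting with exactly this probability is what (together with some technical conditions) guarantees that the posterior is a stationary distribution of the chain.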
Gen’s metropolis_hastings function automatically adds this “accept/reject” check (including the correct computation of the probability of acceptance or rejection), so that inference programmers need only think about what sorts of updates might be useful to propose. Starting in this section, we’ll look at several design patterns for MCMC updates, and how to apply them in Gen.
Block Resimulation
One of the simplest strategies we can use is called Block Resimulation MH, and it works as follows.
We begin, as in most iterative inference algorithms, by sampling an initial trace from our model, fixing the observed choices to their observed values.
# Gen's `generate` function accepts a model, a tuple of arguments to the model,
# and a `ChoiceMap` representing observations (or constraints to satisfy). It returns
# a complete trace consistent with the observations, and an importance weight.
# In this call, we ignore the weight returned.
(tr, _) = generate(regression_with_outliers, (xs,), observations)
Then, in each iteration of our program, we propose changes to all our model’s variables in “blocks,” by erasing a set of variables from our current trace and resimulating them from the model. After resimulating each block of choices, we perform an accept/reject step, deciding whether the proposed changes are worth making.
# Pseudocode
for iter=1:500
tr = maybe_update_block_1(tr)
tr = maybe_update_block_2(tr)
...
tr = maybe_update_block_n(tr)
end
The main design choice in designing a Block Resimulation MH algorithm is how to block the choices together for resimulation. At one extreme, we could put each random choice the model makes in its own block. At the other, we could put all variables into a single block (a strategy sometimes called “independent” MH, and which bears a strong similarity to importance resampling, as it involves repeatedly generating completely new traces and deciding whether to keep them or not). Usually, the right thing to do is somewhere in between.
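For concreteness, here is a rough sketch (not used in the rest of this notebook) of what the single-block extreme would look like for our model: every unobserved choice is selected at once, so each update resimulates an entirely new hypothesis and accepts or rejects it wholesale.
# Sketch of "independent" MH: resimulate every latent choice in one block.
function independent_mh_update(tr)
    (xs,) = get_args(tr)
    everything = select(:slope, :intercept, :noise, :prob_outlier,
                        [:data => i => :is_outlier for i in 1:length(xs)]...)
    (tr, _) = mh(tr, everything)
    tr
end;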
For the regression problem, here is one possible blocking of choices:
Block 1: slope, intercept, and noise. These parameters determine the linear relationship; resimulating them is like picking a new line. We know from our importance sampling experiment above that before too long, we’re bound to sample something close to the right line.
Blocks 2 through N+1: Each is_outlier, in its own block. One problem we saw with importance sampling in this problem was that it tried to sample every outlier classification at once, when in reality the chances of a single sample that correctly classifies all the points are very low. Here, we can choose to resimulate each is_outlier choice separately, and for each one, decide whether to use the resimulated value or not.
Block N+2: prob_outlier. Finally, we can propose a new prob_outlier value; in general, we can expect to accept the proposal when it is in line with the current hypothesized proportion of is_outlier choices that are set to true.
Resimulating a block of variables is the simplest form of update that Gen’s metropolis_hastings operator (or mh for short) supports. When supplied with a current trace and a selection of trace addresses to resimulate, mh performs the resimulation and the appropriate accept/reject check, then returns a possibly updated trace, along with a boolean indicating whether the update was accepted or not. A selection is created using the select method. So a single update of the scheme we proposed above would look like this:
# Perform a single block resimulation update of a trace.
function block_resimulation_update(tr)
# Block 1: Update the line's parameters
line_params = select(:noise, :slope, :intercept)
(tr, _) = mh(tr, line_params)
# Blocks 2-N+1: Update the outlier classifications
(xs,) = get_args(tr)
n = length(xs)
for i=1:n
(tr, _) = mh(tr, select(:data => i => :is_outlier))
end
# Block N+2: Update the prob_outlier parameter
(tr, _) = mh(tr, select(:prob_outlier))
# Return the updated trace
tr
end;
All that’s left is to (a) obtain an initial trace, and then (b) run that update in a loop for as long as we’d like:
function block_resimulation_inference(xs, ys, observations)
observations = make_constraints(ys)
(tr, _) = generate(regression_with_outliers, (xs,), observations)
for iter=1:500
tr = block_resimulation_update(tr)
end
tr
end;
Let’s test it out:
scores = Vector{Float64}(undef, 10)
for i=1:10
@time tr = block_resimulation_inference(xs, ys, observations)
scores[i] = get_score(tr)
end
println("Log probability: ", logmeanexp(scores))
0.773126 seconds (11.47 M allocations: 677.525 MiB, 13.19% gc time, 26.67% compilation time)
0.549964 seconds (10.81 M allocations: 641.426 MiB, 14.86% gc time)
0.558047 seconds (10.81 M allocations: 641.426 MiB, 15.35% gc time)
0.549738 seconds (10.81 M allocations: 641.426 MiB, 15.25% gc time)
0.541699 seconds (10.81 M allocations: 641.426 MiB, 14.85% gc time)
0.538167 seconds (10.81 M allocations: 641.426 MiB, 14.75% gc time)
0.556597 seconds (10.81 M allocations: 641.426 MiB, 15.09% gc time)
0.536488 seconds (10.81 M allocations: 641.426 MiB, 14.92% gc time)
0.541546 seconds (10.81 M allocations: 641.426 MiB, 16.01% gc time)
0.539004 seconds (10.81 M allocations: 641.426 MiB, 14.95% gc time)
Log probability: -50.78536994535881
We note that this is significantly better than importance sampling, even if we run importance sampling for about the same amount of (wall-clock) time per sample:
scores = Vector{Float64}(undef, 10)
for i=1:10
@time (tr, _) = importance_resampling(regression_with_outliers, (xs,), observations, 17000)
scores[i] = get_score(tr)
end
println("Log probability: ", logmeanexp(scores))
0.596626 seconds (12.53 M allocations: 882.477 MiB, 18.17% gc time)
0.603205 seconds (12.53 M allocations: 882.477 MiB, 19.11% gc time)
0.609832 seconds (12.53 M allocations: 882.477 MiB, 19.07% gc time)
0.595268 seconds (12.53 M allocations: 882.477 MiB, 18.82% gc time)
0.605643 seconds (12.53 M allocations: 882.477 MiB, 18.94% gc time)
0.587700 seconds (12.53 M allocations: 882.477 MiB, 17.79% gc time)
0.582221 seconds (12.53 M allocations: 882.477 MiB, 18.52% gc time)
0.582921 seconds (12.53 M allocations: 882.477 MiB, 18.71% gc time)
0.585478 seconds (12.53 M allocations: 882.477 MiB, 18.99% gc time)
0.567708 seconds (12.53 M allocations: 882.477 MiB, 17.21% gc time)
Log probability: -53.7625847077635
It’s one thing to see a log probability increase; it’s better to understand what the inference algorithm is actually doing, and to see why it’s doing better.
A great tool for debugging and improving MCMC algorithms is visualization. We can use Plots.@animate to produce an animated visualization:
t, = generate(regression_with_outliers, (xs,), observations)
viz = Plots.@animate for i in 1:500
global t
t = block_resimulation_update(t)
visualize_trace(t; title="Iteration $i/500")
end;
gif(viz)
We can see that although the algorithm keeps changing the inferences of which points are inliers and outliers, it has a harder time refining the continuous parameters. We address this challenge next.
5. MCMC Inference Part 2: Gaussian Drift MH
So far, we’ve seen one form of incremental trace update:
(tr, did_accept) = mh(tr, select(:address1, :address2, ...))
This update is incremental in that it only proposes changes to part of a trace (the selected addresses). But when computing what changes to propose, it ignores the current state completely and resimulates all-new values from the model.
That wholesale resimulation of values is often not the best way to search for improvements. To that end, Gen also offers a more general flavor of MH:
(tr, did_accept) = mh(tr, custom_proposal, custom_proposal_args)
A “custom proposal” is just what it sounds like: whereas before, we were using the default resimulation proposal to come up with new values for the selected addresses, we can now pass in a generative function that samples proposed values however it wants.
For example, here is a custom proposal that takes in a current trace, and proposes a new slope and intercept by randomly perturbing the existing values:
@gen function line_proposal(current_trace)
slope ~ normal(current_trace[:slope], 0.5)
intercept ~ normal(current_trace[:intercept], 0.5)
end;
This is often called a “Gaussian drift” proposal, because it essentially amounts to proposing steps of a random walk. (What makes it different from a random walk is that we will still use an MH accept/reject step to make sure we don’t wander into areas of very low probability.)
To use the proposal, we write:
(tr, did_accept) = mh(tr, line_proposal, ())
Two things to note:
- We no longer need to pass a selection of addresses. Instead, Gen assumes that whichever addresses are sampled by the proposal (in this case, :slope and :intercept) are being proposed to.
- The argument list to the proposal is an empty tuple, (). The line_proposal generative function does expect an argument, the previous trace, but this is supplied automatically to all MH custom proposals (a proposal generative function for use with mh must take as its first argument the current trace of the model).
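If we did want to pass extra arguments to a proposal, they would go in that tuple. As a hypothetical example (not used below), here is a variant of the proposal whose step size is an argument rather than being hard-coded, along with how it would be called:
@gen function line_proposal_with_width(current_trace, width)
    slope ~ normal(current_trace[:slope], width)
    intercept ~ normal(current_trace[:intercept], width)
end;

# The current trace is still supplied automatically; only `width` comes from the tuple.
(tr, did_accept) = mh(tr, line_proposal_with_width, (0.1,))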
Let’s swap it into our update:
function gaussian_drift_update(tr)
# Gaussian drift on line params
(tr, _) = mh(tr, line_proposal, ())
# Block resimulation: Update the outlier classifications
(xs,) = get_args(tr)
n = length(xs)
for i=1:n
(tr, _) = mh(tr, select(:data => i => :is_outlier))
end
# Block resimulation: Update the prob_outlier and noise parameters
(tr, w) = mh(tr, select(:prob_outlier))
(tr, w) = mh(tr, select(:noise))
tr
end;
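As with block resimulation, we can wrap this update in a loop to obtain a complete inference program; the helper below is a sketch mirroring block_resimulation_inference above, not something defined in the original notebook:
function gaussian_drift_inference(xs, ys, observations)
    (tr, _) = generate(regression_with_outliers, (xs,), observations)
    for iter=1:500
        tr = gaussian_drift_update(tr)
    end
    tr
end;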
If we compare the Gaussian Drift proposal visually with our old algorithm, we can see the new behavior:
tr1, = generate(regression_with_outliers, (xs,), observations)
tr2 = tr1
viz = Plots.@animate for i in 1:300
global tr1, tr2
tr1 = gaussian_drift_update(tr1)
tr2 = block_resimulation_update(tr2)
Plots.plot(visualize_trace(tr1; title="Drift Kernel (Iter $i)"),
visualize_trace(tr2; title="Resim Kernel (Iter $i)"))
end;
gif(viz)
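To compare the two kernels quantitatively as well as visually, we could score the drift kernel the same way we scored block resimulation earlier, using the gaussian_drift_inference helper sketched above (exact numbers will vary from run to run):
scores = Vector{Float64}(undef, 10)
for i=1:10
    @time tr = gaussian_drift_inference(xs, ys, observations)
    scores[i] = get_score(tr)
end
println("Log probability: ", logmeanexp(scores))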