Tutorial: Basics of Iterative Inference Programming in Gen

This tutorial introduces the basics of inference programming in Gen. In particular, in this notebook we’ll focus on iterative inference programs, which include Markov chain Monte Carlo algorithms.

The task: curve-fitting with outliers

Suppose we have a dataset of points in the $x,y$ plane that is mostly explained by a linear relationship, but which also has several outliers. Our goal will be to automatically identify the outliers, and to find a linear relationship (a slope and intercept, as well as an inherent noise level) that explains the rest of the points:

(See https://dspace.mit.edu/bitstream/handle/1721.1/119255/MIT-CSAIL-TR-2018-020.pdf, Figure 2(a).)

This is a simple inference problem. But it has two features that make it ideal for introducing concepts in modeling and inference.

  1. First, we want not only to estimate the slope and intercept of the line that best fits the data, but also to classify each point as an inlier or outlier; that is, there are a large number of latent variables of interest, enough to make importance sampling an unreliable method (absent a more involved custom proposal that does the heavy lifting).
  2. Second, several of the parameters we’re estimating (the slope and intercept) are continuous and amenable to gradient-based search techniques, which will allow us to explore Gen’s optimization capabilities.

Let’s get started!

Outline

Section 1. Writing the model: a first attempt

Section 2. Visualizing the model’s behavior

Section 3. The problem with generic importance sampling

Section 4. MCMC Inference Part 1: Block Resimulation

Section 5. MCMC Inference Part 2: Gaussian Drift

Section 6. MCMC Inference Part 3: Proposals based on heuristics

Section 7. MAP Optimization

import Random, Logging
using Gen, Plots

# Disable logging, because @animate is verbose otherwise
Logging.disable_logging(Logging.Info);

1. Writing the model

We begin, as usual, by writing a model: a generative function responsible (conceptually) for simulating a synthetic dataset.

Our model will take as input a vector of x coordinates, and produce as output corresponding y coordinates.

We will also use this opportunity to introduce some syntactic sugar. As described in the previous notebook, random choices in Gen are given addresses using the syntax {addr} ~ distribution(...). But this can be a bit verbose, and often leads to code that looks like the following:

x = {:x} ~ normal(0, 1)
slope = {:slope} ~ normal(0, 1)

In these examples, the variable name is duplicated as the address of the random choice. Because this is a common pattern, Gen provides syntactic sugar that makes it nicer to use:

# Desugars to "x = {:x} ~ normal(0, 1)"
x ~ normal(0, 1)
# Desugars to "slope = {:slope} ~ normal(0, 1)"
slope ~ normal(0, 1)

Note that sometimes, it is still necessary to use the {...} form, for example in loops:

# INVALID:
for i=1:10
    y ~ normal(0, 1) # The name :y will be used more than once!!
    println(y)
end

# VALID:
for i=1:10
    y = {(:y, i)} ~ normal(0, 1) # OK: the address is different each time.
    println(y)
end

We’ll use this new syntax for writing our model of linear regression with outliers. As we’ve seen before, the model generates parameters from a prior, and then simulates data based on those parameters:

@gen function regression_with_outliers(xs::Vector{<:Real})
    # First, generate some parameters of the model. We make these
    # random choices, because later, we will want to infer them
    # from data. The distributions we use here express our assumptions
    # about the parameters: we think the slope and intercept won't be
    # too far from 0; that the noise is relatively small; and that
    # the proportion of the dataset that doesn't fit a linear relationship
    # (outliers) could be anything between 0 and 1.
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    noise ~ gamma(1, 1)
    prob_outlier ~ uniform(0, 1)
    
    # Next, we generate the actual y coordinates.
    n = length(xs)
    ys = Float64[]
    
    for i = 1:n
        # Decide whether this point is an outlier, and set
        # mean and standard deviation accordingly
        if ({:data => i => :is_outlier} ~ bernoulli(prob_outlier))
            (mu, std) = (0., 10.)
        else
            (mu, std) = (xs[i] * slope + intercept, noise)
        end
        # Sample a y value for this point
        push!(ys, {:data => i => :y} ~ normal(mu, std))
    end
    ys
end;
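Before visualizing the model, it may help to see how the hierarchical addresses above show up in a trace. Here is a minimal sketch (the _demo names are just for illustration) that samples one trace from the prior and reads a few choices back by address:

xs_demo = collect(range(-5, stop=5, length=20))
tr_demo = Gen.simulate(regression_with_outliers, (xs_demo,))

println(tr_demo[:slope])                     # the sampled slope
println(tr_demo[:data => 3 => :is_outlier])  # whether point 3 was generated as an outlier
println(tr_demo[:data => 3 => :y])           # the y coordinate sampled for point 3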

2. What our model is doing: visualizing the prior

Let’s visualize what our model is doing by drawing some samples from the prior.

# Generate nine traces and visualize them
include("visualization/regression_viz.jl")
xs     = collect(range(-5, stop=5, length=20))
traces = [Gen.simulate(regression_with_outliers, (xs,)) for i in 1:9];
Plots.plot([visualize_trace(t) for t in traces]...)

Note that an outlier can occur anywhere — including close to the line — and that our model is capable of generating datasets in which the vast majority of points are outliers.

3. The problem with generic importance sampling

To motivate the need for more complex inference algorithms, let’s begin by using the simple importance sampling method from the previous tutorial, and thinking about where it fails.

First, let’s create a synthetic dataset to perform inference on.

function make_synthetic_dataset(n)
    Random.seed!(1)
    prob_outlier = 0.2
    true_inlier_noise = 0.5
    true_outlier_noise = 5.0
    true_slope = -1
    true_intercept = 2
    xs = collect(range(-5, stop=5, length=n))
    ys = Float64[]
    for (i, x) in enumerate(xs)
        if rand() < prob_outlier
            y = randn() * true_outlier_noise
        else
            y = true_slope * x + true_intercept + randn() * true_inlier_noise
        end
        push!(ys, y)
    end
    (xs, ys)
end
    
(xs, ys) = make_synthetic_dataset(20);
Plots.scatter(xs, ys, color="black", xlabel="X", ylabel="Y", 
              label=nothing, title="Observations - regular data and outliers")

We will express our observations as a ChoiceMap that constrains the values of certain random choices to equal their observed values. Here, we want to constrain the values of the choices with address :data => i => :y (that is, the sampled $y$ coordinates) to equal the observed $y$ values. Let’s write a helper function that takes in a vector of $y$ values and creates a ChoiceMap that we can use to constrain our model:

function make_constraints(ys::Vector{Float64})
    constraints = Gen.choicemap()
    for i=1:length(ys)
        constraints[:data => i => :y] = ys[i]
    end
    constraints
end;

We can apply it to our dataset’s vector of ys to make a set of constraints for doing inference:

observations = make_constraints(ys);

Now, we use the library function importance_resampling to draw approximate posterior samples given those observations:

function logmeanexp(scores)
    logsumexp(scores) - log(length(scores))
end;
traces    = [first(Gen.importance_resampling(regression_with_outliers, (xs,), observations, 2000)) for i in 1:9]
log_probs = [get_score(t) for t in traces]
println("Average log probability: $(logmeanexp(log_probs))")
Plots.plot([visualize_trace(t) for t in traces]...)
Average log probability: -51.8570956904098

We see here that importance resampling hasn’t completely failed: it generally finds a reasonable position for the line. But the details are off: there is little logic to the outlier classification, and the inferred noise around the line is too wide. The problem is that there are just too many variables to get right, and so sampling everything in one go is highly unlikely to produce a perfect hit.

In the remainder of this notebook, we’ll explore techniques for finding the right solution iteratively, beginning with an initial guess and making many small changes, until we achieve a reasonable posterior sample.

4. MCMC Inference Part 1: Block Resimulation

What is MCMC?

Markov chain Monte Carlo (“MCMC”) methods are a powerful family of algorithms for iteratively producing approximate samples from a target distribution. When applied to Bayesian inference problems, that target is the posterior distribution of the unknown (hidden) model variables given the data.

There is a rich theory behind MCMC methods, but we focus on applying MCMC in Gen and introducing theoretical ideas only when necessary for understanding. As we will see, Gen provides abstractions that hide and automate much of the math necessary for implementing MCMC algorithms correctly.

The general shape of an MCMC algorithm is as follows. We begin by sampling an initial setting of all unobserved variables; in Gen, we produce an initial trace consistent with (but not necessarily probable given) our observations. Then, in a long-running loop, we make small, stochastic changes to the trace; in order for the algorithm to be asymptotically correct, these stochastic updates must satisfy certain probabilistic properties.

One common way of ensuring that the updates do satisfy those properties is to compute a Metropolis-Hastings acceptance ratio. Essentially, after proposing a change to a trace, we add an “accept or reject” step that stochastically decides whether to commit the update or to revert it. This is an over-simplification, but generally speaking, this step ensures we are more likely to accept changes that make our trace fit the observed data better, and to reject ones that make our current trace worse. The algorithm also tries not to go down dead ends: it is more likely to take an exploratory step into a low-probability region if it knows it can easily get back to where it came from.
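For reference, this accept/reject step follows the standard Metropolis-Hastings rule: if $p$ denotes the model's (unnormalized) posterior density over traces and $q$ the proposal distribution, a proposed move from trace $t$ to trace $t'$ is accepted with probability

$$\alpha(t \to t') = \min\left(1,\ \frac{p(t')\,q(t \mid t')}{p(t)\,q(t' \mid t)}\right),$$

so moves to higher-probability traces are always accepted, and moves to lower-probability traces are accepted only some of the time.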

Gen’s metropolis_hastings function automatically adds this “accept/reject” check (including the correct computation of the probability of acceptance or rejection), so that inference programmers need only think about what sorts of updates might be useful to propose. Starting in this section, we’ll look at several design patterns for MCMC updates, and how to apply them in Gen.

Block Resimulation

One of the simplest strategies we can use is called Resimulation MH, and it works as follows.

We begin, as in most iterative inference algorithms, by sampling an initial trace from our model, fixing the observed choices to their observed values.

# Gen's `generate` function accepts a model, a tuple of arguments to the model,
# and a `ChoiceMap` representing observations (or constraints to satisfy). It returns
# a complete trace consistent with the observations, and an importance weight.  
# In this call, we ignore the weight returned.
(tr, _) = generate(regression_with_outliers, (xs,), observations)

Then, in each iteration of our program, we propose changes to all our model’s variables in “blocks,” by erasing a set of variables from our current trace and resimulating them from the model. After resimulating each block of choices, we perform an accept/reject step, deciding whether the proposed changes are worth making.

# Pseudocode
for iter=1:500
    tr = maybe_update_block_1(tr)
    tr = maybe_update_block_2(tr)
    ...
    tr = maybe_update_block_n(tr)
end

The main design choice in designing a Block Resimulation MH algorithm is how to block the choices together for resimulation. At one extreme, we could put each random choice the model makes in its own block. At the other, we could put all variables into a single block (a strategy sometimes called “independent” MH, and which bears a strong similarity to importance resampling, as it involves repeatedly generating completely new traces and deciding whether to keep them or not). Usually, the right thing to do is somewhere in between.
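For concreteness, here is a rough sketch of those two extremes, written with Gen's metropolis_hastings operator (mh for short, described in more detail below); this is not the scheme we will actually use:

# One extreme: every choice in its own block (top-level parameters shown;
# each :data => i => :is_outlier choice would also get its own block).
(tr, _) = mh(tr, select(:slope))
(tr, _) = mh(tr, select(:intercept))
(tr, _) = mh(tr, select(:noise))
(tr, _) = mh(tr, select(:prob_outlier))

# The other extreme: resimulate every unobserved choice at once ("independent" MH).
all_latents = select(:slope, :intercept, :noise, :prob_outlier,
                     [:data => i => :is_outlier for i in 1:length(xs)]...)
(tr, _) = mh(tr, all_latents)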

For the regression problem, here is one possible blocking of choices:

Block 1: slope, intercept, and noise. These parameters determine the linear relationship; resimulating them is like picking a new line. We know from our importance sampling experiment above that before too long, we’re bound to sample something close to the right line.

Blocks 2 through N+1: Each is_outlier, in its own block. One problem we saw with importance sampling in this problem was that it tried to sample every outlier classification at once, when in reality the chances of a single sample that correctly classifies all the points are very low. Here, we can choose to resimulate each is_outlier choice separately, and for each one, decide whether to use the resimulated value or not.

Block N+2: prob_outlier. Finally, we can propose a new prob_outlier value; in general, we can expect to accept the proposal when it is in line with the current hypothesized proportion of is_outlier choices that are set to true.

Resimulating a block of variables is the simplest form of update that Gen’s metropolis_hastings operator (or mh for short) supports. When supplied with a current trace and a selection of trace addresses to resimulate, mh performs the resimulation and the appropriate accept/reject check, then returns a possibly updated trace, along with a boolean indicating whether the update was accepted or not. A selection is created using the select method. So a single update of the scheme we proposed above would look like this:

# Perform a single block resimulation update of a trace.
function block_resimulation_update(tr)
    # Block 1: Update the line's parameters
    line_params = select(:noise, :slope, :intercept)
    (tr, _) = mh(tr, line_params)
    
    # Blocks 2-N+1: Update the outlier classifications
    (xs,) = get_args(tr)
    n = length(xs)
    for i=1:n
        (tr, _) = mh(tr, select(:data => i => :is_outlier))
    end
    
    # Block N+2: Update the prob_outlier parameter
    (tr, _) = mh(tr, select(:prob_outlier))
    
    # Return the updated trace
    tr
end;

All that’s left is to (a) obtain an initial trace, and then (b) run that update in a loop for as long as we’d like:

function block_resimulation_inference(xs, ys, observations)
    observations = make_constraints(ys)
    (tr, _) = generate(regression_with_outliers, (xs,), observations)
    for iter=1:500
        tr = block_resimulation_update(tr)
    end
    tr
end;

Let’s test it out:

scores = Vector{Float64}(undef, 10)
for i=1:10
    @time tr = block_resimulation_inference(xs, ys, observations)
    scores[i] = get_score(tr)
end
println("Log probability: ", logmeanexp(scores))
  0.773126 seconds (11.47 M allocations: 677.525 MiB, 13.19% gc time, 26.67% compilation time)
  0.549964 seconds (10.81 M allocations: 641.426 MiB, 14.86% gc time)
  0.558047 seconds (10.81 M allocations: 641.426 MiB, 15.35% gc time)
  0.549738 seconds (10.81 M allocations: 641.426 MiB, 15.25% gc time)
  0.541699 seconds (10.81 M allocations: 641.426 MiB, 14.85% gc time)
  0.538167 seconds (10.81 M allocations: 641.426 MiB, 14.75% gc time)
  0.556597 seconds (10.81 M allocations: 641.426 MiB, 15.09% gc time)
  0.536488 seconds (10.81 M allocations: 641.426 MiB, 14.92% gc time)
  0.541546 seconds (10.81 M allocations: 641.426 MiB, 16.01% gc time)
  0.539004 seconds (10.81 M allocations: 641.426 MiB, 14.95% gc time)
Log probability: -50.78536994535881

We note that this is significantly better than importance sampling, even if we run importance sampling for about the same amount of (wall-clock) time per sample:

scores = Vector{Float64}(undef, 10)
for i=1:10
    @time (tr, _) = importance_resampling(regression_with_outliers, (xs,), observations, 17000)
    scores[i] = get_score(tr)
end
println("Log probability: ", logmeanexp(scores))
  0.596626 seconds (12.53 M allocations: 882.477 MiB, 18.17% gc time)
  0.603205 seconds (12.53 M allocations: 882.477 MiB, 19.11% gc time)
  0.609832 seconds (12.53 M allocations: 882.477 MiB, 19.07% gc time)
  0.595268 seconds (12.53 M allocations: 882.477 MiB, 18.82% gc time)
  0.605643 seconds (12.53 M allocations: 882.477 MiB, 18.94% gc time)
  0.587700 seconds (12.53 M allocations: 882.477 MiB, 17.79% gc time)
  0.582221 seconds (12.53 M allocations: 882.477 MiB, 18.52% gc time)
  0.582921 seconds (12.53 M allocations: 882.477 MiB, 18.71% gc time)
  0.585478 seconds (12.53 M allocations: 882.477 MiB, 18.99% gc time)
  0.567708 seconds (12.53 M allocations: 882.477 MiB, 17.21% gc time)
Log probability: -53.7625847077635

It’s one thing to see a log probability increase; it’s better to understand what the inference algorithm is actually doing, and to see why it’s doing better.

A great tool for debugging and improving MCMC algorithms is visualization. We can use Plots.@animate to produce an animated visualization:

t, = generate(regression_with_outliers, (xs,), observations)

viz = Plots.@animate for i in 1:500
    global t
    t = block_resimulation_update(t)
    visualize_trace(t; title="Iteration $i/500")
end;
gif(viz)

We can see that although the algorithm keeps changing the inferences of which points are inliers and outliers, it has a harder time refining the continuous parameters. We address this challenge next.

5. MCMC Inference Part 2: Gaussian Drift MH

So far, we’ve seen one form of incremental trace update:

(tr, did_accept) = mh(tr, select(:address1, :address2, ...))

This update is incremental in that it only proposes changes to part of a trace (the selected addresses). But when computing what changes to propose, it ignores the current state completely and resimulates all-new values from the model.

That wholesale resimulation of values is often not the best way to search for improvements. To that end, Gen also offers a more general flavor of MH:

(tr, did_accept) = mh(tr, custom_proposal, custom_proposal_args)

A “custom proposal” is just what it sounds like: whereas before, we were using the default resimulation proposal to come up with new values for the selected addresses, we can now pass in a generative function that samples proposed values however it wants.

For example, here is a custom proposal that takes in a current trace, and proposes a new slope and intercept by randomly perturbing the existing values:

@gen function line_proposal(current_trace)
    slope ~ normal(current_trace[:slope], 0.5)
    intercept ~ normal(current_trace[:intercept], 0.5)
end;

This is often called a “Gaussian drift” proposal, because it essentially amounts to proposing steps of a random walk. (What makes it different from a random walk is that we will still use an MH accept/reject step to make sure we don’t wander into areas of very low probability.)
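A brief aside: because this drift proposal is symmetric (the probability density of proposing $t'$ from $t$ equals that of proposing $t$ from $t'$), the proposal terms in the Metropolis-Hastings acceptance ratio cancel, and the acceptance probability reduces to

$$\min\left(1,\ \frac{p(t')}{p(t)}\right),$$

where $p$ is the model's joint density and all choices other than the slope and intercept are left unchanged. As before, Gen computes this quantity for us automatically.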

To use the proposal, we write:

(tr, did_accept) = mh(tr, line_proposal, ())

Two things to note:

  1. We no longer need to pass a selection of addresses. Instead, Gen assumes that whichever addresses are sampled by the proposal (in this case, :slope and :intercept) are being proposed to.

  2. The argument list to the proposal is an empty tuple, (). The line_proposal generative function does take an argument, the current trace, but Gen supplies it automatically: a proposal generative function for use with mh must accept the model's current trace as its first argument, and any remaining arguments are taken from the tuple (see the sketch below).
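If a proposal does take additional arguments beyond the trace, they go in that tuple. For example, a hypothetical variant of line_proposal with a tunable step size (not used in the rest of this notebook) might look like this:

@gen function line_proposal_with_width(current_trace, width)
    slope ~ normal(current_trace[:slope], width)
    intercept ~ normal(current_trace[:intercept], width)
end;

# The current trace is still supplied automatically; only width comes from the tuple.
(tr, did_accept) = mh(tr, line_proposal_with_width, (0.1,))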

Let’s swap the Gaussian drift proposal into our update:

function gaussian_drift_update(tr)
    # Gaussian drift on line params
    (tr, _) = mh(tr, line_proposal, ())
    
    # Block resimulation: Update the outlier classifications
    (xs,) = get_args(tr)
    n = length(xs)
    for i=1:n
        (tr, _) = mh(tr, select(:data => i => :is_outlier))
    end
    
    # Block resimulation: Update the prob_outlier and noise parameters
    (tr, _) = mh(tr, select(:prob_outlier))
    (tr, _) = mh(tr, select(:noise))
    tr
end;

If we compare the Gaussian Drift proposal visually with our old algorithm, we can see the new behavior:

tr1, = generate(regression_with_outliers, (xs,), observations)
tr2 = tr1

viz = Plots.@animate for i in 1:300
    global tr1, tr2
    tr1 = gaussian_drift_update(tr1)
    tr2 = block_resimulation_update(tr2)
    Plots.plot(visualize_trace(tr1; title="Drift Kernel (Iter $i)"), 
               visualize_trace(tr2; title="Resim Kernel (Iter $i)"))
end;
gif(viz)