Variational Inference
Variational inference involves optimizing the parameters of a variational family to maximize the evidence lower bound (ELBO), a lower bound on the log marginal likelihood. In Gen, variational families are represented as generative functions, and variational inference typically involves optimizing the trainable parameters of generative functions.
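The following minimal sketch makes the ELBO concrete. It uses plain Julia rather than Gen, a hypothetical conjugate model z ~ N(0, 1), x | z ~ N(z, 1), and a Gaussian variational family q(z) = N(mu, sigma^2). The names normal_logpdf, log_joint, and elbo_estimate are illustrative and are not part of Gen. The estimate averages log p(x, z) - log q(z) over draws z ~ q. When q equals the exact posterior N(x/2, 1/2), every term equals log p(x), so the bound is tight.

```julia
using Random

# Log density of a univariate normal N(m, s^2) evaluated at x.
normal_logpdf(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1).
log_joint(z, x) = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

# Monte Carlo ELBO estimate for q(z) = N(mu, sigma^2):
# the average of log p(x, z) - log q(z) over samples z ~ q.
function elbo_estimate(rng, x, mu, sigma; n_samples=1000)
    total = 0.0
    for _ in 1:n_samples
        z = mu + sigma * randn(rng)
        total += log_joint(z, x) - normal_logpdf(z, mu, sigma)
    end
    return total / n_samples
end
```

With x = 1.0, elbo_estimate(MersenneTwister(1), 1.0, 0.5, sqrt(0.5)) matches the exact log marginal likelihood normal_logpdf(1.0, 0.0, sqrt(2.0)) up to rounding. Using the prior N(0, 1) as q gives a smaller value, about 0.40 lower in expectation. That gap is the KL divergence from q to the posterior.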
Black box variational inference
There are two procedures in the inference library for performing black box variational inference. Each of these procedures can also train the model using stochastic gradient descent, as in a variational autoencoder.
Gen.black_box_vi!
(elbo_estimate, traces, elbo_history) = black_box_vi!(
    model::GenerativeFunction, model_args::Tuple,
    [model_update::ParamUpdate,]
    observations::ChoiceMap,
    var_model::GenerativeFunction, var_model_args::Tuple,
    var_model_update::ParamUpdate;
    options...)
Fit the parameters of a variational model (var_model) to the posterior distribution implied by the given model and observations using stochastic gradient methods. Users may optionally specify a model_update to jointly update the parameters of model.
Additional arguments:
iters=1000: Number of iterations of gradient descent.
samples_per_iter=100: Number of samples from the variational and generative model to accumulate gradients over before a single gradient step.
verbose=false: If true, print information about the progress of fitting.
callback: Callback function that takes (iter, traces, elbo_estimate) as input, where iter is the iteration number and traces are samples from var_model for that iteration.
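black_box_vi! works with arbitrary generative functions. For the same hypothetical toy model (z ~ N(0, 1), x | z ~ N(z, 1)), the kind of loop such a procedure runs can be sketched by hand. This sketch assumes the score-function (REINFORCE) gradient estimator commonly used in black box VI. It is not Gen's implementation. The (mu, log_sigma) parameterization, the fixed step size, and the two-argument callback are assumptions chosen for illustration. The iters, samples_per_iter, and callback keywords play roles analogous to the options above, and the function also returns a per-iteration ELBO history.

```julia
using Random

normal_logpdf(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1); exact posterior is N(x/2, 1/2).
log_joint(z, x) = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

# Score-function gradient ascent on the ELBO for q(z) = N(mu, exp(log_sigma)^2).
# Illustrative only: loosely mirrors iters / samples_per_iter / callback.
function score_function_vi(rng, x; iters=2000, samples_per_iter=100, step_size=0.01,
                           callback=(iter, elbo_est) -> nothing)
    mu, log_sigma = 0.0, 0.0
    elbo_history = Float64[]
    for iter in 1:iters
        sigma = exp(log_sigma)
        grad_mu, grad_log_sigma, elbo_est = 0.0, 0.0, 0.0
        for _ in 1:samples_per_iter
            z = mu + sigma * randn(rng)
            # Learning signal: log p(x, z) - log q(z).
            signal = log_joint(z, x) - normal_logpdf(z, mu, sigma)
            elbo_est += signal / samples_per_iter
            # Signal times the gradient of log q(z) w.r.t. mu and log_sigma.
            grad_mu += signal * (z - mu) / sigma^2 / samples_per_iter
            grad_log_sigma += signal * (((z - mu) / sigma)^2 - 1) / samples_per_iter
        end
        # One gradient-ascent step after accumulating over samples_per_iter draws.
        mu += step_size * grad_mu
        log_sigma += step_size * grad_log_sigma
        push!(elbo_history, elbo_est)
        callback(iter, elbo_est)
    end
    return mu, exp(log_sigma), elbo_history
end
```

Running score_function_vi(MersenneTwister(0), 1.0) should move (mu, sigma) close to the exact posterior parameters (0.5, sqrt(0.5) ≈ 0.707). Each step averages over many samples because single-sample score-function gradients are very noisy.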
Gen.black_box_vimco!
(iwelbo_estimate, traces, iwelbo_history) = black_box_vimco!(
    model::GenerativeFunction, model_args::Tuple,
    [model_update::ParamUpdate,]
    observations::ChoiceMap,
    var_model::GenerativeFunction, var_model_args::Tuple,
    var_model_update::ParamUpdate,
    grad_est_samples::Int; options...)
Fit the parameters of a variational model (var_model) to the posterior distribution implied by the given model and observations using stochastic gradient methods applied to the Variational Inference with Monte Carlo Objectives (VIMCO) lower bound on the log marginal likelihood. Users may optionally specify a model_update to jointly update the parameters of model.
Additional arguments:
grad_est_samples::Int: Number of samples for the VIMCO gradient estimate.
iters=1000: Number of iterations of gradient descent.
samples_per_iter=100: Number of samples from the variational and generative model to accumulate gradients over before a single gradient step.
geometric=true: Whether to use the geometric (true) or arithmetic (false) baselines described in the paper "Variational Inference with Monte Carlo Objectives".
verbose=false: If true, print information about the progress of fitting.
callback: Callback function that takes (iter, traces, elbo_estimate) as input, where iter is the iteration number and traces are samples from var_model for that iteration.
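The VIMCO objective replaces the single-sample ELBO with the importance-weighted bound L = log((1/K) Σ_j w_j), where w_j = p(x, z_j) / q(z_j) for K samples z_j drawn from q. To reduce gradient variance, VIMCO gives each sample j the learning signal L - L_{-j}. Here L_{-j} recomputes the bound with w_j replaced by a baseline built from the other K - 1 weights: their geometric mean (the geometric option) or their arithmetic mean. The plain-Julia sketch below computes these quantities from a vector of log weights. The function names are hypothetical, and reading grad_est_samples as K is an assumption based on its description above.

```julia
# Numerically stable log(sum(exp.(v))).
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Importance-weighted bound log((1/K) * sum_j w_j), from K log weights.
iwelbo(log_w) = logsumexp(log_w) - log(length(log_w))

# VIMCO per-sample learning signals L - L_{-j}, where L_{-j} replaces log w_j
# by a leave-one-out baseline computed from the other K - 1 log weights.
function vimco_signals(log_w; geometric=true)
    K = length(log_w)
    K >= 2 || throw(ArgumentError("VIMCO needs at least two samples"))
    L = iwelbo(log_w)
    signals = zeros(K)
    for j in 1:K
        others = [log_w[i] for i in 1:K if i != j]
        # geometric: mean of the other log weights (log of their geometric mean);
        # arithmetic: log of the arithmetic mean of the other weights.
        baseline = geometric ? sum(others) / (K - 1) : logsumexp(others) - log(K - 1)
        replaced = float.(log_w)
        replaced[j] = baseline
        signals[j] = L - iwelbo(replaced)
    end
    return L, signals
end
```

When all log weights are equal, each baseline equals the weight it replaces, so every signal is zero. A sample whose weight exceeds all the others always receives a positive signal, because both means of the other weights are smaller than it.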
Reparametrization trick
To use the reparametrization trick to reduce the variance of gradient estimators, users currently need to write two versions of their variational family: one that is reparametrized and one that is not. Gen does not currently include inference library support for this. We plan to add automated support for reparametrization and other variance reduction techniques in the future.
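The plain-Julia sketch below shows why reparametrization helps, using the hypothetical toy model z ~ N(0, 1), x | z ~ N(z, 1). It draws z = mu + sigma * noise with noise ~ N(0, 1) and compares two unbiased per-sample estimates of d(ELBO)/d(mu). The first is the score-function estimate: the learning signal times d/dmu log q(z). The second is a reparametrized estimate that differentiates through z. For this model the true gradient is x - 2 mu. The function name and setup are illustrative only, and this is not how Gen represents a reparametrized variational family.

```julia
using Random, Statistics

normal_logpdf(x, m, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - m) / s)^2
log_joint(z, x) = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

# Per-sample estimates of d(ELBO)/d(mu) for q(z) = N(mu, sigma^2) under the toy model.
function mu_gradient_samples(rng, x, mu, sigma; n=10_000)
    score_fn = Float64[]
    reparam = Float64[]
    for _ in 1:n
        noise = randn(rng)
        # Reparametrized draw: z is a deterministic function of (mu, sigma, noise).
        z = mu + sigma * noise
        # Score-function estimate: learning signal times d/dmu log q(z).
        signal = log_joint(z, x) - normal_logpdf(z, mu, sigma)
        push!(score_fn, signal * (z - mu) / sigma^2)
        # Reparametrized estimate: log q(z) = const - log(sigma) - noise^2/2 does not
        # depend on mu, so only d/dz log p(x, z) = -z + (x - z) contributes.
        push!(reparam, x - 2z)
    end
    return score_fn, reparam
end
```

With x = 1.0, mu = 0.0, and sigma = 1.0, both sample means should be close to the true gradient 1.0. The reparametrized samples have variance 4, compared with about 12 for the score-function samples.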