Generative Functions

Gen.GenerativeFunction (Type)
GenerativeFunction{T,U <: Trace}

Abstract type for a generative function with return value type T and trace type U.

Gen.Trace (Type)
Trace

Abstract type for a trace of a generative function.
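To make the two abstract types concrete, here is a minimal sketch. It assumes Gen.jl is installed and defines a small hypothetical model, `coin_model`, that is not part of Gen. The `@gen` macro produces a subtype of `GenerativeFunction`, and the traces it returns are subtypes of `Trace`.

```julia
using Gen

# Hypothetical model: a biased coin :z, then a Gaussian :x whose mean depends on :z.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

trace = simulate(coin_model, (0.3,))
println(coin_model isa GenerativeFunction)  # true
println(trace isa Trace)                    # true
```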


The complete set of methods in the generative function interface (GFI) is:

Gen.simulate (Function)
trace = simulate(gen_fn, args)

Execute the generative function and return the trace.

Given arguments $x$ (args), sample $(r, t) \sim p(\cdot; x)$ and return a trace with choice map $t$.

If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.
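The sketch below calls `simulate` on the hypothetical `coin_model`. It assumes Gen.jl is installed. Because nothing is constrained, every addressed choice is sampled and can be read back from the trace.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

trace = simulate(coin_model, (0.3,))

# Both choices were sampled by simulate; read them back by address.
z = trace[:z]
x = trace[:x]
println("z = ", z, ", x = ", x)
```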

Gen.generate (Function)
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)

Return a trace of a generative function.

(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
                                constraints::ChoiceMap)

Return a trace of a generative function that is consistent with the given constraints on the random choices.

Given arguments $x$ (args) and assignment $u$ (constraints) (which is empty for the first form), sample $t \sim q(\cdot; u, x)$ and $r \sim q(\cdot; x, t)$, and return the trace $(x, r, t)$ (trace). Also return the weight (weight):

\[\log \frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}\]

If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.

Example without constraints:

(trace, weight) = generate(foo, (2, 4))

Example with the constraint that address :z takes the value true:

(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))
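The weight formula can be checked on a small hypothetical model, `coin_model` (not part of Gen), assuming Gen.jl is installed. With only `:z` constrained, the dynamic DSL proposes `:x` from its prior. That factor cancels against $p$, so the weight reduces to $\log p(z = \text{true}) = \log 0.3$.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

# Constrain :z; :x is filled in by the internal proposal (its prior here).
(trace, weight) = generate(coin_model, (0.3,), choicemap((:z, true)))

println(trace[:z])              # true
println(weight ≈ log(0.3))      # true
```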
Gen.update (Function)
(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
                                               constraints::ChoiceMap)

Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s).

Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a map $u$ (constraints), return a new trace $(x', r', t')$ (new_trace) that is consistent with $u$. The values of choices in $t'$ are either copied from $t$ or from $u$ (with $u$ taking precedence) or are sampled from the internal proposal distribution. All choices in $u$ must appear in $t'$. Also return an assignment $v$ (discard) containing the choices in $t$ that were overwritten by values from $u$, and any choices in $t$ whose address does not appear in $t'$. Sample $t' \sim q(\cdot; x', t + u)$, and $r' \sim q(\cdot; x', t')$, where $t + u$ is the choice map obtained by merging $t$ and $u$ with $u$ taking precedence for overlapping addresses. Also return a weight (weight):

\[\log \frac{p(r', t'; x')}{q(r'; x', t') q(t'; x', t + u)} - \log \frac{p(r, t; x)}{q(r; x, t)}\]

Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the updated trace.

(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)

Shorthand variant of update which assumes the arguments are unchanged.
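The next sketch overwrites one choice with `update`, assuming Gen.jl is installed and using a hypothetical `coin_model`. Every address of the old trace is present in $t + u$, so $t'$ is fully determined and the weight is just the change in log joint density. The unconstrained `:x` is copied over, and the old value of `:z` comes back in `discard`.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

# Start from a trace with z = false and x = 0.5.
(trace, _) = generate(coin_model, (0.3,), choicemap((:z, false), (:x, 0.5)))

# Arguments are unchanged (NoChange); overwrite :z with true.
(new_trace, weight, retdiff, discard) =
    update(trace, (0.3,), (NoChange(),), choicemap((:z, true)))

# Weight = log p(t'; x') - log p(t; x); note that x's density changes too,
# because its mean depends on z.
expected = (logpdf(bernoulli, true, 0.3) + logpdf(normal, 0.5, 1.0, 1.0)) -
           (logpdf(bernoulli, false, 0.3) + logpdf(normal, 0.5, 0.0, 1.0))
println(new_trace[:x], " ", discard[:z], " ", weight ≈ expected)
```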

Gen.regenerate (Function)
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
                                          selection::Selection)

Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family.

Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a set of addresses $A$ (selection), return a new trace $(x', r', t')$ (new_trace) such that $t'$ agrees with $t$ on all addresses not in $A$ ($t$ and $t'$ may have different sets of addresses). Let $u$ denote the restriction of $t$ to the complement of $A$. Sample $t' \sim q(\cdot; u, x')$ and $r' \sim q(\cdot; x', t')$. Return the new trace $(x', r', t')$ (new_trace) and the weight (weight):

\[\log \frac{p(r', t'; x')}{q(t'; u, x') q(r'; x', t')} - \log \frac{p(r, t; x)}{q(t; u', x) q(r; x, t)}\]

where $u'$ is the restriction of $t'$ to the complement of $A$.

Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the regenerated trace.

(new_trace, weight, retdiff) = regenerate(trace, selection::Selection)

Shorthand variant of regenerate which assumes the arguments are unchanged.
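Here is a sketch of resampling a single choice with `regenerate`. It assumes Gen.jl is installed and uses a hypothetical `coin_model`. Only `:x` is selected, so `:z` keeps its value. Nothing depends on `:x` and the arguments are unchanged, so in this particular model the terms of the weight cancel and it comes out as zero.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(trace, _) = generate(coin_model, (0.3,), choicemap((:z, true), (:x, 0.5)))

# Resample :x from the internal proposal; :z lies outside the selection.
(new_trace, weight, retdiff) =
    regenerate(trace, (0.3,), (NoChange(),), select(:x))

println(new_trace[:z], " ", weight)
```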

Gen.get_args (Function)
get_args(trace)

Return the argument tuple for a given execution.

Example:

args::Tuple = get_args(trace)
Gen.get_retval (Function)
get_retval(trace)

Return the return value of the given execution.

Example for a generative function with return type T:

retval::T = get_retval(trace)
Gen.get_choices (Function)
get_choices(trace)

Return a value implementing the assignment (ChoiceMap) interface.

Note that the value of any non-addressed randomness is not externally accessible.

Example:

choices::ChoiceMap = get_choices(trace)
z_val = choices[:z]
Gen.get_score (Function)
get_score(trace)

Return:

\[\log \frac{p(r, t; x)}{q(r; x, t)}\]

When there is no non-addressed randomness, this simplifies to the log probability $\log p(t; x)$.
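The simplification can be checked on a hypothetical `coin_model`, assuming Gen.jl is installed. The model has no non-addressed randomness, so the score should equal the sum of the log densities of its two addressed choices.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(trace, _) = generate(coin_model, (0.3,), choicemap((:z, true), (:x, 0.5)))

# log p(t; x) = log p(z) + log p(x | z)
expected = logpdf(bernoulli, true, 0.3) + logpdf(normal, 0.5, 1.0, 1.0)
println(get_score(trace) ≈ expected)  # true
```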

Gen.get_gen_fn (Function)
gen_fn::GenerativeFunction = get_gen_fn(trace)

Return the generative function that produced the given trace.

Base.getindex (Function)
value = getindex(trace::Trace, addr)

Get the value of the random choice, or auxiliary state (e.g. return value of inner function call), at address addr.

retval = getindex(trace::Trace)
retval = trace[]

Synonym for get_retval.
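The trace accessors can be compared side by side in one sketch. It assumes Gen.jl is installed and uses a hypothetical `coin_model`. The comparison shows that indexing with an address and the empty index `trace[]` agree with the corresponding accessor functions.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

trace = simulate(coin_model, (0.3,))

args = get_args(trace)        # the argument tuple, (0.3,)
retval = get_retval(trace)    # the body's return value (here, the value of :z)
choices = get_choices(trace)  # a ChoiceMap of the addressed choices
gen_fn = get_gen_fn(trace)    # the function that produced the trace

# Indexing shorthands for the accessors above.
println(trace[:z] == choices[:z], " ", trace[] == retval, " ", gen_fn === coin_model)
```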

Gen.project (Function)
weight = project(trace::U, selection::Selection)

Estimate the probability that the selected choices take the values they do in a trace.

Given a trace $(x, r, t)$ (trace) and a set of addresses $A$ (selection), let $u$ denote the restriction of $t$ to $A$. Return the weight (weight):

\[\log \frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}\]
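A sketch of `project` on a hypothetical `coin_model`, assuming Gen.jl is installed. When only `:z` is selected, the unselected `:x` is scored by the same prior the internal proposal would use, so it cancels and leaves $\log p(z = \text{true})$. Selecting every address gives back the full score.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(trace, _) = generate(coin_model, (0.3,), choicemap((:z, true), (:x, 0.5)))

w_z = project(trace, select(:z))          # log p(z = true)
w_all = project(trace, select(:z, :x))    # log p(z, x), the full score
println(w_z ≈ log(0.3), " ", w_all ≈ get_score(trace))
```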

Gen.propose (Function)
(choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)

Sample an assignment and compute the probability of proposing that assignment.

Given arguments $x$ (args), sample $t \sim p(\cdot; x)$ and $r \sim p(\cdot; x, t)$, and return $t$ (choices) and the weight (weight):

\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
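The following sketch assumes Gen.jl is installed and uses a hypothetical `coin_model`. When a model has no non-addressed randomness, the weight returned by `propose` is $\log p(t; x)$, so it should match what `assess` reports for the same choices.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(choices, weight, retval) = propose(coin_model, (0.3,))

# Score the proposed choices independently.
(assessed, _) = assess(coin_model, (0.3,), choices)
println(weight ≈ assessed)  # true
```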

Gen.assess (Function)
(weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)

Return the probability of proposing an assignment.

Given arguments $x$ (args) and an assignment $t$ (choices) such that $p(t; x) > 0$, sample $r \sim q(\cdot; x, t)$ and return the weight (weight):

\[\log \frac{p(r, t; x)}{q(r; x, t)}\]

It is an error if $p(t; x) = 0$.
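A sketch with a hypothetical `coin_model`, assuming Gen.jl is installed. Because every choice is specified and there is no non-addressed randomness, `assess` is deterministic here and returns the log joint density.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(weight, retval) = assess(coin_model, (0.3,), choicemap((:z, true), (:x, 0.5)))

expected = log(0.3) + logpdf(normal, 0.5, 1.0, 1.0)
println(weight ≈ expected, " ", retval)  # true true
```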

Gen.has_argument_grads (Function)
bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution})

Return a tuple of booleans indicating whether a gradient is available for each of its arguments.
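This sketch assumes Gen.jl is installed and defines a hypothetical function, `shifted`. In the dynamic DSL, arguments are marked as differentiable with the `(grad)` annotation. The result is wrapped in `collect` so the comparison holds whether the implementation returns a tuple or a vector.

```julia
using Gen

# Hypothetical function: only `mu` is annotated with (grad).
@gen function shifted((grad)(mu::Float64), sigma::Float64)
    y ~ normal(mu, sigma)
    return y
end

println(collect(has_argument_grads(shifted)))  # gradient for mu, none for sigma
println(collect(has_argument_grads(normal)))   # the built-in normal distribution
```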

Gen.has_submap (Function)
has_submap(choices::ChoiceMap, addr)

Return true if there is a non-empty sub-assignment at the given address.
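A sketch of `has_submap`, assuming Gen.jl is installed. Hierarchical addresses are written with `=>`. Only an address that holds a nested sub-assignment counts; an address that holds a plain value, or a missing address, does not.

```julia
using Gen

# :a holds a sub-assignment (containing :b); :c holds a value.
cm = choicemap((:a => :b, 1.0), (:c, 2.0))

println(has_submap(cm, :a))        # true
println(has_submap(cm, :c))        # false
println(has_submap(cm, :missing))  # false
```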

Gen.accepts_output_grad (Function)
req::Bool = accepts_output_grad(gen_fn::GenerativeFunction)

Return a boolean indicating whether the return value is dependent on any of the gradient source elements for any trace.

The gradient source elements are:

  • Any argument whose position is true in has_argument_grads

  • Any trainable parameter

  • Random choices made at a set of addresses that are selectable by choice_gradients.

Gen.accumulate_param_gradients! (Function)
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)

Increment gradient accumulators for parameters by the gradient of the log-probability of the trace, optionally scaled, and return the gradient with respect to the arguments (not scaled).

Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:

\[∇_x \left( \log P(t; x) + J \right)\]

The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.

Also increment the gradient accumulators for the trainable parameters $Θ$ of the function by:

\[∇_Θ \left( \log P(t; x) + J \right)\]
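A minimal sketch assuming Gen.jl is installed and that `@param`, `init_param!`, and `get_param_grad` behave as in the dynamic DSL of this era. The model `param_model` is hypothetical. With $x = 1$ and $\theta = 0$, $\partial_\theta \log \mathcal{N}(x; \theta, 1) = x - \theta = 1$, so that is the value the accumulator should hold afterwards.

```julia
using Gen

# Hypothetical model with one trainable parameter.
@gen function param_model()
    @param theta::Float64
    x ~ normal(theta, 1.0)
    return nothing
end

init_param!(param_model, :theta, 0.0)
(trace, _) = generate(param_model, (), choicemap((:x, 1.0)))

# retgrad defaults to nothing and scale_factor to 1.0.
arg_grads = accumulate_param_gradients!(trace)

# Expected accumulator value: x - theta = 1.0.
println(get_param_grad(param_model, :theta))
```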

Gen.choice_gradients (Function)
(arg_grads, choice_values, choice_grads) = choice_gradients(
    trace, selection=EmptySelection(), retgrad=nothing)

Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:

\[∇_x \left( \log P(t; x) + J \right)\]

The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.

Also given a set of addresses $A$ (selection) that are continuous-valued random choices, return the following gradient (choice_grads) with respect to the values of these choices:

\[∇_A \left( \log P(t; x) + J \right)\]

The gradient is represented as a choicemap whose value at (hierarchical) address addr is $∂(\log P(t; x) + J)/∂t[\texttt{addr}]$.

Also return the choicemap (choice_values) that is the restriction of $t$ to $A$.
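A sketch with a hypothetical `gauss_model`, assuming Gen.jl is installed. No return-value gradient is passed, so $J = 0$ and the gradient is that of the log density alone. For a standard normal, $\partial_x \log \mathcal{N}(x; 0, 1) = -x$.

```julia
using Gen

# Hypothetical model with a single continuous choice.
@gen function gauss_model()
    x ~ normal(0.0, 1.0)
    return nothing
end

(trace, _) = generate(gauss_model, (), choicemap((:x, 0.5)))

# Select :x; retgrad = nothing, so J contributes nothing.
(arg_grads, choice_values, choice_grads) =
    choice_gradients(trace, select(:x), nothing)

println(choice_values[:x], " ", choice_grads[:x])
```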

Gen.get_params (Function)
get_params(gen_fn::GenerativeFunction)

Return an iterable over the trainable parameters of the generative function.

Gen.Diff (Type)
Diff

Abstract supertype for information about a change to a value.

Gen.UnknownChange (Type)
UnknownChange

Singleton to indicate the change to the value is unknown or unprovided.
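In practice, `UnknownChange()` is passed in the argdiffs tuple of `update` when an argument changed and the caller has no finer description of the change. The sketch assumes Gen.jl is installed and uses a hypothetical `coin_model`. No choices are constrained, so they are kept, and only the probability of the existing value of `:z` changes.

```julia
using Gen

# Hypothetical model used only for illustration.
@gen function coin_model(p::Float64)
    z ~ bernoulli(p)
    x ~ normal(z ? 1.0 : 0.0, 1.0)
    return z
end

(trace, _) = generate(coin_model, (0.3,), choicemap((:z, true), (:x, 0.5)))

# The argument changed from 0.3 to 0.6, with no detail about how.
(new_trace, weight, _, _) =
    update(trace, (0.6,), (UnknownChange(),), choicemap())

# weight = log p(t; 0.6) - log p(t; 0.3) = log(0.6) - log(0.3)
println(get_args(new_trace), " ", weight ≈ log(0.6) - log(0.3))
```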

Gen.Diffed (Type)

Diffed{V,DV <: Diff}

Container for a value and information about a change to its value.

Gen.CustomUpdateGF (Type)
CustomUpdateGF{T,S}

Abstract type for a generative function with a custom update computation, and default behaviors for all other generative function interface methods.

T is the type of the return value and S is the type of state used internally for incremental computation.

Gen.apply_with_state (Function)
retval, state = apply_with_state(gen_fn::CustomUpdateGF, args)

Execute the generative function and return the return value and the state.

Gen.update_with_state (Function)
state, retval, retdiff = update_with_state(gen_fn::CustomUpdateGF, state, args, argdiffs)

Update the arguments to the generative function and return the new state, the new return value, and a diff describing the change to the return value.

Gen.CustomGradientGF (Type)
CustomGradientGF{T}

Abstract type for a generative function with a custom gradient computation, and default behaviors for all other generative function interface methods.

T is the type of the return value.

Gen.apply (Function)
retval = apply(gen_fn::CustomGradientGF, args)

Apply the function to the arguments.

Gen.gradient (Function)
arg_grads = gradient(gen_fn::CustomGradientGF, args, retval, retgrad)

Return the tuple of gradients with respect to the arguments, with nothing in the position of any argument whose gradient is not available.

Gen.init_update_state (Function)
state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector)

Get the initial state for a parameter update to the given parameters of the given generative function.

param_list is a vector of references to parameters of gen_fn. conf configures the update.

Gen.apply_update! (Function)
apply_update!(state)

Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state.
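A sketch of one parameter update that calls the two low-level functions directly. It assumes Gen.jl is installed and that the `FixedStepGradientDescent` configuration performs gradient ascent on the accumulated log-probability gradient, that is, `theta + step_size * grad`. The model `param_model` is hypothetical. With a gradient of 1.0 and a step size of 0.1, the parameter should move from 0.0 to 0.1.

```julia
using Gen

# Hypothetical model with one trainable parameter.
@gen function param_model()
    @param theta::Float64
    x ~ normal(theta, 1.0)
    return nothing
end

init_param!(param_model, :theta, 0.0)
(trace, _) = generate(param_model, (), choicemap((:x, 1.0)))
accumulate_param_gradients!(trace)   # accumulated gradient for theta: 1.0

# Assumed configuration: fixed-step gradient ascent with step size 0.1.
conf = FixedStepGradientDescent(0.1)
state = Gen.init_update_state(conf, param_model, [:theta])
Gen.apply_update!(state)             # theta <- theta + 0.1 * 1.0

println(get_param(param_model, :theta))
```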
