# Generative Functions

One of the core abstractions in Gen is the generative function. Generative functions are used to represent a variety of different types of probabilistic computations including generative models, inference models, custom proposal distributions, and variational approximations.

Generative functions are represented by the following abstract type:

GenerativeFunction{T,U}

Abstract type for a generative function with return value type T and trace type U.


There are various kinds of generative functions, which are represented by concrete subtypes of GenerativeFunction. For example, the Built-in Modeling Language allows generative functions to be constructed using Julia function definition syntax:

@gen function foo(a, b)
    if @trace(bernoulli(0.5), :z)
        return a + b + 1
    else
        return a + b
    end
end

Generative functions behave like Julia functions in some respects. For example, we can call a generative function foo on arguments and get an output value using regular Julia call syntax:

julia> foo(2, 4)
7

However, generative functions are distinct from Julia functions because they support additional behaviors, described in the remainder of this section.

## Mathematical definition

Generative functions represent computations that accept some arguments, may use randomness internally, return an output, and cannot mutate externally observable state. We represent the randomness used during an execution of a generative function as a map from unique addresses to values, denoted $t : A \to V$ where $A$ is an address set and $V$ is a set of possible values that random choices can take. In this section, we assume that random choices are discrete to simplify notation. We say that two random choice maps $t$ and $s$ agree if they assign the same value for any address that is in both of their domains.
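To make the notion of agreement concrete, here is a minimal sketch in plain Julia. It models random choice maps as ordinary `Dict`s rather than Gen's own data structures, and the helper `agree` is a hypothetical name introduced only for this illustration.

```julia
# Hypothetical helper (not part of Gen): two choice maps agree if they assign
# the same value to every address that appears in both of their domains.
agree(t::Dict, s::Dict) = all(t[a] == s[a] for a in keys(t) if haskey(s, a))

t = Dict(:a => false, :b => true)
s = Dict(:b => true, :c => false)   # shares only :b with t, same value
u = Dict(:b => false)               # shares :b with t, different value

agree(t, s)                     # true
agree(t, u)                     # false
agree(t, Dict{Symbol,Bool}())   # true: no shared addresses, so nothing can conflict
```

Note that the empty map agrees with every map, which is why an empty set of constraints never restricts an execution.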

Generative functions may also use untraced randomness, which is not included in the map $t$. However, the state of untraced random choices is maintained by the trace internally. We denote untraced randomness by $r$. Untraced randomness is useful, for example, when calling black-box Julia code that implements a randomized algorithm.

The observable behavior of every generative function is defined by the following mathematical objects:

### 1. Input type

The set of valid argument tuples to the function, denoted $X$.

### 2. Probability distribution family

A family of probability distributions $p(t, r; x)$ on maps $t$ from random choice addresses to their values, and untraced randomness $r$, indexed by arguments $x$, for all $x \in X$. Note that the distribution must be normalized:

$\sum_{t, r} p(t, r; x) = 1 \;\; \mbox{for all} \;\; x \in X$

This corresponds to a requirement that the function terminate with probability 1 for all valid arguments. We use $p(t; x)$ to denote the marginal distribution on the map $t$:

$p(t; x) := \sum_{r} p(t, r; x)$

And we denote the conditional distribution on untraced randomness $r$, given the map $t$, as:

$p(r; x, t) := p(t, r; x) / p(t; x)$
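As a small numeric illustration of the marginal and conditional distributions, the sketch below uses a made-up joint distribution for one fixed argument value $x$. The table `joint` and the helpers `marginal` and `conditional` are hypothetical, chosen only so the sums are easy to follow.

```julia
# Hypothetical joint p(t, r; x) for a fixed x. Here t is the value of a single
# traced Boolean choice and r is one untraced choice taking the value 1 or 2.
joint = Dict((true, 1) => 0.1, (true, 2) => 0.3,
             (false, 1) => 0.45, (false, 2) => 0.15)

# Marginal p(t; x): sum the joint over the untraced randomness r.
marginal(t) = sum(p for ((t2, r), p) in joint if t2 == t)

# Conditional p(r; x, t) = p(t, r; x) / p(t; x).
conditional(r, t) = joint[(t, r)] / marginal(t)

marginal(true)         # 0.1 + 0.3
conditional(2, true)   # 0.3 / 0.4
```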

### 3. Return value function

A (deterministic) function $f$ that maps the tuple $(x, t)$ of the arguments and the random choice map to the return value of the function (which we denote by $y$). Note that the return value cannot depend on the untraced randomness.

### 4. Internal proposal distribution family

A family of probability distributions $q(t; x, u)$ on maps $t$ from random choice addresses to their values, indexed by tuples $(x, u)$ where $u$ is a map from random choice addresses to values, and where $x$ are the arguments to the function. It must satisfy the following conditions:

$\sum_{t} q(t; x, u) = 1 \;\; \mbox{for all} \;\; x \in X, u$
$p(t; x) > 0 \mbox{ if and only if } q(t; x, u) > 0 \mbox{ for all } u \mbox{ where } u \mbox{ and } t \mbox{ agree }$
$q(t; x, u) > 0 \mbox{ implies that } u \mbox{ and } t \mbox{ agree }.$

There is also a family of probability distributions $q(r; x, t)$ on untraced randomness, that satisfies:

$q(r; x, t) > 0 \mbox{ if and only if } p(r; x, t) > 0$

## Traces

An execution trace (or just trace) is a record of an execution of a generative function. There is no abstract type representing all traces. Different concrete types of generative functions use different data structures and different Julia types for their traces. The trace type that a generative function uses is the second type parameter of the GenerativeFunction abstract type.

A trace of a generative function can be produced using:

(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)

Return a trace of a generative function.

(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
                              constraints::ChoiceMap)

Return a trace of a generative function that is consistent with the given constraints on the random choices.

Given arguments $x$ (args) and assignment $u$ (constraints) (which is empty for the first form), sample $t \sim q(\cdot; u, x)$ and $r \sim q(\cdot; x, t)$, and return the trace $(x, t, r)$ (trace). Also return the weight (weight):

$\log \frac{p(t, r; x)}{q(t; u, x) q(r; x, t)}$

Example without constraints:

(trace, weight) = generate(foo, (2, 4))

Example with the constraint that address :z takes value true:

(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))
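To see where the weight comes from, the following is a hand-written analogue of `generate` for the function `foo` defined earlier. It is a sketch in plain Julia, not Gen's implementation: `generate_foo` is a hypothetical name, choice maps are modeled as `Dict`s, and it assumes that the internal proposal samples any unconstrained choice from the same `bernoulli(0.5)` distribution the model uses and that there is no untraced randomness.

```julia
# Hand-written analogue of `generate` for `foo` (hypothetical, for illustration).
function generate_foo(a, b, constraints::Dict{Symbol,Bool})
    if haskey(constraints, :z)
        z = constraints[:z]   # constrained: q(t; u, x) = 1 for this value
        weight = log(0.5)     # log p(t; x) - log 1
    else
        z = rand() < 0.5      # unconstrained: sampled from p, so q equals p
        weight = 0.0          # log p(t; x) - log p(t; x)
    end
    retval = z ? a + b + 1 : a + b
    return (Dict(:z => z), retval, weight)
end

# With the constraint :z => true, the weight is log 0.5 and the return value is 7.
(choices, retval, weight) = generate_foo(2, 4, Dict(:z => true))

# Without constraints, the weight is always 0 under the assumption above.
(_, _, weight_free) = generate_foo(2, 4, Dict{Symbol,Bool}())
```

Under these assumptions, the weight is the log probability of the constrained choices only, which is why it vanishes when no constraints are given.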

The trace contains various information about the execution, including:

The arguments to the generative function:

get_args(trace)

Return the argument tuple for a given execution.

Example:

args::Tuple = get_args(trace)

The return value of the generative function:

get_retval(trace)

Return the return value of the given execution.

Example for generative function with return type T:

retval::T = get_retval(trace)

The map $t$ from addresses of random choices to their values:

get_choices(trace)

Return a value implementing the assignment interface.

Note that the value of any non-addressed randomness is not externally accessible.

Example:

choices::ChoiceMap = get_choices(trace)
z_val = choices[:z]

The log probability that the random choices took the values they did:

get_score(trace)

Return $\log \frac{p(r, t; x)}{q(r; x, t)}$. When there is no non-addressed randomness, this simplifies to the log probability $\log p(t; x)$.


A reference to the generative function that was executed:

gen_fn::GenerativeFunction = get_gen_fn(trace)

Return the generative function that produced the given trace.


## Trace update methods

It is often important to update or adjust the trace of a generative function. In Gen, traces are persistent data structures, meaning they can be treated as immutable values. There are several methods that take a trace of a generative function as input and return a new trace of the generative function based on adjustments to the execution history of the function. We will illustrate these methods using the following generative function:

@gen function foo()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

Suppose we have a trace (trace) with initial choices:

│
├── :a : false
│
├── :b : true
│
├── :c : false
│
└── :e : true

Note that address :d is not present: the branch in which :d is sampled was not taken, because random choice :b had value true.

### Update

(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiff,
                                               constraints::ChoiceMap)

Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for any newly introduced random choice(s).

Given a previous trace $(x, t, r)$ (trace), new arguments $x'$ (args), and a map $u$ (constraints), return a new trace $(x', t', r')$ (new_trace) that is consistent with $u$. The values of choices in $t'$ are deterministically copied either from $t$ or from $u$ (with $u$ taking precedence). All choices in $u$ must appear in $t'$. Also return an assignment $v$ (discard) containing the choices in $t$ that were overwritten by values from $u$, and any choices in $t$ whose address does not appear in $t'$. The new non-addressed randomness is sampled from $r' \sim q(\cdot; x', t')$. Also return a weight (weight):

$\log \frac{p(r', t'; x') q(r; x, t)}{p(r, t; x) q(r'; x', t')}$

Suppose we run update on the example trace, with the following constraints:

│
├── :b : false
│
└── :d : true

constraints = choicemap((:b, false), (:d, true))
(new_trace, w, _, discard) = update(trace, (), noargdiff, constraints)

Then get_choices(new_trace) will be:

│
├── :a : false
│
├── :b : false
│
├── :d : true
│
└── :e : true

and discard will be:

│
├── :b : true
│
└── :c : false

Note that the discard contains both the previous values of addresses that were overwritten, and the values for addresses that were in the previous trace but are no longer in the new trace. The weight (w) is computed as:

$p(t; x) = 0.7 \times 0.4 \times 0.4 \times 0.7 = 0.0784$
$p(t'; x') = 0.7 \times 0.6 \times 0.1 \times 0.7 = 0.0294$
$w = \log \frac{p(t'; x')}{p(t; x)} = \log \frac{0.0294}{0.0784} = \log 0.375$
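The discard and the weight for this example can be reproduced by hand. The following is a minimal sketch in plain Julia that does not call Gen: choice maps are modeled as `Dict`s, and the helpers `bern` and `prob` are hypothetical names that simply transcribe the probabilities in the definition of `foo`, assuming there is no untraced randomness.

```julia
# Probability that a Bernoulli(p) choice takes the value v.
bern(p, v) = v ? p : 1 - p

# Probability of a complete choice map for the example `foo`, transcribed by
# hand from its definition (hypothetical helper, not computed by Gen).
function prob(t::Dict{Symbol,Bool})
    pr = bern(0.3, t[:a]) * bern(0.4, t[:b])
    pr *= t[:b] ? bern(0.6, t[:c]) : bern(0.1, t[:d])
    return pr * bern(0.7, t[:e])
end

t_old = Dict(:a => false, :b => true, :c => false, :e => true)
constraints = Dict(:b => false, :d => true)
t_new = Dict(:a => false, :b => false, :d => true, :e => true)

# Discard: previous values that were overwritten or that no longer appear.
discard = Dict(a => v for (a, v) in t_old
               if haskey(constraints, a) || !haskey(t_new, a))

# Weight: log p(t'; x') - log p(t; x), which equals log 0.375.
w = log(prob(t_new)) - log(prob(t_old))
```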

### Regenerate

(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiff,
                                          selection::AddressSet)

Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family.

Given a previous trace $(x, t, r)$ (trace), new arguments $x'$ (args), and a set of addresses $A$ (selection), construct a new trace in which $t'$ agrees with $t$ on all addresses not in $A$ ($t$ and $t'$ may have different sets of addresses). Let $u$ denote the restriction of $t$ to the complement of $A$. Sample $t' \sim q(\cdot; u, x')$ and sample $r' \sim q(\cdot; x', t')$. Return the new trace $(x', t', r')$ (new_trace) and the weight (weight):

$\log \frac{p(r', t'; x') q(t; u', x) q(r; x, t)}{p(r, t; x) q(t'; u, x') q(r'; x', t')}$

where $u'$ is the restriction of $t'$ to the complement of $A$.


Suppose we run regenerate on the example trace, with selection :a and :b:

(new_trace, w, _) = regenerate(trace, (), noargdiff, select(:a, :b))

Then, a new value for :a will be sampled from bernoulli(0.3), and a new value for :b will be sampled from bernoulli(0.4). If the new value for :b is true, then the previous value for :c (false) will be retained. If the new value for :b is false, then a new value for :d will be sampled from bernoulli(0.1). The previous value for :e will always be retained. Suppose the new value for :a is true, and the new value for :b is true. Then get_choices(new_trace) will be:

│
├── :a : true
│
├── :b : true
│
├── :c : false
│
└── :e : true

The weight (w) is $\log 1 = 0$.
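The zero weight can be verified term by term. The derivation below is a sketch that assumes there is no untraced randomness and that the internal proposal samples each selected choice from the same distribution the model uses, so retained choices contribute a factor of 1 to $q$. Here both $u$ and $u'$ contain :c (false) and :e (true).

```latex
\begin{aligned}
p(t; x)      &= 0.7 \times 0.4 \times 0.4 \times 0.7 = 0.0784 \\
p(t'; x')    &= 0.3 \times 0.4 \times 0.4 \times 0.7 = 0.0336 \\
q(t'; u, x') &= 0.3 \times 0.4 = 0.12 \\
q(t; u', x)  &= 0.7 \times 0.4 = 0.28 \\
w &= \log \frac{p(t'; x') \, q(t; u', x)}{p(t; x) \, q(t'; u, x')}
   = \log \frac{0.0336 \times 0.28}{0.0784 \times 0.12}
   = \log 1 = 0
\end{aligned}
```

The factors for the resampled choices :a and :b cancel between $p$ and $q$, and the retained choices :c and :e contribute the same factors to the numerator and the denominator.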

### Extend

(new_trace, weight, retdiff) = extend(trace, args::Tuple, argdiff,
                                      constraints::ChoiceMap)

Extend a trace with new random choices by changing the arguments.

Given a previous trace $(x, t, r)$ (trace), new arguments $x'$ (args), and an assignment $u$ (constraints) that shares no addresses with $t$, return a new trace $(x', t', r')$ (new_trace) such that $t'$ agrees with $t$ on all addresses in $t$ and $t'$ agrees with $u$ on all addresses in $u$. Sample $t' \sim q(\cdot; t + u, x')$ and $r' \sim q(\cdot; x', t')$. Also return the weight (weight):

$\log \frac{p(r', t'; x') q(r; x, t)}{p(r, t; x) q(t'; t + u, x') q(r'; x', t')}$

### Argdiffs

In addition to the input trace, and other arguments that indicate how to adjust the trace, each of these methods also accepts an args argument and an argdiff argument. The args argument contains the new arguments to the generative function, which may differ from the previous arguments to the generative function (which can be retrieved by applying get_args to the previous trace). In many cases, the adjustment to the execution specified by the other arguments to these methods is 'small' and only affects certain parts of the computation. Therefore, it is often possible to generate the new trace and the appropriate log probability ratios required for these methods without revisiting every state of the computation of the generative function. To enable this, the argdiff argument provides additional information about the difference between the previous arguments to the generative function and the new arguments. This argdiff information permits the implementation of the update method to avoid inspecting the entire argument data structure to identify which parts were updated. Note that the correctness of the argdiff is in general not verified by Gen; passing incorrect argdiff information may result in incorrect behavior.

The trace update methods above should, for all generative functions, accept at least the following types of argdiffs:

const noargdiff = NoArgDiff()

Indication that there was no change to the arguments of the generative function.

const unknownargdiff = UnknownArgDiff()

Indication that no information is provided about the change to the arguments of the generative function.


Generative functions may also accept custom types for their argdiffs that allow more precise information about the difference to be supplied. It is the responsibility of the author of a generative function to specify the valid argdiff types in the documentation of their function, and it is the responsibility of the user of a generative function to construct and pass in the appropriate argdiff value.

### Retdiffs

To enable generative functions that invoke other functions to efficiently make use of incremental computation, the trace update methods of generative functions also return a retdiff value, which provides information about the difference between the return value of the previous trace and the return value of the new trace.

Generative functions may return arbitrary retdiff values, provided that the type has the following method:

isnodiff(retdiff)::Bool

Return true if the given retdiff value indicates no change to the return value.
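As an illustration of a custom retdiff, here is a minimal sketch for a generative function whose return value is a vector. The type `ChangedIndices` and the helper `needs_recompute` are hypothetical, and `isnodiff` is defined locally as a stand-in so that the sketch runs without Gen; in a real implementation one would presumably add a method to Gen's own `isnodiff` instead.

```julia
# Hypothetical retdiff type: records which indices of the returned vector changed.
struct ChangedIndices
    indices::Vector{Int}
end

# Local stand-in for the method described above (assumption: with Gen loaded,
# this would be written as a method of Gen's `isnodiff`).
isnodiff(d::ChangedIndices) = isempty(d.indices)

# A caller can skip incremental work entirely when nothing changed.
needs_recompute(d) = !isnodiff(d)

needs_recompute(ChangedIndices(Int[]))    # false
needs_recompute(ChangedIndices([2, 5]))   # true
```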


It is the responsibility of the author of the generative function to document the possible retdiff values that may be returned, and how they should be interpreted. There are two generic constant retdiffs provided for authors of generative functions to use in simple cases:

const defaultretdiff = DefaultRetDiff()

A default retdiff value that provides no information about the return value difference.

const noretdiff = NoRetDiff()

A retdiff value that indicates that there was no difference to the return value.


## Differentiable programming

Generative functions may support computation of gradients with respect to (i) all or a subset of their arguments, (ii) their trainable parameters, and (iii) the values of certain random choices. The set of elements (arguments, trainable parameters, or random choices) for which gradients are available is called the gradient source set. A generative function statically reports whether or not it is able to compute gradients with respect to each of its arguments, through the function has_argument_grads. Let $x_G$ denote the set of arguments for which the generative function does support gradient computation. Similarly, a generative function supports gradients with respect to the values of random choices made at all or a subset of addresses. If the return value of the function is conditionally dependent on any element in the gradient source set, given the other elements in the gradient source set and the values of all other random choices, for any possible trace of the function, then the generative function requires a return value gradient to compute gradients with respect to elements of the gradient source set. This static property of the generative function is reported by accepts_output_grad.

bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution})

Return a tuple of booleans indicating whether a gradient is available for each of its arguments.

req::Bool = accepts_output_grad(gen_fn::GenerativeFunction)

Return a boolean indicating whether the return value is dependent on any of the gradient source elements for any trace. The gradient source elements are:

• Any argument whose position is true in has_argument_grads

• Any static parameter

• Random choices made at a set of addresses that are selectable by choice_gradients.

arg_grads = accumulate_param_gradients!(trace, retgrad, scaler=1.)

Increment gradient accumulators for parameters by the gradient of the log-probability of the trace, optionally scaled, and return the gradient with respect to the arguments (not scaled).

Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:

$∇_x \left( \log P(t; x) + J \right)$

Also increment the gradient accumulators for the static parameters $Θ$ of the function by:

$∇_Θ \left( \log P(t; x) + J \right)$

(arg_grads, choice_values, choice_grads) = choice_gradients(trace, selection::AddressSet,
                                                            retgrad)

Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:

$∇_x \left( \log P(t; x) + J \right)$

Also given a set of addresses $A$ (selection) that are continuous-valued random choices, return the following gradient (choice_grads) with respect to the values of these choices:

$∇_A \left( \log P(t; x) + J \right)$

Also return the assignment (choice_values) that is the restriction of $t$ to $A$.
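To make these two gradients concrete, the sketch below works through a tiny hypothetical model by hand, without calling Gen: the single argument is `mu`, there is one traced choice `:y` drawn from a normal distribution with mean `mu` and standard deviation 1, and the return value is `y` itself. The downstream objective `J` and all helper names are assumptions made only for this illustration.

```julia
# Log density of the single traced choice y ~ normal(mu, 1).
logp(y, mu) = -0.5 * (y - mu)^2 - 0.5 * log(2 * pi)

# Assumed downstream objective of the return value; its gradient is retgrad.
retgrad = 0.3
J(retval) = retgrad * retval

mu, y = 1.0, 2.5

# Gradient with respect to the argument: J does not depend on mu once y is fixed.
arg_grad = y - mu
# Gradient with respect to the choice: the log density term plus the return
# value gradient, because the return value is y itself.
choice_grad = -(y - mu) + retgrad

# Central finite differences as a sanity check on both gradients.
h = 1e-6
objective(y, mu) = logp(y, mu) + J(y)
fd_arg = (objective(y, mu + h) - objective(y, mu - h)) / (2h)
fd_choice = (objective(y + h, mu) - objective(y - h, mu)) / (2h)
```

In this model the return value depends on the selected choice, so the return value gradient is needed to obtain the correct choice gradient.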

get_params(gen_fn::GenerativeFunction)

Return an iterable over the trainable parameters of the generative function.


weight = project(trace::U, selection::AddressSet)

Estimate the probability that the selected choices take the values they do in a trace.

Given a trace $(x, t, r)$ (trace) and a set of addresses $A$ (selection), let $u$ denote the restriction of $t$ to $A$. Return the weight (weight):

$\log \frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}$

(choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)

Sample an assignment and compute the probability of proposing that assignment.

Given arguments $x$ (args), sample $t \sim p(\cdot; x)$ and $r \sim p(\cdot; x, t)$, and return $t$ (choices) and the weight (weight):

$\log \frac{p(r, t; x)}{q(r; x, t)}$

(weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)

Return the probability of proposing an assignment.

Given arguments $x$ (args) and an assignment $t$ (choices) such that $p(t; x) > 0$, sample $r \sim q(\cdot; x, t)$ and return the weight (weight):

$\log \frac{p(r, t; x)}{q(r; x, t)}$

It is an error if $p(t; x) = 0$.


## Custom generative function types

Most users can just use generative functions written in the Built-in Modeling Language and can skip this section. However, to develop new modeling DSLs, or optimized implementations of certain probabilistic modeling components, users can also implement custom types of generative functions. We recommend following the steps below when implementing a new type of generative function, and looking at the implementation of the DynamicDSLFunction type as an example.

### Define a trace data type

struct MyTraceType
    ..
end

### Decide the return type for the generative function

Suppose our return type is Vector{Float64}.

### Define a data type for your generative function

This should be a subtype of GenerativeFunction, with the appropriate type parameters.

struct MyGenerativeFunction <: GenerativeFunction{Vector{Float64},MyTraceType}
    ..
end

Note that your generative function may not need to have any fields. You can create a constructor for it, e.g.:

function MyGenerativeFunction(...)
    ..
end

### Decide what the arguments to a generative function should be

For example, our generative functions might take two arguments, a (of type Int) and b (of type Float64). Then, the argument tuple passed to e.g. generate will have two elements.

NOTE: Be careful to distinguish between arguments to the generative function itself, and arguments to the constructor of the generative function. For example, if you have a generative function type that is parametrized by, for example, modeling DSL code, this DSL code would be a parameter of the generative function constructor.

### Decide what the traced random choices (if any) will be

Remember that each random choice is assigned a unique address in a (possibly) hierarchical address space. You are free to design this address space as you wish, although you should document it for users of your generative function type.