Extending Gen

Extending Gen

Gen is designed for extensibility. To implement behaviors that are not directly supported by the existing modeling languages, users can implement `black-box' generative functions directly, without using built-in modeling language. These generative functions can then be invoked by generative functions defined using the built-in modeling language. This includes several special cases:

Custom gradients

To add a custom gradient for a differentiable deterministic computation, define a concrete subtype of CustomGradientGF with the following methods:

For example:

struct MyPlus <: CustomGradientGF{Float64} end

Gen.apply(::MyPlus, args) = args[1] + args[2]
Gen.gradient(::MyPlus, args, retval, retgrad) = (retgrad, retgrad)
Gen.has_argument_grads(::MyPlus) = (true, true)

Abstract type for a generative function with a custom gradient computation, and default behaviors for all other generative function interface methods.

T is the type of the return value.

retval = apply(gen_fn::CustomGradientGF, args)

Apply the function to the arguments.

arg_grads = gradient(gen_fn::CustomDetermGF, args, retval, retgrad)

Return the gradient tuple with respect to the arguments, where nothing is for argument(s) whose gradient is not available.


Custom incremental computation

Iterative inference techniques like Markov chain Monte Carlo involve repeatedly updating the execution traces of generative models. In some cases, the output of a deterministic computation within the model can be incrementally computed during each of these updates, instead of being computed from scratch.

To add a custom incremental computation for a deterministic computation, define a concrete subtype of CustomUpdateGF with the following methods:

The second type parameter of CustomUpdateGF is the type of the state that may be used internally to facilitate incremental computation within update_with_state.

For example, we can implement a function for computing the sum of a vector that efficiently computes the new sum when a small fraction of the vector elements change:

struct MyState

struct MySum <: CustomUpdateGF{Float64,MyState} end

function Gen.apply_with_state(::MySum, args)
    arr = args[1]
    s = sum(arr)
    state = MyState(arr, s)
    (s, state)

function Gen.update_with_state(::MySum, state, args, argdiffs::Tuple{VectorDiff})
    arr = args[1]
    prev_sum = state.sum
    retval = prev_sum
    for i in keys(argdiffs[1].updated)
        retval += (arr[i] - state.prev_arr[i])
    prev_length = length(state.prev_arr)
    new_length = length(arr)
    for i=prev_length+1:new_length
        retval += arr[i]
    for i=new_length+1:prev_length
        retval -= arr[i]
    state = MyState(arr, retval)
    (state, retval, UnknownChange())

Gen.num_args(::MySum) = 1

Abstract type for a generative function with a custom update computation, and default behaviors for all other generative function interface methods.

T is the type of the return value and S is the type of state used internally for incremental computation.

retval, state = apply_with_state(gen_fn::CustomDetermGF, args)

Execute the generative function and return the return value and the state.

state, retval, retdiff = update_with_state(gen_fn::CustomDetermGF, state, args, argdiffs)

Update the arguments to the generative function and return new return value and state.


Custom distributions

Users can extend Gen with new probability distributions, which can then be used to make random choices within generative functions. Simple transformations of existing distributions can be created using the @dist DSL. For arbitrary distributions, including distributions that cannot be expressed in the @dist DSL, users can define a custom distribution by implementing Gen's Distribution interface directly, as defined below.

Probability distributions are singleton types whose supertype is Distribution{T}, where T indicates the data type of the random sample.

abstract type Distribution{T} end

A new Distribution type must implement the following methods:

By convention, distributions have a global constant lower-case name for the singleton value. For example:

struct Bernoulli <: Distribution{Bool} end
const bernoulli = Bernoulli()

Distribution values should also be callable, which is a syntactic sugar with the same behavior of calling random:

bernoulli(0.5) # identical to random(bernoulli, 0.5) and random(Bernoulli(), 0.5)
val::T = random(dist::Distribution{T}, args...)

Sample a random choice from the given distribution with the given arguments.

lpdf = logpdf(dist::Distribution{T}, value::T, args...)

Evaluate the log probability (density) of the value.

has::Bool = has_output_grad(dist::Distribution)

Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice.

grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...)

Compute the gradient of the logpdf with respect to the value, and each of the arguments.

If has_output_grad returns false, then the first element of the returned tuple is nothing. Otherwise, the first element of the tuple is the gradient with respect to the value. If the return value of has_argument_grads has a false value for at position i, then the i+1th element of the returned tuple has value nothing. Otherwise, this element contains the gradient with respect to the ith argument.


Custom generative functions

We recommend the following steps for implementing a new type of generative function, and also looking at the implementation for the DynamicDSLFunction type as an example.

Define a trace data type
struct MyTraceType <: Trace
Decide the return type for the generative function

Suppose our return type is Vector{Float64}.

Define a data type for your generative function

This should be a subtype of GenerativeFunction, with the appropriate type parameters.

struct MyGenerativeFunction <: GenerativeFunction{Vector{Float64},MyTraceType}

Note that your generative function may not need to have any fields. You can create a constructor for it, e.g.:

function MyGenerativeFunction(...)
Decide what the arguments to a generative function should be

For example, our generative functions might take two arguments, a (of type Int) and b (of type Float64). Then, the argument tuple passed to e.g. generate will have two elements.

NOTE: Be careful to distinguish between arguments to the generative function itself, and arguments to the constructor of the generative function. For example, if you have a generative function type that is parametrized by, for example, modeling DSL code, this DSL code would be a parameter of the generative function constructor.

Decide what the traced random choices (if any) will be

Remember that each random choice is assigned a unique address in (possibly) hierarchical address space. You are free to design this address space as you wish, although you should document it for users of your generative function type.

Implement methods of the Generative Function Interface

At minimum, you need to implement the following methods:

If you want to use the generative function within models, you should implement:

If you want to use MCMC on models that call your generative function, then implement:

If you want to use gradient-based inference techniques on models that call your generative function, then implement:

If your generative function has trainable parameters, then implement:

Custom modeling languages

Gen can be extended with new modeling languages by implementing new generative function types, and constructors for these types that take models as input. This typically requires implementing the entire generative function interface, and is advanced usage of Gen.