# Extending Gen

Gen is designed for extensibility. To implement behaviors that are not directly supported by the existing modeling languages, users can implement "black-box" generative functions directly, without using a built-in modeling language. These generative functions can then be invoked by generative functions defined using the built-in modeling languages. This includes several special cases:

- Extending Gen with custom gradient computations
- Extending Gen with custom incremental computation of return values
- Extending Gen with new modeling languages

## Custom gradients

To add a custom gradient for a differentiable deterministic computation, define a concrete subtype of `CustomGradientGF` with the following methods:

- `apply`
- `gradient`
- `has_argument_grads`

For example:

```
struct MyPlus <: CustomGradientGF{Float64} end
Gen.apply(::MyPlus, args) = args[1] + args[2]
Gen.gradient(::MyPlus, args, retval, retgrad) = (retgrad, retgrad)
Gen.has_argument_grads(::MyPlus) = (true, true)
```

`Gen.CustomGradientGF` (Type)

`CustomGradientGF{T}`

Abstract type for a generative function with a custom gradient computation, and default behaviors for all other generative function interface methods. `T` is the type of the return value.

`Gen.apply` (Function)

`retval = apply(gen_fn::CustomGradientGF, args)`

Apply the function to the arguments.

`Gen.gradient` (Function)

`arg_grads = gradient(gen_fn::CustomGradientGF, args, retval, retgrad)`

Return the tuple of gradients with respect to the arguments, given the gradient `retgrad` with respect to the return value. An element is `nothing` for any argument whose gradient is not available.
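As a minimal sketch of the `nothing` convention, the hypothetical `ScaleByCount` type below (not part of Gen) multiplies a real number by an integer count. Only the first argument is differentiable, so `has_argument_grads` reports `(true, false)` and `gradient` returns `nothing` in the second position. The methods are called directly here to show what they return; in normal use Gen invokes them through the generative function interface.

```julia
using Gen

# Hypothetical example type, introduced only for illustration.
struct ScaleByCount <: CustomGradientGF{Float64} end

# Forward computation: x * n, where x is a Float64 and n is an Int.
Gen.apply(::ScaleByCount, args) = args[1] * args[2]

# d(x * n)/dx = n. The integer argument has no gradient, so its slot is `nothing`.
Gen.gradient(::ScaleByCount, args, retval, retgrad) = (retgrad * args[2], nothing)

# Declare which arguments support gradients, in argument order.
Gen.has_argument_grads(::ScaleByCount) = (true, false)

# Call the methods directly to inspect their results.
retval = Gen.apply(ScaleByCount(), (1.5, 4))
arg_grads = Gen.gradient(ScaleByCount(), (1.5, 4), retval, 1.0)
```

The length and order of the returned tuple match the argument tuple, which is what lets callers pair each gradient with its argument.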

## Custom incremental computation

Iterative inference techniques like Markov chain Monte Carlo involve repeatedly updating the execution traces of generative models. In some cases, the output of a deterministic computation within the model can be incrementally computed during each of these updates, instead of being computed from scratch.

To add a custom incremental computation for a deterministic computation, define a concrete subtype of `CustomUpdateGF` with the following methods:

- `apply_with_state`
- `update_with_state`
- `num_args`

The second type parameter of `CustomUpdateGF` is the type of the state that may be used internally to facilitate incremental computation within `update_with_state`.

For example, we can implement a function for computing the sum of a vector that efficiently computes the new sum when a small fraction of the vector elements change:

```
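# State cached between updates: the previous array and its sum.
struct MyState
    prev_arr::Vector{Float64}
    sum::Float64
end

struct MySum <: CustomUpdateGF{Float64,MyState} end

# Compute the sum from scratch and record the state.
function Gen.apply_with_state(::MySum, args)
    arr = args[1]
    s = sum(arr)
    state = MyState(arr, s)
    (s, state)
end

# Update the sum incrementally using the diff of the vector argument.
function Gen.update_with_state(::MySum, state, args, argdiffs::Tuple{VectorDiff})
    arr = args[1]
    prev_sum = state.sum
    retval = prev_sum
    # Adjust for elements that changed in place.
    for i in keys(argdiffs[1].updated)
        retval += (arr[i] - state.prev_arr[i])
    end
    prev_length = length(state.prev_arr)
    new_length = length(arr)
    # Add elements appended to the vector.
    for i=prev_length+1:new_length
        retval += arr[i]
    end
    # Subtract elements removed from the end; these exist only in the previous array.
    for i=new_length+1:prev_length
        retval -= state.prev_arr[i]
    end
    state = MyState(arr, retval)
    (state, retval, UnknownChange())
end

Gen.num_args(::MySum) = 1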
```

`Gen.CustomUpdateGF` (Type)

`CustomUpdateGF{T,S}`

Abstract type for a generative function with a custom update computation, and default behaviors for all other generative function interface methods. `T` is the type of the return value and `S` is the type of state used internally for incremental computation.

`Gen.apply_with_state` (Function)

`retval, state = apply_with_state(gen_fn::CustomDetermGF, args)`

Execute the generative function and return the return value and the state.

`Gen.update_with_state` (Function)

`state, retval, retdiff = update_with_state(gen_fn::CustomDetermGF, state, args, argdiffs)`

Update the arguments to the generative function and return the new state, the new return value, and the change to the return value (`retdiff`).
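The following sketch shows the same interface on a smaller case: a hypothetical `MySquare` type (not part of Gen) that caches its result and reuses it when the argument diff is `NoChange`. It assumes that `update` on the resulting trace forwards `argdiffs` to `update_with_state`, which is how the interface above is described, and that a deterministic function is updated with an empty choice map.

```julia
using Gen

# Hypothetical example type: the state is simply the cached return value.
struct MySquare <: CustomUpdateGF{Float64,Float64} end

# Compute from scratch and cache the result as the state.
function Gen.apply_with_state(::MySquare, args)
    retval = args[1]^2
    (retval, retval)
end

# The argument is unchanged: reuse the cached value and report NoChange.
function Gen.update_with_state(::MySquare, state, args, argdiffs::Tuple{NoChange})
    (state, state, NoChange())
end

# Any other kind of change: recompute from scratch.
function Gen.update_with_state(::MySquare, state, args, argdiffs::Tuple)
    retval = args[1]^2
    (retval, retval, UnknownChange())
end

Gen.num_args(::MySquare) = 1

trace = simulate(MySquare(), (3.0,))

# Same argument: takes the NoChange method, so nothing is recomputed.
same_trace, _, same_diff, _ = update(trace, (3.0,), (NoChange(),), choicemap())

# New argument: takes the fallback method.
new_trace, _, new_diff, _ = update(trace, (4.0,), (UnknownChange(),), choicemap())
```

Returning `NoChange()` as the `retdiff` matters as much as skipping the work, because it lets callers of this function skip their own recomputation.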

## Custom distributions

Users can extend Gen with new probability distributions, which can then be used to make random choices within generative functions. Simple transformations of existing distributions can be created using the `@dist` DSL. For arbitrary distributions, including distributions that cannot be expressed in the `@dist` DSL, users can define a custom distribution by implementing Gen's Distribution interface directly, as defined below.

Probability distributions are singleton types whose supertype is `Distribution{T}`, where `T` indicates the data type of the random sample.

`abstract type Distribution{T} end`

A new `Distribution` type must implement the following methods:

- `random`
- `logpdf`
- `has_output_grad`
- `logpdf_grad`
- `has_argument_grads`

By convention, distributions have a global constant lower-case name for the singleton value. For example:

```
struct Bernoulli <: Distribution{Bool} end
const bernoulli = Bernoulli()
```

Distribution values should also be callable. Calling one is syntactic sugar with the same behavior as calling `random`:

`bernoulli(0.5) # identical to random(bernoulli, 0.5) and random(Bernoulli(), 0.5)`

`Gen.random` (Function)

`val::T = random(dist::Distribution{T}, args...)`

Sample a random choice from the given distribution with the given arguments.

`Gen.logpdf` (Function)

`lpdf = logpdf(dist::Distribution{T}, value::T, args...)`

Evaluate the log probability (density) of the value.

`Gen.has_output_grad` (Function)

`has::Bool = has_output_grad(dist::Distribution)`

Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice.

`Gen.logpdf_grad` (Function)

`grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...)`

Compute the gradient of the logpdf with respect to the value, and each of the arguments.

If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. Otherwise, the first element of the tuple is the gradient with respect to the value. If the return value of `has_argument_grads` has a false value at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. Otherwise, this element contains the gradient with respect to the `i`th argument.
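Putting the interface together, here is a minimal sketch of a custom distribution. The `HalfNormal` type, its `half_normal` constant, and the `scale_model` function are hypothetical names chosen for illustration and are not part of Gen. The distribution has a single argument `sigma`, so `has_argument_grads` returns a one-element tuple and `logpdf_grad` returns two elements: the gradient with respect to the value, then the gradient with respect to `sigma`.

```julia
using Gen

# Hypothetical distribution type used only for illustration:
# half-normal with scale `sigma`, supported on x >= 0.
struct HalfNormal <: Gen.Distribution{Float64} end
const half_normal = HalfNormal()

# Sample by taking the absolute value of a scaled standard normal draw.
Gen.random(::HalfNormal, sigma::Real) = abs(sigma * randn())

# Log density: log(sqrt(2/pi) / sigma) - x^2 / (2 sigma^2) for x >= 0.
function Gen.logpdf(::HalfNormal, x::Real, sigma::Real)
    x < 0 && return -Inf
    0.5 * log(2 / pi) - log(sigma) - x^2 / (2 * sigma^2)
end

# Gradients with respect to the value first, then each argument.
function Gen.logpdf_grad(::HalfNormal, x::Real, sigma::Real)
    x_grad = -x / sigma^2
    sigma_grad = -1 / sigma + x^2 / sigma^3
    (x_grad, sigma_grad)
end

Gen.has_output_grad(::HalfNormal) = true
Gen.has_argument_grads(::HalfNormal) = (true,)

# Make the singleton callable, following the convention described above.
(dist::HalfNormal)(sigma) = Gen.random(dist, sigma)

# Use the new distribution for a traced random choice in a model.
@gen function scale_model()
    s = @trace(half_normal(2.0), :s)
    return s
end

trace = simulate(scale_model, ())
```

If a gradient were not available, the corresponding flag would be false and the matching tuple element would be `nothing`, as described above.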

## Custom generative functions

We recommend the following steps for implementing a new type of generative function. The implementation of the `DynamicDSLFunction` type is also a useful example to study.

##### Define a trace data type

```
struct MyTraceType <: Trace
..
end
```

##### Decide the return type for the generative function

Suppose our return type is `Vector{Float64}`.

##### Define a data type for your generative function

This should be a subtype of `GenerativeFunction`, with the appropriate type parameters.

```
struct MyGenerativeFunction <: GenerativeFunction{Vector{Float64},MyTraceType}
..
end
```

Note that your generative function may not need to have any fields. You can create a constructor for it, e.g.:

```
function MyGenerativeFunction(...)
..
end
```

##### Decide what the arguments to a generative function should be

For example, our generative functions might take two arguments, `a` (of type `Int`) and `b` (of type `Float64`). Then, the argument tuple passed to e.g. `generate` will have two elements.

NOTE: Be careful to distinguish between arguments to the generative function itself, and arguments to the constructor of the generative function. For example, if you have a generative function type that is parametrized by, for example, modeling DSL code, this DSL code would be a parameter of the generative function constructor.

##### Decide what the traced random choices (if any) will be

Remember that each random choice is assigned a unique address in a (possibly) hierarchical address space. You are free to design this address space as you wish, although you should document it for users of your generative function type.

##### Implement methods of the Generative Function Interface

At minimum, you need to implement the following methods:

If you want to use the generative function within models, you should implement:

If you want to use MCMC on models that call your generative function, then implement:

If you want to use gradient-based inference techniques on models that call your generative function, then implement:

If your generative function has trainable parameters, then implement:

## Custom modeling languages

Gen can be extended with new modeling languages by implementing new generative function types, and constructors for these types that take models as input. This typically requires implementing the entire generative function interface, and is an advanced use of Gen.