Generative Function Combinators

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. They express patterns of repeated computation that appear frequently in generative models. Some combinators are similar to higher-order functions from functional programming languages. However, generative function combinators are not 'higher-order generative functions', because they are not themselves generative functions; they are regular Julia functions.

Map combinator

Gen.Map (Type)
gen_fn = Map(kernel::GenerativeFunction)

Return a new generative function that applies the kernel independently for a vector of inputs.

The returned generative function has one argument of type Vector{X} for each argument of the input generative function of type X. All argument vectors must have the same length N, which determines the number of times the input generative function is called. The i-th call receives the i-th element of each argument vector and is made under address namespace i, for i = 1, ..., N. The return value of the returned function has type FunctionalCollections.PersistentVector{Y}, where Y is the type of the return value of the input function. The map combinator is similar to the 'map' higher-order function in functional programming, except that the map combinator returns a new generative function that must then be separately applied.
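
The following plain-Julia sketch mirrors the argument handling described above without using Gen: it checks that the argument vectors share one length N and calls the kernel once per index, passing the i-th element of each vector. The function name sketch_map and the kernel foo_sketch are hypothetical, the sketch only collects return values (it builds no trace and no address namespaces), and it returns an ordinary Vector rather than a PersistentVector.

```julia
# Hypothetical sketch of Map's calling pattern; not Gen's implementation.
function sketch_map(kernel, arg_vectors::Vector...)
    # All argument vectors must have the same length N.
    lengths = map(length, arg_vectors)
    all(==(first(lengths)), lengths) ||
        throw(ArgumentError("argument vectors must have equal length"))
    n = first(lengths)
    retvals = Vector{Any}(undef, n)
    for i in 1:n
        # Call i receives the i-th element of every argument vector
        # (in Gen, its random choices would live under address namespace i).
        retvals[i] = kernel(map(v -> v[i], arg_vectors)...)
    end
    return retvals
end

# Hypothetical stand-in for foo: returns a draw from normal(x1 + x2, 1.0).
foo_sketch(x1, x2) = x1 + x2 + randn()

ys = sketch_map(foo_sketch, [0.0, 0.5], [0.5, 1.0])  # calls (0.0, 0.5), then (0.5, 1.0)
```

With a deterministic kernel the pairing is easy to see: sketch_map(+, [1, 2], [10, 20]) returns [11, 22].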

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.

[Figure: schematic of the map combinator]

For example, consider the following generative function, which makes one random choice at address :z:

@gen function foo(x1::Float64, x2::Float64)
    y = @addr(normal(x1 + x2, 1.0), :z)
    return y
end

We apply the map combinator to produce a new generative function bar:

bar = Map(foo)

We can then obtain a trace of bar:

(trace, _) = initialize(bar, ([0.0, 0.5], [0.5, 1.0]))

This causes foo to be invoked twice, once with arguments (0.0, 0.5) in address namespace 1 and once with arguments (0.5, 1.0) in address namespace 2. If the resulting trace has random choices:

│
├── 1
│   │
│   └── :z : -0.5757913836706721
│
└── 2
    │
    └── :z : 0.7357177113395333

then the return value is:

FunctionalCollections.PersistentVector{Any}[-0.575791, 0.735718]
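
Because the kernel applications are independent, the joint log density of the choices in the example trace is the sum of one normal log density per application. The sketch below computes that sum in plain Julia for the two choices shown above; normal_logpdf is a helper written here for illustration, not a Gen function.

```julia
# Log density of normal(mu, sigma) at x (hand-written helper, not from Gen).
normal_logpdf(x, mu, sigma) = -0.5 * log(2pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

x1s = [0.0, 0.5]
x2s = [0.5, 1.0]
zs = [-0.5757913836706721, 0.7357177113395333]  # values at 1 => :z and 2 => :z

# One term per kernel application; foo uses normal(x1 + x2, 1.0).
terms = [normal_logpdf(zs[i], x1s[i] + x2s[i], 1.0) for i in eachindex(zs)]
joint = sum(terms)
```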

Argdiffs

Generative functions produced by this combinator accept the following argdiff types:

argdiff = MapCustomArgDiff{T}(retained_argdiffs::Dict{Int,T})

Construct an argdiff value that contains argdiff information for some subset of applications of the kernel.

If the number of applications of the kernel, which is determined from the length of the input vector(s), has changed, then retained_argdiffs may only contain argdiffs for kernel applications that exist in both the previous trace and the new trace. For each i in keys(retained_argdiffs), retained_argdiffs[i] contains the argdiff information for the i-th application. If no entry is provided for some i that exists in both the previous and new traces, then its argdiff is assumed to be NoArgDiff.
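
The rule above can be sketched in plain Julia as a function that rejects entries for applications missing from either trace and fills every other shared application with a default. The function name sketch_resolve_argdiffs and the symbol :no_arg_diff (standing in for NoArgDiff) are hypothetical; this illustrates the documented contract, not how Gen processes MapCustomArgDiff internally.

```julia
# Hypothetical sketch of the retained_argdiffs contract; not Gen's code.
function sketch_resolve_argdiffs(retained::Dict{Int,T}, n_prev::Int, n_new::Int) where {T}
    n_both = min(n_prev, n_new)  # applications present in both traces
    for i in keys(retained)
        1 <= i <= n_both ||
            throw(ArgumentError("application $i is not in both traces"))
    end
    # Applications without an entry are treated as unchanged.
    return [get(retained, i, :no_arg_diff) for i in 1:n_both]
end

# Three applications shrink to two; only application 2 has an argdiff.
resolved = sketch_resolve_argdiffs(Dict(2 => :changed), 3, 2)
```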

Retdiffs

Generative functions produced by this combinator may return retdiffs that are one of the following types:

retdiff = VectorCustomRetDiff(retained_retdiffs::Dict{Int,Any})

Construct a retdiff that provides retdiff information about some elements of the returned vector.

retdiff[i]

Return the retdiff value for the ith element of the vector.

haskey(retdiff, i::Int)

Return true if there is a retdiff value for the i-th element of the vector, or false if there is no difference in this element.

keys(retdiff)

Return an iterator over the elements with retdiff values.
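
As an illustration of the query interface just listed, here is a hypothetical stand-in type, SketchVectorRetDiff, that wraps a Dict{Int,Any} and forwards retdiff[i], haskey and keys to it. It is a sketch of how a caller might inspect such a retdiff, assuming only the three operations documented above; it is not Gen's VectorCustomRetDiff.

```julia
# Hypothetical stand-in for VectorCustomRetDiff; not Gen's type.
struct SketchVectorRetDiff
    retained_retdiffs::Dict{Int,Any}
end

Base.getindex(d::SketchVectorRetDiff, i::Int) = d.retained_retdiffs[i]
Base.haskey(d::SketchVectorRetDiff, i::Int) = haskey(d.retained_retdiffs, i)
Base.keys(d::SketchVectorRetDiff) = keys(d.retained_retdiffs)

# Only element 2 of the returned vector has retdiff information.
rd = SketchVectorRetDiff(Dict{Int,Any}(2 => :changed))
for i in keys(rd)
    println("element $i: ", rd[i])  # visits only the elements with retdiffs
end
```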

Unfold combinator

Gen.Unfold (Type)
gen_fn = Unfold(kernel::GenerativeFunction)

Return a new generative function that applies the kernel in sequence, passing the return value of one application as an input to the next.

The kernel accepts the following arguments:

  • The first argument is the Int index indicating the position in the sequence (starting from 1).

  • The second argument is the state.

  • The kernel may have additional arguments after the state.

The return type of the kernel must be the same type as the state.

The returned generative function accepts the following arguments:

  • The number of times (N) to apply the kernel.

  • The initial state.

  • The rest of the arguments (not including the state) that will be passed to each kernel application.

The return type of the returned generative function is FunctionalCollections.PersistentVector{T} where T is the return type of the kernel.
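
The sequential state threading described above can be sketched in plain Julia. The name sketch_unfold is hypothetical; the sketch calls kernel(t, state, params...) for t = 1, ..., n, feeds each return value back in as the next state, and collects the states in an ordinary Vector rather than a PersistentVector. It builds no trace, so it shows only the calling pattern, not Gen's implementation.

```julia
# Hypothetical sketch of Unfold's calling pattern; not Gen's implementation.
function sketch_unfold(kernel, n::Int, init_state, params...)
    states = Vector{typeof(init_state)}(undef, n)  # kernel returns the state type
    state = init_state
    for t in 1:n
        state = kernel(t, state, params...)  # the output becomes the next input
        states[t] = state
    end
    return states
end

# Deterministic kernel: add the index t to the state.
sketch_unfold((t, s) -> s + t, 4, 0)  # returns [1, 3, 6, 10]
```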

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel, not including the state, are $z$.

[Figure: schematic of the unfold combinator]

For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @addr(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

We apply the unfold combinator to produce a new generative function bar:

bar = Unfold(foo)

We can then obtain a trace of bar:

(trace, _) = initialize(bar, (5, false, 0.05, 0.95))

This causes foo to be invoked five times, in address namespaces 1 through 5. If the resulting trace has random choices:

│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true

then the return value is:

FunctionalCollections.PersistentVector{Any}[true, false, true, false, true]
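
With z1 = 0.05 and z2 = 0.95, each step is likely to flip the previous value, which is why an alternating sequence like the one above is typical. The hypothetical helper below multiplies the per-step Bernoulli probabilities used by foo to get the probability of a given sequence; for the trace shown, every step has probability 0.95, so the total is 0.95^5, roughly 0.774.

```julia
# Hypothetical helper: probability of a Bool sequence under the example kernel.
function sketch_trace_prob(ys::Vector{Bool}, y0::Bool, z1::Float64, z2::Float64)
    prob = 1.0
    prev = y0
    for y in ys
        p_true = prev ? z1 : z2          # Bernoulli parameter foo uses at this step
        prob *= y ? p_true : 1 - p_true  # probability of the observed value
        prev = y
    end
    return prob
end

p = sketch_trace_prob([true, false, true, false, true], false, 0.05, 0.95)
```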

Argdiffs

Generative functions produced by this combinator accept the following argdiff types:

argdiff = UnfoldCustomArgDiff(init_changed::Bool, params_changed::Bool)

Construct an argdiff that indicates whether the initial state may have changed (init_changed), and whether the remaining arguments to the kernel may have changed (params_changed).

Retdiffs

Generative functions produced by this combinator may return retdiffs that are one of the following types:

Recurse combinator

TODO: document me

[Figure: schematic of the recurse combinator]