Generative Function Combinators

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. They are used to express patterns of repeated computation that appear frequently in generative models. Some generative function combinators resemble higher-order functions from functional programming languages. However, generative function combinators are not 'higher-order generative functions', because they are not themselves generative functions (they are regular Julia functions).
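A combinator returns a generative function, so that result can be passed to another combinator. The minimal sketch below is not taken from the Gen documentation. It assumes the `Map` and `Unfold` combinators described below, plus Gen's `choicemap`, `generate`, `get_choices` and `get_retval`. It builds a model of several independent random walks. The kernel `walk_step` and all argument values are hypothetical.

```julia
using Gen

# Hypothetical kernel: one step of a Gaussian random walk.
@gen function walk_step(t::Int, x_prev::Float64, step_std::Float64)
    x = @trace(normal(x_prev, step_std), :x)
    return x
end

# Unfold turns the kernel into a whole walk; Map turns one walk into many walks.
many_walks = Map(Unfold(walk_step))

# Walk 1 takes 3 steps from 0.0; walk 2 takes 2 steps from 10.0.
# Choices are addressed as walk => step => :x; here we pin walk 2, step 1.
constraints = choicemap((2 => 1 => :x, 10.5))
(trace, _) = generate(many_walks, ([3, 2], [0.0, 10.0], [1.0, 1.0]), constraints)

walks = get_retval(trace)   # one vector of positions per walk
```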

Map combinator

Gen.Map (Type)
gen_fn = Map(kernel::GenerativeFunction)

Return a new generative function that applies the kernel independently for a vector of inputs.

The returned generative function has one argument with type Vector{X} for each argument of the input generative function with type X. The length of each argument, which must be the same for each argument, determines the number of times the input generative function is called (N). Each call to the input function is made under address namespace i for i=1..N. The return value of the returned function has type FunctionalCollections.PersistentVector{Y}, where Y is the type of the return value of the input function. The map combinator is similar to the 'map' higher-order function in functional programming, except that the map combinator returns a new generative function that must then be separately applied.

If kernel has optional trailing arguments, the corresponding Vector arguments can be omitted from calls to Map(kernel).


In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.

[Figure: schematic of the map combinator]

For example, consider the following generative function, which makes one random choice at address :z:

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end

We apply the map combinator to produce a new generative function bar:

bar = Map(foo)

We can then obtain a trace of bar:

(trace, _) = generate(bar, ([0.0, 0.5], [0.5, 1.0]))

This causes foo to be invoked twice, once with arguments (0.0, 0.5) in address namespace 1 and once with arguments (0.5, 1.0) in address namespace 2. If the resulting trace has random choices:

│
├── 1
│   │
│   └── :z : -0.5757913836706721
│
└── 2
    │
    └── :z : 0.7357177113395333

then the return value is:

FunctionalCollections.PersistentVector{Any}[-0.575791, 0.735718]
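Each kernel call lives under its own integer namespace, so a choice inside a Map trace is addressed hierarchically as `i => :z`. The sketch below constrains only the second call and reads the value back. It assumes Gen's `choicemap`, `generate`, `get_choices`, `get_retval` and `logpdf`. The constrained value 1.25 is arbitrary.

```julia
using Gen

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end

bar = Map(foo)

# Constrain only the choice made by the second kernel call (namespace 2).
constraints = choicemap((2 => :z, 1.25))
(trace, weight) = generate(bar, ([0.0, 0.5], [0.5, 1.0]), constraints)

# Read the constrained choice back through its hierarchical address.
z2 = get_choices(trace)[2 => :z]

# The return value has one entry per kernel call, in call order.
retval = get_retval(trace)

# Only the constrained choice should contribute to the weight:
# its log density under normal(0.5 + 1.0, 1.0).
expected_weight = logpdf(normal, 1.25, 1.5, 1.0)
```

The unconstrained first call is sampled from its prior, so it should not change the weight.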

Unfold combinator

Gen.Unfold (Type)
gen_fn = Unfold(kernel::GenerativeFunction)

Return a new generative function that applies the kernel in sequence, passing the return value of one application as an input to the next.

The kernel accepts the following arguments:

  • The first argument is the Int index indicating the position in the sequence (starting from 1).

  • The second argument is the state.

  • The kernel may have additional arguments after the state.

The return type of the kernel must be the same type as the state.

The returned generative function accepts the following arguments:

  • The number of times (N) to apply the kernel.

  • The initial state.

  • The rest of the arguments (not including the state) that will be passed to each kernel application.

The return type of the returned generative function is FunctionalCollections.PersistentVector{T} where T is the return type of the kernel.

If kernel has optional trailing arguments, the corresponding arguments can be omitted from calls to Unfold(kernel).
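The argument layout above (index, state, then extra arguments) can be seen in a small Gaussian random walk. Each application's return value becomes the next application's state. The sketch assumes Gen's `choicemap`, `generate`, `get_choices` and `get_retval`. The kernel `walk_step` and its parameter `step_std` are hypothetical.

```julia
using Gen

# Hypothetical kernel: the state is the current position, and `step_std`
# is an extra argument passed unchanged to every application.
@gen function walk_step(t::Int, x_prev::Float64, step_std::Float64)
    x = @trace(normal(x_prev, step_std), :x)
    return x
end

walk = Unfold(walk_step)

# Arguments: number of applications (4), initial state (0.0), then step_std.
# The position at step 3 is constrained to 2.0.
constraints = choicemap((3 => :x, 2.0))
(trace, _) = generate(walk, (4, 0.0, 1.0), constraints)

xs = get_retval(trace)   # one position per application
```

Because step 4 receives `xs[3]` as its state, pinning step 3 also fixes the mean of step 4.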


In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel (not including the state) are $z$.

[Figure: schematic of the unfold combinator]

For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

We apply the unfold combinator to produce a new generative function bar:

bar = Unfold(foo)

We can then obtain a trace of bar:

(trace, _) = generate(bar, (5, false, 0.05, 0.95))

This causes foo to be invoked five times, with index arguments 1 through 5. If the resulting trace contains the following random choices:

│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true

then the return value is:

FunctionalCollections.PersistentVector{Any}[true, false, true, false, true]
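Because the number of applications is an ordinary argument, an existing Unfold trace can be extended by one step with Gen's `update` function. The sketch assumes the signature `update(trace, args, argdiffs, constraints)` returning `(new_trace, weight, retdiff, discard)`. It marks the length argument with `UnknownChange()` and the unchanged arguments with `NoChange()`. Existing choices are not constrained, so they should be retained.

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

bar = Unfold(foo)
(trace, _) = generate(bar, (5, false, 0.05, 0.95))

# Extend the sequence from 5 to 6 applications. Only the length changes;
# the initial state and the kernel parameters are marked NoChange.
argdiffs = (UnknownChange(), NoChange(), NoChange(), NoChange())
(new_trace, weight, _, _) = update(trace, (6, false, 0.05, 0.95), argdiffs, choicemap())

old_ys = get_retval(trace)
new_ys = get_retval(new_trace)
```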

Recurse combinator

TODO: document me

[Figure: schematic of the recurse combinator]

Switch combinator

Gen.Switch (Type)
gen_fn = Switch(gen_fns::GenerativeFunction...)

Returns a new generative function that accepts an argument tuple of type Tuple{Int, ...}. The first element is the index of the branch to call, and the remaining elements are passed to that branch.

gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T

Returns a new generative function that accepts an argument tuple of type Tuple{Int, ...} or Tuple{T, ...}. The first element is either the index of the branch to call or a key of d, which maps to the index of the selected branch. This form is meant for convenience: it lets the programmer use d like the cases of an if-else or case statement.

Switch is designed to allow for the expression of patterns of if-else control flow. gen_fns must satisfy a few requirements:

  1. Each gen_fn in gen_fns must accept the same argument types.
  2. Each gen_fn in gen_fns must return the same return type.

Apart from these requirements, the branches are unrestricted. For example, they may be written in different modeling languages or produce different trace structures.
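The Dict form described above can be sketched as follows. It assumes `Switch(d, gen_fns...)` accepts a `Dict{Symbol, Int}`. The case labels `:sum` and `:weighted` are hypothetical. Constraining `:z` makes the returned weight reveal which branch ran, because it should equal the log density of the constrained value under the selected branch.

```julia
using Gen

# Two branches with the same argument types and the same (inferred) return type.
@gen function bang(x::Float64, y::Float64)
    z = @trace(normal(x + y, 3.0), :z)
    return z
end

@gen function fuzz(x::Float64, y::Float64)
    z = @trace(normal(x + 2 * y, 3.0), :z)
    return z
end

# Hypothetical labels: :sum selects branch 1 (bang), :weighted selects branch 2 (fuzz).
sc = Switch(Dict(:sum => 1, :weighted => 2), bang, fuzz)

# With x = 5.0 and y = 3.0, fuzz has mean 11.0 and bang has mean 8.0.
constraints = choicemap((:z, 11.0))
(trace, weight) = generate(sc, (:weighted, 5.0, 3.0), constraints)
```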

[Figure: schematic of the switch combinator]

Consider the following constructions:

@gen function bang((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + y, std), :z)
    return z
end

@gen function fuzz((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + 2 * y, std), :z)
    return z
end

sc = Switch(bang, fuzz)

This creates a new generative function sc. We can then obtain a trace of sc:

(trace, _) = simulate(sc, (2, 5.0, 3.0))

The resulting trace contains the subtrace from the branch with index 2, which in this case is a call to fuzz:

│
└── :z : 13.552870875213735