Generative Function Combinators

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. They express patterns of repeated computation that appear frequently in generative models. Some generative function combinators are similar to higher-order functions from functional programming languages. However, generative function combinators are not 'higher-order generative functions', because they are not themselves generative functions (they are regular Julia functions).

Map combinator

Gen.Map (Type)
gen_fn = Map(kernel::GenerativeFunction)

Return a new generative function that applies the kernel independently to each element of a vector of inputs.

The returned generative function has one argument of type Vector{X} for each argument of the input generative function of type X. All argument vectors must have the same length N, which determines the number of times the input generative function is called. The i-th call receives the i-th element of each argument vector and is made under address namespace i, for i = 1..N. The return value of the returned function has type FunctionalCollections.PersistentVector{Y}, where Y is the return type of the input function. The map combinator is similar to the 'map' higher-order function in functional programming, except that the map combinator returns a new generative function that must then be applied separately.
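To make the argument handling concrete, here is a minimal plain-Julia sketch of what the returned function does on each call. It is not Gen's implementation and does not use Gen: the names `map_sketch`, `mapped` and `sum_kernel` are hypothetical, and the "kernel" is an ordinary Julia function returning a `(return_value, choices)` pair, with `choices` a `Dict` from addresses to values, standing in for a generative function and its trace.

```julia
# Hypothetical sketch of Map's semantics in plain Julia (not Gen's implementation).
# A "kernel" here is an ordinary function returning (return_value, choices),
# where `choices` is a Dict mapping addresses to values.
function map_sketch(kernel)
    # Like Map(kernel), this returns a new function that must be applied separately.
    function mapped(args::AbstractVector...)
        N = length(first(args))
        # Every argument vector must have the same length N.
        all(a -> length(a) == N, args) ||
            throw(ArgumentError("all argument vectors must have the same length"))
        retvals = Vector{Any}(undef, N)
        choices = Dict{Any,Any}()
        for i in 1:N
            # Call i receives the i-th element of every argument vector.
            y, kernel_choices = kernel((a[i] for a in args)...)
            retvals[i] = y
            # Choices made by call i are stored under address namespace i.
            for (addr, val) in kernel_choices
                choices[i => addr] = val
            end
        end
        return retvals, choices
    end
    return mapped
end

# Usage with a deterministic stand-in kernel that records its sum at :z.
sum_kernel(x1, x2) = (x1 + x2, Dict(:z => x1 + x2))
bar_sketch = map_sketch(sum_kernel)
retvals, choices = bar_sketch([0.0, 0.5], [0.5, 1.0])
# retvals == [0.5, 1.5]; choices[1 => :z] == 0.5; choices[2 => :z] == 1.5
```

Returning a closure mirrors the two-step usage of the real combinator (`bar = Map(foo)`, then call `bar`), and the length check corresponds to the requirement that all argument vectors have the same length.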

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.

[Figure: schematic of map combinator]
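Written out as formulas, the schematic corresponds to the following. The notation is ours rather than the original documentation's: $m$ is the number of kernel arguments, $x_{j,i}$ is the $i$-th element of the $j$-th argument vector, $\tau_i$ is the set of random choices under address namespace $i$, and $p_{\mathrm{k}}$ is the kernel's probability of its choices. Because the $N$ calls are independent, the probability of all choices factorizes over $i$, and the return value is $(y_1, \ldots, y_N)$:

$$
y_i = \mathcal{G}_{\mathrm{k}}(x_{1,i}, \ldots, x_{m,i}), \quad i = 1, \ldots, N
$$

$$
p(\tau_1, \ldots, \tau_N) = \prod_{i=1}^{N} p_{\mathrm{k}}(\tau_i ; x_{1,i}, \ldots, x_{m,i})
$$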

For example, consider the following generative function, which makes one random choice at address :z:

using Gen

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end

We apply the map combinator to produce a new generative function bar:

bar = Map(foo)

We can then obtain a trace of bar:

(trace, _) = generate(bar, ([0.0, 0.5], [0.5, 1.0]))

This causes foo to be invoked twice, once with arguments (0.0, 0.5) in address namespace 1 and once with arguments (0.5, 1.0) in address namespace 2. If the resulting trace has the following random choices:

│
├── 1
│   │
│   └── :z : -0.5757913836706721
│
└── 2
    │
    └── :z : 0.7357177113395333

then the return value is:

FunctionalCollections.PersistentVector{Any}[-0.575791, 0.735718]
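Because the two calls to foo are independent, the log probability of this trace is a sum of one normal log density per address namespace. The sketch below computes it in plain Julia without Gen, using the argument and :z values shown above; the helper `normal_logpdf` is a hypothetical name written out from the standard normal density formula, not a Gen function.

```julia
# Hypothetical illustration: the log probability of a Map trace is the sum of the
# kernel's log probabilities, one term per address namespace i.
normal_logpdf(x, mu, sigma) = -0.5 * log(2pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# Arguments and :z values from the example above.
x1 = [0.0, 0.5]
x2 = [0.5, 1.0]
z = [-0.5757913836706721, 0.7357177113395333]

# Call i of foo draws :z from normal(x1[i] + x2[i], 1.0).
terms = [normal_logpdf(z[i], x1[i] + x2[i], 1.0) for i in eachindex(z)]
log_prob = sum(terms)
```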

Unfold combinator

Gen.Unfold (Type)
gen_fn = Unfold(kernel::GenerativeFunction)

Return a new generative function that applies the kernel in sequence, passing the return value of one application as an input to the next.

The kernel accepts the following arguments:

  • The first argument is the Int index indicating the position in the sequence (starting from 1).

  • The second argument is the state.

  • The kernel may have additional arguments after the state.

The return type of the kernel must be the same type as the state.

The returned generative function accepts the following arguments:

  • The number of times (N) to apply the kernel.

  • The initial state.

  • The remaining kernel arguments (those after the state), which are passed unchanged to each kernel application.

The return type of the returned generative function is FunctionalCollections.PersistentVector{T} where T is the return type of the kernel.
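The argument conventions above can be sketched in plain Julia as a loop that threads the state through successive kernel calls. This is not Gen's implementation and does not use Gen: `unfold_sketch`, `unfolded` and `add_kernel` are hypothetical names, and the kernel is an ordinary function `(t, state, rest...) -> (new_state, choices)` standing in for a generative function.

```julia
# Hypothetical sketch of Unfold's semantics in plain Julia (not Gen's implementation).
# A "kernel" is an ordinary function (t, state, rest...) -> (new_state, choices).
function unfold_sketch(kernel)
    # Like Unfold(kernel), this returns a new function taking (N, init_state, rest...).
    function unfolded(N::Int, init_state, rest...)
        states = Vector{Any}(undef, N)
        choices = Dict{Any,Any}()
        state = init_state
        for t in 1:N
            # The kernel gets the index t, the previous state, and the extra arguments.
            state, kernel_choices = kernel(t, state, rest...)
            states[t] = state
            # Choices made by application t are stored under address namespace t.
            for (addr, val) in kernel_choices
                choices[t => addr] = val
            end
        end
        # One entry per application; the initial state is not included.
        return states, choices
    end
    return unfolded
end

# Usage with a deterministic stand-in kernel that adds `step` to the state.
add_kernel(t, state, step) = (state + step, Dict(:y => state + step))
bar_sketch = unfold_sketch(add_kernel)
states, choices = bar_sketch(3, 0, 2)
# states == [2, 4, 6]; choices[3 => :y] == 6
```

Because each application's output becomes the next application's state input, the kernel's return type has to match the state type, which is the constraint stated above.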

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel, not including the state, are $z$.

[Figure: schematic of unfold combinator]
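In formulas, using the schematic's symbols plus notation of our own ($\tau_t$ for the random choices under address namespace $t$, and $p_{\mathrm{k}}$ for the kernel's probability of its choices), each application consumes the previous state and produces the next, and the return value is $(y_1, \ldots, y_n)$. Unlike the map combinator, the factors are not independent: $y_{t-1}$ depends on the choices made in earlier applications, so the product below is a chain-rule factorization.

$$
y_t = \mathcal{G}_{\mathrm{k}}(t, y_{t-1}, z), \quad t = 1, \ldots, n
$$

$$
p(\tau_1, \ldots, \tau_n) = \prod_{t=1}^{n} p_{\mathrm{k}}(\tau_t ; t, y_{t-1}, z)
$$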

For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

We apply the unfold combinator to produce a new generative function bar:

bar = Unfold(foo)

We can then obtain a trace of bar:

(trace, _) = generate(bar, (5, false, 0.05, 0.95))

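This causes foo to be invoked five times, in address namespaces 1 through 5. The first invocation receives the initial state false, and each later invocation receives the previous invocation's return value as y_prev. If the resulting trace has the following random choices: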

│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true

then the return value is:

FunctionalCollections.PersistentVector{Any}[true, false, true, false, true]
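With z1 = 0.05 and z2 = 0.95, the state flips with probability 0.95 at every step, so an alternating sequence like the one shown is the most likely outcome. The sketch below reproduces the example's chain in plain Julia, without Gen, and computes the probability of a given sequence as a product over steps. The names `simulate_chain` and `sequence_prob` are hypothetical, and the Bernoulli draw is written as `rand(rng) < p` rather than using Gen's `bernoulli`.

```julia
using Random

# Plain-Julia stand-in for the example kernel `foo`: the next state is true with
# probability z1 if the previous state was true, and with probability z2 otherwise.
function simulate_chain(rng, n, y0, z1, z2)
    ys = Bool[]
    y = y0
    for t in 1:n
        p = y ? z1 : z2
        y = rand(rng) < p   # a Bernoulli(p) draw
        push!(ys, y)
    end
    return ys
end

# Probability of a given sequence ys under the same model, as a product over steps.
function sequence_prob(ys, y0, z1, z2)
    prob = 1.0
    prev = y0
    for y in ys
        p = prev ? z1 : z2
        prob *= y ? p : 1 - p
        prev = y
    end
    return prob
end

ys = simulate_chain(MersenneTwister(0), 5, false, 0.05, 0.95)
p_example = sequence_prob([true, false, true, false, true], false, 0.05, 0.95)
# p_example ≈ 0.95^5 ≈ 0.774
```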

Recurse combinator

TODO: document me

[Figure: schematic of recurse combinator]