# Generative Function Combinators

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. Combinators express patterns of repeated computation that appear frequently in generative models. Some combinators resemble higher-order functions from functional programming languages. However, combinators are not 'higher-order generative functions', because they are not themselves generative functions; they are regular Julia functions.

## Map combinator

`Gen.Map` (Type)

`gen_fn = Map(kernel::GenerativeFunction)`

Return a new generative function that applies the kernel independently for a vector of inputs.

The returned generative function has one argument with type `Vector{X}` for each argument of the input generative function with type `X`. The length of each argument, which must be the same for every argument, determines the number of times (N) the input generative function is called. Each call to the input function is made under address namespace `i` for `i = 1, ..., N`. The return value of the returned function has type `FunctionalCollections.PersistentVector{Y}`, where `Y` is the type of the return value of the input function. The map combinator is similar to the `map` higher-order function in functional programming, except that the map combinator returns a new generative function that must then be applied separately.

If `kernel` has optional trailing arguments, the corresponding `Vector` arguments can be omitted from calls to `Map(kernel)`.
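To make the argument layout concrete, here is a plain-Julia sketch of what `Map(kernel)` computes, ignoring tracing and randomness entirely. The helper `map_semantics` and the deterministic kernel `mean_of_z` are hypothetical illustrations, not part of Gen. Call `i` receives the `i`-th element of every vector argument, which corresponds to address namespace `i` in the real combinator.

```julia
# Hypothetical helper (not part of Gen): how Map(kernel) lines up its
# vector arguments. No tracing or random choices are involved.
function map_semantics(kernel, args::Vector...)
    N = length(first(args))
    # every vector argument must have the same length N
    @assert all(a -> length(a) == N, args) "argument vectors must have equal length"
    # call i receives the i-th element of each vector (namespace i in Gen)
    return [kernel((a[i] for a in args)...) for i in 1:N]
end

# Hypothetical deterministic stand-in for a kernel: returns the mean x1 + x2
mean_of_z(x1, x2) = x1 + x2

map_semantics(mean_of_z, [0.0, 0.5], [0.5, 1.0])  # returns [0.5, 1.5]
```

The vectors are read "column-wise": the first call gets `(0.0, 0.5)` and the second gets `(0.5, 1.0)`, not the two vectors themselves.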

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.

For example, consider the following generative function, which makes one random choice at address `:z`:

```julia
@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end
```

We apply the map combinator to produce a new generative function `bar`:

`bar = Map(foo)`

We can then obtain a trace of `bar`:

`(trace, _) = generate(bar, ([0.0, 0.5], [0.5, 1.0]))`

This causes `foo` to be invoked twice, once with arguments `(0.0, 0.5)` in address namespace `1` and once with arguments `(0.5, 1.0)` in address namespace `2`. If the resulting trace has random choices:

```
│
├── 1
│   │
│   └── :z : -0.5757913836706721
│
└── 2
    │
    └── :z : 0.7357177113395333
```

then the return value is:

`FunctionalCollections.PersistentVector{Any}[-0.575791, 0.735718]`
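Random choices of a map trace can also be constrained when calling `generate`, using hierarchical addresses of the form `i => :z`. The sketch below assumes Gen's `choicemap`, `generate`, `get_retval`, and `logpdf` functions; the constrained values `-0.5` and `0.7` are arbitrary.

```julia
using Gen

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end

bar = Map(foo)

# Constrain :z in namespace 1 and in namespace 2 (values chosen arbitrarily).
constraints = choicemap((1 => :z, -0.5), (2 => :z, 0.7))
(trace, weight) = generate(bar, ([0.0, 0.5], [0.5, 1.0]), constraints)

retval = get_retval(trace)   # PersistentVector whose elements are -0.5 and 0.7

# Each kernel call contributes its own log density; call i has mean x1[i] + x2[i].
expected_weight = logpdf(normal, -0.5, 0.5, 1.0) + logpdf(normal, 0.7, 1.5, 1.0)
```

Because every choice is constrained, `weight` should equal the joint log density of the constrained values, which is the sum of one `normal` log density per kernel application.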

## Unfold combinator

`Gen.Unfold` (Type)

`gen_fn = Unfold(kernel::GenerativeFunction)`

Return a new generative function that applies the kernel in sequence, passing the return value of one application as an input to the next.

The kernel accepts the following arguments:

- The first argument is the `Int` index indicating the position in the sequence (starting from 1).
- The second argument is the *state*.
- The kernel may have additional arguments after the state.

The return type of the kernel must be the same type as the state.

The returned generative function accepts the following arguments:

- The number of times (N) to apply the kernel.
- The initial state.
- The rest of the arguments (not including the state) that will be passed to each kernel application.

The return type of the returned generative function is `FunctionalCollections.PersistentVector{T}`, where `T` is the return type of the kernel.

If `kernel` has optional trailing arguments, the corresponding arguments can be omitted from calls to `Unfold(kernel)`.
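The state threading described above can be sketched in plain Julia, again without tracing or randomness. The helper `unfold_semantics` and the deterministic kernel `likely_kernel` are hypothetical: `likely_kernel` returns the more probable outcome of a Bernoulli draw with probability `y_prev ? z1 : z2` instead of sampling it, so the output is reproducible.

```julia
# Hypothetical helper (not part of Gen): how Unfold(kernel) threads state
# through N kernel calls. No tracing or random choices are involved.
function unfold_semantics(kernel, N::Int, init_state, rest...)
    states = typeof(init_state)[]
    state = init_state
    for t in 1:N
        # call t receives its index, the previous state, and the shared arguments
        state = kernel(t, state, rest...)
        push!(states, state)
    end
    return states  # the initial state itself is not included
end

# Hypothetical kernel: the more likely outcome of bernoulli(y_prev ? z1 : z2)
likely_kernel(t::Int, y_prev::Bool, z1::Float64, z2::Float64) = (y_prev ? z1 : z2) > 0.5

unfold_semantics(likely_kernel, 5, false, 0.05, 0.95)
# returns [true, false, true, false, true]
```

The result has exactly N elements, one per kernel application, and each element is also the state passed to the next application.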

In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel (not including the state) are $z$.

For example, consider the following kernel, with state type `Bool`, which makes one random choice at address `:y`:

```julia
@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end
```

We apply the unfold combinator to produce a new generative function `bar`:

`bar = Unfold(foo)`

We can then obtain a trace of `bar`:

`(trace, _) = generate(bar, (5, false, 0.05, 0.95))`

This causes `foo` to be invoked five times. If the resulting trace has the following random choices:

```
│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true
```

then the return value is:

`FunctionalCollections.PersistentVector{Any}[true, false, true, false, true]`
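As with `Map`, the choices of an unfold trace can be constrained with `generate`. The sketch below assumes Gen's `choicemap`, `generate`, and `get_retval`, and pins every `:y` to the alternating pattern in the trace above.

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

bar = Unfold(foo)

# Pin :y in namespaces 1..5 to an alternating pattern.
pattern = [true, false, true, false, true]
constraints = choicemap([(t => :y, pattern[t]) for t in 1:5]...)
(trace, weight) = generate(bar, (5, false, 0.05, 0.95), constraints)

# Starting from state false, every step takes the likely outcome,
# so each of the five steps contributes log(0.95) to the weight.
retval = get_retval(trace)
```

Here the previous state decides which probability each step uses, so the weight depends on the whole pattern, not on each value in isolation.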

## Recurse combinator

TODO: document me

## Switch combinator

`Gen.Switch` (Type)

`gen_fn = Switch(gen_fns::GenerativeFunction...)`

Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}`, where the first element indicates which branch to call.

`gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T`

Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}`, where the first element either indicates which branch to call directly or is a key into `d`, which maps it to the selected branch. This form is meant for convenience: it allows the programmer to use `d` like if-else or case statements.
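A sketch of the dictionary form, assuming Gen's `Switch`, `generate`, `choicemap`, and `logpdf`. The labels `:sum` and `:weighted` are hypothetical, and the two branches are simplified versions of the `bang` and `fuzz` functions used in the example below (without `(grad)` annotations). Constraining `:z` makes the weight deterministic, so it reveals which branch ran.

```julia
using Gen

@gen function bang(x::Float64, y::Float64)
    z = @trace(normal(x + y, 3.0), :z)
    return z
end

@gen function fuzz(x::Float64, y::Float64)
    z = @trace(normal(x + 2 * y, 3.0), :z)
    return z
end

# Hypothetical labels: :sum selects branch 1 (bang), :weighted selects branch 2 (fuzz).
sc = Switch(Dict(:sum => 1, :weighted => 2), bang, fuzz)

# Select a branch by label instead of by integer index, and pin :z to 11.0.
(trace, weight) = generate(sc, (:weighted, 5.0, 3.0), choicemap((:z, 11.0)))

# fuzz has mean 5.0 + 2 * 3.0 = 11.0, so weight should be logpdf(normal, 11.0, 11.0, 3.0);
# under bang (mean 8.0) it would instead be logpdf(normal, 11.0, 8.0, 3.0).
```

Passing `2` instead of `:weighted` as the first argument should select the same branch.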

`Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements:

- Each `gen_fn` in `gen_fns` must accept the same argument types.
- Each `gen_fn` in `gen_fns` must return the same return type.

Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc.
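Because only the argument and return types must agree, two branches can make entirely different random choices. In this sketch, the names `one_choice`, `two_choices`, and `sw` are hypothetical; branch 1 makes a choice at `:z` while branch 2 makes choices at `:a` and `:b`. The assumed Gen calls are `Switch`, `simulate`, `get_choices`, and `has_value`.

```julia
using Gen

# Branch 1: a single random choice at :z.
@gen function one_choice(x::Float64)
    z = @trace(normal(x, 1.0), :z)
    return z
end

# Branch 2: two random choices at :a and :b, with the same argument types.
@gen function two_choices(x::Float64)
    a = @trace(normal(x, 1.0), :a)
    b = @trace(normal(a, 1.0), :b)
    return b
end

sw = Switch(one_choice, two_choices)

t1 = simulate(sw, (1, 0.0))  # trace of one_choice: contains :z only
t2 = simulate(sw, (2, 0.0))  # trace of two_choices: contains :a and :b
```

The switch trace exposes the selected branch's choices directly, so the available addresses depend on which branch was taken.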

Consider the following constructions:

```julia
@gen function bang((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + y, std), :z)
    return z
end

@gen function fuzz((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + 2 * y, std), :z)
    return z
end

sc = Switch(bang, fuzz)
```

This creates a new generative function `sc`. We can then obtain a trace of `sc`:

`trace = simulate(sc, (2, 5.0, 3.0))`

The resulting trace contains the subtrace from the branch with index `2`, in this case a call to `fuzz`:

```
│
└── :z : 13.552870875213735
```