Generative Combinators

Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. They are used to express patterns of repeated computation that appear frequently in generative models. Some generative function combinators are similar to higher-order functions from functional programming languages. For example, the map combinator is analogous to the familiar map function.

Map combinator

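The map combinator takes a generative function (the kernel) and returns a new generative function that applies the kernel independently to each element of a vector of inputs. The new function takes one vector argument for each argument of the kernel. All of these vectors must have the same length, and the return value is the vector of the kernel's return values. In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.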

[Figure: schematic of the map combinator]
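The map combinator can also be summarized in symbols. The following is a sketch in the notation above. The symbols $x^{(1)}, \ldots, x^{(m)}$ (the $m$ input vectors, each of length $n$), $r_i$ (the return value of the $i$-th kernel application), and $\tau_i$ (the kernel trace stored in address namespace $i$) are introduced here for illustration. Because the applications are independent, the score of the combined trace $\tau$ is the sum of the kernel scores, and the return value is $(r_1, \ldots, r_n)$:

```latex
\begin{aligned}
r_i &= \operatorname{retval}\!\big(\mathcal{G}_{\mathrm{k}}(x^{(1)}_i, \ldots, x^{(m)}_i)\big), \qquad i = 1, \ldots, n,\\
\log p(\tau) &= \sum_{i=1}^{n} \log p_{\mathrm{k}}(\tau_i).
\end{aligned}
```

In the example that follows, the two kernel scores are about $-0.9905$ and $-0.9201$, and the score of the map trace is their sum, about $-1.9107$.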

For example, consider the following generative function, which makes one random choice at address :r:

using Gen
@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end

We apply the map combinator to produce a new generative function bar:

bar = Map(foo)

We can then obtain a trace of bar:

trace, _ = generate(bar, ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0]))

This causes foo to be invoked twice, once with arguments (0.0, 0.5, 1.0) in address namespace 1 and once with arguments (0.5, 1.0, -1.0) in address namespace 2.
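Because each application lives in its own integer namespace, a single application can be constrained by using a hierarchical address such as `2 => :r`. The sketch below assumes the standard Gen calls `choicemap`, `generate`, `get_choices`, `get_score`, and `logpdf`. The constrained value `3.0` is arbitrary and was chosen only for illustration. The first application is unconstrained and is sampled from the kernel. The returned weight is therefore the log density of the constrained choice alone.

```julia
using Gen

@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end

bar = Map(foo)

# Fix the choice :r of the second application (address 2 => :r); the first
# application is left unconstrained and is sampled from the kernel.
constraints = choicemap((2 => :r, 3.0))
trace, weight = generate(bar, ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0]), constraints)

# Application 2 receives (0.5, 1.0, -1.0): mean 0.25 + 1.0 + 1.0 = 2.25.
expected_weight = logpdf(normal, 3.0, 2.25, 1.0)

# Application 1 receives (0.0, 0.5, 1.0): mean 0.0 + 0.25 + 1.0 = 1.25.
# The map score is the sum of the two kernel scores.
r1 = get_choices(trace)[1 => :r]
expected_score = logpdf(normal, r1, 1.25, 1.0) + expected_weight

println(get_choices(trace)[2 => :r])  # 3.0
```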

get_choices(trace)
│
├── 1
│   │
│   └── :r : 0.8716112694695082
│
└── 2
    │
    └── :r : 2.2011632799317287

The return value is the vector of the kernel's return values, one per application:

get_retval(trace)
Persistent{Any}[0.8716112694695082, 2.2011632799317287]
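Since `Map(foo)` is itself a generative function, it can be called with `~` from inside another model. Its choices then appear under the caller's address, for example `:rs => 3 => :r`. In the sketch below, `foo_map` and `model` are hypothetical names used for illustration. The mapped function is constructed once at top level, and the two extra argument vectors are filled with zeros so that all three vectors have the same length.

```julia
using Gen

@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end

# Hypothetical name: construct the mapped function once, outside the model.
foo_map = Map(foo)

# Hypothetical model that calls the mapped function like any other
# generative function; its choices are stored under the address :rs.
@gen function model(xs::Vector{Float64})
    n = length(xs)
    # All argument vectors passed to foo_map must have the same length n.
    rs ~ foo_map(xs, zeros(n), zeros(n))
    return rs
end

trace = simulate(model, ([1.0, 2.0, 3.0],))
choices = get_choices(trace)
println(has_value(choices, :rs => 3 => :r))  # true
```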

Unfold combinator

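The unfold combinator takes a generative function (the kernel) and returns a new generative function that applies the kernel $n$ times in sequence, passing the state returned by each application as the state input to the next. In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel, not including the state, are $z$. The kernel's first argument is the index $t$ of the application, starting at 1, and its second argument is the previous state. The new generative function takes the arguments $(n, y_0, z)$.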

[Figure: schematic of the unfold combinator]
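The unfolding can be written as a recurrence. The following is a sketch in the notation above, where $\tau_t$ (the kernel trace in address namespace $t$) and $\tau$ (the combined trace) are introduced here for illustration. Each state is the return value of one kernel application, the return value is $(y_1, \ldots, y_n)$, and the score of the combined trace is the sum of the per-step kernel scores:

```latex
\begin{aligned}
y_t &= \operatorname{retval}\!\big(\mathcal{G}_{\mathrm{k}}(t,\, y_{t-1},\, z)\big), \qquad t = 1, \ldots, n,\\
\log p(\tau) &= \sum_{t=1}^{n} \log p_{\mathrm{k}}(\tau_t \mid t,\, y_{t-1},\, z).
\end{aligned}
```

In the example that follows, four steps contribute $\log 0.95 \approx -0.0513$ each and one step contributes $\log 0.05 \approx -2.9957$. Their sum, about $-3.2009$, is the score of the unfold trace.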

For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:

using Gen
@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y ~ bernoulli(y_prev ? z1 : z2)
    return y
end

We apply the unfold combinator to produce a new generative function bar:

bar = Unfold(foo)

We can then obtain a trace of bar:

trace, _ = generate(bar, (5, false, 0.05, 0.95))

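This causes foo to be invoked five times, in address namespaces 1 through 5. The first invocation receives the initial state false, and each later invocation receives the state returned by the one before it. The resulting trace may contain the following random choices: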

get_choices(trace)
│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : true
│
└── 5
    │
    └── :y : false

The return value is the vector of states returned by the five applications:

get_retval(trace)
Persistent{Any}[true, false, true, true, false]
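Constraining every `:y` choice makes the unfold trace fully determined, which makes the step-by-step state passing easy to check. The sketch below assumes the standard Gen calls `choicemap`, `generate`, `get_retval`, and `get_score`. The sequence length 3 and the constrained values are arbitrary choices for illustration. When every choice is constrained, the weight returned by `generate` should equal the trace score.

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y ~ bernoulli(y_prev ? z1 : z2)
    return y
end

bar = Unfold(foo)

# Constrain all three :y choices of a length-3 unfold (addresses 1, 2, 3).
constraints = choicemap((1 => :y, true), (2 => :y, true), (3 => :y, false))
trace, weight = generate(bar, (3, false, 0.05, 0.95), constraints)

# Step 1: y_prev = false (initial state), so P(y = true) = z2 = 0.95.
# Step 2: y_prev = true, so P(y = true) = z1 = 0.05.
# Step 3: y_prev = true, so P(y = false) = 1 - z1 = 0.95.
expected_score = log(0.95) + log(0.05) + log(0.95)

println(collect(get_retval(trace)))
```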

Switch combinator

[Figure: schematic of the switch combinator]

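The switch combinator takes several generative functions and returns a new generative function that runs exactly one of them, selected by its first argument. Consider the following two generative functions, which differ only in the standard deviation of the random choice at address :z: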

@gen function line(x)
    z ~ normal(3*x + 1, 1.0)
    return z
end

@gen function outlier(x)
    z ~ normal(3*x + 1, 10.0)
    return z
end

switch_model = Switch(line, outlier)

This creates a new generative function switch_model whose arguments take the form (branch, args...). By default, branch is an integer indicating which generative function to execute. For example, branch 2 corresponds to outlier:

trace = simulate(switch_model, (2, 5.0))
get_choices(trace)
│
└── :z : 20.353293276948396
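A common use of the switch combinator is to choose the branch with a random choice, which produces a mixture model. In the sketch below, `mixture` is a hypothetical model. The outlier probability `0.1` and the observed value `20.0` are arbitrary values chosen for illustration. Switch does not add a namespace of its own, so the branch's choice `:z` appears directly under the caller's address `:y`. The sketch scores the same observation under both branches: with `x = 5.0` both means are 16.0, so a value of 20.0 is far more plausible under the wide outlier branch.

```julia
using Gen

@gen function line(x)
    z ~ normal(3*x + 1, 1.0)
    return z
end

@gen function outlier(x)
    z ~ normal(3*x + 1, 10.0)
    return z
end

switch_model = Switch(line, outlier)

# Hypothetical mixture model: a coin flip selects the branch
# (branch 1 => line, branch 2 => outlier).
@gen function mixture(x)
    is_outlier ~ bernoulli(0.1)
    y ~ switch_model(is_outlier ? 2 : 1, x)
    return y
end

# Score the same observation y => :z = 20.0 under each branch.
trace_out, _ = generate(mixture, (5.0,), choicemap((:is_outlier, true), (:y => :z, 20.0)))
trace_line, _ = generate(mixture, (5.0,), choicemap((:is_outlier, false), (:y => :z, 20.0)))

# Both branches have mean 3*5.0 + 1 = 16.0.
expected_out = log(0.1) + logpdf(normal, 20.0, 16.0, 10.0)
expected_line = log(0.9) + logpdf(normal, 20.0, 16.0, 1.0)
```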