# Scaling with the Static Modeling Language

## Introduction

For prototyping models and working with dynamic structures, Gen's Dynamic Modeling Language is a great (and the default) way of writing probabilistic programs in nearly pure Julia. However, better performance and scaling characteristics can be obtained using specialized modeling languages or modeling constructs. This notebook introduces a more specialized modeling language, the Static Modeling Language (SML), which is also built into Gen. The SML speeds up inference by analyzing the structure of a model ahead of time to determine which computations each inference update actually needs to perform.

**Prerequisites for this tutorial**

This tutorial assumes familiarity with the robust regression model used to introduce iterative inference in an earlier tutorial. Here we take that model and optimize the speed of inference using the SML.

## Slow Inference Program Case Study

```
using Gen
using Plots

@gen function model(xs::Vector{Float64})
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    noise ~ gamma(1, 1)
    prob_outlier ~ uniform(0, 1)
    n = length(xs)
    ys = Vector{Float64}(undef, n)
    for i = 1:n
        if ({:data => i => :is_outlier} ~ bernoulli(prob_outlier))
            (mu, std) = (0., 10.)
        else
            (mu, std) = (xs[i] * slope + intercept, noise)
        end
        ys[i] = {:data => i => :y} ~ normal(mu, std)
    end
    ys
end
```


We wrote a Markov chain Monte Carlo inference update for this model that performs updates on each of the 'global' parameters (`noise`, `slope`, `intercept`, and `prob_outlier`), as well as the 'local' `is_outlier` variable associated with each data point. The update takes a trace as input and returns the new trace as output. We reproduce it here:

```
function block_resimulation_update(tr)
    # Block 1: Update the line's parameters
    line_params = select(:noise, :slope, :intercept)
    (tr, _) = mh(tr, line_params)

    # Blocks 2-N+1: Update the outlier classifications
    (xs,) = get_args(tr)
    n = length(xs)
    for i=1:n
        (tr, _) = mh(tr, select(:data => i => :is_outlier))
    end

    # Block N+2: Update the prob_outlier parameter
    (tr, _) = mh(tr, select(:prob_outlier))

    # Return the updated trace
    tr
end
```


We write a helper function that takes a vector of y-coordinates and populates a constraints choice map:

```
function make_constraints(ys::Vector{Float64})
    constraints = choicemap()
    for i=1:length(ys)
        constraints[:data => i => :y] = ys[i]
    end
    constraints
end
```


Finally, we package this into an inference program that takes the full data set of x- and y-coordinates and returns a trace. We will be experimenting with different variants of the model, so we make the model an argument to this function:

```
function block_resimulation_inference(model, xs, ys)
    observations = make_constraints(ys)
    (tr, _) = generate(model, (xs,), observations)
    for iter=1:500
        tr = block_resimulation_update(tr)
    end
    tr
end
```


Let's see how the running time of this inference program changes as we increase the number of data points. We don't expect the running time to depend too much on the actual values of the data points, so we just construct a random data set for each run:

```
ns = [1, 3, 7, 10, 30, 70, 100]
times = []
for n in ns
    xs = rand(n)
    ys = rand(n)
    start = time_ns()
    tr = block_resimulation_inference(model, xs, ys)
    push!(times, (time_ns() - start) / 1e9)
end
nothing
```
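One caveat about this measurement: Julia compiles `model` and the inference code the first time they run, so the timing for the smallest `n` may include compilation overhead. Below is a minimal sketch of a timing helper that makes one untimed warm-up call and then measures each size with Julia's built-in `@elapsed` macro; `time_per_size` is a hypothetical name, not part of Gen.

```julia
# Hypothetical helper (not part of Gen): time `f(n)` for each problem size n,
# after one untimed warm-up call so that JIT compilation is not measured.
function time_per_size(f, ns)
    f(first(ns))                        # warm-up: triggers compilation
    return [@elapsed(f(n)) for n in ns] # wall-clock seconds per size
end

# Usage with the tutorial's objects (requires Gen, `model`, and the helpers above):
# times = time_per_size(n -> block_resimulation_inference(model, rand(n), rand(n)), ns)
```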

We now plot the running time versus the number of data points:

```
plot(ns, times, xlabel="number of data points", ylabel="running time (seconds)", label=nothing)
```

The inference program seems to scale quadratically in the number of data points.
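One way to check this impression numerically is to fit a line to the log-log data: if the running time behaves like `c * n^k`, the slope of `log(time)` against `log(n)` estimates `k`. This is a sketch using only the `Statistics` standard library; `scaling_exponent` is a hypothetical helper name.

```julia
using Statistics

# Least-squares estimate of k in `time approx c * n^k`, from the slope of
# log(time) versus log(n). k near 2 suggests quadratic scaling, k near 1 linear.
function scaling_exponent(ns, times)
    log_n = log.(Float64.(ns))
    log_t = log.(Float64.(times))
    return cov(log_n, log_t) / var(log_n)
end

# Usage with the measurements above. Small sizes are dominated by fixed
# overheads, so fitting only the larger sizes may give a cleaner estimate:
# scaling_exponent(ns[4:end], times[4:end])
```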

To understand why, consider the block of code inside `block_resimulation_update` that loops over the data points:

```
# Blocks 2-N+1: Update the outlier classifications
(xs,) = get_args(tr)
n = length(xs)
for i=1:n
    (tr, _) = mh(tr, select(:data => i => :is_outlier))
end
```

The reason for the quadratic scaling is that the running time of each call to `mh` inside this loop also grows in proportion to the number of data points. This is because updating a trace of a model written in the generic built-in modeling language always involves re-running **the entire** model generative function.
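To make the arithmetic behind this explanation concrete, here is a toy cost model (an illustration, not Gen's implementation) that counts per-data-point density evaluations in one sweep over the n `is_outlier` variables. If every MH step re-runs the whole model, each of the n steps touches all n data points, giving n^2 evaluations per sweep; if each step only rescored the point that changed, a sweep would cost n. The function name `sweep_evaluations` is hypothetical.

```julia
# Toy cost model: count per-data-point density evaluations in one sweep of
# `is_outlier` updates (one MH step per data point).
# full_rescore = true  mimics re-running the whole model on every MH step;
# full_rescore = false mimics rescoring only the data point that changed.
function sweep_evaluations(n::Int; full_rescore::Bool)
    evals = 0
    for _ in 1:n
        evals += full_rescore ? n : 1   # visit every point, or only one
    end
    return evals
end

sweep_evaluations(100; full_rescore=true)    # 10000
sweep_evaluations(100; full_rescore=false)   # 100
```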

However, it should be possible for the algorithm to scale linearly in the number of data points. Briefly, deciding whether to update a given `is_outlier` variable can be done without referencing the other data points, because each `is_outlier` variable is conditionally independent of the outlier variables and y-coordinates of the other data points, given the global parameters.
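This claim can be checked against the MH acceptance probability. For a resimulation move like `mh(tr, select(:data => i => :is_outlier))`, the new value is proposed from its prior, so the prior terms cancel and the acceptance ratio reduces to the ratio of the likelihoods of the choices that depend on it, which here is only `y` for data point i. The sketch below writes that ratio out in plain Julia; the function names are hypothetical and the normal log-density is hand-coded to avoid dependencies. Note that its only inputs are `x`, `y` for one point and the global parameters.

```julia
# Normal log-density, written out so the sketch has no package dependencies.
normal_logpdf(y, mu, std) = -0.5 * log(2π) - log(std) - 0.5 * ((y - mu) / std)^2

# Acceptance probability of a resimulation MH move on one `is_outlier` flag.
# The proposal is the prior bernoulli(prob_outlier), so prob_outlier cancels
# and only the likelihood of this point's y remains in the ratio.
function outlier_flip_acceptance(x, y, old_flag::Bool, new_flag::Bool,
                                 slope, intercept, noise)
    # (mu, std) of y under each value of the flag, as in `model`
    mu_std(flag) = flag ? (0.0, 10.0) : (x * slope + intercept, noise)
    (mu_old, std_old) = mu_std(old_flag)
    (mu_new, std_new) = mu_std(new_flag)
    log_ratio = normal_logpdf(y, mu_new, std_new) - normal_logpdf(y, mu_old, std_old)
    return min(1.0, exp(log_ratio))
end

# A point lying exactly on the line y = 2x: switching it to "outlier"
# is accepted with probability below 1.
outlier_flip_acceptance(1.0, 2.0, false, true, 2.0, 0.0, 0.5)
```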

We can make this conditional independence structure explicit using the `Map` generative function combinator. Combinators like `Map` encapsulate common modeling patterns (e.g., a loop in which each iteration makes independent choices), and when you use them, Gen can take advantage of the restrictions they enforce to implement performance optimizations automatically during inference. The `Map` combinator, like the `map` function in a functional programming language, executes the same generative code repeatedly, once per element of its input vectors.
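As a rough plain-Julia analogue (no tracing, and not how `Map` is implemented), the sketch below writes the per-point sampling logic as an ordinary function and applies it with Julia's `map` across several equal-length vectors. `Map(generate_single_point)` consumes its vector arguments in the same elementwise way, while additionally recording each call's random choices in the trace. The names `sample_single_point`, `xs_demo`, and `ys_demo` are illustrative.

```julia
using Random

# Plain-Julia analogue of the per-point model (no tracing): an explicit RNG
# stands in for Gen's random choices.
function sample_single_point(rng::AbstractRNG, x, prob_outlier, noise, slope, intercept)
    is_outlier = rand(rng) < prob_outlier
    mu  = is_outlier ? 0.0 : x * slope + intercept
    std = is_outlier ? 10.0 : noise
    return mu + std * randn(rng)
end

# `map` calls the function once per index, taking the i-th element of each
# equal-length vector. With prob_outlier = 0 and noise = 0, every y lies
# exactly on the line.
rng = MersenneTwister(1)
xs_demo = [0.0, 1.0, 2.0]
n_demo = length(xs_demo)
ys_demo = map((x, p, s, m, b) -> sample_single_point(rng, x, p, s, m, b),
              xs_demo, fill(0.0, n_demo), fill(0.0, n_demo),
              fill(0.7, n_demo), fill(-2.0, n_demo))
```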

## Rewriting the Program with Combinators

To use the `Map` combinator to express the conditional independences in our model, we first write a generative function that generates the `is_outlier` variable and the y-coordinate for a single data point:

```
@gen function generate_single_point(x::Float64, prob_outlier::Float64, noise::Float64,
                                    slope::Float64, intercept::Float64)
    is_outlier ~ bernoulli(prob_outlier)
    mu = is_outlier ? 0. : x * slope + intercept
    std = is_outlier ? 10. : noise
    y ~ normal(mu, std)
    return y
end;
```


We then apply `Map`, which is a Julia function, to this generative function to obtain a new generative function:

```
generate_all_points = Map(generate_single_point);
```


This new generative function has one argument for each argument of `generate_single_point`, except that these arguments are now vector-valued instead of scalar-valued. We can run the generative function on some fake data to test this:

```
xs = Float64[0, 1, 2, 3, 4]
prob_outliers = fill(0.5, 5)
noises = fill(0.2, 5)
slopes = fill(0.7, 5)
intercepts = fill(-2.0, 5)
trace = simulate(generate_all_points, (xs, prob_outliers, noises, slopes, intercepts));
```


We see that the `generate_all_points` function has traced 5 calls to `generate_single_point`, under namespaces `1` through `5`. The `Map` combinator automatically adds these indices to the trace address.

```
get_choices(trace)
```

```
│
├── 1
│ │
│ ├── :y : 20.284440959615864
│ │
│ └── :is_outlier : true
│
├── 2
│ │
│ ├── :y : -2.1063451364355843
│ │
│ └── :is_outlier : true
│
├── 3
│ │
│ ├── :y : -0.3796604209588168
│ │
│ └── :is_outlier : false
│
├── 4
│ │
│ ├── :y : -7.867943504745139
│ │
│ └── :is_outlier : true
│
└── 5
│
├── :y : -2.467603369814431
│
└── :is_outlier : true
```

Now, let's replace the Julia `for` loop in our model with a call to this new function:

```
@gen function model_with_map(xs::Vector{Float64})
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    noise ~ gamma(1, 1)
    prob_outlier ~ uniform(0, 1)
    n = length(xs)
    data ~ generate_all_points(xs, fill(prob_outlier, n), fill(noise, n), fill(slope, n), fill(intercept, n))
    return data
end;
```


Note that this new model has the same address structure as our original model, so our inference code will not need to change. For example, the 5th data point's y-coordinate will be stored at the address `:data => 5 => :y`, just as before. (The `:data` comes from our `data ~ ...` invocation in the `model_with_map` definition, and the `:y` comes from `generate_single_point`; only the `5` has been inserted automatically by `Map`.)
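The addresses themselves are ordinary Julia values: `=>` builds a `Pair` and is right-associative, so `:data => 5 => :y` is the nested pair `:data => (5 => :y)`. This small sketch (plain Julia, no Gen needed) shows how the namespace from `data ~ ...`, the index inserted by `Map`, and the choice name nest inside one address.

```julia
addr = :data => 5 => :y

# `=>` is right-associative, so this is the nested pair :data => (5 => :y)
addr == Pair(:data, Pair(5, :y))   # true
addr.first                          # :data    (namespace from `data ~ ...`)
addr.second                         # 5 => :y  (index added by Map, then choice name)

# The y-coordinate addresses for the first three data points, as used when
# building constraints:
y_addrs = [:data => i => :y for i in 1:3]
```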

```
trace = simulate(model_with_map, (xs,));
get_choices(trace)
```

```
│
├── :intercept : 1.2682537328550278
│
├── :slope : -2.5084458345941254
│
├── :prob_outlier : 0.08561037483554168
│
├── :noise : 0.8060620212868442
│
└── :data
│
├── 1
│ │
│ ├── :y : 1.020666043971035
│ │
│ └── :is_outlier : false
│
├── 2
│ │
│ ├── :y : -2.699652782727517
│ │
│ └── :is_outlier : false
│
├── 3
│ │
│ ├── :y : -3.7446683199416957
│ │
│ └── :is_outlier : false
│
├── 4
│ │
│ ├── :y : -5.671080203999857
│ │
│ └── :is_outlier : false
│
└── 5
│
├── :y : 4.047762698055234
│
└── :is_outlier : true
```

Let's test the running time of the inference program, applied to this new model:

```
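with_map_times = []
for n in ns
    # `local` keeps xs/ys loop-local, avoiding the soft-scope warning
    # caused by the global `xs` defined earlier
    local xs = rand(n)
    local ys = rand(n)
    start = time_ns()
    tr = block_resimulation_inference(model_with_map, xs, ys)
    push!(with_map_times, (time_ns() - start) / 1e9)
end
nothing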
```


We plot the results and compare them to the original model, which used the Julia `for` loop:

```
plot(ns, times, label="original", xlabel="number of data points", ylabel="running time (seconds)")
plot!(ns, with_map_times, label="with map")
```

We see that the quadratic scaling did not improve. In fact, the only change was a constant-factor **slowdown**.

We can understand why we still have quadratic scaling by examining the call to `generate_all_points`:

`data ~ generate_all_points(xs, fill(prob_outlier, n), fill(noise, n), fill(slope, n), fill(intercept, n))`

Even though the function `generate_all_points` knows that each of the calls to `generate_single_point` is conditionally independent, and even though it knows that each update to `is_outlier` only involves a single application of `generate_single_point`, it does not know that **none of its arguments change** within an update to `is_outlier`. Therefore, it needs to visit each call to `generate_single_point`. The generic built-in modeling language does not provide this information to the generative functions that it invokes.
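In Gen, this kind of information is passed to a callee during an update as argument diffs (for example `NoChange()` versus `UnknownChange()`). The toy function below is not Gen's implementation, and `children_to_visit` is a hypothetical name; it only captures the decision a `Map`-like combinator faces. Without a guarantee that its vector arguments are unchanged, it must revisit all n child calls on every update, which reproduces the quadratic sweep; with that guarantee, it only needs to revisit the children whose choices were updated.

```julia
# Toy model of one update step inside a Map-like combinator with n children.
# `touched` holds the indices of children whose choices the update modifies.
function children_to_visit(n::Int, args_may_have_changed::Bool, touched::Set{Int})
    if args_may_have_changed
        return collect(1:n)              # any child's arguments might differ: revisit all
    else
        return sort!(collect(touched))   # arguments unchanged: revisit only touched children
    end
end

children_to_visit(5, true,  Set([3]))   # [1, 2, 3, 4, 5]
children_to_visit(5, false, Set([3]))   # [3]
```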

## Rewriting in the Static Modeling Language

In order to provide `generate_all_points` with the knowledge that its arguments do not change during an update to the `is_outlier` variable, we need to write the top-level model generative function that calls `generate_all_points` in the Static Modeling Language, which is a restricted variant of the built-in modeling language that uses static analysis of the computation graph to generate specialized trace data structures and specialized implementations of trace operations. We indicate that a function is to be interpreted using the static language with the `static` annotation:

```
@gen (static) function static_model_with_map(xs::Vector{Float64})
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    noise ~ gamma(1, 1)
    prob_outlier ~ uniform(0, 1)
    n = length(xs)
    data ~ generate_all_points(xs, fill(prob_outlier, n), fill(noise, n), fill(slope, n), fill(intercept, n))
    return data
end
```


The static language has a number of restrictions that make it more amenable to static analysis than the unrestricted modeling language. For example, we cannot use Julia `for` loops, and the return value needs to explicitly use the `return` keyword, followed by a symbol (e.g. `data`). Also, each symbol used on the left-hand side of an assignment must be unique. A more complete list of restrictions is given in the documentation.

Below, we show the static dependency graph that Gen builds for this function. Arguments are shown as diamonds, Julia computations are shown as squares, random choices are shown as circles, and calls to other generative functions are shown as stars. The call that produces the return value of the function is shaded in blue.

<img src="graph.png" width="100%"/>

Now, consider the update to the `is_outlier` variable:

`(tr, _) = mh(tr, select(:data => i => :is_outlier))`

Because this update only causes values under the address `:data` to change, the `static_model_with_map` function can use the graph above to infer that none of the arguments to `generate_all_points` could possibly have changed. This allows us to obtain the linear scaling we expected.
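The inference in the previous paragraph amounts to a reachability question on the dependency graph: which nodes lie downstream of the changed value? The sketch below encodes a simplified version of the graph for `static_model_with_map` as a plain Julia `Dict` (node names such as `:fill_slope` are illustrative, not Gen's internal names) and computes everything reachable from a changed node. Nothing downstream of `:data` feeds back into the arguments of the `Map` call, whereas a change to `:slope` does reach it.

```julia
# Simplified dependency graph for `static_model_with_map`: each node maps to
# the nodes that consume its value. `:data` is the Map call, whose value is returned.
dep_graph = Dict{Symbol,Vector{Symbol}}(
    :xs             => [:n, :data],
    :n              => [:fill_prob, :fill_noise, :fill_slope, :fill_intercept],
    :prob_outlier   => [:fill_prob],
    :noise          => [:fill_noise],
    :slope          => [:fill_slope],
    :intercept      => [:fill_intercept],
    :fill_prob      => [:data],
    :fill_noise     => [:data],
    :fill_slope     => [:data],
    :fill_intercept => [:data],
    :data           => Symbol[],
)

# All nodes reachable from `changed` (including itself), by depth-first search.
function downstream(graph::Dict{Symbol,Vector{Symbol}}, changed::Symbol)
    seen = Set([changed])
    stack = [changed]
    while !isempty(stack)
        node = pop!(stack)
        for child in get(graph, node, Symbol[])
            if !(child in seen)
                push!(seen, child)
                push!(stack, child)
            end
        end
    end
    return seen
end

downstream(dep_graph, :data)    # only :data itself: no argument of the Map call is affected
downstream(dep_graph, :slope)   # includes :fill_slope and :data
```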

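However, before we can use a function written in the static modeling language, older versions of Gen require running the following macro. This is needed for technical reasons, because functions written in the static modeling language use a staged programming feature of Julia called *generated functions*. In recent versions of Gen this step is no longer necessary, and running it only emits a deprecation warning (as seen later in this notebook).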

`Gen.@load_generated_functions`

Finally, we can re-run the experiment with our model that combines the map combinator with the static language:

```
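static_with_map_times = []
for n in ns
    # `local` keeps xs/ys loop-local, avoiding the soft-scope warning
    # caused by the global `xs` defined earlier
    local xs = rand(n)
    local ys = rand(n)
    start = time_ns()
    tr = block_resimulation_inference(static_model_with_map, xs, ys)
    push!(static_with_map_times, (time_ns() - start) / 1e9)
end
nothing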
```


We compare these results to those for the earlier models:

```
plot(ns, times, label="original", xlabel="number of data points", ylabel="running time (seconds)")
plot!(ns, with_map_times, label="with map")
plot!(ns, static_with_map_times, label="with map and static outer fn")
```

We see that we now have the linear running time that we expected.

## Benchmarking the Performance Gain

**Note:** *the following section was drafted using an earlier version of Julia. As of Julia 1.7, the dynamic modeling language is fast enough in some cases that you may not see constant-factor performance gains by switching simple dynamic models, with few choices and no control flow, to use the static modeling language. Based on the experiment below, this model falls into that category.*

Note that in our latest model above, `generate_single_point` was still written in the dynamic modeling language. It is not necessary to write `generate_single_point` in the static language, but doing so might provide modest constant-factor performance improvements. Here we rewrite this function in the static language. The static modeling language does not support `if` statements, but does support ternary expressions (`a ? b : c`):

```
@gen (static) function static_generate_single_point(x::Float64, prob_outlier::Float64, noise::Float64,
                                                   slope::Float64, intercept::Float64)
    is_outlier ~ bernoulli(prob_outlier)
    mu = is_outlier ? 0. : x * slope + intercept
    std = is_outlier ? 10. : noise
    y ~ normal(mu, std)
    return y
end;

static_generate_all_points = Map(static_generate_single_point);

@gen (static) function fully_static_model_with_map(xs::Vector{Float64})
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    noise ~ gamma(1, 1)
    prob_outlier ~ uniform(0, 1)
    n = length(xs)
    data ~ static_generate_all_points(xs, fill(prob_outlier, n), fill(noise, n), fill(slope, n), fill(intercept, n))
    return data
end;

# Only required by older versions of Gen (see the warning below)
Gen.@load_generated_functions
```

```
┌ Warning: `Gen.@load_generated_functions` is no longer necessary and will be removed in a future release.
└ @ Gen ~/work/Gen.jl/Gen.jl/src/Gen.jl:33
```

Now, we re-run the experiment with our new model:

```
fully_static_with_map_times = []
let
    for n in ns
        xs = rand(n)
        ys = rand(n)
        start = time_ns()
        tr = block_resimulation_inference(fully_static_model_with_map, xs, ys)
        push!(fully_static_with_map_times, (time_ns() - start) / 1e9)
    end
end
```

In earlier versions of Julia, we saw a modest improvement in running time, but here (running Julia 1.7.1) we see it makes little to no difference:
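When curves overlap in a plot, a constant-factor difference can be hard to read off. A small, hypothetical helper that divides one timing vector by another makes the ratio explicit; values above 1 mean the second variant is faster at that problem size.

```julia
# Hypothetical helper: elementwise ratio of baseline running time to variant
# running time. Accepts the untyped (`[]`) timing vectors used in this notebook.
speedups(baseline, variant) = Float64.(baseline) ./ Float64.(variant)

# Usage with the timing vectors collected above, e.g.:
# speedups(static_with_map_times, fully_static_with_map_times)
```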

```
plot(ns, times, label="original", xlabel="number of data points", ylabel="running time (seconds)")
plot!(ns, with_map_times, label="with map")
plot!(ns, static_with_map_times, label="with map and static outer fn")
plot!(ns, fully_static_with_map_times, label="with map and static outer and inner fns")
```

## Checking the Inference Programs

Before wrapping up, let's confirm that all of our models give good results, using a synthetic data set:

```
true_inlier_noise = 0.5
true_outlier_noise = 10.
prob_outlier = 0.1
true_slope = -1
true_intercept = 2
xs = collect(range(-5, stop=5, length=50))
ys = Float64[]
for (i, x) in enumerate(xs)
    if rand() < prob_outlier
        y = 0. + randn() * true_outlier_noise
    else
        y = true_slope * x + true_intercept + randn() * true_inlier_noise
    end
    push!(ys, y)
end
ys[end-3] = 14
ys[end-5] = 13;
scatter(xs, ys, xlim=(-7,7), ylim=(-7,15), label=nothing)
```
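The data set above changes on every run because it draws from Julia's global random number generator. Below is a sketch of a seeded variant, assuming reproducible plots are wanted; the function name `make_synthetic_data`, its keyword defaults (copied from the values above), and the seed `0` are choices made for this sketch.

```julia
using Random

# Seeded variant of the synthetic data generator above. Passing an explicit
# RNG makes the data set, and hence the plots below, reproducible.
function make_synthetic_data(rng::AbstractRNG; n=50, prob_outlier=0.1,
                             slope=-1.0, intercept=2.0,
                             inlier_noise=0.5, outlier_noise=10.0)
    xs = collect(range(-5, stop=5, length=n))
    ys = map(xs) do x
        if rand(rng) < prob_outlier
            outlier_noise * randn(rng)                       # outlier centered at 0
        else
            slope * x + intercept + inlier_noise * randn(rng) # inlier near the line
        end
    end
    ys[end-3] = 14   # the same two injected outliers as above
    ys[end-5] = 13
    return xs, ys
end

xs_seeded, ys_seeded = make_synthetic_data(MersenneTwister(0))
```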

We write a trace rendering function that shows the inferred line on top of the observed data set:

```
function render_trace(trace, title)
    xs, = get_args(trace)
    xlim = [-5, 5]
    slope = trace[:slope]
    intercept = trace[:intercept]
    plot(xlim, slope * xlim .+ intercept, color="black", xlim=(-7,7), ylim=(-7,15), title=title, label=nothing)
    ys = [trace[:data => i => :y] for i=1:length(xs)]
    scatter!(xs, ys, label=nothing)
end;
```


Finally, we run the experiment. We will visualize just one trace produced by applying our inference program to each of the four variants of our model:

```
tr = block_resimulation_inference(model, xs, ys)
fig1 = render_trace(tr, "model")
tr = block_resimulation_inference(model_with_map, xs, ys)
fig2 = render_trace(tr, "model with map")
tr = block_resimulation_inference(static_model_with_map, xs, ys)
fig3 = render_trace(tr, "static model with map")
tr = block_resimulation_inference(fully_static_model_with_map, xs, ys)
fig4 = render_trace(tr, "fully static model with map")
plot(fig1, fig2, fig3, fig4)
```

Inference appears to be working reasonably well for all four models.
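The plots give a qualitative check. For a rough quantitative one, a hypothetical helper can report how far a single inferred line is from the parameters used to generate the data (`true_slope = -1`, `true_intercept = 2` above). Since each run of the inference program returns one approximate posterior sample, these numbers will vary between runs.

```julia
# Hypothetical helper: absolute error of one inferred line relative to the
# generating parameters used for the synthetic data set.
function line_param_errors(slope_hat, intercept_hat; true_slope=-1.0, true_intercept=2.0)
    return (slope = abs(slope_hat - true_slope),
            intercept = abs(intercept_hat - true_intercept))
end

# Usage with a trace returned by `block_resimulation_inference`:
# line_param_errors(tr[:slope], tr[:intercept])
```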