Probability Distributions
Gen provides a library of built-in probability distributions, and two ways of writing custom distributions, both of which are explained below:
- The @dist constructor, for a distribution that can be expressed as a simple deterministic transformation (technically, a pushforward) of an existing distribution.
- An API for defining arbitrary custom distributions in plain Julia code.
Built-In Distributions
Gen.bernoulli(prob_true::Real): Sample a Bool value that is true with probability prob_true.
Gen.beta(alpha::Real, beta::Real): Sample a Float64 from a beta distribution.
Gen.beta_uniform(theta::Real, alpha::Real, beta::Real): Sample a Float64 value from a mixture of a uniform distribution on [0, 1] with probability 1-theta and a beta distribution with parameters alpha and beta with probability theta.
Gen.binom(n::Integer, p::Real): Sample an Int from the binomial distribution with parameters n (number of trials) and p (probability of success in each trial).
Gen.categorical(probs::AbstractArray{U, 1}) where {U <: Real}: Given a vector of probabilities probs where sum(probs) = 1, sample an Int i from the set {1, 2, ..., length(probs)} with probability probs[i].
Gen.exponential(rate::Real): Sample a Float64 from the exponential distribution with rate parameter rate.
Gen.gamma(shape::Real, scale::Real): Sample a Float64 from a gamma distribution.
Gen.geometric(p::Real): Sample an Int from the geometric distribution with parameter p.
Gen.inv_gamma(shape::Real, scale::Real): Sample a Float64 from an inverse gamma distribution.
Gen.laplace(loc::Real, scale::Real): Sample a Float64 from a Laplace distribution.
Gen.mvnormal(mu::AbstractVector{T}, cov::AbstractMatrix{U}) where {T<:Real, U<:Real}: Sample a Vector{Float64} value from a multivariate normal distribution.
Gen.neg_binom(r::Real, p::Real): Sample an Int from a negative binomial distribution. Returns the number of failures before the rth success in a sequence of independent Bernoulli trials. r is the number of successes (which may be fractional) and p is the probability of success per trial.
Gen.normal(mu::Real, std::Real): Sample a Float64 value from a normal distribution.
Gen.piecewise_uniform(bounds, probs): Sample a Float64 value from a piecewise uniform continuous distribution.
There are n bins where n = length(probs) and n + 1 = length(bounds). Bounds must satisfy bounds[i] < bounds[i+1] for all i. The probability density at x is zero if x <= bounds[1] or x >= bounds[end], and is otherwise probs[bin] / (bounds[bin+1] - bounds[bin]), where bin is the index with bounds[bin] < x <= bounds[bin+1].
Gen.poisson(lambda::Real): Sample an Int from the Poisson distribution with rate lambda.
Gen.uniform(low::Real, high::Real): Sample a Float64 from the uniform distribution on the interval [low, high].
Gen.uniform_discrete(low::Integer, high::Integer): Sample an Int from the uniform distribution on the set {low, low + 1, ..., high-1, high}.
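A short usage sketch, assuming Gen.jl is installed: outside a generative function, a built-in distribution can be sampled with random and scored with logpdf. Both take the distribution object followed by its parameters, and logpdf takes the value to score before the parameters. The bounds and probabilities at the end are illustrative values chosen to check the piecewise_uniform density formula above.

```julia
using Gen

# Draw samples by passing the distribution and its parameters to `random`.
x = random(normal, 0.0, 1.0)               # a Float64
flip = random(bernoulli, 0.3)              # a Bool
k = random(categorical, [0.2, 0.5, 0.3])   # an Int in 1:3

# `logpdf(dist, value, params...)` returns the log density (or log mass).
lp_normal = logpdf(normal, 0.0, 0.0, 1.0)  # standard normal at 0: -log(2π)/2

# Illustrative piecewise_uniform: bins (0, 1] and (1, 3), probability 0.5 each.
bounds = [0.0, 1.0, 3.0]
probs = [0.5, 0.5]
# x = 2.0 lies in bin 2, so the density is probs[2] / (bounds[3] - bounds[2]).
density = exp(logpdf(piecewise_uniform, 2.0, bounds, probs))
```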
Defining New Distributions Inline with the @dist DSL
The @dist DSL allows the user to concisely define a distribution, as long as that distribution can be expressed as a certain type of deterministic transformation of an existing distribution. The syntax of the @dist DSL, as well as the class of permitted deterministic transformations, is explained below.
@dist name(arg1, arg2, ..., argN) = body
or
@dist function name(arg1, arg2, ..., argN)
    body
end
Here body is ordinary Julia code, with the constraint that body must contain exactly one random choice. The value of the @dist expression is then a Gen.Distribution object called name, parameterized by arg1, ..., argN, representing the distribution over return values of body.
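As a minimal sketch of both syntactic forms, assuming Gen.jl is installed: the hypothetical distributions shifted_exponential and shifted_exponential_block below describe the same random variable, an exponential draw plus a constant shift. Because the value of an @dist expression is a Gen.Distribution, random and logpdf can be applied to it directly.

```julia
using Gen

# One-line form (hypothetical name): an exponential draw shifted right by `shift`.
@dist shifted_exponential(rate, shift) = exponential(rate) + shift

# Block form of the same distribution; the body still makes exactly one random choice.
@dist function shifted_exponential_block(rate, shift)
    y = exponential(rate)
    y + shift
end

# Both definitions are Gen.Distribution objects, so `random` and `logpdf` apply.
s = random(shifted_exponential, 2.0, 1.0)   # always at least 1.0

# A shift has unit Jacobian, so the density at 3.0 with shift 1.0
# equals the exponential density (rate 2.0) at 2.0.
lp = logpdf(shifted_exponential, 3.0, 2.0, 1.0)
lp_block = logpdf(shifted_exponential_block, 3.0, 2.0, 1.0)
```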
This DSL is designed to address the issue that sometimes, values stored in the trace do not correspond to the most natural physical elements of the model state space, making inference programming and querying more taxing than necessary. For example, suppose we have a model of classes at a school, where the number of students is random, with mean 10, but always at least 3. Rather than writing the model as
@gen function class_model()
    n_students = @trace(poisson(7), :n_students_minus_3) + 3
    ...
end
and thinking about the random variable :n_students_minus_3, you can use the @dist DSL to instead write
@dist student_distr(mean, min) = poisson(mean-min) + min
@gen function class_model()
    n_students = @trace(student_distr(10, 3), :n_students)
    ...
end
and think about the more natural random variable :n_students. This leads to more natural inference programs, which can constrain and propose directly to the :n_students trace address.
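A sketch of that inference-side benefit, assuming Gen.jl is installed: generate runs class_model with the :n_students address constrained (12 is an arbitrary value) and returns the trace together with a log weight. Since the only random choice is constrained, the weight should equal the log probability that student_distr assigns to 12. The elided ... body is replaced by a return statement so the sketch runs.

```julia
using Gen

# The @dist definition from the text: a Poisson draw shifted up by `min`.
@dist student_distr(mean, min) = poisson(mean - min) + min

@gen function class_model()
    n_students = @trace(student_distr(10, 3), :n_students)
    return n_students
end

# Constrain the natural address :n_students directly.
constraints = choicemap((:n_students, 12))
trace, weight = generate(class_model, (), constraints)

# With every random choice constrained, the weight is the log probability
# of the constrained value under student_distr(10, 3).
n = trace[:n_students]
expected = logpdf(student_distr, 12, 10, 3)
```

Note that the constraint is stated in terms of the number of students itself, with no need to subtract the minimum by hand.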
Permitted constructs for the body of a @dist
@dist cannot handle an arbitrary body. We now describe which constructs are permitted inside the body of a @dist expression.
We can think of the body of an @dist function as containing ordinary Julia code, except that in addition to being described by their ordinary Julia types, each expression also belongs to one of three "type spaces." These are:
- CONST: Constants, whose value is known at the time this @dist expression is evaluated.
- ARG: Arguments and (deterministic, differentiable) functions of arguments. All expressions representing non-random values that depend on distribution arguments are ARG expressions.
- RND: Random variables. All expressions whose runtime values may differ across multiple calls to this distribution (with the same arguments) are RND expressions.
Importantly, Julia control flow constructs generally expect CONST values: the condition of an if or the range of a for loop cannot be ARG or RND.
The body expression as a whole must be a RND expression, representing a random variable. The behavior of the @dist definition is then to define a new distribution (with name name) that samples and evaluates the logpdf of the random variable represented by the body expression.
Expressions are typed compositionally, with the following typing rules:
1. Literals and free variables are CONSTs. Literals and symbols that appear free in the @dist body are of type CONST.
2. Arguments are ARGs. Symbols bound as arguments in the @dist declaration have type ARG in its body.
3. Drawing from a distribution gives RND. If d is a distribution, and x_i are of type ARG or CONST, d(x_1, x_2, ...) is of type RND.
4. Functions of CONSTs are CONSTs. If f is a deterministic function and x_i are all of type CONST, f(x_1, x_2, ...) is of type CONST.
5. Functions of CONSTs and ARGs are ARGs. If f is a differentiable function, and each x_i is either a CONST or a scalar ARG (with at least one x_i being an ARG), then f(x_1, x_2, ...) is of type ARG.
6. Functions of CONSTs, ARGs, and RNDs are RNDs. If f is one of a special set of deterministic functions we've defined (+, -, *, /, exp, log, getindex), and exactly one of its arguments x_i is of type RND, then f(x_1, x_2, ...) is of type RND.
One way to think about this, without all the rules, is that CONST values are "contaminated" by interaction with ARG values (becoming ARGs themselves), and both CONST and ARG are "contaminated" by interaction with RND. Thinking of the body as an AST, the journey from leaf node to root node always involves transitions in the direction of CONST -> ARG -> RND, never in reverse.
Restrictions
Users may not reassign to arguments (like x in the examples below), and may not apply functions with side effects. Names bound to expressions of type RND must be used only once. For example, a body that binds x = normal(0, 1) and then returns x + x is not allowed: + would receive two RND arguments, while rule 6 permits exactly one.
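A disallowed x + x body can usually be rewritten so that the random variable is used only once. A minimal sketch, assuming Gen.jl is installed: for a single draw x, x + x equals x * 2, which applies * to exactly one RND argument as rule 6 requires. The name doubled_normal is hypothetical, and the final check reflects the change-of-variables factor for scaling by 2.

```julia
using Gen

# Disallowed: binding x = normal(mu, 1) and returning x + x uses x twice.
# Allowed (hypothetical name): the same random variable, using the draw once.
@dist doubled_normal(mu) = normal(mu, 1) * 2

# Multiplying by 2 spreads the mass over twice the width, so the log density
# gains a -log(2) Jacobian term relative to the underlying normal at y / 2.
lp = logpdf(doubled_normal, 1.0, 0.0)
reference = logpdf(normal, 0.5, 0.0, 1.0) - log(2)
```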
Examples
Let's walk through some examples.
@dist f(x) = exp(normal(x, 1))
We can annotate with types:
1 :: CONST (by rule 1)
x :: ARG (by rule 2)
normal(x, 1) :: RND (by rule 3)
exp(normal(x, 1)) :: RND (by rule 6)
Here's another:
@dist function labeled_cat(labels, probs)
    index = categorical(probs)
    labels[index]
end
And the types:
probs :: ARG (by rule 2)
categorical(probs) :: RND (by rule 3)
index :: RND (Julia assignment)
labels :: ARG (by rule 2)
labels[index] :: RND (by rule 6, f == getindex)
Note that getindex is designed to work on anything indexable, not just vectors. So, for example, it also works with Dicts.
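A usage sketch for labeled_cat, assuming Gen.jl is installed and that the labels are distinct, so each label corresponds to exactly one index and its probability is that index's probability. The colors, probabilities, and the Dict label_dict are illustrative values; the Dict is keyed by the indices 1 to 3 that categorical can return.

```julia
using Gen

@dist function labeled_cat(labels, probs)
    index = categorical(probs)
    labels[index]
end

labels = ["red", "green", "blue"]
probs = [0.2, 0.5, 0.3]

# A sample is one of the labels rather than an integer index.
color = random(labeled_cat, labels, probs)

# With distinct labels, the mass of "green" is the probability of index 2.
lp_green = logpdf(labeled_cat, "green", labels, probs)

# getindex also works on a Dict keyed by the sampled index.
label_dict = Dict(1 => :low, 2 => :mid, 3 => :high)
level = random(labeled_cat, label_dict, probs)
```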
Another one (not as realistic, but it uses all the rules):
@dist function weird(x)
    log(normal(exp(x), exp(x))) + (x * (2 + 3))
end
And the types:
2, 3 :: CONST (by rule 1)
2 + 3 :: CONST (by rule 4)
x :: ARG (by rule 2)
x * (2 + 3) :: ARG (by rule 5)
exp(x) :: ARG (by rule 5)
normal(exp(x), exp(x)) :: RND (by rule 3)
log(normal(exp(x), exp(x))) :: RND (by rule 6)
log(normal(exp(x), exp(x))) + (x * (2 + 3)) :: RND (by rule 6)
Defining New Distributions From Scratch
For distributions that cannot be expressed in the @dist DSL, users can define a custom distribution by defining an (ordinary Julia) subtype of Gen.Distribution and implementing the methods of the Distribution API. This method requires more custom code than using the @dist DSL, but also affords more flexibility: arbitrary user-defined logic for sampling, PDF evaluation, etc.
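As an illustration, the folded normal (the absolute value of a normal draw) cannot be written with @dist, because abs is not among the functions allowed by rule 6 and two different draws map to the same value. For x >= 0 its density is the sum of the normal densities at x and at -x. The sketch below defines it from scratch, assuming the Gen.jl Distribution API methods random, logpdf, logpdf_grad, has_output_grad, has_argument_grads, and is_discrete. FoldedNormal, folded_normal, and folded_model are hypothetical names, and gradients are declared unavailable rather than implemented.

```julia
using Gen

# Hypothetical distribution over Float64 values: |y| where y ~ normal(mu, std).
struct FoldedNormal <: Gen.Distribution{Float64} end
const folded_normal = FoldedNormal()

# Sampling: draw a normal value and fold it onto [0, ∞).
Gen.random(::FoldedNormal, mu::Real, std::Real) = abs(mu + std * randn())

# Density: draws at x and at -x both fold to x, so their normal densities add.
function Gen.logpdf(::FoldedNormal, x::Real, mu::Real, std::Real)
    x < 0 && return -Inf
    z_pos = (x - mu) / std
    z_neg = (x + mu) / std
    density = (exp(-z_pos^2 / 2) + exp(-z_neg^2 / 2)) / (std * sqrt(2π))
    return log(density)
end

# Declare that no gradients are available and that the support is continuous.
Gen.logpdf_grad(::FoldedNormal, x, mu, std) = (nothing, nothing, nothing)
Gen.has_output_grad(::FoldedNormal) = false
Gen.has_argument_grads(::FoldedNormal) = (false, false)
Gen.is_discrete(::FoldedNormal) = false

# Make folded_normal(mu, std) callable like the built-in distributions.
(::FoldedNormal)(mu, std) = Gen.random(FoldedNormal(), mu, std)

# The new distribution can be traced inside a generative function.
@gen function folded_model()
    return @trace(folded_normal(0.0, 1.0), :x)
end

trace, weight = generate(folded_model, (), choicemap((:x, 1.0)))
```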