Probability Distributions

Gen provides a library of built-in probability distributions, and two ways of writing custom distributions, both of which are explained below:

  1. The @dist constructor, for a distribution that can be expressed as a simple deterministic transformation (technically, a pushforward) of an existing distribution.

  2. An API for defining arbitrary custom distributions in plain Julia code.

Built-In Distributions


bernoulli(prob_true::Real)

Samples a Bool value which is true with probability prob_true.

beta(alpha::Real, beta::Real)

Sample a Float64 from a beta distribution.

beta_uniform(theta::Real, alpha::Real, beta::Real)

Samples a Float64 value from a mixture of a uniform distribution on [0, 1] with probability 1-theta and a beta distribution with parameters alpha and beta with probability theta.

binom(n::Integer, p::Real)

Sample an Int from the Binomial distribution with parameters n (number of trials) and p (probability of success in each trial).

categorical(probs::AbstractArray{U, 1}) where {U <: Real}

Given a vector of probabilities probs where sum(probs) = 1, sample an Int i from the set {1, 2, .., length(probs)} with probability probs[i].


exponential(rate::Real)

Sample a Float64 from the exponential distribution with rate parameter rate.

gamma(shape::Real, scale::Real)

Sample a Float64 from a gamma distribution.


geometric(p::Real)

Sample an Int from the Geometric distribution with parameter p.

inv_gamma(shape::Real, scale::Real)

Sample a Float64 from an inverse gamma distribution.

laplace(loc::Real, scale::Real)

Sample a Float64 from a Laplace distribution.

mvnormal(mu::AbstractVector{T}, cov::AbstractMatrix{U}) where {T<:Real,U<:Real}

Samples a Vector{Float64} value from a multivariate normal distribution.

neg_binom(r::Real, p::Real)

Sample an Int from a Negative Binomial distribution. Returns the number of failures before the rth success in a sequence of independent Bernoulli trials. r is the number of successes (which may be fractional) and p is the probability of success per trial.

normal(mu::Real, std::Real)

Samples a Float64 value from a normal distribution.

piecewise_uniform(bounds, probs)

Samples a Float64 value from a piecewise uniform continuous distribution.

There are n bins where n = length(probs) and n + 1 = length(bounds). Bounds must satisfy bounds[i] < bounds[i+1] for all i. The probability density at x is zero if x <= bounds[1] or x >= bounds[end] and is otherwise probs[bin] / (bounds[bin+1] - bounds[bin]) where bounds[bin] < x <= bounds[bin+1].
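The density rule can be sketched in plain Julia, dividing each bin's probability by that bin's width. The helper name piecewise_uniform_pdf is hypothetical and is not part of Gen; it only illustrates the formula.

```julia
# Density of a piecewise uniform distribution, following the description above.
# `piecewise_uniform_pdf` is a hypothetical helper, not a Gen function.
function piecewise_uniform_pdf(x::Real, bounds::Vector{<:Real}, probs::Vector{<:Real})
    length(bounds) == length(probs) + 1 ||
        error("need length(bounds) == length(probs) + 1")
    # Outside the support the density is zero.
    (x <= bounds[1] || x >= bounds[end]) && return 0.0
    # Find the bin with bounds[bin] < x <= bounds[bin + 1].
    bin = searchsortedfirst(bounds, x) - 1
    # Probability of the bin divided by the (positive) width of the bin.
    return probs[bin] / (bounds[bin + 1] - bounds[bin])
end

bounds = [0.0, 1.0, 3.0]
probs = [0.4, 0.6]
d1 = piecewise_uniform_pdf(0.5, bounds, probs)   # first bin, width 1
d2 = piecewise_uniform_pdf(2.0, bounds, probs)   # second bin, width 2
```

With these inputs the two bins contribute 0.4 * 1 and 0.3 * 2 to the total mass, which sums to 1 as expected.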


poisson(lambda::Real)

Sample an Int from the Poisson distribution with rate lambda.

uniform(low::Real, high::Real)

Sample a Float64 from the uniform distribution on the interval [low, high].

uniform_discrete(low::Integer, high::Integer)

Sample an Int from the uniform distribution on the set {low, low + 1, ..., high-1, high}.
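As a brief usage sketch (assuming the Gen package is installed), a built-in distribution can be called like a function to draw a sample, and its log density can be evaluated with Gen.logpdf, passing the distribution, the value, and then the parameters:

```julia
using Gen

x = normal(0.0, 1.0)                    # a Float64 sample
lp = Gen.logpdf(normal, x, 0.0, 1.0)    # log density of x under normal(0, 1)

i = categorical([0.2, 0.3, 0.5])        # an Int in 1:3
n = uniform_discrete(1, 6)              # an Int in 1:6
```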


Defining New Distributions Inline with the @dist DSL

The @dist DSL allows the user to concisely define a distribution, as long as that distribution can be expressed as a certain type of deterministic transformation of an existing distribution. The syntax of the @dist DSL, as well as the class of permitted deterministic transformations, are explained below.

@dist name(arg1, arg2, ..., argN) = body

or, equivalently,

@dist function name(arg1, arg2, ..., argN)
    body
end

Here body is ordinary Julia code, with the constraint that body must contain exactly one random choice. The value of the @dist expression is then a Gen.Distribution object called name, parameterized by arg1, ..., argN, representing the distribution over return values of body.

This DSL is designed to address the issue that sometimes, values stored in the trace do not correspond to the most natural physical elements of the model state space, making inference programming and querying more taxing than necessary. For example, suppose we have a model of classes at a school, where the number of students is random, with mean 10, but always at least 3. Rather than writing the model as

@gen function class_model()
   n_students = @trace(poisson(7), :n_students_minus_3) + 3
   # ...
end

and thinking about the random variable :n_students_minus_3, you can use the @dist DSL to instead write

@dist student_distr(mean, min) = poisson(mean-min) + min

@gen function class_model()
   n_students = @trace(student_distr(10, 3), :n_students)
   # ...
end

and think about the more natural random variable :n_students. This leads to more natural inference programs, which can constrain and propose directly to the :n_students trace address.
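The following sketch (assuming the Gen package is installed) shows what the pushforward means in practice: the log probability of a value n under student_distr(mean, min) should equal the log probability of n - min under poisson(mean - min). The variable names lp_shifted and lp_base are illustrative only.

```julia
using Gen

@dist student_distr(mean, min) = poisson(mean - min) + min

# Sampling: every draw is the Poisson draw shifted up by `min`.
n = student_distr(10, 3)

# Scoring: 12 students under student_distr(10, 3) corresponds to
# a Poisson draw of 12 - 3 = 9 with rate 10 - 3 = 7.
lp_shifted = Gen.logpdf(student_distr, 12, 10, 3)
lp_base = Gen.logpdf(poisson, 9, 7)
```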

Permitted constructs for the body of a @dist

The @dist macro cannot handle an arbitrary body. We now describe which constructs are permitted inside the body of a @dist expression.

We can think of the body of a @dist function as containing ordinary Julia code, except that in addition to being described by their ordinary Julia types, each expression also belongs to one of three "type spaces." These are:

  1. CONST: Constants, whose value is known at the time this @dist expression is evaluated.
  2. ARG: Arguments and (deterministic, differentiable) functions of arguments. All expressions representing non-random values that depend on distribution arguments are ARG expressions.
  3. RND: Random variables. All expressions whose runtime values may differ across multiple calls to this distribution (with the same arguments) are RND expressions.

Importantly, Julia control flow constructs generally expect CONST values: the condition of an if or the range of a for loop cannot be ARG or RND.

The body expression as a whole must be a RND expression, representing a random variable. The behavior of the @dist definition is then to define a new distribution (with name name) that samples and evaluates the logpdf of the random variable represented by the body expression.

Expressions are typed compositionally, with the following typing rules:

  1. Literals and free variables are CONSTs. Literals and symbols that appear free in the @dist body are of type CONST.

  2. Arguments are ARGs. Symbols bound as arguments in the @dist declaration have type ARG in its body.

  3. Drawing from a distribution gives RND. If d is a distribution, and x_i are of type ARG or CONST, d(x_1, x_2, ...) is of type RND.

  4. Functions of CONSTs are CONSTs. If f is a deterministic function and x_i are all of type CONST, f(x_1, x_2, ...) is of type CONST.

  5. Functions of CONSTs and ARGs are ARGs. If f is a differentiable function, and each x_i is either a CONST or a scalar ARG (with at least one x_i being an ARG), then f(x_1, x_2, ...) is of type ARG.

  6. Functions of CONSTs, ARGs, and RNDs are RNDs. If f is one of a special set of deterministic functions we've defined (+, -, *, /, exp, log, getindex), and exactly one of its arguments x_i is of type RND, then f(x_1, x_2, ...) is of type RND.

One way to think about this, without all the rules, is that CONST values are "contaminated" by interaction with ARG values (becoming ARGs themselves), and both CONST and ARG are "contaminated" by interaction with RND. Thinking of the body as an AST, the journey from leaf node to root node always involves transitions in the direction of CONST -> ARG -> RND, never in reverse.
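The typing rules can be sketched as a toy classifier over Julia expressions. This is not how Gen implements @dist; it is a minimal illustration in plain Julia, and the names TypeSpace, typespace, and RND_FUNCS are hypothetical. It ignores the differentiability and scalar requirements of rule 5, and it takes the set of argument names and the set of distribution names from the caller.

```julia
# Toy classifier for the three type spaces; NOT Gen's actual implementation.
@enum TypeSpace CONST ARG RND

# The special deterministic functions named in rule 6.
const RND_FUNCS = (:+, :-, :*, :/, :exp, :log, :getindex)

function typespace(ex, args::Set{Symbol}, dists::Set{Symbol})
    ex isa Symbol && return ex in args ? ARG : CONST    # rules 1 and 2
    ex isa Expr || return CONST                          # literals (rule 1)
    if ex.head == :ref                                   # a[i] means getindex(a, i)
        f, operands = :getindex, ex.args
    elseif ex.head == :call
        f, operands = ex.args[1], ex.args[2:end]
    else
        error("unsupported expression: $ex")
    end
    kinds = [typespace(a, args, dists) for a in operands]
    n_rnd = count(==(RND), kinds)
    if f in dists                                        # rule 3
        n_rnd == 0 || error("distribution parameters must be CONST or ARG: $ex")
        return RND
    end
    if n_rnd == 0                                        # rules 4 and 5
        return any(==(ARG), kinds) ? ARG : CONST
    end
    # Rule 6: a permitted function with exactly one RND argument.
    (n_rnd == 1 && f in RND_FUNCS) || error("rule 6 violated in: $ex")
    return RND
end

args = Set([:x])
dists = Set([:normal])
t1 = typespace(:(x * (2 + 3)), args, dists)
t2 = typespace(:(log(normal(exp(x), exp(x))) + (x * (2 + 3))), args, dists)
```

An expression such as normal(0, 1) + normal(0, 1) has two RND operands under +, so the sketch raises an error for it, mirroring the "exactly one" condition in rule 6.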


Users may not reassign to arguments (such as x in the examples below), and may not apply functions with side effects. Names bound to expressions of type RND must be used only once; for example, let x = normal(0, 1) in x + x is not allowed.


Let's walk through some examples.

@dist f(x) = exp(normal(x, 1))

We can annotate with types:

1 :: CONST		  (by rule 1)
x :: ARG 		  (by rule 2)
normal(x, 1) :: RND 	  (by rule 3)
exp(normal(x, 1)) :: RND  (by rule 6)

Here's another:

@dist function labeled_cat(labels, probs)
	index = categorical(probs)
	labels[index]
end

And the types:

probs :: ARG 			(by rule 2)
categorical(probs) :: RND 	(by rule 3)
index :: RND 			(Julia assignment)
labels :: ARG 			(by rule 2)
labels[index] :: RND 		(by rule 6, f == getindex)

Note that getindex is designed to work on anything indexable, not just vectors. So, for example, it also works with Dicts.
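A short usage sketch of labeled_cat (assuming the Gen package is installed; the label and probability values are made up for illustration). Sampling returns one of the labels, and Gen.logpdf scores a label rather than an index:

```julia
using Gen

@dist function labeled_cat(labels, probs)
    index = categorical(probs)
    labels[index]
end

labels = [:a, :b, :c]
probs = [0.2, 0.3, 0.5]

label = labeled_cat(labels, probs)                 # one of :a, :b, :c
lp = Gen.logpdf(labeled_cat, :b, labels, probs)    # log probability of the label :b
```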

Another one (not as realistic, but it uses all the rules):

@dist function weird(x)
  log(normal(exp(x), exp(x))) + (x * (2 + 3))
end

And the types:

2, 3 :: CONST 						(by rule 1)
2 + 3 :: CONST 						(by rule 4)
x :: ARG 						(by rule 2)
x * (2 + 3) :: ARG 					(by rule 5)
exp(x) :: ARG 						(by rule 5)
normal(exp(x), exp(x)) :: RND 				(by rule 3)
log(normal(exp(x), exp(x))) :: RND 			(by rule 6)
log(normal(exp(x), exp(x))) + (x * (2 + 3)) :: RND 	(by rule 6)

Defining New Distributions From Scratch

For distributions that cannot be expressed in the @dist DSL, users can define a custom distribution by defining an (ordinary Julia) subtype of Gen.Distribution and implementing the methods of the Distribution API. This method requires more custom code than using the @dist DSL, but also affords more flexibility: arbitrary user-defined logic for sampling, PDF evaluation, etc.
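As a minimal sketch of this approach (assuming the Gen package is installed), the following defines a hypothetical point-mass distribution over integers. The type PointMass and the constant point_mass are illustrative names, not part of Gen; the methods shown are the ones this sketch assumes the Distribution API expects.

```julia
using Gen

# Hypothetical example: a distribution that always returns its argument `v`.
struct PointMass <: Gen.Distribution{Int} end
const point_mass = PointMass()

# Sampling.
Gen.random(::PointMass, v::Int) = v

# Log probability of output x given parameter v.
Gen.logpdf(::PointMass, x::Int, v::Int) = x == v ? 0.0 : -Inf

# Gradients with respect to the output and each argument; none are available here.
Gen.logpdf_grad(::PointMass, x::Int, v::Int) = (nothing, nothing)
Gen.has_output_grad(::PointMass) = false
Gen.has_argument_grads(::PointMass) = (false,)

# Make the distribution callable, like the built-in ones.
(d::PointMass)(v::Int) = Gen.random(d, v)

sample = point_mass(3)
lp_hit = Gen.logpdf(point_mass, 3, 3)
lp_miss = Gen.logpdf(point_mass, 4, 3)
```

Once defined this way, point_mass can be used like a built-in distribution, for example in @trace(point_mass(3), :x) inside a @gen function.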