# Trace Translators

While Generative Functions define probability distributions on traces, **Trace Translators** convert from one space of traces to another space of traces. Trace translators are building blocks of inference programs that utilize multiple model representations, like Involutive MCMC.

Trace translators are significantly more general than Bijectors. Trace translators can (i) convert between spaces of traces that include mixed discrete and continuous random choices, as well as stochastic control flow, and (ii) convert between spaces for which there is no one-to-one correspondence (e.g. between models of different dimensionality, or between discrete and continuous models). Bijectors are limited to deterministic transformations between real-valued vectors of constant dimension.

## Deterministic Trace Translators

Inference programs manipulate traces, but they also keep track of probabilities and probability densities associated with these traces. Suppose we have two generative functions `p1` and `p2`. Given a trace `t2` of `p2`, we can easily compute the log probability (or log probability density) that the trace would have been generated by `p2` using `get_score(t2)`. But suppose we want to construct the trace of `p2` by first sampling a trace `t1` of `p1` and then applying a deterministic transformation to that trace to obtain `t2`. How can we compute the probability that a trace `t2` would have been produced by this process? This probability is needed if, for example, `p2` defines a probabilistic model and we want to use `p1` as a proposal distribution within importance sampling. If we produce `t2` via an arbitrary deterministic transformation of the random choices in `t1`, then computing the necessary probability is difficult.

If we restrict ourselves to deterministic transformations that are *bijections* (one-to-one correspondences) from the set of traces of `p1` to the set of traces of `p2`, then the problem is much simplified. If the transformation is a bijection, this means that (i) each trace of `p1` gets mapped to a different trace of `p2`, and (ii) for every trace of `p2` there is some trace of `p1` that maps to it. Bijective transformations between traces are useful components of inference programs because the probability that a given trace `t2` of `p2` would have been produced by first sampling from `p1` and then applying the transform can be computed simply as the probability that `p1` would produce the (unique) trace `t1` that gets mapped to the given trace by the transform. Conceptually, bijective trace transforms *preserve probability*. When trace transforms operate on traces with continuous random choices, computing probability densities of the transformed traces requires computing a Jacobian associated with the continuous part of the transformation.
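The claim that bijections preserve probability can be checked on a toy discrete example. The following is a minimal plain-Julia sketch that does not use Gen: integers stand in for traces, and the table `p1` and the functions `g` and `ginv` are hypothetical. Summing the probability of every input that maps to a given output leaves exactly one term, the probability of the unique preimage.

```julia
# Hypothetical distribution over "traces" of p1, represented as plain integers.
p1 = Dict(0 => 0.5, 1 => 0.3, 2 => 0.2)

# A hypothetical bijection from {0, 1, 2} onto {10, 20, 30}, and its inverse.
g(t1) = 10 * (t1 + 1)
ginv(t2) = div(t2, 10) - 1

# Probability of each output under "sample t1 from p1, then apply g",
# computed the slow way by summing over every input that maps to it.
pushforward = Dict{Int,Float64}()
for (t1, prob) in p1
    t2 = g(t1)
    pushforward[t2] = get(pushforward, t2, 0.0) + prob
end

# Because g is a bijection, each sum has exactly one term: the probability
# of the unique preimage ginv(t2).
for (t2, prob) in pushforward
    @assert prob == p1[ginv(t2)]
end
```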

Gen provides a DSL for expressing bijections between spaces of traces, called the **Trace Transform DSL**. We introduce this DSL via an example. Below are two generative functions. The first samples polar coordinates and the second uses Cartesian coordinates.

```julia
@gen function p1()
    r ~ inv_gamma(1, 1)
    # the full range of angles, so that every point (x, y) of p2 is reachable
    theta ~ uniform(-pi, pi)
end
```

```julia
@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end
```

### Defining a trace transform with the Trace Transform DSL

The following trace transform DSL program defines a transformation (called `f`) that transforms traces of `p1` into traces of `p2`:

```julia
@transform f (t1) to (t2) begin
    r = @read(t1[:r], :continuous)
    theta = @read(t1[:theta], :continuous)
    @write(t2[:x], r * cos(theta), :continuous)
    @write(t2[:y], r * sin(theta), :continuous)
end
```

This transform reads values of random choices in the input trace (`t1`) at specific addresses (indicated by the syntax `t1[addr]`) using `@read` and writes values to the output trace (`t2`) using `@write`. Each read and write operation is labeled with whether the random choice is discrete or continuous. The Trace Transform DSL section below defines the DSL in more detail.

It is usually a good idea to write the inverse of the bijection. The inverse can provide a dynamic check that the transform truly is a bijection. The inverse of the above transformation is:

```julia
@transform finv (t2) to (t1) begin
    x = @read(t2[:x], :continuous)
    y = @read(t2[:y], :continuous)
    r = sqrt(x^2 + y^2)
    @write(t1[:r], r, :continuous)
    @write(t1[:theta], atan(y, x), :continuous)
end
```

We can inform Gen that two transforms are inverses of one another using `pair_bijections!`:

`pair_bijections!(f, finv)`
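The dynamic check that pairing enables amounts to confirming that running the inverse after the forward transform recovers the input. A minimal sketch of that idea in plain Julia, with NamedTuples as hypothetical stand-ins for traces (this is not how Gen implements the check):

```julia
# Plain-Julia stand-ins for the transform f and its inverse finv, operating on
# NamedTuples instead of Gen traces (an assumption made for illustration).
f(t1) = (x = t1.r * cos(t1.theta), y = t1.r * sin(t1.theta))
finv(t2) = (r = sqrt(t2.x^2 + t2.y^2), theta = atan(t2.y, t2.x))

# Check that finv undoes f on one input, up to floating-point error.
function roundtrip_ok(t1; atol = 1e-12)
    back = finv(f(t1))
    return isapprox(back.r, t1.r; atol = atol) &&
           isapprox(back.theta, t1.theta; atol = atol)
end

samples = [(r = 2.0, theta = 0.7), (r = 0.5, theta = -1.2), (r = 3.0, theta = 1.4)]
@assert all(roundtrip_ok, samples)
```

A round trip on a handful of inputs can only expose a bug, not prove bijectivity, which is also why Gen describes its check as a dynamic one.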

### Wrapping a trace transform in a trace translator

Note that the transform DSL code does not specify what the two generative functions are, or what the arguments to these generative functions are. This information will be required for computing probabilities and probability densities of traces. We provide this information by constructing a **Trace Translator** that wraps the transform along with this information:

`translator = DeterministicTraceTranslator(p2, (), choicemap(), f)`

We can then apply the translator to a trace of `p1` using function call syntax. The translator returns a trace of `p2` and a log-weight that we can use to compute the probability (density) of the resulting trace:

`t2, log_weight = translator(t1)`

Specifically, the log probability (density) that the trace `t2` was produced by first sampling `t1 = simulate(p1, ())` and then applying the trace translator is:

`get_score(t1) + log_weight`

Let's unpack the previous few code blocks in more detail. First, note that we did not pass in the source generative function (`p1`) or its arguments (`()`) when we constructed the trace translator. This is because this information will be attached to the input trace `t1` itself. We *did* need to pass in the target generative function (`p2`) and its arguments (`()`) however, because this information is not included in `t1`.

In this case, because continuous random choices are involved, the probabilities are probability densities, and the trace translator uses automatic differentiation of the code in the transform `f` to compute a change-of-variables Jacobian that is necessary to compute the correct probability density of the new trace `t2`.
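To see what that Jacobian term is for this example, here is a minimal plain-Julia sketch (no Gen) that approximates the Jacobian of the continuous map $(r, \theta) \mapsto (x, y)$ with central finite differences, standing in for the automatic differentiation the translator performs. The helper names `polar_to_cartesian` and `fd_jacobian` are hypothetical. For polar coordinates the determinant is exactly $r$, so the density of $(x, y)$ under "sample from `p1`, then apply `f`" is the density of $(r, \theta)$ under `p1` divided by $r$.

```julia
using LinearAlgebra

# The continuous part of f, written as a map between plain vectors:
# (r, theta) -> (x, y).
polar_to_cartesian(v) = [v[1] * cos(v[2]), v[1] * sin(v[2])]

# Central finite-difference Jacobian (rows: outputs, columns: inputs).
function fd_jacobian(F, v; h = 1e-6)
    n = length(v)
    J = zeros(length(F(v)), n)
    for k in 1:n
        e = zeros(n)
        e[k] = h
        J[:, k] = (F(v .+ e) .- F(v .- e)) ./ (2h)
    end
    return J
end

r, theta = 2.0, 0.7
logabsdet_J = log(abs(det(fd_jacobian(polar_to_cartesian, [r, theta]))))

# Analytically, det J = r cos^2(theta) + r sin^2(theta) = r.
@assert isapprox(logabsdet_J, log(r); atol = 1e-6)
```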

### Observations

Typically, there are observations associated with one or both of the generative functions involved, and we have values for these in a choice map, so we do not want the trace translator to be responsible for populating the values of these observed random choices. For example, suppose we want to condition `p2` on an observed random choice `z`:

```julia
@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(x + y, 0.1)
end
observations = choicemap()
observations[:z] = 2.3
```

When `p2` has observations, these can be passed in as an additional argument to the trace translator constructor:

`translator = DeterministicTraceTranslator(p2, (), observations, f)`

### Discrete random choices and stochastic control flow

Trace transforms and trace translators interoperate seamlessly with discrete random choices and stochastic control flow. Both the input and the output traces can contain a mix of discrete and continuous choices, and arbitrary stochastic control flow. Consider the following simple example, where we use two different discrete representations to represent probability distributions on the integers 0-7:

```julia
@gen function p1()
    bit1 ~ bernoulli(0.5)
    bit2 ~ bernoulli(0.5)
    bit3 ~ bernoulli(0.5)
end
```

```julia
@gen function p2()
    # categorical samples from 1:8, so n - 1 is the integer in 0-7
    n ~ categorical([0.1, 0.1, 0.1, 0.2, 0.2, 0.15, 0.1, 0.05])
end
```

We define the forward and inverse transforms. Because `categorical` is indexed from 1, the forward transform adds 1 to the integer encoded by the bits, and the inverse subtracts it:

```julia
@transform f (t1) to (t2) begin
    n = (
        @read(t1[:bit1], :discrete) * 1 +
        @read(t1[:bit2], :discrete) * 2 +
        @read(t1[:bit3], :discrete) * 4)
    @write(t2[:n], n + 1, :discrete)
end
```

```julia
@transform finv (t2) to (t1) begin
    # pad=3 so that small integers still produce three binary digits
    bits = digits(@read(t2[:n], :discrete) - 1, base=2, pad=3)
    @write(t1[:bit1], bits[1] == 1, :discrete)
    @write(t1[:bit2], bits[2] == 1, :discrete)
    @write(t1[:bit3], bits[3] == 1, :discrete)
end
```
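The bit arithmetic in this pair of transforms can be checked exhaustively, because there are only 8 traces on each side. A minimal plain-Julia sketch with hypothetical helper names, working directly on the integer in 0-7 (no Gen, no categorical offset); it also shows why `digits` needs `pad=3`:

```julia
# Bits (as Bools) -> integer in 0:7, mirroring the arithmetic in f.
bits_to_int(b1, b2, b3) = b1 * 1 + b2 * 2 + b3 * 4

# Integer in 0:7 -> bits, mirroring finv. `pad = 3` guarantees three digits
# even for small integers (digits(1, base = 2) alone returns just [1]).
function int_to_bits(m)
    d = digits(m, base = 2, pad = 3)
    return (d[1] == 1, d[2] == 1, d[3] == 1)
end

# Every one of the 8 bit patterns round-trips, and all 8 integers are hit
# exactly once, so the map is a bijection between the two sets of traces.
patterns = [(b1, b2, b3) for b1 in (false, true), b2 in (false, true), b3 in (false, true)]
images = [bits_to_int(p...) for p in patterns]
@assert sort(vec(images)) == collect(0:7)
@assert all(int_to_bits(bits_to_int(p...)) == p for p in patterns)
```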

Here is an example that includes discrete random choices, stochastic control flow, and continuous random choices.

```julia
@gen function p1()
    if ({:branch} ~ bernoulli(0.5))
        x ~ normal(0, 1)
    else
        other ~ categorical([0.3, 0.7])
    end
end
```

```julia
@gen function p2()
    k ~ uniform_discrete(1, 4)
    if k <= 2
        y ~ gamma(1, 1)
    end
end
```

Note that transformations between spaces of traces need not be intuitive (although they probably should be)! Try to convince yourself that the functions below are indeed a pair of bijections between the traces of these two generative functions.

```julia
@transform f (t1) to (t2) begin
    if @read(t1[:branch], :discrete)
        x = @read(t1[:x], :continuous)
        if x > 0
            @write(t2[:k], 2, :discrete)
        else
            @write(t2[:k], 1, :discrete)
        end
        @write(t2[:y], abs(x), :continuous)
    else
        other = @read(t1[:other], :discrete)
        @write(t2[:k], (other == 1) ? 3 : 4, :discrete)
    end
end
```

```julia
@transform finv (t2) to (t1) begin
    k = @read(t2[:k], :discrete)
    if k <= 2
        @write(t1[:branch], true, :discrete)
        y = @read(t2[:y], :continuous)
        @write(t1[:x], (k == 1) ? -y : y, :continuous)
    else
        @write(t1[:branch], false, :discrete)
        @write(t1[:other], (k == 3) ? 1 : 2, :discrete)
    end
end
```
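One way to convince yourself is to run both directions on a trace from every path through `p1`. Note that the inverse must also reconstruct the `:branch` choice of `p1`, since it is part of every trace of `p1`. A minimal plain-Julia sketch, with `Dict{Symbol,Any}` as a hypothetical stand-in for a trace (real transforms use `@read` and `@write`):

```julia
# Plain-Julia stand-ins for f and finv; Dicts play the role of traces.
function f(t1)
    if t1[:branch]
        x = t1[:x]
        return Dict{Symbol,Any}(:k => (x > 0 ? 2 : 1), :y => abs(x))
    else
        return Dict{Symbol,Any}(:k => (t1[:other] == 1 ? 3 : 4))
    end
end

function finv(t2)
    k = t2[:k]
    if k <= 2
        y = t2[:y]
        return Dict{Symbol,Any}(:branch => true, :x => (k == 1 ? -y : y))
    else
        return Dict{Symbol,Any}(:branch => false, :other => (k == 3 ? 1 : 2))
    end
end

# One sample per path through p1 (x = 0 is skipped: it has probability zero).
samples = [
    Dict{Symbol,Any}(:branch => true, :x => 1.3),
    Dict{Symbol,Any}(:branch => true, :x => -0.4),
    Dict{Symbol,Any}(:branch => false, :other => 1),
    Dict{Symbol,Any}(:branch => false, :other => 2),
]
@assert all(finv(f(t)) == t for t in samples)
```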

## General Trace Translators

Note that for two arbitrary generative functions `p1` and `p2` there may not exist any one-to-one correspondence between their spaces of traces. For example, consider a generative function `p1` that samples points within the unit square $[0, 1]^2$:

```julia
@gen function p1()
    x ~ uniform(0, 1)
    y ~ uniform(0, 1)
end
```

and another generative function `p2` that samples one of 100 possible discrete values, each value representing one cell of the unit square:

```julia
@gen function p2()
    i ~ uniform_discrete(1, 10) # interval [(i-1)/10, i/10]
    j ~ uniform_discrete(1, 10) # interval [(j-1)/10, j/10]
end
```

There is no one-to-one correspondence between the spaces of traces of these two generative functions: The first is an uncountably infinite set, and the other is a finite set with 100 elements in it.

However, there is an intuitive notion of correspondence that we would like to be able to encode. Each discrete cell $(i, j)$ corresponds to a subset of the unit square $[(i - 1)/10, i/10] \times [(j-1)/10, j/10]$.

We can express this correspondence (and any correspondence between two arbitrary generative functions) by introducing two auxiliary generative functions `q1` and `q2`. The first function `q1` will take a trace of `p1` as input, and the second function `q2` will take a trace of `p2` as input. Then, instead of a transformation between traces of `p1` and traces of `p2`, our trace transform will transform between (i) the space of pairs of traces of `p1` and `q1` and (ii) the space of pairs of traces of `p2` and `q2`. We construct `q1` and `q2` so that the two spaces have the same size, and a one-to-one correspondence is possible.

For our example above, we construct `q2` to sample the offset of the point within its cell, a value in $[0, 0.1]^2$. We construct `q1` to be empty: there is already a mapping from each trace of `p1` to a trace of `p2` that simply identifies which cell $(i, j)$ a given point in $[0, 1]^2$ is in, so no extra random choices are needed.
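With these choices the two extended spaces agree on probability, not just on size. A short worked check, assuming the uniform distributions written above: every point of the unit square has density 1 under `p1` (and the empty `q1` contributes a factor of 1), while a cell $(i, j)$ together with an offset $(dx, dy)$ has joint density

$$p_2(i, j)\, q_2(dx, dy) = \left(\tfrac{1}{10}\cdot\tfrac{1}{10}\right)\left(\tfrac{1}{0.1}\cdot\tfrac{1}{0.1}\right) = 1 = p_1(x, y) \cdot 1.$$

Within a fixed cell, the continuous part of the map, $(x, y) \mapsto (dx, dy) = (x - (i-1)/10,\; y - (j-1)/10)$, is a translation, so its Jacobian determinant is 1 and no density correction is needed.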

```julia
@gen function q1(p1_trace)
end
@gen function q2(p2_trace)
    dx ~ uniform(0.0, 0.1)
    dy ~ uniform(0.0, 0.1)
end
```

### Trace transforms between pairs of traces

To handle general trace translators that require auxiliary probability distributions, the trace transform DSL supports defining transformations between *pairs* of traces. For example, the following defines a trace transform that maps from pairs of traces of `p1` and `q1` to pairs of traces of `p2` and `q2`:

```julia
@transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin
    x = @read(p1_trace[:x], :continuous)
    y = @read(p1_trace[:y], :continuous)
    # integer index of the cell containing (x, y)
    i = ceil(Int, x * 10)
    j = ceil(Int, y * 10)
    @write(p2_trace[:i], i, :discrete)
    @write(p2_trace[:j], j, :discrete)
    # offset of (x, y) within its cell
    @write(q2_trace[:dx], x - (i-1)/10, :continuous)
    @write(q2_trace[:dy], y - (j-1)/10, :continuous)
end
```

and the inverse transform:

```julia
@transform f_inv (p2_trace, q2_trace) to (p1_trace, q1_trace) begin
    i = @read(p2_trace[:i], :discrete)
    j = @read(p2_trace[:j], :discrete)
    dx = @read(q2_trace[:dx], :continuous)
    dy = @read(q2_trace[:dy], :continuous)
    x = (i-1)/10 + dx
    y = (j-1)/10 + dy
    @write(p1_trace[:x], x, :continuous)
    @write(p1_trace[:y], y, :continuous)
end
```

which we associate as inverses:

`pair_bijections!(f, f_inv)`
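The arithmetic in this pair of transforms can be sanity-checked in plain Julia. A minimal sketch with hypothetical helper names `to_cell` and `from_cell` (no Gen involved): each point lands in a valid cell with an offset in $(0, 0.1]$, and the inverse recovers the point.

```julia
# Forward map (x, y) -> (i, j, dx, dy), mirroring f above, and its inverse.
function to_cell(x, y)
    i, j = ceil(Int, 10x), ceil(Int, 10y)
    return (i, j, x - (i - 1) / 10, y - (j - 1) / 10)
end
from_cell(i, j, dx, dy) = ((i - 1) / 10 + dx, (j - 1) / 10 + dy)

for (x, y) in [(0.05, 0.95), (0.42, 0.17), (0.999, 0.001)]
    i, j, dx, dy = to_cell(x, y)
    @assert 1 <= i <= 10 && 1 <= j <= 10
    @assert 0 < dx <= 0.1 + 1e-12 && 0 < dy <= 0.1 + 1e-12
    x2, y2 = from_cell(i, j, dx, dy)
    @assert isapprox(x2, x; atol = 1e-12) && isapprox(y2, y; atol = 1e-12)
end
```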

### Constructing a general trace translator

We now wrap the transform above into a general trace translator, by providing the three probabilistic programs `p2`, `q1`, and `q2` that it uses (a reference to `p1` will be included in the input trace), and the arguments to these functions.

```julia
translator = GeneralTraceTranslator(
    p_new=p2,
    p_new_args=(),
    new_observations=choicemap(),
    q_forward=q1,
    q_forward_args=(),
    q_backward=q2,
    q_backward_args=(),
    f=f)
```

Then, we can apply the trace translator to a trace (`t1`) of `p1` and get a trace (`t2`) of `p2` and a log-weight:

`t2, log_weight = translator(t1)`

## Symmetric Trace Translators

When the previous and new generative functions (e.g. `p1` and `p2` in the previous example) are the same, their arguments are the same, and `q_forward` and `q_backward` (and their arguments) are also identical, we call the trace translator a **Symmetric Trace Translator**. Symmetric trace translators are important because they form the basis of Involutive MCMC. Instead of translating a trace of one generative function to the trace of another generative function, they translate a trace of a generative function to another trace of the *same* generative function.

Symmetric trace translators have the interesting property that the function `f` is an **involution**, or a function that is its own inverse. To indicate that a trace transform is an involution, use `is_involution!`.
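A minimal sketch of what "its own inverse" means, using a hypothetical move of the kind found in Involutive MCMC, written on plain tuples rather than Gen traces: the state is a position `x` and an auxiliary direction `v`; the move shifts `x` by `v` and flips the sign of `v`, so applying it twice returns the original state.

```julia
# Hypothetical involution on a state (x, v): move x by v, then flip v.
involution((x, v)) = (x + v, -v)

state = (1.5, 0.25)
# Applying the map twice recovers the starting state.
@assert involution(involution(state)) == state
```

With general floating-point values the round trip is exact only up to rounding, so a check like this should compare with a tolerance.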

Because symmetric trace translators translate within the same generative function, their implementation uses `update` to incrementally modify the trace from the previous to the new trace. This has two benefits when the previous and new traces have random choices that aren't modified between them: (i) the incremental modification may be more efficient than writing the new trace entirely from scratch, and (ii) the transform DSL program does not need to specify a value for addresses whose value is not changed from the previous trace.

## Simple Extending Trace Translators

Simple extending trace translators extend an existing trace with new random choices sampled from a proposal distribution, as well as any new observations. The arguments of the trace may also be updated. This type of trace translation is the basic operation used in Particle Filtering. For example, we might have a model that sequentially samples new latent variables `(:z, t)` and observations `(:x, t)` for each timestep `t`:

```julia
@gen function model(T::Int)
    for t in 1:T
        z = {(:z, t)} ~ normal(0, 1)
        x = {(:x, t)} ~ normal(z, 1)
    end
end
```

Each time we observe a new `(:x, t)`, we might want to propose `(:z, t)` so that it is close in value:

```julia
@gen function proposal(trace::Trace, x::Real)
    t = get_args(trace)[1] + 1
    {(:z, t)} ~ normal(x, 1)
end
```

Suppose we initially generated a trace up to timestep `t=1`, e.g. by calling `t1 = simulate(model, (1,))`. Now we observe `(:x, 2)` to be `5.0`. By constructing a simple extending trace translator, we can simultaneously update the trace `t1` with new arguments, introduce the new observation at `(:x, 2)`, and propose a likely value for `(:z, 2)`:

```julia
translator = SimpleExtendingTraceTranslator(
    p_new_args=(2,), p_argdiffs=(UnknownChange(),),
    new_observations=choicemap((:x, 2) => 5.0),
    q_forward=proposal, q_forward_args=(5.0,))
t2, log_weight = translator(t1)
```

Similar functionality can be achieved through a combination of `propose` on the proposal and `update` on the original trace, but using a trace translator provides a nice layer of abstraction.
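For intuition about the quantity a particle filter needs at this step, here is a plain-Julia sketch (it does not call Gen) of the textbook incremental importance weight for extending the trace from `T = 1` to `T = 2`: the model's log density for the new choices `(:z, 2)` and `(:x, 2)`, minus the proposal's log density for `(:z, 2)`. The helper names are hypothetical, and that Gen's `log_weight` is computed exactly this way is an assumption, not something stated above.

```julia
# Log density of normal(mu, sigma) at x, written out to stay dependency-free.
normal_logpdf(x, mu, sigma) = -0.5 * log(2pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# Incremental log weight: new model terms minus the proposal term.
function incremental_log_weight(z2, x2)
    model_terms = normal_logpdf(z2, 0, 1) + normal_logpdf(x2, z2, 1)
    proposal_term = normal_logpdf(z2, x2, 1)   # proposal: (:z, 2) ~ normal(x, 1)
    return model_terms - proposal_term
end

w = incremental_log_weight(4.2, 5.0)
# The observation term and the proposal term cancel here (the normal density is
# symmetric in x and mu), so the weight reduces to the prior term on (:z, 2).
@assert isapprox(w, normal_logpdf(4.2, 0, 1))
```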

## Trace Transform DSL

The **Trace Transform DSL** is a differentiable programming language for writing deterministic transformations of traces. Programs written in this DSL are called *transforms*. Transforms read the value of random choices from input trace(s) and write values of random choices to output trace(s). These programs are not typically executed directly by users, but are instead wrapped into one of the several forms of trace translators listed above (`DeterministicTraceTranslator`, `GeneralTraceTranslator`, and `SymmetricTraceTranslator`).

A transform is identified with the `@transform` macro, and uses one of the following four syntactic forms for the signature (the name of the transform and the names of the input and output traces are all user-defined variables; the only keywords are `@transform`, `to`, `begin`, and `end`):

*A transform from one trace to another, without extra parameters*

```julia
@transform f t1 to t2 begin
    ...
end
```

*A transform from one trace to another, with extra parameters*

```julia
@transform f(x, y, ..) t1 to t2 begin
    ...
end
```

*A transform from pairs of traces to pairs of traces, without extra parameters*

```julia
@transform f (t1, t2) to (t3, t4) begin
    ...
end
```

*A transform from pairs of traces to pairs of traces, with extra parameters*

```julia
@transform f(x, y, ..) (t1, t2) to (t3, t4) begin
    ...
end
```

The extra parameters are optional, and can be used to pass arguments to a transform function that is invoked, from another transform function, using the `@tcall` macro. For example:

```julia
@transform g(x) t1 to t2 begin
    ...
end
@transform f t1 to t2 begin
    x = ..
    @tcall(g(x))
end
```

The callee transform function (`g` above) reads and writes to the same input and output traces as the caller transform function (`f` above). Note that the input and output traces can be assigned different names in the caller and the callee.

The body of a transform reads the values of random choices at addresses in the input trace(s), performs computation using regular Julia code (provided this code can be differentiated with ForwardDiff.jl), and writes values of random choices at addresses in the output trace(s). In the body, `@read` expressions read a value from a particular address of an input trace:

`val = @read(<source>, <type-label>)`

where `<source>` is an expression of the form `<trace>[<addr>]` where `<trace>` must be the name of an input trace in the transform's signature. The `<type-label>` is either `:continuous` or `:discrete`, and indicates whether the random choice is discrete or continuous (in measure-theoretic terms, whether it uses the counting measure, or the Lebesgue measure on a Euclidean space of some dimension). Similarly, `@write` expressions write a value to a particular address in an output trace:

`@write(<destination>, <value>, <type-label>)`

Sometimes trace transforms need to directly copy the value from one address in an input trace to one address in an output trace. In these cases, it is recommended to use the explicit `@copy` expression:

`@copy(<source>, <destination>)`

where `<source>` and `<destination>` are of the form `<trace>[<addr>]` as above. Note you can also copy collections of multiple random choices under an address namespace in an input trace to an address namespace in an output trace. For example,

`@copy(trace1[:foo], trace2[:bar])`

will copy every random choice in `trace1` with an address of the form `:foo => <rest>` to `trace2` at address `:bar => <rest>`.
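The namespace semantics of `@copy` can be pictured with a plain-Julia sketch that is not Gen's implementation: a hypothetical flat dictionary keyed by hierarchical `Pair` addresses stands in for a trace, and the hypothetical helper `copy_namespace!` rewrites the `:foo` prefix to `:bar` while keeping the `<rest>` part unchanged.

```julia
# Hypothetical flat "traces": Dicts keyed by hierarchical Pair addresses.
src = Dict{Any,Float64}(
    (:foo => :a) => 1.0,
    (:foo => (:b => 1)) => 2.0,
    :baz => 3.0,
)
dst = Dict{Any,Float64}()

# Copy every choice under namespace `from` in src to namespace `to` in dst,
# keeping the remainder of the address (the <rest> part) unchanged.
function copy_namespace!(dst, src, from, to)
    for (addr, val) in src
        if addr isa Pair && addr.first == from
            dst[to => addr.second] = val
        end
    end
    return dst
end

copy_namespace!(dst, src, :foo, :bar)
@assert dst[:bar => :a] == 1.0
@assert dst[:bar => (:b => 1)] == 2.0
@assert !haskey(dst, :baz)   # choices outside the namespace are not copied
```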

It is also possible to read the *return value* from an input trace using the following syntax, but this value must be discrete (in the local neighborhood of traces, the return value must be constant as a function of all continuous random choices in input traces):

`val = @read(<trace>[], :discrete)`

This feature is useful when the generative function precomputes a quantity as part of its return value, and we would like to reuse this value, instead of having to recompute it as part of the transform. The `:discrete` requirement is needed because the transform DSL does not currently backpropagate through the return value (this feature could be added in the future).

Tips for defining valid transforms:

- If you find yourself copying the same continuous source address to multiple locations, it probably means your transform is not valid (the Jacobian matrix will have rows that are identical, and so the Jacobian determinant will be zero).
- You can gain some confidence that your transform is valid by enabling dynamic checks (`check=true`) in the trace translator that uses it.
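The first tip can be seen numerically. In this minimal plain-Julia sketch, a hypothetical invalid transform copies the continuous input `x` to two output addresses (and drops `y`); its Jacobian has two identical rows, so the determinant is zero and the log-determinant needed for the density is `-Inf`.

```julia
using LinearAlgebra

# Jacobian of the hypothetical map (x, y) -> (x, x): rows are outputs,
# columns are inputs. Both outputs depend only on x, in the same way.
J = [1.0 0.0;
     1.0 0.0]

# Identical rows make the determinant zero, so log|det J| is -Inf and no
# finite density can be assigned to the output trace.
@assert det(J) == 0.0
@assert log(abs(det(J))) == -Inf
```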

## API

`Gen.@transform` (macro)

```julia
@transform f[(params...)] (in1 [,in2]) to (out1 [,out2]) begin
    ..
end
```

Write a program in the Trace Transform DSL.

`Gen.@read` (macro)

`@read(<source>, <annotation>)`

Macro for reading the value of a random choice from an input trace in the Trace Transform DSL.

`<source>` is of the form `<trace>[<addr>]` where `<trace>` is an input trace, and `<annotation>` is either `:discrete` or `:continuous`.

`Gen.@write` (macro)

`@write(<destination>, <value>, <annotation>)`

Macro for writing the value of a random choice to an output trace in the Trace Transform DSL.

`<destination>` is of the form `<trace>[<addr>]` where `<trace>` is an output trace, and `<annotation>` is either `:discrete` or `:continuous`.

`Gen.@copy` (macro)

`@copy(<source>, <destination>)`

Macro for copying the value of a random choice (or a whole namespace of random choices) from an input trace to an output trace in the Trace Transform DSL.

`<source>` is of the form `<trace>[<addr>]` where `<trace>` is an input trace, and `<destination>` is of the form `<trace>[<addr>]` where `<trace>` is an output trace.

`Gen.pair_bijections!` (function)

`pair_bijections!(f1::TraceTransformDSLProgram, f2::TraceTransformDSLProgram)`

Assert that a pair of bijections constructed using the Trace Transform DSL are inverses of one another.

`Gen.is_involution!` (function)

`is_involution!(f::TraceTransformDSLProgram)`

Assert that a bijection constructed with the Trace Transform DSL is its own inverse.

`Gen.inverse` (function)

`b::TraceTransformDSLProgram = inverse(a::TraceTransformDSLProgram)`

Obtain the inverse of a bijection that was constructed with the Trace Transform DSL.

The inverse must have been associated with the bijection either via `pair_bijections!` or `is_involution!`.

`Gen.DeterministicTraceTranslator` (type)

```julia
translator = DeterministicTraceTranslator(;
    p_new::GenerativeFunction,
    p_new_args::Tuple = (),
    new_observations::ChoiceMap = EmptyChoiceMap(),
    f::TraceTransformDSLProgram)
```

Constructor for a deterministic trace translator.

Run the translator with:

`(output_trace, log_weight) = translator(input_trace)`

`Gen.GeneralTraceTranslator` (type)

```julia
translator = GeneralTraceTranslator(;
    p_new::GenerativeFunction,
    p_new_args::Tuple = (),
    new_observations::ChoiceMap = EmptyChoiceMap(),
    q_forward::GenerativeFunction,
    q_forward_args::Tuple = (),
    q_backward::GenerativeFunction,
    q_backward_args::Tuple = (),
    f::TraceTransformDSLProgram)
```

Constructor for a general trace translator.

Run the translator with:

`(output_trace, log_weight) = translator(input_trace; check=false, prev_observations=EmptyChoiceMap())`

Use `check` to enable a bijection check (this requires that the transform `f` has been paired with its inverse using `pair_bijections!` or `is_involution!`).

If `check` is enabled, then `prev_observations` is a choice map containing the observed random choices in the previous trace.

`Gen.SimpleExtendingTraceTranslator` (type)

```julia
translator = SimpleExtendingTraceTranslator(;
    p_new_args::Tuple = (),
    p_argdiffs::Tuple = (),
    new_observations::ChoiceMap = EmptyChoiceMap(),
    q_forward::GenerativeFunction,
    q_forward_args::Tuple = ())
```

Constructor for a simple extending trace translator.

Run the translator with:

`(output_trace, log_weight) = translator(input_trace)`

`Gen.SymmetricTraceTranslator` (type)

```julia
translator = SymmetricTraceTranslator(;
    q::GenerativeFunction,
    q_args::Tuple = (),
    involution::Union{TraceTransformDSLProgram,Function})
```

Constructor for a symmetric trace translator.

The involution is either constructed via the `@transform` macro (recommended), or can be provided as a Julia function.

Run the translator with:

`(output_trace, log_weight) = translator(input_trace; check=false, observations=EmptyChoiceMap())`

Use `check` to enable the involution check (this requires that the involution has been marked with `is_involution!`).

If `check` is enabled, then `observations` is a choice map containing the observed random choices, and the check will additionally ensure they are not mutated by the involution.