Reparameterizing distributions in numpyro

Every single time I’ve tried to “chance” it with probabilistic methods I am once again reminded that you can’t treat them like a black box. The following is a self-contained example of a pathological case so simple it almost feels like it should just work, and yet it fails miserably without intervention (in this case by reparameterizing the distributions). Part of me wonders if this all just means we need better tools or whether we just have to accept the fundamental complexity and model defensively.

import numpy as np
import pandas as pd
import seaborn as sns
from jax import lax
from jax import numpy as jnp
from jax.random import PRNGKey
from matplotlib import pyplot as plt
from numpyro import deterministic, handlers, sample
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, reparam
from numpyro.optim import Adam

sns.set_theme('talk', 'ticks', font='Arial', font_scale=1.0, rc={'svg.fonttype': 'none'})

Problem setting

Consider a catalyzed reaction $A \xrightarrow{\mathrm{cat}} \cdots$ where the catalyst itself is slowly undergoing decomposition cat → ⋯. Initial concentrations [A]0 and

$\frac{d [A]}{d t} = -k[A][\mathrm{cat}]$
$\frac{d [\mathrm{cat}]}{d t} = -k_d[\mathrm{cat}]$

A couple of simple helper functions to integrate these differential equations.

def solve(a_0, k, cats, dt):
    return lax.scan(lambda a_n, cat: (a_n - dt * k * cat * a_n, a_n - dt * k * cat * a_n), a_0, cats)[1]

def solve_cat(cat_0, k_d, dt, n):
    return lax.scan(lambda cat_n, _: (cat_n - dt * k_d * cat_n, cat_n - dt * k_d * cat_n), cat_0, length=n)[1]
    

plt.plot(solve(1.0, 1.0, jnp.ones((1000,)), 0.01))
plt.plot(solve_cat(1.0, 1.0, 0.02, 1000))
[<matplotlib.lines.Line2D at 0x132ea1a60>]
png

Now a simple model, first without worrying about measurement error. We’ll use both MCMC and SVI.

def model(a_0, cat_0, dt, n):
    k = sample('k', dist.LogNormal(0.0, 1.0))
    k_d = sample('k_d', dist.LogNormal(-2.0, 1.0))
    cats = deterministic('cat', solve_cat(cat_0, k_d, dt, n))
    a_n = deterministic('a_n', solve(a_0, k, cats, dt))

args = (1.0, 0.1, 0.01, 500)
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=100)
mcmc.run(PRNGKey(0), *args)
mcmc_samples = mcmc.get_samples()

guide = autoguide.AutoNormal(model)
svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(PRNGKey(1), 2000, *args)
svi_samples = guide.sample_posterior(PRNGKey(0), svi_result.params, sample_shape=(100,))
sample: 100%|██████████| 1100/1100 [00:00<00:00, 1685.64it/s, 3 steps of size 9.31e-01. acc. prob=0.92]
100%|██████████| 2000/2000 [00:00<00:00, 4384.19it/s, init loss: 7.5166, avg. loss [1901-2000]: 0.0118]
f, (a1, a2) = plt.subplots(ncols=2, sharey=True, figsize=(10, 4))
a1.set_title('MCMC')
a2.set_title('SVI')

a_line, *_ = a1.plot(mcmc_samples['a_n'].T, c='b', alpha=0.1, label='[A]')
a_line.set_alpha(1)
cat_line, *_ = a1.plot(mcmc_samples['cat'].T, c='r', alpha=0.1, label='[cat]')
cat_line.set_alpha(1)
a2.plot(svi_samples['a_n'].T, c='b', alpha=0.1)
a2.plot(svi_samples['cat'].T, c='r', alpha=0.1)
f.legend(handles=[a_line, cat_line])
a_line.set_alpha(0.1)
cat_line.set_alpha(0.1)
png

Very nice, we get a range of [A] trajectories based on possible k’s but also kd’s. But now let’s just sample simulate adding measurement error as a LogNormal. The previous variables shouldn’t be affected because we are not making an observation. What’s notable about this a_draw variable is that it has a very narrow distribution (σ = 0.02).

def model(a_0, cat_0, dt, n, err):
    k = sample('k', dist.LogNormal(0.0, 1.0))
    k_d = sample('k_d', dist.LogNormal(-2.0, 1.0))
    cats = deterministic('cat', solve_cat(cat_0, k_d, dt, n))
    a_n = deterministic('a_n', solve(a_0, k, cats, dt))
    a_draw = sample('a_draw', dist.LogNormal(jnp.log(a_n), err))

args = (1.0, 0.1, 0.01, 500, 0.05)
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=100)
mcmc.run(PRNGKey(0), *args)
mcmc_samples = mcmc.get_samples()

guide = autoguide.AutoNormal(model)
svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(PRNGKey(1), 2000, *args)
svi_samples = guide.sample_posterior(PRNGKey(0), svi_result.params, sample_shape=(100,))
sample: 100%|██████████| 1100/1100 [00:02<00:00, 470.02it/s, 63 steps of size 4.88e-02. acc. prob=0.92]
100%|██████████| 2000/2000 [00:00<00:00, 2674.99it/s, init loss: 139124.0000, avg. loss [1901-2000]: 26.2503]
f, (a1, a2) = plt.subplots(ncols=2, sharey=True, figsize=(10, 4))
a1.set_title('MCMC')
a2.set_title('SVI')

a_line, *_ = a1.plot(mcmc_samples['a_n'].T, c='b', alpha=0.1, label='[A]')
a_line.set_alpha(1)
cat_line, *_ = a1.plot(mcmc_samples['cat'].T, c='r', alpha=0.1, label='[cat]')
cat_line.set_alpha(1)
meas_line, *_ = a1.plot(mcmc_samples['a_draw'].T, c='k', alpha=0.02, label='[A] measured')
meas_line.set_alpha(1)
a2.plot(svi_samples['a_n'].T, c='b', alpha=0.1)
a2.plot(svi_samples['cat'].T, c='r', alpha=0.1)
a2.plot(svi_samples['a_draw'].T, c='k', alpha=0.05)
f.legend(handles=[a_line, cat_line, meas_line])
a_line.set_alpha(0.1)
cat_line.set_alpha(0.1)
meas_line.set_alpha(0.02)
png

Yes, not looking good at all. We were expecting the exact sample plots! I tried increasing the number of samples, warmup and SVI steps and those didn’t help either. The culprit is clearly the narrow distribution of a_draw as increasing its std to 1.0 everything goes back to normal.

After a bit of head scratching, it turns out there is a way to fix this without rewriting the model and manually de-centering/scaling. This happens on the level of the Normal distributions being sampled with (μ = 0 and σ = 1.0) but the LogNormal distribution is a further exponential transform away from so that’s where TransformReparam comes in.

reparam_config = {
    'k': reparam.TransformReparam(),
    'k_d': reparam.TransformReparam(),
    'a_draw': reparam.TransformReparam(),
    'k_base': reparam.LocScaleReparam(0),
    'k_d_base': reparam.LocScaleReparam(0),
    'a_draw_base': reparam.LocScaleReparam(0),
}

def model(a_0, cat_0, dt, n, err):
    k = sample('k', dist.LogNormal(0.0, 1.0))
    k_d = sample('k_d', dist.LogNormal(-2.0, 1.0))
    cats = deterministic('cat', solve_cat(cat_0, k_d, dt, n))
    a_n = deterministic('a_n', solve(a_0, k, cats, dt))
    a_draw = sample('a_draw', dist.LogNormal(jnp.log(a_n), err))

model = handlers.reparam(model, reparam_config)

args = (1.0, 0.1, 0.01, 500, 0.05)
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=100)
mcmc.run(PRNGKey(0), *args)
mcmc_samples = mcmc.get_samples()

guide = autoguide.AutoNormal(model)
svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(PRNGKey(1), 2000, *args)
svi_samples = guide.sample_posterior(PRNGKey(0), svi_result.params, sample_shape=(100,))
sample: 100%|██████████| 1100/1100 [00:00<00:00, 1229.30it/s, 15 steps of size 3.03e-01. acc. prob=0.89]
100%|██████████| 2000/2000 [00:00<00:00, 3245.90it/s, init loss: 1233.4703, avg. loss [1901-2000]: 2.5142]
f, (a1, a2) = plt.subplots(ncols=2, sharey=True, figsize=(10, 4))
a1.set_title('MCMC')
a2.set_title('SVI')

a_line, *_ = a1.plot(mcmc_samples['a_n'].T, c='b', alpha=0.1, label='[A]')
a_line.set_alpha(1)
cat_line, *_ = a1.plot(mcmc_samples['cat'].T, c='r', alpha=0.1, label='[cat]')
cat_line.set_alpha(1)
meas_line, *_ = a1.plot(mcmc_samples['a_draw'].T, c='k', alpha=0.05, label='[A] measured')
meas_line.set_alpha(1)
a2.plot(svi_samples['a_n'].T, c='b', alpha=0.1)
a2.plot(svi_samples['cat'].T, c='r', alpha=0.1)
a2.plot(svi_samples['a_draw'].T, c='k', alpha=0.05)
f.legend(handles=[a_line, cat_line, meas_line])
a_line.set_alpha(0.1)
cat_line.set_alpha(0.1)
meas_line.set_alpha(0.02)
png

Beautiful! The latent concentrations are now exactly as before and we also have a nice extra variable showing measurement with error. How about a quick look at the sampled values of k and kd, which should ideally be uncorrelated and have a lognormal marginal, i.e. each is lognormally distributed irrespective of the other one.

def posterior_df(samples):
    dfs = []

    for var, vals in samples.items():
        dfs.append(pd.DataFrame(samples[var]).melt().rename(columns={"value": var})[[var]])
    
    return pd.concat([dfs[0], *[df.iloc[:, 0] for df in dfs[1:]]], axis=1)

sns.jointplot(posterior_df(mcmc_samples), x='k', y='k_d', kind='kde', fill=True)
plt.gca().set(ylabel='$k_d$', xlabel='$k$')
plt.gca().set_xlim(left=0.0)
plt.gca().set_ylim(bottom=0.0);
png

Nice, normal as expected!

You will notice a quirk: I had to know that the name of the underlying Normal sites. And one downside of using this automatic reparameterisation is that TransformReparam doesn’t currently support observations. That’s fine though, you can split the sampled and observed parts (the observed part won’t need reparameterising as it won’t need to be sampled).

Adding support for observations

Visualising the prior predictive, i.e. what’s predicted solely from our priors without observing anything is nice but not that exciting. It took some trial and error to find the right combination of types and effect handlers to do this but the solution is to use a masked distribution. Here is how it works:

  1. For any unobserved values in a_draw that need imputing (i.e. drawing from the prior predictive rather than observing), we can continue to use our reparamterized TransformedDistribution, but masked so the log probability of the any observed sites is ignored. Essentially, you are still sampling them but not accounting for their log probability, since we are later going to observe them.
  2. For any observed values in a_draw_obs, reparameterizing is not necessary, since we are not going to sample them. We apply the opposite mask: observed values simply add to our overall joint log probability without having to be sampled; unobserved values are also observed from the imputed values but their log probs are ignored because we mask them. This is because we already calculated their log prob in the a_draw site.

I have to confess that this took a while to get right. Especially because some of the types don’t compose very well, so for example TransformedDistribution(MaskedDistribution(...)) is not supported but using a mask handler works.

reparam_config = {
    'k': reparam.TransformReparam(),
    'k_d': reparam.TransformReparam(),
    'a_draw': reparam.TransformReparam(),
    'k_base': reparam.LocScaleReparam(0),
    'k_d_base': reparam.LocScaleReparam(0),
    'a_draw_base': reparam.LocScaleReparam(0),
}

def model(a_0, cat_0, dt, n, err, obs):
    has_obs = ~jnp.isnan(obs)
    k = sample('k', dist.LogNormal(0.0, 1.0))
    k_d = sample('k_d', dist.LogNormal(-2.0, 1.0))
    cats = deterministic('cat', solve_cat(cat_0, k_d, dt, n))
    a_n = deterministic('a_n', solve(a_0, k, cats, dt))
    with handlers.mask(mask=~has_obs):
        d = dist.Normal(jnp.log(a_n), err)
    a_draw = sample('a_draw', dist.TransformedDistribution(d, dist.transforms.ExpTransform()))
    obs = jnp.where(has_obs, obs, a_draw)
    a_draw_combined = deterministic('a_draw_combined', obs)
    a_draw_obs = sample('a_draw_obs', dist.LogNormal(jnp.log(a_n), err).mask(has_obs), obs=obs)

def infer(model, args, guide=None):
    mcmc = MCMC(NUTS(model), num_warmup=2000, num_samples=100)
    mcmc.run(PRNGKey(0), *args)
    mcmc_samples = mcmc.get_samples()

    guide = guide or autoguide.AutoNormal(model)
    svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
    svi_result = svi.run(PRNGKey(1), 10000, *args)
    svi_samples = guide.sample_posterior(PRNGKey(0), svi_result.params, sample_shape=(100,))
    return locals()


model = handlers.reparam(model, reparam_config)

obs = np.full((500,), np.nan)
args = (1.0, 0.1, 0.01, 500, 0.05, obs)

results = infer(model, args)
sample: 100%|██████████| 2100/2100 [00:01<00:00, 1092.37it/s, 15 steps of size 3.03e-01. acc. prob=0.85]
100%|██████████| 10000/10000 [00:01<00:00, 5173.40it/s, init loss: 1233.4703, avg. loss [9501-10000]: 2.4468]

Sampling from the prior predictive, i.e. when the observations are all np.nan.

def plot_samples(mcmc_samples, svi_samples, obs):
    f, (a1, a2) = plt.subplots(ncols=2, sharey=True, figsize=(12, 5))
    a1.set_title('MCMC')
    a2.set_title('SVI')
    has_obs = ~jnp.isnan(obs)


    a_line, *_ = a1.plot(mcmc_samples['a_n'].T, c='b', alpha=0.1, label='[A]')
    a_line.set_alpha(1)
    cat_line, *_ = a1.plot(mcmc_samples['cat'].T, c='r', alpha=0.1, label='[cat]')
    cat_line.set_alpha(1)
    meas_line, *_ = a1.plot(mcmc_samples['a_draw_combined'].T, c='y', alpha=0.05, label='[A] measured', zorder=-1)
    meas_line.set_alpha(1)
    a2.plot(svi_samples['a_n'].T, c='b', alpha=0.1)
    a2.plot(svi_samples['cat'].T, c='r', alpha=0.1)
    a2.plot(svi_samples['a_draw_combined'].T, c='y', alpha=0.05, zorder=-1)
    f.legend(handles=[a_line, cat_line, meas_line])
    a1.scatter(jnp.nonzero(has_obs)[0], obs[has_obs], c='k', zorder=10)
    a2.scatter(jnp.nonzero(has_obs)[0], obs[has_obs], c='k', zorder=10)
    a1.set(ylim=(-0.1,1.1))
    a_line.set_alpha(0.1)
    cat_line.set_alpha(0.1)
    meas_line.set_alpha(0.05)
    return f

plot_samples(results['mcmc_samples'], results['svi_samples'], obs)
plt.tight_layout()
png

Now let’s supply a single observation, specifically one that suggests the reaction is faster than predicted by most trajectories. In the posterior we would expect to see either a higher than anticipated k or a lower than anticipated kd and the two should be positively correlated.

obs[100] = 0.6

args = (1.0, 0.1, 0.01, 500, 0.05, obs)

results = infer(model, args)

plot_samples(results['mcmc_samples'], results['svi_samples'], obs)
plt.tight_layout()
sample: 100%|██████████| 2100/2100 [00:02<00:00, 868.05it/s, 15 steps of size 3.08e-01. acc. prob=0.85] 
100%|██████████| 10000/10000 [00:02<00:00, 4996.71it/s, init loss: 1234.0802, avg. loss [9501-10000]: 3.8389]
png
all_data = pd.concat([
    posterior_df(results['mcmc_samples'])[['k', 'k_d']].assign(Method='MCMC'),
    posterior_df(results['svi_samples'])[['k', 'k_d']].assign(Method='SVI'),
    ], ignore_index=True)
sns.kdeplot(all_data, x='k', y='k_d', hue="Method", fill=True)
plt.gca().set(xlabel='$k$', ylabel='$k_d$')
[Text(0.5, 0, '$k$'), Text(0, 0.5, '$k_d$')]
png

And here we go, a very different picture now, and as expected the AutoNormal SVI guide fails to capture the correlation between k and kd. A full rank multivariate normal should be able to handle this, albeit taking about 10x as long as MCMC. For some reason autoguide.AutoLowRankMultivariateNormal(rank=2) doesn’t seem to do the trick.

A more advanced option to get the best of both worlds would be creating a separate full-rank guide for k and kd then using a simple AutoNormal for the rest (specifically a_draw) using AutoGuideList.

# "Easy" but slow and doesn't really scale
# results = infer(model, args, autoguide.AutoMultivariateNormal(model))

# More involved but very fast
guide = autoguide.AutoGuideList(model)
guide.append(autoguide.AutoNormal(handlers.block(model, expose=['a_draw_base_decentered'])))
guide.append(autoguide.AutoMultivariateNormal(handlers.block(model, expose=['k_base_decentered', 'k_d_base_decentered'])))
results = infer(model, args, guide)
sample: 100%|██████████| 2100/2100 [00:02<00:00, 932.25it/s, 15 steps of size 3.08e-01. acc. prob=0.85]  
100%|██████████| 10000/10000 [00:02<00:00, 4894.64it/s, init loss: 1272.4036, avg. loss [9501-10000]: 3.7564]

This is more involved and it seems like AutoGuideList doesn’t deal with deterministic sites, so we have derive them manually. I had to read a lot of numpyro’s source code to figure out how.

from numpyro.infer.util import soft_vmap

def predictive(sample):
    with handlers.seed(rng_seed=PRNGKey(0)):
        with handlers.substitute(data=sample):
            return {k: v['value'] for k, v in handlers.trace(model).get_trace(*args).items()}

results['svi_samples'].update(soft_vmap(predictive, results['svi_samples']))
all_data = pd.concat([
    posterior_df(results['mcmc_samples'])[['k', 'k_d']].assign(Method='MCMC'),
    posterior_df(results['svi_samples'])[['k', 'k_d']].assign(Method='SVI'),
    ], ignore_index=True)
sns.kdeplot(all_data, x='k', y='k_d', hue="Method", fill=True)
plt.gca().set(xlabel='$k$', ylabel='$k_d$')
[Text(0.5, 0, '$k$'), Text(0, 0.5, '$k_d$')]
png

And here we go!

plot_samples(results['mcmc_samples'], results['svi_samples'], obs)
plt.tight_layout()
png

I hope this was fun to explore together. Very happy to chat if you have any suggestions or questions. This experiment was motivated by my upcoming RSC book on digital chemistry, as part of which I am hoping to present probabilistic programming as a powerful data interpretation aid for experimentalists working with automated platforms that can generate a wealth of experimental data.