Live variational inference

This is mostly to show two things:

The key is the SVI.update call, which I have avoided so far in favour of the convenience of SVI.run. With it, a batch of steps can be taken in a jitted loop for performance, and the loss can be monitored for convergence whenever a parameter or the input data changes.
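For contrast, a minimal sketch of the two styles (argument values here are illustrative):

# SVI.run: one call, a fixed number of steps, losses returned at the end
result = svi.run(rng_key, 1000, alpha_scale=1.0)
params = result.params

# SVI.init / SVI.update: you hold the state, so you can step incrementally
state = svi.init(rng_key, alpha_scale=1.0)
for _ in range(1000):
    state, loss = svi.update(state, alpha_scale=1.0)
params = svi.get_params(state)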

With matplotlib you need to use an interactive backend like qt or osx and add a plt.pause(...) call in the loop so the figure gets a chance to re-render. Super useful and something I didn’t know about (thanks ChatGPT).
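That pattern in isolation looks like this, a toy sketch assuming an interactive backend is already active:

import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
for i in range(50):
    # Add one point per iteration...
    ax.scatter([i], [np.sin(i / 5)], c="k", s=5)
    # ...and give the GUI event loop a chance to redraw the figure.
    plt.pause(0.01)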

A minimal example below, where the concentration parameter of a Dirichlet prior over a categorical observation is varied smoothly from 1.0 (i.e. flat) down to roughly 0.02 (pretty pointy at the extremes of the simplex).

%matplotlib osx
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import seaborn as sns
from matplotlib import pyplot as plt
from numpyro.infer import SVI, Trace_ELBO, autoguide
from numpyro.optim import Adam
from tqdm import tqdm

sns.set_theme(
    "talk", "ticks", font="Arial", font_scale=1.0, rc={"svg.fonttype": "none"}
)

num_steps = 100


def model(alpha_scale):
    # Symmetric Dirichlet prior over 12 categories, scaled by alpha_scale
    alpha = jnp.ones(12) * alpha_scale
    theta = numpyro.sample("theta", dist.Dirichlet(alpha))
    # Eight observations: five of category 2, three of category 5
    numpyro.sample(
        "obs1", dist.Categorical(theta), obs=jnp.array([2, 2, 2, 2, 2, 5, 5, 5])
    )

guide = autoguide.AutoNormal(model)

optimizer = Adam(1e-3)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

rng_key = jax.random.PRNGKey(0)
state = svi.init(
    rng_key,
    alpha_scale=1.0,
)


@jax.jit
def run_stage(state, init_loss, alpha_scale):
    # Take 100 SVI steps in a single compiled fori_loop; the carry is the
    # (state, loss) tuple that svi.update returns.
    def body_fn(i, val):
        return svi.update(val[0], alpha_scale)

    return jax.lax.fori_loop(0, 100, body_fn, (state, init_loss))
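
The otherwise-unused init_loss in the initial carry is there so that the tuple handed to jax.lax.fori_loop has the same structure as the (state, loss) pair that svi.update returns, which fori_loop requires of its carry.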


f, (a1, a2) = plt.subplots(nrows=2, sharex=True)
a1.set(ylabel="ELBO loss")
a2.set(xlabel="SVI batch", ylabel=r"$\theta$ (posterior)")
f.tight_layout()
for i in tqdm(range(num_steps)):
    # Anneal the concentration from 1.0 down to roughly 0.02
    alpha_scale = 0.02 ** (i / num_steps)
    a1.set_title(f"$\\alpha$: {alpha_scale:.2f}")
    while True:
        # Run batches of SVI steps until the loss changes by less than 2%
        init_loss = svi.evaluate(state, alpha_scale)
        state, loss = run_stage(state, init_loss, alpha_scale)
        if jnp.abs(loss - init_loss) / jnp.abs(init_loss) < 0.02:
            break
    params = svi.get_params(state)
    a1.scatter([i], loss, c="k", s=5)
    # One draw from the fitted guide, for plotting
    posterior = guide.sample_posterior(rng_key, params)
    theta = posterior["theta"]
    a2.scatter(i * np.ones_like(theta), theta, c=np.arange(12), s=5)
    # This is really important to get live updates.
    plt.pause(0.01)

f.savefig("2025-10-15-live-svi_result.svg")
print("Posterior theta mean:", posterior["theta"])
plt.close()
100%|██████████| 100/100 [00:06<00:00, 14.36it/s]


Posterior theta sample: [1.26911415e-08 7.69821611e-08 6.65301085e-01 7.23654193e-06
 1.26773830e-06 2.86240101e-01 2.82157103e-10 9.47567692e-04
 2.23408958e-10 4.75026183e-02 6.67153069e-11 5.59603919e-09]
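
Note that sample_posterior with the default sample_shape=() returns a single draw from the guide, so the values above are one posterior sample rather than a mean. To get an actual posterior mean, draw many samples; a quick sketch (the sample count is arbitrary):

samples = guide.sample_posterior(rng_key, params, sample_shape=(1000,))
print("Posterior theta mean:", samples["theta"].mean(axis=0))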

Not sure if the result will show up as expected when this notebook is converted to HTML, but it looks very nice live and it’s good to be able to watch it converge.

The result