This is mostly to show two things:
The key is the SVI.update call, which I have avoided so far in favour of the convenience of SVI.run. A batch of steps can be taken in a jitted loop for performance, and the loss monitored for convergence every time a parameter or the input data changes.
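For reference, here is the rough shape of the two APIs side by side (a sketch, not from the example below; model, guide and data are stand-ins):

# SVI.run: init + N steps in one call, with the model arguments fixed throughout.
svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
result = svi.run(jax.random.PRNGKey(0), 2000, data)

# SVI.update: you carry the state yourself, so the arguments can change between
# steps and the loss can be inspected whenever you like.
state = svi.init(jax.random.PRNGKey(0), data)
for _ in range(2000):
    state, loss = svi.update(state, data)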
With matplotlib you need to use an interactive backend like qt or osx, and add a plt.pause(...) call in the loop so that the figure gets a chance to re-render. Super useful, and something I didn't know about (thanks ChatGPT).
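The whole trick is just this (a minimal standalone sketch, using the qt backend as an example):

%matplotlib qt
import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
for step in range(50):
    ax.scatter(step, np.random.randn(), c="k", s=5)
    plt.pause(0.01)  # hand control to the GUI event loop so the figure re-draws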
A minimal example below, where the concentration parameter of a Dirichlet prior over a categorical observation is swept geometrically from 1.0 (i.e. flat) down towards 0.02 (pretty pointy at the extremes of the simplex).
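(As a quick aside, you can see what that sweep does to the prior on its own by sampling it directly; a standalone sketch, separate from the example:)

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

key = jax.random.PRNGKey(0)
# alpha = 1.0: samples are roughly uniform over the 12-simplex.
print(dist.Dirichlet(jnp.ones(12) * 1.0).sample(key))
# alpha = 0.02: nearly all of the mass piles onto one or two components.
print(dist.Dirichlet(jnp.ones(12) * 0.02).sample(key))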
%matplotlib osx
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import seaborn as sns
from matplotlib import pyplot as plt
from numpyro.infer import SVI, Trace_ELBO, autoguide
from numpyro.optim import Adam
from tqdm import tqdm
sns.set_theme(
    "talk", "ticks", font="Arial", font_scale=1.0, rc={"svg.fonttype": "none"}
)
num_steps = 100
def model(alpha_scale):
    # Dirichlet prior over 12 categories; the concentration is an input we can vary.
    alpha = jnp.ones(12) * alpha_scale
    theta = numpyro.sample("theta", dist.Dirichlet(alpha))
    # All observations fall in categories 2 and 5.
    numpyro.sample("obs1", dist.Categorical(theta), obs=jnp.array([2, 2, 2, 2, 2, 5, 5, 5]))
guide = autoguide.AutoNormal(model)
optimizer = Adam(1e-3)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
rng_key = jax.random.PRNGKey(0)
state = svi.init(
    rng_key,
    alpha_scale=1.0,
)
@jax.jit
def run_stage(state, init_loss, alpha_scale):
    # Take 100 SVI steps inside a single jitted fori_loop; the carry is (state, loss).
    def body_fn(i, val):
        return svi.update(val[0], alpha_scale)

    return jax.lax.fori_loop(0, 100, body_fn, (state, init_loss))
f, (a1, a2) = plt.subplots(nrows=2, sharex=True)
a1.set(ylabel="ELBO loss")
a2.set(xlabel="SVI batch", ylabel=r"$\theta$ (posterior)")
f.tight_layout()
for i in tqdm(range(num_steps)):
    # Sweep the concentration geometrically from 1.0 down towards 0.02.
    alpha_scale = 0.02 ** (i / num_steps)
    a1.set_title(f"$\\alpha$: {alpha_scale:.2f}")
    while True:
        # Run batches of SVI steps until the relative change in the loss is under 2%.
        init_loss = svi.evaluate(state, alpha_scale)
        state, loss = run_stage(state, init_loss, alpha_scale)
        if jnp.abs(loss - init_loss) / jnp.abs(init_loss) < 0.02:
            break
    params = svi.get_params(state)
    a1.scatter([i], loss, c="k", s=5)
    posterior = guide.sample_posterior(rng_key, params)
    theta = posterior["theta"]
    a2.scatter(i * np.ones_like(theta), theta, c=np.arange(12), s=5)
    # This is really important to get live updates.
    plt.pause(0.01)
f.savefig("2025-10-15-live-svi_result.svg")
# Note: sample_posterior with the default sample_shape returns a single draw,
# not the posterior mean.
print("Posterior theta sample:", posterior["theta"])
plt.close()
100%|██████████| 100/100 [00:06<00:00, 14.36it/s]
Posterior theta sample: [1.26911415e-08 7.69821611e-08 6.65301085e-01 7.23654193e-06
 1.26773830e-06 2.86240101e-01 2.82157103e-10 9.47567692e-04
 2.23408958e-10 4.75026183e-02 6.67153069e-11 5.59603919e-09]
Not sure if the result will show up as expected when this notebook is converted to HTML, but live it looks very nice, and it's good to be able to watch it converge.