%23%20%2F%2F%2F%20script%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22diffrax%3D%3D0.7.2%22%2C%0A%23%20%20%20%20%20%22flax%3D%3D0.12.7%22%2C%0A%23%20%20%20%20%20%22jax%3D%3D0.10.0%22%2C%0A%23%20%20%20%20%20%22marimo%22%2C%0A%23%20%20%20%20%20%22matplotlib%3D%3D3.10.9%22%2C%0A%23%20%20%20%20%20%22numpy%3D%3D2.4.4%22%2C%0A%23%20%20%20%20%20%22numpyro%3D%3D0.21.0%22%2C%0A%23%20%20%20%20%20%22optax%3D%3D0.2.8%22%2C%0A%23%20%5D%0A%23%20requires-python%20%3D%20%22%3E%3D3.13%22%0A%23%20%2F%2F%2F%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.23.5%22%0Aapp%20%3D%20marimo.App(%0A%20%20%20%20width%3D%22medium%22%2C%0A%20%20%20%20app_title%3D%22Exploration%3A%20Flow%20matching%2C%20NeuTraHMC%20and%20normalizing%20flow%20guides%20in%20NumPyro%22%2C%0A)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Exploration%3A%20Flow%20matching%2C%20NeuTraHMC%20and%20normalizing%20flow%20guides%20in%20NumPyro%0A%20%20%20%20I've%20been%20reading%20papers%20about%20normalizing%20flows%20and%20flow%20matching%2C%20mostly%20as%20a%20way%20to%20express%20complex%20and%20multimodal%20posteriors%20in%20my%20%60numpyro%60%20models%20(mixture%20distributions%20in%20%60numpyro%60%20could%20partially%20fulfil%20the%20same%20need).%0A%0A%20%20%20%20The%20following%20is%20not%20original%20writing%2C%20though%20I%20have%20made%20a%20few%20tweaks.%20It%20is%20the%20result%20of%20an%20exploration%20with%20a%20coding%20agent%20(Pi%20%2B%20Opus%204.7)%2C%20useful%20as%20a%20future%20reference%20and%20hopefully%20for%20anyone%20looking%20for%20examples%20of%20these%20building%20blocks%20in%20action.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20A%20Flow%20Matching%20demo%0A%0A%20%20%20%20We'll%20fit%20a%20**flow-matching**%20transport%20from%20%24p_0%20%3D%20%5Cmathcal%7BN%7D(0%2C%20I_2)%24%20to%20a%204%C3%974%0A%20%20%20%20**checkerboard**%20target%20%24p_1%24%2C%20then%20use%20the%20learned%20flow%20as%0A%0A%20%20%20%201.%20a%20**NumPyro%20SVI%20guide**%20(the%20flow-matching%20loss%20is%20a%20tractable%20surrogate%0A%20%20%20%20%20%20%20for%20the%20variational%20objective%20%E2%80%94%20no%20log-determinants%2C%20no%20ELBO)%2C%0A%20%20%20%202.%20a%20**NeuTra%20reparametrization**%20for%20HMC%3A%20run%20MCMC%20in%20the%20simple%20base%20space%0A%20%20%20%20%20%20%20and%20push%20samples%20through%20the%20flow.%0A%0A%20%20%20%20**Reference.**%20Lipman%20et%20al.%2C%20*Flow%20Matching%20for%20Generative%20Modeling*%2C%202022%0A%20%20%20%20(%5BarXiv%3A2210.02747%5D(https%3A%2F%2Farxiv.org%2Fpdf%2F2210.02747)).%0A%0A%20%20%20%20The%20conditional%20flow-matching%20(CFM)%20loss%2C%20with%20Gaussian%20probability%20paths%0A%20%20%20%20%24p_t(x%20%5Cmid%20x_1)%20%3D%20%5Cmathcal%7BN%7D(t%20x_1%2C%20(1-(1-%5Csigma_%7B%5Cmin%7D)t)%5E2%20I)%24%2C%20is%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cmathcal%7BL%7D_%7B%5Ctext%7BCFM%7D%7D(%5Ctheta)%20%3D%20%5Cmathbb%7BE%7D_%7Bt%2C%5C%2C%20x_1%20%5Csim%20p_1%2C%5C%2C%20x_0%20%5Csim%20%5Cmathcal%7BN%7D(0%2CI)%7D%0A%20%20%20%20%5Cbig%5C%7C%20v_%5Ctheta(x_t%2C%20t)%20-%20(x_1%20-%20(1-%5Csigma_%7B%5Cmin%7D)%20x_0)%20%5Cbig%5C%7C%5E2%2C%0A%20%20%20%20%24%24%0A%0A%20%20%20%20where%20%24x_t%20%3D%20(1-(1-%5Csigma_%7B%5Cmin%7D)t)%20x_0%20%2B%20t%20x_1%24%20is%20a%20straight-line%0A%20%20%20%20interpolation.%20We%20learn%20%24v_%5Ctheta%24%20and%20integrate%20%24%5Cdot%20x%20%3D%20v_%5Ctheta(x%2C%20t)%24%0A%20%20%20%20from%20%24t%3D0%24%20to%20%24t%3D1%24%20to%20push%20base%20samples%20to%20the%20target.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20jax%0A%20%20%20%20import%20jax.numpy%20as%20jnp%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20flax.linen%20as%20nn%0A%20%20%20%20import%20optax%0A%20%20%20%20import%20diffrax%0A%20%20%20%20import%20numpyro%0A%20%20%20%20import%20numpyro.distributions%20as%20dist%0A%20%20%20%20from%20numpyro.infer%20import%20MCMC%2C%20NUTS%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20import%20seaborn%20as%20sns%0A%0A%20%20%20%20sns.set_theme('talk'%2C%20'ticks'%2C%20font%3D'Arial'%2C%20font_scale%3D1.0%2C%20rc%3D%7B'svg.fonttype'%3A%20'none'%7D)%0A%20%20%20%20return%20MCMC%2C%20NUTS%2C%20diffrax%2C%20dist%2C%20jax%2C%20jnp%2C%20mo%2C%20nn%2C%20np%2C%20numpyro%2C%20optax%2C%20plt%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20The%20target%3A%20a%204%C3%974%20checkerboard%20on%20%24%5B-2%2C%202%5D%5E2%24%0A%0A%20%20%20%20Eight%20unit%20squares%20of%20mass%2C%20equally%20weighted.%20Density%20is%20uniform%20on%20the%0A%20%20%20%20%22white%22%20cells%20and%20zero%20elsewhere%20%E2%80%94%20a%20classic%20stress%20test%20for%20normalizing%0A%20%20%20%20flows%20because%20it%20has%20disconnected%20support%20and%20sharp%20boundaries.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(jax%2C%20jnp)%3A%0A%20%20%20%20def%20sample_checkerboard(key%2C%20n%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Sample%20n%20points%20uniformly%20from%20the%208%20white%20cells%20of%20a%204x4%20checkerboard.%22%22%22%0A%20%20%20%20%20%20%20%20k1%2C%20k2%20%3D%20jax.random.split(key)%0A%20%20%20%20%20%20%20%20whites%20%3D%20jnp.array(%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%5B0%2C%200%5D%2C%20%5B2%2C%200%5D%2C%20%5B1%2C%201%5D%2C%20%5B3%2C%201%5D%2C%20%5B0%2C%202%5D%2C%20%5B2%2C%202%5D%2C%20%5B1%2C%203%5D%2C%20%5B3%2C%203%5D%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20cell%20%3D%20whites%5Bjax.random.randint(k1%2C%20(n%2C)%2C%200%2C%208)%5D%0A%20%20%20%20%20%20%20%20cell_w%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20lo%20%3D%20-scale%20%2B%20cell%20*%20cell_w%0A%20%20%20%20%20%20%20%20return%20lo%20%2B%20jax.random.uniform(k2%2C%20(n%2C%202))%20*%20cell_w%0A%0A%0A%20%20%20%20def%20checkerboard_log_prob(x%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Log%20density%20(up%20to%20additive%20constant%20from%20total%20mass%20%3D%201).%22%22%22%0A%20%20%20%20%20%20%20%20cell_w%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20%23%20snap%20to%20integer%20cell%20index%0A%20%20%20%20%20%20%20%20ij%20%3D%20jnp.floor((x%20%2B%20scale)%20%2F%20cell_w).astype(jnp.int32)%0A%20%20%20%20%20%20%20%20inside%20%3D%20jnp.all((ij%20%3E%3D%200)%20%26%20(ij%20%3C%204)%2C%20axis%3D-1)%0A%20%20%20%20%20%20%20%20white%20%3D%20((ij%5B...%2C%200%5D%20%2B%20ij%5B...%2C%201%5D)%20%25%202%20%3D%3D%200)%20%26%20inside%0A%20%20%20%20%20%20%20%20%23%208%20white%20cells%20of%20area%20cell_w**2%20-%3E%20uniform%20density%201%2F(8*cell_w**2)%0A%20%20%20%20%20%20%20%20return%20jnp.where(white%2C%20jnp.log(1.0%20%2F%20(8%20*%20cell_w%20**%202))%2C%20-jnp.inf)%0A%0A%0A%20%20%20%20return%20checkerboard_log_prob%2C%20sample_checkerboard%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(jax%2C%20plt%2C%20sample_checkerboard)%3A%0A%20%20%20%20_xs%20%3D%20sample_checkerboard(jax.random.PRNGKey(0)%2C%204000)%0A%20%20%20%20_fig%2C%20_ax%20%3D%20plt.subplots(figsize%3D(4%2C%204))%0A%20%20%20%20_ax.scatter(_xs%5B%3A%2C%200%5D%2C%20_xs%5B%3A%2C%201%5D%2C%20s%3D2%2C%20alpha%3D0.5)%0A%20%20%20%20_ax.set_xlim(-2.5%2C%202.5)%3B%20_ax.set_ylim(-2.5%2C%202.5)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20_ax.set_title(%22samples%20from%20target%20%24p_1%24%22)%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Velocity%20field%20%24v_%5Ctheta(x%2C%20t)%24%0A%0A%20%20%20%20A%20small%20MLP%20that%20takes%20the%20spatial%20point%20%24x%20%5Cin%20%5Cmathbb%7BR%7D%5E2%24%20concatenated%0A%20%20%20%20with%20a%20sinusoidal%20embedding%20of%20%24t%20%5Cin%20%5B0%2C%201%5D%24%20and%20returns%20a%20velocity%20in%0A%20%20%20%20%24%5Cmathbb%7BR%7D%5E2%24.%20This%20is%20the%20only%20learned%20object.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jax%2C%20jnp%2C%20nn)%3A%0A%20%20%20%20def%20sinusoidal_embed(t%2C%20dim%3D32%2C%20max_freq%3D1000.0)%3A%0A%20%20%20%20%20%20%20%20half%20%3D%20dim%20%2F%2F%202%0A%20%20%20%20%20%20%20%20freqs%20%3D%20jnp.exp(jnp.linspace(0.0%2C%20jnp.log(max_freq)%2C%20half))%0A%20%20%20%20%20%20%20%20ang%20%3D%20t%5B...%2C%20None%5D%20*%20freqs%0A%20%20%20%20%20%20%20%20return%20jnp.concatenate(%5Bjnp.sin(ang)%2C%20jnp.cos(ang)%5D%2C%20axis%3D-1)%0A%0A%0A%20%20%20%20class%20VelocityField(nn.Module)%3A%0A%20%20%20%20%20%20%20%20hidden%3A%20int%20%3D%20128%0A%20%20%20%20%20%20%20%20depth%3A%20int%20%3D%204%0A%20%20%20%20%20%20%20%20t_embed_dim%3A%20int%20%3D%2032%0A%0A%20%20%20%20%20%20%20%20%40nn.compact%0A%20%20%20%20%20%20%20%20def%20__call__(self%2C%20x%2C%20t)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20x%3A%20(...%2C%202)%2C%20t%3A%20(...)%20scalar%20in%20%5B0%2C1%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20t%20%3D%20jnp.broadcast_to(t%2C%20x.shape%5B%3A-1%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20h%20%3D%20jnp.concatenate(%5Bx%2C%20sinusoidal_embed(t%2C%20self.t_embed_dim)%5D%2C%20axis%3D-1)%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20_%20in%20range(self.depth)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20h%20%3D%20nn.Dense(self.hidden)(h)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20h%20%3D%20nn.silu(h)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20nn.Dense(2)(h)%0A%0A%0A%20%20%20%20vfield%20%3D%20VelocityField()%0A%20%20%20%20_init_params%20%3D%20vfield.init(jax.random.PRNGKey(0)%2C%20jnp.zeros((1%2C%202))%2C%20jnp.zeros((1%2C)))%0A%20%20%20%20_n_params%20%3D%20sum(p.size%20for%20p%20in%20jax.tree_util.tree_leaves(_init_params))%0A%20%20%20%20print(f%22velocity%20field%20has%20%7B_n_params%3A%2C%7D%20parameters%22)%0A%20%20%20%20return%20(vfield%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Flow-matching%20loss%0A%0A%20%20%20%20Conditional%20flow%20matching%20with%20%24%5Csigma_%7B%5Cmin%7D%20%3D%2010%5E%7B-4%7D%24.%20We%20sample%0A%20%20%20%20%24t%20%5Csim%20%5Cmathrm%7BU%7D(0%2C%201)%24%2C%20%24x_0%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I)%24%2C%0A%20%20%20%20%24x_1%20%5Csim%20p_1%24%2C%20build%20%24x_t%20%3D%20(1%20-%20(1-%5Csigma_%7B%5Cmin%7D)%20t)%5C%2C%20x_0%20%2B%20t%5C%2C%20x_1%24%0A%20%20%20%20and%20regress%20%24v_%5Ctheta(x_t%2C%20t)%24%20onto%20the%20conditional%20vector%20field%0A%20%20%20%20%24u_t(x_t%20%5Cmid%20x_1)%20%3D%20x_1%20-%20(1-%5Csigma_%7B%5Cmin%7D)%5C%2C%20x_0%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(jax%2C%20jnp%2C%20sample_checkerboard%2C%20vfield)%3A%0A%20%20%20%20SIGMA_MIN%20%3D%201e-4%0A%0A%20%20%20%20def%20fm_loss(params%2C%20key%2C%20batch_size%3D512)%3A%0A%20%20%20%20%20%20%20%20k_t%2C%20k_x0%2C%20k_x1%20%3D%20jax.random.split(key%2C%203)%0A%20%20%20%20%20%20%20%20t%20%3D%20jax.random.uniform(k_t%2C%20(batch_size%2C))%0A%20%20%20%20%20%20%20%20x0%20%3D%20jax.random.normal(k_x0%2C%20(batch_size%2C%202))%0A%20%20%20%20%20%20%20%20x1%20%3D%20sample_checkerboard(k_x1%2C%20batch_size)%0A%20%20%20%20%20%20%20%20a%20%3D%201.0%20-%20(1.0%20-%20SIGMA_MIN)%20*%20t%20%20%23%20noise-scale%20schedule%0A%20%20%20%20%20%20%20%20xt%20%3D%20a%5B%3A%2C%20None%5D%20*%20x0%20%2B%20t%5B%3A%2C%20None%5D%20*%20x1%0A%20%20%20%20%20%20%20%20target%20%3D%20x1%20-%20(1.0%20-%20SIGMA_MIN)%20*%20x0%0A%20%20%20%20%20%20%20%20pred%20%3D%20vfield.apply(params%2C%20xt%2C%20t)%0A%20%20%20%20%20%20%20%20return%20jnp.mean(jnp.sum((pred%20-%20target)%20**%202%2C%20axis%3D-1))%0A%0A%0A%20%20%20%20return%20(fm_loss%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%20Training%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(fm_loss%2C%20jax%2C%20jnp%2C%20np%2C%20optax%2C%20vfield)%3A%0A%20%20%20%20N_STEPS%20%3D%2015_000%0A%20%20%20%20BATCH%20%3D%201024%0A%20%20%20%20LR%20%3D%201e-3%0A%0A%20%20%20%20_optim%20%3D%20optax.chain(%0A%20%20%20%20%20%20%20%20optax.clip_by_global_norm(1.0)%2C%0A%20%20%20%20%20%20%20%20optax.adam(LR)%2C%0A%20%20%20%20)%0A%0A%20%20%20%20%40jax.jit%0A%20%20%20%20def%20_step(params%2C%20opt_state%2C%20key)%3A%0A%20%20%20%20%20%20%20%20loss%2C%20grads%20%3D%20jax.value_and_grad(fm_loss)(params%2C%20key%2C%20BATCH)%0A%20%20%20%20%20%20%20%20updates%2C%20opt_state%20%3D%20_optim.update(grads%2C%20opt_state)%0A%20%20%20%20%20%20%20%20params%20%3D%20optax.apply_updates(params%2C%20updates)%0A%20%20%20%20%20%20%20%20return%20params%2C%20opt_state%2C%20loss%0A%0A%20%20%20%20_key%20%3D%20jax.random.PRNGKey(0)%0A%20%20%20%20_params0%20%3D%20vfield.init(_key%2C%20jnp.zeros((1%2C%202))%2C%20jnp.zeros((1%2C)))%0A%20%20%20%20_opt_state0%20%3D%20_optim.init(_params0)%0A%0A%20%20%20%20def%20train_flow(n_steps%3DN_STEPS)%3A%0A%20%20%20%20%20%20%20%20p%2C%20s%20%3D%20_params0%2C%20_opt_state0%0A%20%20%20%20%20%20%20%20losses%20%3D%20np.zeros(n_steps)%0A%20%20%20%20%20%20%20%20k%20%3D%20jax.random.PRNGKey(1)%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(n_steps)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20k%2C%20sk%20%3D%20jax.random.split(k)%0A%20%20%20%20%20%20%20%20%20%20%20%20p%2C%20s%2C%20loss%20%3D%20_step(p%2C%20s%2C%20sk)%0A%20%20%20%20%20%20%20%20%20%20%20%20losses%5Bi%5D%20%3D%20float(loss)%0A%20%20%20%20%20%20%20%20return%20p%2C%20losses%0A%0A%20%20%20%20trained_params%2C%20loss_history%20%3D%20train_flow()%0A%20%20%20%20print(f%22final%20loss%20after%20%7BN_STEPS%7D%20steps%3A%20%7Bloss_history%5B-200%3A%5D.mean()%3A.4f%7D%22)%0A%20%20%20%20return%20loss_history%2C%20trained_params%0A%0A%0A%40app.cell%0Adef%20_(loss_history%2C%20np%2C%20plt)%3A%0A%20%20%20%20_fig%2C%20_ax%20%3D%20plt.subplots(figsize%3D(5%2C%203))%0A%20%20%20%20_ax.plot(loss_history%2C%20lw%3D0.5%2C%20alpha%3D0.6)%0A%20%20%20%20%23%20smoothed%0A%20%20%20%20_w%20%3D%20100%0A%20%20%20%20_smooth%20%3D%20np.convolve(loss_history%2C%20np.ones(_w)%2F_w%2C%20mode%3D%22valid%22)%0A%20%20%20%20_ax.plot(np.arange(_w-1%2C%20len(loss_history))%2C%20_smooth%2C%20lw%3D1.5%2C%20color%3D%22C1%22)%0A%20%20%20%20_ax.set_xlabel(%22step%22)%3B%20_ax.set_ylabel(%22flow-matching%20loss%22)%0A%20%20%20%20_ax.set_yscale(%22log%22)%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Pushing%20samples%20through%20the%20learned%20flow%0A%0A%20%20%20%20With%20the%20trained%20%24v_%5Ctheta%24%2C%20integrate%0A%0A%20%20%20%20%24%24%5Cdot%20x(t)%20%3D%20v_%5Ctheta(x(t)%2C%20t)%2C%20%5Cqquad%20x(0)%20%3D%20z%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I_2)%24%24%0A%0A%20%20%20%20from%20%24t%3D0%24%20to%20%24t%3D1%24%20using%20%60diffrax%60%20(Tsit5%2C%20adaptive).%20The%20endpoint%20%24x(1)%24%20is%0A%20%20%20%20our%20model%20sample%2C%20and%20the%20same%20ODE%20solved%20backward%20maps%20target%20%E2%86%92%20base.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(diffrax%2C%20jax%2C%20jnp%2C%20vfield)%3A%0A%20%20%20%20def%20_vf_term(t%2C%20x%2C%20params)%3A%0A%20%20%20%20%20%20%20%20return%20vfield.apply(params%2C%20x%2C%20jnp.full(x.shape%5B%3A-1%5D%2C%20t))%0A%0A%0A%20%20%20%20_solver%20%3D%20diffrax.Tsit5()%0A%20%20%20%20_term%20%3D%20diffrax.ODETerm(_vf_term)%0A%0A%20%20%20%20%40jax.jit%0A%20%20%20%20def%20push_forward(params%2C%20z%2C%20t0%3D0.0%2C%20t1%3D1.0)%3A%0A%20%20%20%20%20%20%20%20%22%22%22z%3A%20(...%2C%202).%20Integrate%20v_theta%20from%20t0%20to%20t1.%22%22%22%0A%20%20%20%20%20%20%20%20sol%20%3D%20diffrax.diffeqsolve(%0A%20%20%20%20%20%20%20%20%20%20%20%20_term%2C%20_solver%2C%20t0%3Dt0%2C%20t1%3Dt1%2C%20dt0%3D0.05%2C%20y0%3Dz%2C%20args%3Dparams%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20stepsize_controller%3Ddiffrax.PIDController(rtol%3D1e-4%2C%20atol%3D1e-4)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_steps%3D4096%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20return%20sol.ys%5B-1%5D%0A%0A%0A%20%20%20%20from%20functools%20import%20partial%0A%0A%20%20%20%20%40partial(jax.jit%2C%20static_argnames%3D(%22n_save%22%2C))%0A%20%20%20%20def%20push_forward_trajectory(params%2C%20z%2C%20n_save%3D20)%3A%0A%20%20%20%20%20%20%20%20ts%20%3D%20jnp.linspace(0.0%2C%201.0%2C%20n_save)%0A%20%20%20%20%20%20%20%20sol%20%3D%20diffrax.diffeqsolve(%0A%20%20%20%20%20%20%20%20%20%20%20%20_term%2C%20_solver%2C%20t0%3D0.0%2C%20t1%3D1.0%2C%20dt0%3D0.05%2C%20y0%3Dz%2C%20args%3Dparams%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20saveat%3Ddiffrax.SaveAt(ts%3Dts)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20stepsize_controller%3Ddiffrax.PIDController(rtol%3D1e-4%2C%20atol%3D1e-4)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_steps%3D4096%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20return%20ts%2C%20sol.ys%0A%0A%0A%20%20%20%20return%20push_forward%2C%20push_forward_trajectory%0A%0A%0A%40app.cell%0Adef%20_(jax%2C%20plt%2C%20push_forward%2C%20push_forward_trajectory%2C%20trained_params)%3A%0A%20%20%20%20_z%20%3D%20jax.random.normal(jax.random.PRNGKey(11)%2C%20(4000%2C%202))%0A%20%20%20%20_x1%20%3D%20push_forward(trained_params%2C%20_z)%0A%0A%20%20%20%20_z_traj%20%3D%20jax.random.normal(jax.random.PRNGKey(12)%2C%20(1000%2C%202))%0A%20%20%20%20_ts%2C%20_traj%20%3D%20push_forward_trajectory(trained_params%2C%20_z_traj%2C%20n_save%3D40)%0A%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(1%2C%203%2C%20figsize%3D(12%2C%204))%0A%20%20%20%20_axes%5B0%5D.scatter(_z%5B%3A%2C%200%5D%2C%20_z%5B%3A%2C%201%5D%2C%20s%3D2%2C%20alpha%3D0.4)%0A%20%20%20%20_axes%5B0%5D.set_title(%22base%20%24z%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I)%24%22)%0A%20%20%20%20for%20_a%20in%20_axes%3A%0A%20%20%20%20%20%20%20%20_a.set_xlim(-3%2C%203)%3B%20_a.set_ylim(-3%2C%203)%3B%20_a.set_aspect(%22equal%22)%0A%0A%20%20%20%20%23%20trajectories%0A%20%20%20%20for%20_i%20in%20range(_traj.shape%5B1%5D)%3A%0A%20%20%20%20%20%20%20%20_axes%5B1%5D.plot(_traj%5B%3A%2C%20_i%2C%200%5D%2C%20_traj%5B%3A%2C%20_i%2C%201%5D%2C%20lw%3D0.4%2C%20alpha%3D0.5%2C%20color%3D%22C0%22)%0A%20%20%20%20_axes%5B1%5D.scatter(_traj%5B0%2C%20%3A%2C%200%5D%2C%20_traj%5B0%2C%20%3A%2C%201%5D%2C%20s%3D4%2C%20color%3D%22C2%22%2C%20label%3D%22%24t%3D0%24%22)%0A%20%20%20%20_axes%5B1%5D.scatter(_traj%5B-1%2C%20%3A%2C%200%5D%2C%20_traj%5B-1%2C%20%3A%2C%201%5D%2C%20s%3D4%2C%20color%3D%22C3%22%2C%20label%3D%22%24t%3D1%24%22)%0A%20%20%20%20_axes%5B1%5D.legend(loc%3D%22upper%20right%22%2C%20fontsize%3D8)%0A%20%20%20%20_axes%5B1%5D.set_title(%22flow%20trajectories%22)%0A%0A%20%20%20%20_axes%5B2%5D.scatter(_x1%5B%3A%2C%200%5D%2C%20_x1%5B%3A%2C%201%5D%2C%20s%3D2%2C%20alpha%3D0.5%2C%20color%3D%22C3%22)%0A%20%20%20%20_axes%5B2%5D.set_title(%22pushed%20samples%20%24x%20%3D%20%5CPhi_1(z)%24%22)%0A%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20Interactive%3A%20drag%20a%20point%20and%20watch%20the%20gradients%0A%0A%20%20%20%20Drag%20the%20handle%20in%20the%20left%20panel%20(%24z%24-space%2C%20base%20%24%5Cmathcal%7BN%7D(0%2C%20I)%24).%0A%20%20%20%20The%20right%20panel%20shows%20where%20the%20FM%20flow%20sends%20it%3A%20%24x%20%3D%20%5CPhi(z)%24%20on%20the%0A%20%20%20%20checkerboard.%0A%0A%20%20%20%20Three%20arrows%20tell%20the%20NeuTra%20story%3A%0A%0A%20%20%20%20*%20**Right%20panel%2C%20red.**%20%24%5Cnabla_x%20%5Clog%20%5Ctilde%20p(x)%24%20%E2%80%94%20the%20gradient%0A%20%20%20%20%20%20vanilla%20HMC%20would%20feel%20at%20%24x%24.%20Near%20a%20cell%20boundary%20it's%20huge%20(sharpness%0A%20%20%20%20%20%20times%20unit-distance%20%E2%89%88%2020)%2C%20pointing%20into%20the%20white%20cell.%0A%20%20%20%20*%20**Left%20panel%2C%20orange.**%20%24J_%5CPhi(z)%5E%5Ctop%20%5Cnabla_x%20%5Clog%20%5Ctilde%20p%24%20%E2%80%94%0A%20%20%20%20%20%20that%20same%20gradient%20pulled%20back%20through%20the%20flow's%20Jacobian%20transpose.%0A%20%20%20%20%20%20Often%20*much%20smaller*%20than%20the%20red%20arrow%3A%20this%20is%20the%20flow%20attenuating%0A%20%20%20%20%20%20the%20target's%20sharpness.%0A%20%20%20%20*%20**Left%20panel%2C%20cyan.**%20%24%5Cnabla_z%20%5Clog%20%7C%5Cdet%20J_%5CPhi(z)%7C%24%20%E2%80%94%20the%20volume%0A%20%20%20%20%20%20gradient.%20Points%20away%20from%20regions%20where%20the%20flow%20expands%20%24z%24-volume.%0A%0A%20%20%20%20NeuTra-HMC%20follows%20the%20**vector%20sum**%20(drawn%20in%20white).%20Compare%20that%20with%0A%20%20%20%20the%20red%20arrow%20on%20the%20right%3A%20NeuTra%20sees%20a%20tamer%20landscape%20than%20naive%20HMC%0A%20%20%20%20would%2C%20but%20the%20sharpness%20is%20*attenuated*%2C%20not%20removed%20%E2%80%94%20which%20is%20exactly%0A%20%20%20%20the%20point%20of%20the%20earlier%20discussion%20about%20what%20flows%20can%20and%20can't%20fix.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(flow_transform%2C%20jax%2C%20jnp%2C%20soft_checkerboard_log_prob)%3A%0A%20%20%20%20import%20anywidget%0A%20%20%20%20import%20traitlets%0A%0A%0A%20%20%20%20%23%20----------%20Python-side%3A%20gradient%20computations%20----------%0A%0A%20%20%20%20def%20_phi(z)%3A%0A%20%20%20%20%20%20%20%20return%20flow_transform(z)%0A%0A%0A%20%20%20%20def%20_log_pi_z(z%2C%20sharpness)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20flow_transform(z)%0A%20%20%20%20%20%20%20%20ld%20%3D%20flow_transform.log_abs_det_jacobian(z%2C%20x)%0A%20%20%20%20%20%20%20%20return%20soft_checkerboard_log_prob(x%2C%20sharpness%3Dsharpness)%20%2B%20ld%0A%0A%0A%20%20%20%20def%20_log_p_at_phi(z%2C%20sharpness)%3A%0A%20%20%20%20%20%20%20%20return%20soft_checkerboard_log_prob(flow_transform(z)%2C%20sharpness%3Dsharpness)%0A%0A%0A%20%20%20%20%40jax.jit%0A%20%20%20%20def%20widget_quantities(z%2C%20sharpness)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20flow_transform(z)%0A%20%20%20%20%20%20%20%20grad_z_full%20%3D%20jax.grad(_log_pi_z)(z%2C%20sharpness)%0A%20%20%20%20%20%20%20%20grad_z_target%20%3D%20jax.grad(_log_p_at_phi)(z%2C%20sharpness)%0A%20%20%20%20%20%20%20%20grad_z_logdet%20%3D%20grad_z_full%20-%20grad_z_target%0A%20%20%20%20%20%20%20%20grad_x_target%20%3D%20jax.grad(%0A%20%20%20%20%20%20%20%20%20%20%20%20lambda%20xx%3A%20soft_checkerboard_log_prob(xx%2C%20sharpness%3Dsharpness)%0A%20%20%20%20%20%20%20%20)(x)%0A%20%20%20%20%20%20%20%20return%20x%2C%20grad_z_full%2C%20grad_z_target%2C%20grad_z_logdet%2C%20grad_x_target%0A%0A%0A%20%20%20%20%23%20----------%20anywidget%20definition%20----------%0A%0A%20%20%20%20_NEUTRA_VIEW_ESM%20%3D%20r%22%22%22%0A%20%20%20%20const%20NS%20%3D%20%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%3B%0A%20%20%20%20const%20RANGE%20%3D%203.0%3B%20%20%20%20%20%20%20%20%20%20%2F%2F%20both%20panels%3A%20%5B-RANGE%2C%20RANGE%5D%0A%20%20%20%20const%20SIZE%20%3D%20300%3B%20%20%20%20%20%20%20%20%20%20%20%2F%2F%20px%20per%20panel%0A%20%20%20%20const%20CB_HALF%20%3D%202.0%3B%20%20%20%20%20%20%20%20%2F%2F%20checkerboard%20extends%20to%20%5B-2%2C%202%5D%0A%20%20%20%20const%20ARROW_SCALE%20%3D%200.04%3B%20%20%20%2F%2F%20visual%20scale%3A%20data%20unit%20per%20pixel-screen%20vector%0A%0A%20%20%20%20function%20s(tag%2C%20attrs%20%3D%20%7B%7D)%20%7B%0A%20%20%20%20%20%20const%20e%20%3D%20document.createElementNS(NS%2C%20tag)%3B%0A%20%20%20%20%20%20for%20(const%20%5Bk%2C%20v%5D%20of%20Object.entries(attrs))%20e.setAttribute(k%2C%20v)%3B%0A%20%20%20%20%20%20return%20e%3B%0A%20%20%20%20%7D%0A%20%20%20%20function%20dataToPx(d)%20%7B%20return%20((d%20%2B%20RANGE)%20%2F%20(2%20*%20RANGE))%20*%20SIZE%3B%20%7D%0A%20%20%20%20function%20pxToData(p)%20%7B%20return%20(p%20%2F%20SIZE)%20*%202%20*%20RANGE%20-%20RANGE%3B%20%7D%0A%0A%20%20%20%20function%20makePanel(title)%20%7B%0A%20%20%20%20%20%20const%20wrap%20%3D%20document.createElement(%22div%22)%3B%0A%20%20%20%20%20%20wrap.style.cssText%20%3D%0A%20%20%20%20%20%20%20%20%22display%3Ainline-block%3B%20margin%3A4px%3B%20vertical-align%3Atop%3B%20font-family%3A%20ui-sans-serif%2C%20system-ui%2C%20sans-serif%3B%22%3B%0A%20%20%20%20%20%20const%20h%20%3D%20document.createElement(%22div%22)%3B%0A%20%20%20%20%20%20h.textContent%20%3D%20title%3B%0A%20%20%20%20%20%20h.style.cssText%20%3D%20%22font-size%3A12px%3B%20color%3A%23555%3B%20text-align%3Acenter%3B%20margin-bottom%3A2px%3B%22%3B%0A%20%20%20%20%20%20const%20svg%20%3D%20s(%22svg%22%2C%20%7B%0A%20%20%20%20%20%20%20%20width%3A%20SIZE%2C%20height%3A%20SIZE%2C%20viewBox%3A%20%600%200%20%24%7BSIZE%7D%20%24%7BSIZE%7D%60%2C%0A%20%20%20%20%20%20%20%20style%3A%20%22border%3A1px%20solid%20%23ddd%3B%20background%3A%23fafafa%3B%20cursor%3A%20default%3B%22%2C%0A%20%20%20%20%20%20%7D)%3B%0A%20%20%20%20%20%20wrap.append(h%2C%20svg)%3B%0A%20%20%20%20%20%20return%20%7B%20wrap%2C%20svg%20%7D%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20drawAxes(svg)%20%7B%0A%20%20%20%20%20%20const%20c%20%3D%20dataToPx(0)%3B%0A%20%20%20%20%20%20svg.append(s(%22line%22%2C%20%7B%20x1%3A%200%2C%20y1%3A%20c%2C%20x2%3A%20SIZE%2C%20y2%3A%20c%2C%20stroke%3A%20%22%23ddd%22%2C%20%22stroke-width%22%3A%201%20%7D))%3B%0A%20%20%20%20%20%20svg.append(s(%22line%22%2C%20%7B%20x1%3A%20c%2C%20y1%3A%200%2C%20x2%3A%20c%2C%20y2%3A%20SIZE%2C%20stroke%3A%20%22%23ddd%22%2C%20%22stroke-width%22%3A%201%20%7D))%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20drawCheckerboard(svg)%20%7B%0A%20%20%20%20%20%20const%20cw%20%3D%20(CB_HALF%20*%202%20%2F%204)%3B%0A%20%20%20%20%20%20for%20(let%20i%20%3D%200%3B%20i%20%3C%204%3B%20i%2B%2B)%20%7B%0A%20%20%20%20%20%20%20%20for%20(let%20j%20%3D%200%3B%20j%20%3C%204%3B%20j%2B%2B)%20%7B%0A%20%20%20%20%20%20%20%20%20%20const%20white%20%3D%20(i%20%2B%20j)%20%25%202%20%3D%3D%3D%200%3B%0A%20%20%20%20%20%20%20%20%20%20const%20x0%20%3D%20dataToPx(-CB_HALF%20%2B%20i%20*%20cw)%3B%0A%20%20%20%20%20%20%20%20%20%20%2F%2F%20SVG%20y%20is%20flipped%3A%20screen%20y%20grows%20downward%2C%20data%20y%20grows%20upward.%0A%20%20%20%20%20%20%20%20%20%20const%20y0%20%3D%20dataToPx(-(-CB_HALF%20%2B%20(j%20%2B%201)%20*%20cw))%3B%0A%20%20%20%20%20%20%20%20%20%20const%20w%20%3D%20dataToPx(-CB_HALF%20%2B%20cw)%20-%20dataToPx(-CB_HALF)%3B%0A%20%20%20%20%20%20%20%20%20%20svg.append(s(%22rect%22%2C%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20x%3A%20x0%2C%20y%3A%20y0%2C%20width%3A%20w%2C%20height%3A%20w%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20fill%3A%20white%20%3F%20%22%23e8e8e8%22%20%3A%20%22%23ffffff%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20stroke%3A%20%22%23ccc%22%2C%20%22stroke-width%22%3A%200.5%2C%0A%20%20%20%20%20%20%20%20%20%20%7D))%3B%0A%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20%7D%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20drawNormalContours(svg)%20%7B%0A%20%20%20%20%20%20%2F%2F%201%2C%202%2C%203%20sigma%20circles%20on%20N(0%2C%20I)%0A%20%20%20%20%20%20const%20cx%20%3D%20dataToPx(0)%2C%20cy%20%3D%20dataToPx(0)%3B%0A%20%20%20%20%20%20for%20(const%20r%20of%20%5B1%2C%202%2C%203%5D)%20%7B%0A%20%20%20%20%20%20%20%20const%20rpx%20%3D%20dataToPx(r)%20-%20dataToPx(0)%3B%0A%20%20%20%20%20%20%20%20svg.append(s(%22circle%22%2C%20%7B%0A%20%20%20%20%20%20%20%20%20%20cx%2C%20cy%2C%20r%3A%20rpx%2C%20fill%3A%20%22none%22%2C%0A%20%20%20%20%20%20%20%20%20%20stroke%3A%20%22%23bbb%22%2C%20%22stroke-width%22%3A%200.7%2C%20%22stroke-dasharray%22%3A%20%223%203%22%2C%0A%20%20%20%20%20%20%20%20%7D))%3B%0A%20%20%20%20%20%20%7D%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20arrowDef(svg%2C%20id%2C%20color)%20%7B%0A%20%20%20%20%20%20let%20defs%20%3D%20svg.querySelector(%22defs%22)%3B%0A%20%20%20%20%20%20if%20(!defs)%20%7B%20defs%20%3D%20s(%22defs%22)%3B%20svg.append(defs)%3B%20%7D%0A%20%20%20%20%20%20const%20m%20%3D%20s(%22marker%22%2C%20%7B%0A%20%20%20%20%20%20%20%20id%2C%20viewBox%3A%20%220%200%2010%2010%22%2C%20refX%3A%208%2C%20refY%3A%205%2C%0A%20%20%20%20%20%20%20%20markerWidth%3A%206%2C%20markerHeight%3A%206%2C%20orient%3A%20%22auto-start-reverse%22%2C%0A%20%20%20%20%20%20%7D)%3B%0A%20%20%20%20%20%20m.append(s(%22path%22%2C%20%7B%20d%3A%20%22M%200%200%20L%2010%205%20L%200%2010%20z%22%2C%20fill%3A%20color%20%7D))%3B%0A%20%20%20%20%20%20defs.append(m)%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20drawArrow(svg%2C%20x_data%2C%20y_data%2C%20vx_data%2C%20vy_data%2C%20color%2C%20width%2C%20markerId)%20%7B%0A%20%20%20%20%20%20const%20x0%20%3D%20dataToPx(x_data)%3B%0A%20%20%20%20%20%20const%20y0%20%3D%20dataToPx(-y_data)%3B%0A%20%20%20%20%20%20%2F%2F%20end%20point%20in%20data%2C%20then%20convert%3B%20arrow%20vector%20is%20in%20data-units%20already.%0A%20%20%20%20%20%20const%20x1%20%3D%20dataToPx(x_data%20%2B%20vx_data)%3B%0A%20%20%20%20%20%20const%20y1%20%3D%20dataToPx(-(y_data%20%2B%20vy_data))%3B%0A%20%20%20%20%20%20const%20line%20%3D%20s(%22line%22%2C%20%7B%0A%20%20%20%20%20%20%20%20x1%3A%20x0%2C%20y1%3A%20y0%2C%20x2%3A%20x1%2C%20y2%3A%20y1%2C%0A%20%20%20%20%20%20%20%20stroke%3A%20color%2C%20%22stroke-width%22%3A%20width%2C%0A%20%20%20%20%20%20%20%20%22marker-end%22%3A%20%60url(%23%24%7BmarkerId%7D)%60%2C%0A%20%20%20%20%20%20%20%20%22stroke-linecap%22%3A%20%22round%22%2C%0A%20%20%20%20%20%20%7D)%3B%0A%20%20%20%20%20%20svg.append(line)%3B%0A%20%20%20%20%20%20return%20line%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20clipVec(vx%2C%20vy%2C%20maxLen)%20%7B%0A%20%20%20%20%20%20const%20L%20%3D%20Math.hypot(vx%2C%20vy)%3B%0A%20%20%20%20%20%20if%20(L%20%3C%201e-9)%20return%20%5B0%2C%200%5D%3B%0A%20%20%20%20%20%20const%20f%20%3D%20Math.min(1%2C%20maxLen%20%2F%20L)%3B%0A%20%20%20%20%20%20return%20%5Bvx%20*%20f%2C%20vy%20*%20f%5D%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20function%20render(%7B%20model%2C%20el%20%7D)%20%7B%0A%20%20%20%20%20%20const%20left%20%3D%20makePanel(%22z-space%20%20(base%20N(0%2C%20I))%22)%3B%0A%20%20%20%20%20%20const%20right%20%3D%20makePanel(%22x-space%20%20(target%20on%20checkerboard)%22)%3B%0A%20%20%20%20%20%20const%20row%20%3D%20document.createElement(%22div%22)%3B%0A%20%20%20%20%20%20row.style.cssText%20%3D%20%22display%3Aflex%3B%20gap%3A6px%3B%20flex-wrap%3Awrap%3B%20align-items%3Aflex-start%3B%22%3B%0A%20%20%20%20%20%20row.append(left.wrap%2C%20right.wrap)%3B%0A%0A%20%20%20%20%20%20const%20info%20%3D%20document.createElement(%22div%22)%3B%0A%20%20%20%20%20%20info.style.cssText%20%3D%0A%20%20%20%20%20%20%20%20%22font%3A%2011px%20ui-monospace%2C%20monospace%3B%20color%3A%23444%3B%20padding%3A6px%208px%3B%20line-height%3A1.5%3B%22%3B%0A%20%20%20%20%20%20el.append(row%2C%20info)%3B%0A%0A%20%20%20%20%20%20%2F%2F%20arrow%20markers%0A%20%20%20%20%20%20arrowDef(left.svg%2C%20%22ah-target%22%2C%20%20%22%23e8893a%22)%3B%20%20%2F%2F%20orange%0A%20%20%20%20%20%20arrowDef(left.svg%2C%20%22ah-logdet%22%2C%20%20%22%233aa6e8%22)%3B%20%20%2F%2F%20cyan%0A%20%20%20%20%20%20arrowDef(left.svg%2C%20%22ah-full%22%2C%20%20%20%20%22%23222%22)%3B%20%20%20%20%20%2F%2F%20dark%20for%20%22what%20HMC%20follows%22%0A%20%20%20%20%20%20arrowDef(right.svg%2C%20%22ah-targetx%22%2C%22%23d33b3b%22)%3B%20%20%2F%2F%20red%0A%0A%20%20%20%20%20%20drawAxes(left.svg)%3B%0A%20%20%20%20%20%20drawNormalContours(left.svg)%3B%0A%20%20%20%20%20%20drawAxes(right.svg)%3B%0A%20%20%20%20%20%20drawCheckerboard(right.svg)%3B%0A%0A%20%20%20%20%20%20%2F%2F%20dynamic%20layer%3A%20redrawn%20on%20every%20update%0A%20%20%20%20%20%20const%20dynLeft%20%3D%20s(%22g%22)%3B%20left.svg.append(dynLeft)%3B%0A%20%20%20%20%20%20const%20dynRight%20%3D%20s(%22g%22)%3B%20right.svg.append(dynRight)%3B%0A%0A%20%20%20%20%20%20%2F%2F%20draggable%20handle%20in%20z-space%0A%20%20%20%20%20%20const%20handle%20%3D%20s(%22circle%22%2C%20%7B%20r%3A%207%2C%20fill%3A%20%22%23333%22%2C%20stroke%3A%20%22%23fff%22%2C%20%22stroke-width%22%3A%202%2C%20style%3A%20%22cursor%3Agrab%3B%22%20%7D)%3B%0A%20%20%20%20%20%20left.svg.append(handle)%3B%0A%0A%20%20%20%20%20%20%2F%2F%20image%20point%20in%20x-space%0A%20%20%20%20%20%20const%20imgPt%20%3D%20s(%22circle%22%2C%20%7B%20r%3A%205%2C%20fill%3A%20%22%23333%22%2C%20stroke%3A%20%22%23fff%22%2C%20%22stroke-width%22%3A%201.5%20%7D)%3B%0A%20%20%20%20%20%20right.svg.append(imgPt)%3B%0A%0A%20%20%20%20%20%20function%20update()%20%7B%0A%20%20%20%20%20%20%20%20const%20z%20%3D%20model.get(%22z%22)%3B%0A%20%20%20%20%20%20%20%20const%20x%20%3D%20model.get(%22x%22)%3B%0A%20%20%20%20%20%20%20%20const%20gT%20%3D%20model.get(%22grad_z_target%22)%3B%0A%20%20%20%20%20%20%20%20const%20gL%20%3D%20model.get(%22grad_z_logdet%22)%3B%0A%20%20%20%20%20%20%20%20const%20gF%20%3D%20model.get(%22grad_z_full%22)%3B%0A%20%20%20%20%20%20%20%20const%20gXt%20%3D%20model.get(%22grad_x_target%22)%3B%0A%0A%20%20%20%20%20%20%20%20handle.setAttribute(%22cx%22%2C%20dataToPx(z%5B0%5D))%3B%0A%20%20%20%20%20%20%20%20handle.setAttribute(%22cy%22%2C%20dataToPx(-z%5B1%5D))%3B%0A%20%20%20%20%20%20%20%20imgPt.setAttribute(%22cx%22%2C%20dataToPx(x%5B0%5D))%3B%0A%20%20%20%20%20%20%20%20imgPt.setAttribute(%22cy%22%2C%20dataToPx(-x%5B1%5D))%3B%0A%0A%20%20%20%20%20%20%20%20%2F%2F%20clear%20dyn%20layers%0A%20%20%20%20%20%20%20%20while%20(dynLeft.firstChild)%20dynLeft.removeChild(dynLeft.firstChild)%3B%0A%20%20%20%20%20%20%20%20while%20(dynRight.firstChild)%20dynRight.removeChild(dynRight.firstChild)%3B%0A%0A%20%20%20%20%20%20%20%20%2F%2F%20arrows.%20Scale%20them%20so%20they're%20visible%20but%20not%20running%20off%20the%20panel.%0A%20%20%20%20%20%20%20%20%2F%2F%20Target%20gradient%20in%20z%3A%20orange.%20log-det%3A%20cyan.%20Full%20(sum)%3A%20dark.%0A%20%20%20%20%20%20%20%20const%20sT%20%20%3D%20clipVec(gT%5B0%5D%20%20*%20ARROW_SCALE%2C%20gT%5B1%5D%20%20*%20ARROW_SCALE%2C%202.5)%3B%0A%20%20%20%20%20%20%20%20const%20sL%20%20%3D%20clipVec(gL%5B0%5D%20%20*%20ARROW_SCALE%2C%20gL%5B1%5D%20%20*%20ARROW_SCALE%2C%202.5)%3B%0A%20%20%20%20%20%20%20%20const%20sF%20%20%3D%20clipVec(gF%5B0%5D%20%20*%20ARROW_SCALE%2C%20gF%5B1%5D%20%20*%20ARROW_SCALE%2C%202.5)%3B%0A%20%20%20%20%20%20%20%20const%20sXt%20%3D%20clipVec(gXt%5B0%5D%20*%20ARROW_SCALE%2C%20gXt%5B1%5D%20*%20ARROW_SCALE%2C%202.5)%3B%0A%0A%20%20%20%20%20%20%20%20drawArrow(dynLeft%2C%20z%5B0%5D%2C%20z%5B1%5D%2C%20sT%5B0%5D%2C%20sT%5B1%5D%2C%20%22%23e8893a%22%2C%201.8%2C%20%22ah-target%22)%3B%0A%20%20%20%20%20%20%20%20drawArrow(dynLeft%2C%20z%5B0%5D%2C%20z%5B1%5D%2C%20sL%5B0%5D%2C%20sL%5B1%5D%2C%20%22%233aa6e8%22%2C%201.8%2C%20%22ah-logdet%22)%3B%0A%20%20%20%20%20%20%20%20drawArrow(dynLeft%2C%20z%5B0%5D%2C%20z%5B1%5D%2C%20sF%5B0%5D%2C%20sF%5B1%5D%2C%20%22%23222%22%2C%20%20%20%202.6%2C%20%22ah-full%22)%3B%0A%20%20%20%20%20%20%20%20drawArrow(dynRight%2C%20x%5B0%5D%2C%20x%5B1%5D%2C%20sXt%5B0%5D%2C%20sXt%5B1%5D%2C%20%22%23d33b3b%22%2C%202.2%2C%20%22ah-targetx%22)%3B%0A%0A%20%20%20%20%20%20%20%20%2F%2F%20numeric%20readout%0A%20%20%20%20%20%20%20%20const%20fmt%20%3D%20(v)%20%3D%3E%20v.toFixed(2).padStart(6)%3B%0A%20%20%20%20%20%20%20%20info.innerHTML%20%3D%0A%20%20%20%20%20%20%20%20%20%20%60%3Cb%3Ez%3C%2Fb%3E%20%3D%20(%24%7Bfmt(z%5B0%5D)%7D%2C%20%24%7Bfmt(z%5B1%5D)%7D)%20%20%E2%86%92%20%20%3Cb%3Ex%3C%2Fb%3E%20%3D%20(%24%7Bfmt(x%5B0%5D)%7D%2C%20%24%7Bfmt(x%5B1%5D)%7D)%3Cbr%3E%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%3Cspan%20style%3D%22color%3A%23d33b3b%22%3E%E2%96%AE%3C%2Fspan%3E%20%E2%88%87%3Csub%3Ex%3C%2Fsub%3E%20log%20p%CC%83(x)%20%3D%20(%24%7Bfmt(gXt%5B0%5D)%7D%2C%20%24%7Bfmt(gXt%5B1%5D)%7D)%20%20%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%7C%C2%B7%7C%20%3D%20%24%7Bfmt(Math.hypot(gXt%5B0%5D%2C%20gXt%5B1%5D))%7D%3Cbr%3E%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%3Cspan%20style%3D%22color%3A%23e8893a%22%3E%E2%96%AE%3C%2Fspan%3E%20J%E1%B5%80%20%E2%88%87%3Csub%3Ex%3C%2Fsub%3E%20log%20p%CC%83%20%3D%20(%24%7Bfmt(gT%5B0%5D)%7D%2C%20%24%7Bfmt(gT%5B1%5D)%7D)%20%20%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%7C%C2%B7%7C%20%3D%20%24%7Bfmt(Math.hypot(gT%5B0%5D%2C%20gT%5B1%5D))%7D%3Cbr%3E%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%3Cspan%20style%3D%22color%3A%233aa6e8%22%3E%E2%96%AE%3C%2Fspan%3E%20%E2%88%87%3Csub%3Ez%3C%2Fsub%3E%20log%7Cdet%20J%7C%20%3D%20(%24%7Bfmt(gL%5B0%5D)%7D%2C%20%24%7Bfmt(gL%5B1%5D)%7D)%20%20%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%7C%C2%B7%7C%20%3D%20%24%7Bfmt(Math.hypot(gL%5B0%5D%2C%20gL%5B1%5D))%7D%3Cbr%3E%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%3Cspan%20style%3D%22color%3A%23222%22%3E%E2%96%AE%3C%2Fspan%3E%20NeuTra%20force%20%3D%20(%24%7Bfmt(gF%5B0%5D)%7D%2C%20%24%7Bfmt(gF%5B1%5D)%7D)%20%20%60%20%2B%0A%20%20%20%20%20%20%20%20%20%20%60%7C%C2%B7%7C%20%3D%20%24%7Bfmt(Math.hypot(gF%5B0%5D%2C%20gF%5B1%5D))%7D%60%3B%0A%20%20%20%20%20%20%7D%0A%0A%20%20%20%20%20%20%2F%2F%20drag%20handler%0A%20%20%20%20%20%20const%20ctrl%20%3D%20new%20AbortController()%3B%0A%20%20%20%20%20%20const%20%7B%20signal%20%7D%20%3D%20ctrl%3B%0A%20%20%20%20%20%20let%20dragging%20%3D%20false%3B%0A%20%20%20%20%20%20function%20startDrag(e)%20%7B%20dragging%20%3D%20true%3B%20handle.style.cursor%20%3D%20%22grabbing%22%3B%20doMove(e)%3B%20%7D%0A%20%20%20%20%20%20function%20endDrag()%20%20%20%20%7B%20dragging%20%3D%20false%3B%20handle.style.cursor%20%3D%20%22grab%22%3B%20%7D%0A%20%20%20%20%20%20function%20doMove(e)%20%7B%0A%20%20%20%20%20%20%20%20if%20(!dragging)%20return%3B%0A%20%20%20%20%20%20%20%20const%20rect%20%3D%20left.svg.getBoundingClientRect()%3B%0A%20%20%20%20%20%20%20%20const%20px%20%3D%20e.clientX%20-%20rect.left%3B%0A%20%20%20%20%20%20%20%20const%20py%20%3D%20e.clientY%20-%20rect.top%3B%0A%20%20%20%20%20%20%20%20const%20zx%20%3D%20pxToData(px)%3B%0A%20%20%20%20%20%20%20%20const%20zy%20%3D%20-pxToData(py)%3B%20%20%20%2F%2F%20flip%0A%20%20%20%20%20%20%20%20const%20zxc%20%3D%20Math.max(-RANGE%2C%20Math.min(RANGE%2C%20zx))%3B%0A%20%20%20%20%20%20%20%20const%20zyc%20%3D%20Math.max(-RANGE%2C%20Math.min(RANGE%2C%20zy))%3B%0A%20%20%20%20%20%20%20%20model.set(%22z%22%2C%20%5Bzxc%2C%20zyc%5D)%3B%0A%20%20%20%20%20%20%20%20model.save_changes()%3B%0A%20%20%20%20%20%20%7D%0A%20%20%20%20%20%20handle.addEventListener(%22mousedown%22%2C%20startDrag%2C%20%7B%20signal%20%7D)%3B%0A%20%20%20%20%20%20left.svg.addEventListener(%22mousedown%22%2C%20startDrag%2C%20%7B%20signal%20%7D)%3B%20%20%2F%2F%20click%20anywhere%20to%20move%0A%20%20%20%20%20%20window.addEventListener(%22mousemove%22%2C%20doMove%2C%20%7B%20signal%20%7D)%3B%0A%20%20%20%20%20%20window.addEventListener(%22mouseup%22%2C%20endDrag%2C%20%7B%20signal%20%7D)%3B%0A%0A%20%20%20%20%20%20model.on(%22change%3Ax%22%2C%20update)%3B%0A%20%20%20%20%20%20update()%3B%0A%0A%20%20%20%20%20%20return%20()%20%3D%3E%20ctrl.abort()%3B%0A%20%20%20%20%7D%0A%0A%20%20%20%20export%20default%20%7B%20render%20%7D%3B%0A%20%20%20%20%22%22%22%0A%0A%0A%20%20%20%20class%20NeuTraGradientView(anywidget.AnyWidget)%3A%0A%20%20%20%20%20%20%20%20z%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.5%2C%20-0.3%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20x%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.0%2C%200.0%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20grad_z_full%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.0%2C%200.0%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20grad_z_target%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.0%2C%200.0%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20grad_z_logdet%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.0%2C%200.0%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20grad_x_target%20%3D%20traitlets.List(traitlets.Float()%2C%20default_value%3D%5B0.0%2C%200.0%5D).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20sharpness%20%3D%20traitlets.Float(20.0).tag(sync%3DTrue)%0A%20%20%20%20%20%20%20%20_esm%20%3D%20_NEUTRA_VIEW_ESM%0A%0A%0A%20%20%20%20def%20refresh_quantities(widget)%3A%0A%20%20%20%20%20%20%20%20z%20%3D%20jnp.array(widget.z%2C%20dtype%3Djnp.float32)%0A%20%20%20%20%20%20%20%20x%2C%20gF%2C%20gT%2C%20gL%2C%20gXt%20%3D%20widget_quantities(z%2C%20float(widget.sharpness))%0A%20%20%20%20%20%20%20%20with%20widget.hold_sync()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20widget.grad_z_target%20%3D%20%5Bfloat(gT%5B0%5D)%2C%20%20float(gT%5B1%5D)%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20widget.grad_z_logdet%20%3D%20%5Bfloat(gL%5B0%5D)%2C%20%20float(gL%5B1%5D)%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20widget.grad_z_full%20%20%20%3D%20%5Bfloat(gF%5B0%5D)%2C%20%20float(gF%5B1%5D)%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20widget.grad_x_target%20%3D%20%5Bfloat(gXt%5B0%5D)%2C%20float(gXt%5B1%5D)%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20widget.x%20%3D%20%5Bfloat(x%5B0%5D)%2C%20float(x%5B1%5D)%5D%20%20%23%20last%3A%20triggers%20single%20change%3Ax%20in%20JS%0A%0A%0A%20%20%20%20return%20NeuTraGradientView%2C%20refresh_quantities%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(NeuTraGradientView%2C%20refresh_quantities)%3A%0A%20%20%20%20neutra_view%20%3D%20NeuTraGradientView()%0A%20%20%20%20refresh_quantities(neutra_view)%0A%20%20%20%20neutra_view.observe(lambda%20_ch%3A%20refresh_quantities(neutra_view)%2C%20names%3D%5B%22z%22%2C%20%22sharpness%22%5D)%0A%20%20%20%20neutra_view%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20From%20flow%20to%20NumPyro%20%60Transform%60%0A%0A%20%20%20%20For%20SVI%20%2F%20MCMC%20we%20need%20a%20bijection%20%24%5CPhi%20%3A%20z%20%5Cmapsto%20x%24%20together%20with%20the%0A%20%20%20%20log-determinant%20%24%5Clog%20%7C%5Cdet%20%5Cpartial%20%5CPhi%20%2F%20%5Cpartial%20z%7C%24.%20Both%20come%20from%0A%20%20%20%20the%20**augmented%20ODE**%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cfrac%7Bd%7D%7Bdt%7D%5C!%5Cbegin%7Bpmatrix%7D%20x%20%5C%5C%20%5Cell%20%5Cend%7Bpmatrix%7D%20%3D%0A%20%20%20%20%5Cbegin%7Bpmatrix%7D%20v_%5Ctheta(x%2C%20t)%20%5C%5C%20%5Cnabla%20%5C!%5Ccdot%20v_%5Ctheta(x%2C%20t)%20%5Cend%7Bpmatrix%7D%2C%0A%20%20%20%20%5Cqquad%20(x(0)%2C%20%5Cell(0))%20%3D%20(z%2C%200).%0A%20%20%20%20%24%24%0A%0A%20%20%20%20In%202D%20the%20divergence%20is%20just%20the%20trace%20of%20a%20%242%5Ctimes%202%24%20Jacobian%2C%20so%20we%0A%20%20%20%20compute%20it%20exactly%20with%20%60jax.jacrev%60%20(no%20Hutchinson%20estimator%20needed).%0A%20%20%20%20The%20transform's%20%60_inverse%60%20runs%20the%20same%20ODE%20from%20%24t%3D1%24%20back%20to%20%24t%3D0%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(diffrax%2C%20jax%2C%20jnp%2C%20trained_params%2C%20vfield)%3A%0A%20%20%20%20from%20numpyro.distributions.transforms%20import%20Transform%0A%20%20%20%20from%20numpyro.distributions%20import%20constraints%0A%0A%0A%20%20%20%20def%20_vf_aug_(t%2C%20y%2C%20params)%3A%0A%20%20%20%20%20%20%20%20x%2C%20_%20%3D%20y%0A%20%20%20%20%20%20%20%20def%20single(xi)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20J%20%3D%20jax.jacrev(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20lambda%20xx%3A%20vfield.apply(params%2C%20xx%2C%20jnp.full(xx.shape%5B%3A-1%5D%2C%20t))%0A%20%20%20%20%20%20%20%20%20%20%20%20)(xi)%0A%20%20%20%20%20%20%20%20%20%20%20%20v%20%3D%20vfield.apply(params%2C%20xi%2C%20jnp.full(()%2C%20t))%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20v%2C%20jnp.trace(J)%0A%20%20%20%20%20%20%20%20v%2C%20div%20%3D%20jax.vmap(single)(x)%0A%20%20%20%20%20%20%20%20return%20(v%2C%20div)%0A%0A%0A%20%20%20%20_aug_term_%20%3D%20diffrax.ODETerm(_vf_aug_)%0A%20%20%20%20_solver_%20%3D%20diffrax.Tsit5()%0A%0A%0A%20%20%20%20%40jax.jit%0A%20%20%20%20def%20_integrate_aug(params%2C%20z%2C%20t0%2C%20t1)%3A%0A%20%20%20%20%20%20%20%20y0%20%3D%20(z%2C%20jnp.zeros(z.shape%5B%3A-1%5D))%0A%20%20%20%20%20%20%20%20sol%20%3D%20diffrax.diffeqsolve(%0A%20%20%20%20%20%20%20%20%20%20%20%20_aug_term_%2C%20_solver_%2C%20t0%3Dt0%2C%20t1%3Dt1%2C%20dt0%3D0.1%20*%20jnp.sign(t1%20-%20t0)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20y0%3Dy0%2C%20args%3Dparams%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20stepsize_controller%3Ddiffrax.PIDController(rtol%3D1e-3%2C%20atol%3D1e-3)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_steps%3D2048%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20return%20sol.ys%5B0%5D%5B-1%5D%2C%20sol.ys%5B1%5D%5B-1%5D%0A%0A%0A%20%20%20%20class%20FlowMatchTransform(Transform)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Bijection%20z%20-%3E%20x%20defined%20by%20integrating%20v_theta%20from%20t%3D0%20to%20t%3D1.%22%22%22%0A%20%20%20%20%20%20%20%20domain%20%3D%20constraints.real_vector%0A%20%20%20%20%20%20%20%20codomain%20%3D%20constraints.real_vector%0A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20params)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.params%20%3D%20params%0A%0A%20%20%20%20%20%20%20%20def%20__call__(self%2C%20z)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%2C%20_%20%3D%20_integrate_aug(self.params%2C%20jnp.atleast_2d(z)%2C%200.0%2C%201.0)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x.reshape(z.shape)%0A%0A%20%20%20%20%20%20%20%20def%20_inverse(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20z%2C%20_%20%3D%20_integrate_aug(self.params%2C%20jnp.atleast_2d(x)%2C%201.0%2C%200.0)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20z.reshape(x.shape)%0A%0A%20%20%20%20%20%20%20%20def%20log_abs_det_jacobian(self%2C%20z%2C%20x%2C%20intermediates%3DNone)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20_%2C%20ld%20%3D%20_integrate_aug(self.params%2C%20jnp.atleast_2d(z)%2C%200.0%2C%201.0)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20ld.reshape(z.shape%5B%3A-1%5D)%0A%0A%20%20%20%20%20%20%20%20def%20tree_flatten(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20(self.params%2C)%2C%20((%22params%22%2C)%2C%20dict())%0A%0A%20%20%20%20%20%20%20%20%40classmethod%0A%20%20%20%20%20%20%20%20def%20tree_unflatten(cls%2C%20aux_data%2C%20children)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20(params%2C)%20%3D%20children%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20cls(params)%0A%0A%0A%20%20%20%20flow_transform%20%3D%20FlowMatchTransform(trained_params)%0A%0A%20%20%20%20_z_test%20%3D%20jax.random.normal(jax.random.PRNGKey(0)%2C%20(3%2C%202))%0A%20%20%20%20_x_test%20%3D%20flow_transform(_z_test)%0A%20%20%20%20_z_back%20%3D%20flow_transform._inverse(_x_test)%0A%20%20%20%20print(%22round-trip%20max%20error%3A%22%2C%20float(jnp.max(jnp.abs(_z_test%20-%20_z_back))))%0A%20%20%20%20print(%22log%7Cdet%20J%7C%3A%22%2C%20flow_transform.log_abs_det_jacobian(_z_test%2C%20_x_test))%0A%0A%20%20%20%20return%20(flow_transform%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20NeuTra-HMC%3A%20sampling%20the%20checkerboard%20via%20the%20flow%0A%0A%20%20%20%20Define%20a%20NumPyro%20model%20whose%20target%20is%20the%20(unnormalized)%20checkerboard%0A%20%20%20%20log-density%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Clog%20p(x)%20%3D%20%5Clog%20%5Ctilde%20p_%7B%5Ctext%7Bchecker%7D%7D(x).%0A%20%20%20%20%24%24%0A%0A%20%20%20%20Naively%20running%20HMC%20on%20this%20is%20hopeless%20%E2%80%94%20the%20support%20is%20disconnected%20and%0A%20%20%20%20the%20density%20is%20piecewise%20constant%2C%20so%20gradients%20are%20zero%20almost%20everywhere.%0A%0A%20%20%20%20**NeuTra-HMC**%20(Hoffman%20et%20al.%2C%202019)%20sidesteps%20this%20by%20**reparametrizing**%0A%20%20%20%20in%20the%20base%20space.%20We%20declare%20%24x%20%3D%20%5CPhi(z)%24%20with%20%24z%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I)%24%0A%20%20%20%20via%20a%20%60TransformedDistribution%60%2C%20then%20%60TransformReparam%60%20rewrites%20the%20model%0A%20%20%20%20to%20sample%20%24z%24%20directly.%20HMC%20runs%20in%20%24z%24-space%20%E2%80%94%20where%20the%20geometry%20is%20benign%2C%0A%20%20%20%20courtesy%20of%20the%20trained%20flow%20%E2%80%94%20and%20we%20transform%20back.%0A%0A%20%20%20%20Because%20real%20HMC%20needs%20differentiable%20log-densities%2C%20we%20use%20a%20**smoothed**%0A%20%20%20%20checkerboard%20target%20(a%20soft-indicator%20over%20each%20white%20cell).%20The%20trained%0A%20%20%20%20flow%20already%20concentrates%20mass%20there%2C%20so%20HMC%20just%20has%20to%20clean%20up.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(checkerboard_log_prob%2C%20jax%2C%20jnp)%3A%0A%20%20%20%20def%20soft_checkerboard_log_prob(x%2C%20scale%3D2.0%2C%20sharpness%3D20.0)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Differentiable%20approximation%20of%20the%20checkerboard%20log-density.%22%22%22%0A%20%20%20%20%20%20%20%20cell_w%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20whites%20%3D%20jnp.array(%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%5B0%2C%200%5D%2C%20%5B2%2C%200%5D%2C%20%5B1%2C%201%5D%2C%20%5B3%2C%201%5D%2C%20%5B0%2C%202%5D%2C%20%5B2%2C%202%5D%2C%20%5B1%2C%203%5D%2C%20%5B3%2C%203%5D%5D%0A%20%20%20%20%20%20%20%20)%20%20%23%20(8%2C%202)%0A%20%20%20%20%20%20%20%20centers%20%3D%20-scale%20%2B%20(whites%20%2B%200.5)%20*%20cell_w%20%20%23%20(8%2C%202)%0A%20%20%20%20%20%20%20%20%23%20log%20p(x)%20%3D%20logsumexp%20over%20white%20cells%20of%20log%20soft-indicator%0A%20%20%20%20%20%20%20%20%23%20soft%20indicator%20on%20a%20cell%3A%20product%20over%20dims%20of%20sigmoid(s*(half%20-%20%7Cx-c%7C))%0A%20%20%20%20%20%20%20%20half%20%3D%20cell_w%20%2F%202%0A%20%20%20%20%20%20%20%20diff%20%3D%20jnp.abs(x%5B...%2C%20None%2C%20%3A%5D%20-%20centers)%20%20%23%20(...%2C%208%2C%202)%0A%20%20%20%20%20%20%20%20log_ind%20%3D%20jax.nn.log_sigmoid(sharpness%20*%20(half%20-%20diff)).sum(axis%3D-1)%20%20%23%20(...%2C%208)%0A%20%20%20%20%20%20%20%20return%20jax.scipy.special.logsumexp(log_ind%2C%20axis%3D-1)%20-%20jnp.log(8.0)%0A%0A%0A%20%20%20%20%23%20sanity%3A%20should%20be%20high%20inside%20white%20cells%20and%20low%20elsewhere%0A%20%20%20%20_pts%20%3D%20jnp.array(%5B%5B0.5%2C%200.5%5D%2C%20%5B-0.5%2C%20-0.5%5D%2C%20%5B-1.5%2C%20-1.5%5D%2C%20%5B0.0%2C%200.0%5D%5D)%0A%20%20%20%20print(%22soft%20log-prob%20at%20test%20points%3A%22%2C%20soft_checkerboard_log_prob(_pts))%0A%20%20%20%20print(%22hard%20log-prob%20at%20test%20points%3A%22%2C%20checkerboard_log_prob(_pts))%0A%20%20%20%20return%20(soft_checkerboard_log_prob%2C)%0A%0A%0A%40app.cell%0Adef%20_(dist%2C%20flow_transform%2C%20jnp%2C%20numpyro%2C%20soft_checkerboard_log_prob)%3A%0A%20%20%20%20from%20numpyro.handlers%20import%20reparam%0A%20%20%20%20from%20numpyro.infer.reparam%20import%20TransformReparam%2C%20NeuTraReparam%0A%20%20%20%20from%20numpyro.distributions%20import%20TransformedDistribution%0A%0A%20%20%20%20def%20checkerboard_model()%3A%0A%20%20%20%20%20%20%20%20%23%20x%20%3D%20Phi(z)%20with%20z%20~%20N(0%2CI)%3B%20TransformReparam%20sites%20the%20latent%20z%20under%20HMC.%0A%20%20%20%20%20%20%20%20base%20%3D%20dist.Normal(jnp.zeros(2)%2C%20jnp.ones(2)).to_event(1)%0A%20%20%20%20%20%20%20%20x%20%3D%20numpyro.sample(%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20TransformedDistribution(base%2C%20%5Bflow_transform%5D)%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20numpyro.factor(%22target%22%2C%20soft_checkerboard_log_prob(x))%0A%0A%0A%20%20%20%20reparam_model%20%3D%20reparam(checkerboard_model%2C%20config%3D%7B%22x%22%3A%20TransformReparam()%7D)%0A%20%20%20%20return%20NeuTraReparam%2C%20reparam_model%0A%0A%0A%40app.cell%0Adef%20_(MCMC%2C%20NUTS%2C%20jax%2C%20reparam_model)%3A%0A%20%20%20%20_kernel%20%3D%20NUTS(reparam_model%2C%20max_tree_depth%3D6%2C%20target_accept_prob%3D0.8)%0A%20%20%20%20_mcmc%20%3D%20MCMC(_kernel%2C%20num_warmup%3D200%2C%20num_samples%3D500%2C%20num_chains%3D1%2C%20progress_bar%3DTrue)%0A%20%20%20%20_mcmc.run(jax.random.PRNGKey(0))%0A%20%20%20%20mcmc_samples%20%3D%20_mcmc.get_samples()%0A%20%20%20%20print(%22MCMC%20sample%20keys%3A%22%2C%20list(mcmc_samples.keys()))%0A%20%20%20%20print(%22x%20shape%3A%22%2C%20mcmc_samples%5B%22x%22%5D.shape)%0A%20%20%20%20_mcmc.print_summary()%0A%0A%20%20%20%20return%20(mcmc_samples%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Comparison%3A%20true%20target%20vs%20flow-only%20vs%20flow-NeuTra-HMC%0A%0A%20%20%20%20Three%20sets%20of%202-D%20points%3A%0A%0A%20%20%20%201.%20**True%20%24p_1%24**%20%E2%80%94%20ground-truth%20uniform%20on%20the%20white%20cells.%0A%20%20%20%202.%20**Flow%20only**%20%E2%80%94%20push%20%24z%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I)%24%20through%20%24%5CPhi%24%20once.%0A%20%20%20%20%20%20%20Approximates%20the%20target%20but%20leaks%20mass%20onto%20black%20cells%20(the%20FM%20loss%0A%20%20%20%20%20%20%20was%20minimized%2C%20not%20driven%20to%20zero).%0A%20%20%20%203.%20**NeuTra-HMC**%20%E2%80%94%20HMC%20in%20%24z%24-space%20against%20the%20(smoothed)%20checkerboard%0A%20%20%20%20%20%20%20posterior.%20The%20MCMC%20step%20*corrects*%20the%20flow's%20residual%20error%3A%20mass%20on%0A%20%20%20%20%20%20%20white%20cells%20goes%20from%20~87%25%20to%20~98%25.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20jax%2C%0A%20%20%20%20mcmc_samples%2C%0A%20%20%20%20np%2C%0A%20%20%20%20plt%2C%0A%20%20%20%20push_forward%2C%0A%20%20%20%20sample_checkerboard%2C%0A%20%20%20%20trained_params%2C%0A)%3A%0A%20%20%20%20_n%20%3D%202000%0A%20%20%20%20_z%20%3D%20jax.random.normal(jax.random.PRNGKey(20)%2C%20(_n%2C%202))%0A%20%20%20%20flow_samples%20%3D%20np.asarray(push_forward(trained_params%2C%20_z))%0A%20%20%20%20true_samples%20%3D%20np.asarray(sample_checkerboard(jax.random.PRNGKey(21)%2C%20_n))%0A%20%20%20%20mcmc_x%20%3D%20np.asarray(mcmc_samples%5B%22x%22%5D)%0A%0A%20%20%20%20def%20_draw_grid(ax%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20cw%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20j%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20white%20%3D%20(i%20%2B%20j)%20%25%202%20%3D%3D%201%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ax.add_patch(plt.Rectangle(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(-scale%20%2B%20i*cw%2C%20-scale%20%2B%20j*cw)%2C%20cw%2C%20cw%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20facecolor%3D%220.92%22%20if%20white%20else%20%221.0%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edgecolor%3D%220.7%22%2C%20linewidth%3D0.5%2C%20zorder%3D0%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20))%0A%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(1%2C%203%2C%20figsize%3D(12%2C%204))%0A%20%20%20%20for%20_ax%2C%20_x%2C%20_ti%20in%20zip(_axes%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5Btrue_samples%2C%20flow_samples%2C%20mcmc_x%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B%22true%20%24p_1%24%22%2C%20%22flow%20only%20%24%5CPhi(z)%24%22%2C%20%22NeuTra-HMC%22%5D)%3A%0A%20%20%20%20%20%20%20%20_draw_grid(_ax)%0A%20%20%20%20%20%20%20%20_ax.scatter(_x%5B%3A%2C%200%5D%2C%20_x%5B%3A%2C%201%5D%2C%20s%3D2%2C%20alpha%3D0.5)%0A%20%20%20%20%20%20%20%20_ax.set_xlim(-2.5%2C%202.5)%3B%20_ax.set_ylim(-2.5%2C%202.5)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20%20%20%20%20_ax.set_title(_ti)%0A%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%20flow_samples%2C%20mcmc_x%2C%20true_samples%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(flow_samples%2C%20mcmc_x%2C%20mo%2C%20np%2C%20true_samples)%3A%0A%20%20%20%20def%20_white_mass(x%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20np.asarray(x)%0A%20%20%20%20%20%20%20%20edges%20%3D%20np.linspace(-scale%2C%20scale%2C%205)%0A%20%20%20%20%20%20%20%20H%2C%20_%2C%20_%20%3D%20np.histogram2d(x%5B%3A%2C%200%5D%2C%20x%5B%3A%2C%201%5D%2C%20bins%3D%5Bedges%2C%20edges%5D)%0A%20%20%20%20%20%20%20%20if%20H.sum()%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%200.0%0A%20%20%20%20%20%20%20%20H%20%3D%20H%20%2F%20H.sum()%0A%20%20%20%20%20%20%20%20mw%20%3D%20((np.add.outer(np.arange(4)%2C%20np.arange(4))%20%25%202)%20%3D%3D%200)%0A%20%20%20%20%20%20%20%20return%20float(H%5Bmw%5D.sum())%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20f%22%22%22%0A%20%20%20%20%20%20%20%20%7C%20sampler%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%7C%20white-cell%20mass%20(target%20%3D%201.0)%20%7C%0A%20%20%20%20%20%20%20%20%7C------------------------%7C---%3A%7C%0A%20%20%20%20%20%20%20%20%7C%20flow%20only%20%24%5C%5CPhi(z)%24%20%20%20%20%7C%20%7B_white_mass(flow_samples)%3A.3f%7D%20%7C%0A%20%20%20%20%20%20%20%20%7C%20NeuTra-HMC%20%20%20%20%20%20%20%20%20%20%20%20%20%7C%20%7B_white_mass(mcmc_x)%3A.3f%7D%20%7C%0A%20%20%20%20%20%20%20%20%7C%20true%20%24p_1%24%20(sanity)%20%20%20%20%7C%20%7B_white_mass(true_samples)%3A.3f%7D%20%7C%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20A%20discrete%20flow%20guide%20for%20SVI%0A%0A%20%20%20%20A%20continuous%20flow%20trained%20by%20flow%20matching%20is%20*not*%20a%20natural%20fit%20for%20SVI%3A%0A%20%20%20%20the%20FM%20loss%20avoids%20log-determinants%20by%20construction%2C%20so%20to%20use%20the%20same%0A%20%20%20%20network%20as%20a%20variational%20posterior%20we'd%20have%20to%20integrate%20the%20augmented%0A%20%20%20%20ODE%20through%20every%20gradient%20step%20%E2%80%94%20exactly%20the%20cost%20FM%20is%20designed%20to%20avoid.%0A%0A%20%20%20%20For%20SVI%2C%20the%20principled%20choice%20is%20a%20**discrete%20normalizing%20flow**%20with%20a%0A%20%20%20%20closed-form%20Jacobian.%20We%20use%20NumPyro's%20%60AutoBNAFNormal%60%20(Block%20Neural%0A%20%20%20%20Autoregressive%20Flow%2C%20%5BDe%20Cao%20et%20al.%2C%202019%5D(https%3A%2F%2Farxiv.org%2Fabs%2F1904.04676))%3A%0A%20%20%20%20a%20stack%20of%20monotone%20autoregressive%20transforms%20whose%20log-determinant%20is%0A%20%20%20%20the%20sum%20of%20the%20diagonal%20of%20a%20triangular%20Jacobian%20%E2%80%94%20no%20ODEs%2C%20no%20Hutchinson%0A%20%20%20%20trace%2C%20just%20dense%20matmuls.%0A%0A%20%20%20%20**Mode%20collapse%20and%20how%20to%20avoid%20it.**%20Reverse-KL%20is%20mode-seeking%2C%20and%0A%20%20%20%20naive%20ELBO%20training%20of%20a%20unimodal-base%20flow%20on%20a%20multimodal%20target%0A%20%20%20%20almost%20always%20loses%20modes%20%E2%80%94%20the%20flow%20finds%20a%20few%20good%20cells%20and%20never%0A%20%20%20%20discovers%20the%20rest.%20We%20fix%20this%20with%20a%20**temperature%20anneal**%3A%20start%20with%0A%20%20%20%20a%20smooth%20target%20(low%20%60sharpness%60%2C%20where%20the%20cells%20overlap%20into%20one%0A%20%20%20%20blurry%20blob%20with%20all%208%20modes%20attracting%20mass)%2C%20gradually%20sharpen%20to%20the%0A%20%20%20%20real%20target.%20By%20the%20time%20the%20cells%20are%20crisp%2C%20the%20flow's%20topology%20is%0A%20%20%20%20locked%20in%20over%20all%208%20modes.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(dist%2C%20jnp%2C%20numpyro%2C%20soft_checkerboard_log_prob)%3A%0A%20%20%20%20from%20numpyro.infer%20import%20SVI%2C%20Trace_ELBO%0A%20%20%20%20from%20numpyro.infer.autoguide%20import%20AutoBNAFNormal%2C%20AutoIAFNormal%0A%0A%0A%20%20%20%20def%20svi_model(sharpness%3D40.0)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Smoothed%20checkerboard%20target.%20%60%60sharpness%60%60%20is%20a%20per-step%20kwarg%0A%20%20%20%20%20%20%20%20so%20we%20can%20anneal%20it%20during%20training.%22%22%22%0A%20%20%20%20%20%20%20%20x%20%3D%20numpyro.sample(%0A%20%20%20%20%20%20%20%20%20%20%20%20%22x%22%2C%20dist.Normal(jnp.zeros(2)%2C%205%20*%20jnp.ones(2)).to_event(1)%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20numpyro.factor(%22target%22%2C%20soft_checkerboard_log_prob(x%2C%20sharpness%3Dsharpness))%0A%0A%0A%20%20%20%20%23%20A%20larger%20flow%20than%20the%20single-stage%20version%3A%204%20BNAF%20layers%20x%2064%20hidden.%0A%20%20%20%20bnaf_guide%20%3D%20AutoBNAFNormal(svi_model%2C%20num_flows%3D4%2C%20hidden_factors%3D%5B64%2C%2064%5D)%0A%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20AutoBNAFNormal%2C%0A%20%20%20%20%20%20%20%20AutoIAFNormal%2C%0A%20%20%20%20%20%20%20%20SVI%2C%0A%20%20%20%20%20%20%20%20Trace_ELBO%2C%0A%20%20%20%20%20%20%20%20bnaf_guide%2C%0A%20%20%20%20%20%20%20%20svi_model%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(SVI%2C%20Trace_ELBO%2C%20bnaf_guide%2C%20jax%2C%20jnp%2C%20np%2C%20optax%2C%20svi_model)%3A%0A%20%20%20%20SHARPNESS_SCHEDULE%20%3D%20%5B2.0%2C%205.0%2C%2010.0%2C%2020.0%2C%2040.0%5D%0A%20%20%20%20STEPS_PER_STAGE%20%3D%203000%0A%0A%20%20%20%20_svi%20%3D%20SVI(%0A%20%20%20%20%20%20%20%20svi_model%2C%20bnaf_guide%2C%0A%20%20%20%20%20%20%20%20optax.adam(2e-3)%2C%0A%20%20%20%20%20%20%20%20Trace_ELBO(num_particles%3D32)%2C%0A%20%20%20%20)%0A%0A%20%20%20%20_state%20%3D%20_svi.init(jax.random.PRNGKey(0)%2C%20sharpness%3DSHARPNESS_SCHEDULE%5B0%5D)%0A%20%20%20%20_loss_chunks%20%3D%20%5B%5D%0A%20%20%20%20for%20_i%2C%20_s%20in%20enumerate(SHARPNESS_SCHEDULE)%3A%0A%20%20%20%20%20%20%20%20_res%20%3D%20_svi.run(%0A%20%20%20%20%20%20%20%20%20%20%20%20jax.random.PRNGKey(100%20%2B%20_i)%2C%20STEPS_PER_STAGE%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20progress_bar%3DFalse%2C%20init_state%3D_state%2C%20sharpness%3D_s%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20_state%20%3D%20_res.state%0A%20%20%20%20%20%20%20%20_loss_chunks.append(np.asarray(_res.losses))%0A%20%20%20%20%20%20%20%20print(f%22sharpness%3D%7B_s%3A%3E5.1f%7D%20%20-ELBO%3D%7Bfloat(jnp.mean(_res.losses%5B-100%3A%5D))%3A.3f%7D%22)%0A%0A%20%20%20%20svi_params%20%3D%20_svi.get_params(_state)%0A%20%20%20%20svi_losses%20%3D%20np.concatenate(_loss_chunks)%0A%20%20%20%20svi_stage_boundaries%20%3D%20np.cumsum(%5Blen(c)%20for%20c%20in%20_loss_chunks%5D)%0A%0A%20%20%20%20return%20SHARPNESS_SCHEDULE%2C%20svi_losses%2C%20svi_params%2C%20svi_stage_boundaries%0A%0A%0A%40app.cell%0Adef%20_(SHARPNESS_SCHEDULE%2C%20np%2C%20plt%2C%20svi_losses%2C%20svi_stage_boundaries)%3A%0A%20%20%20%20_fig%2C%20_ax%20%3D%20plt.subplots(figsize%3D(6%2C%203))%0A%20%20%20%20_ax.plot(svi_losses%2C%20lw%3D0.4%2C%20alpha%3D0.5)%0A%20%20%20%20_w%20%3D%20100%0A%20%20%20%20_ax.plot(np.arange(_w-1%2C%20len(svi_losses))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20np.convolve(svi_losses%2C%20np.ones(_w)%2F_w%2C%20mode%3D%22valid%22)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20lw%3D1.5%2C%20color%3D%22C1%22)%0A%20%20%20%20%23%20annotate%20stage%20boundaries%0A%20%20%20%20for%20_b%2C%20_s%20in%20zip(svi_stage_boundaries%5B%3A-1%5D%2C%20SHARPNESS_SCHEDULE%5B1%3A%5D)%3A%0A%20%20%20%20%20%20%20%20_ax.axvline(_b%2C%20color%3D%220.7%22%2C%20lw%3D0.6%2C%20linestyle%3D%22--%22)%0A%20%20%20%20%20%20%20%20_ax.text(_b%2C%20_ax.get_ylim()%5B1%5D*0.98%2C%20f%22%20%20s%3D%7B_s%3Ag%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20fontsize%3D8%2C%20va%3D%22top%22%2C%20color%3D%220.4%22)%0A%20%20%20%20_ax.set_xlabel(%22SVI%20step%22)%3B%20_ax.set_ylabel(%22%24-%5Cmathrm%7BELBO%7D%24%22)%0A%20%20%20%20_ax.set_title(%22BNAF%20guide%20training%20(annealed%20sharpness)%22)%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(bnaf_guide%2C%20jax%2C%20np%2C%20svi_params)%3A%0A%20%20%20%20_post%20%3D%20bnaf_guide.sample_posterior(%0A%20%20%20%20%20%20%20%20jax.random.PRNGKey(7)%2C%20svi_params%2C%20sample_shape%3D(4000%2C)%0A%20%20%20%20)%0A%20%20%20%20svi_samples%20%3D%20np.asarray(_post%5B%22x%22%5D)%0A%0A%20%20%20%20return%20(svi_samples%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Four-way%20comparison%0A%0A%20%20%20%20Now%20we%20have%20four%20samplers%20on%20the%20same%20target.%20Same%20scatter%20axes%2C%20same%0A%20%20%20%20checkerboard%20backdrop%3A%0A%0A%20%20%20%201.%20**True%20%24p_1%24**%20%E2%80%94%20uniform%20on%20white%20cells.%0A%20%20%20%202.%20**FM%20flow**%20%E2%80%94%20trained%20on%20samples%2C%20no%20log-det%2C%20sample-time%20ODE.%0A%20%20%20%203.%20**NeuTra-HMC**%20%E2%80%94%20uses%20the%20FM%20flow%20as%20a%20reparametrization%20for%20HMC.%0A%20%20%20%20%20%20%20The%20MCMC%20step%20*corrects*%20the%20flow's%20leakage%20onto%20black%20cells.%0A%20%20%20%204.%20**BNAF%20%2B%20SVI**%20%E2%80%94%20discrete%20flow%20trained%20by%20reverse-KL%2C%20no%20samples%20needed.%0A%20%20%20%20%20%20%20Sharper%20cell%20boundaries%20than%20the%20FM%20flow%20(it's%20directly%20minimizing%0A%20%20%20%20%20%20%20reverse-KL%20against%20the%20smoothed%20target)%20but%20mode-seeking%20%E2%80%94%20typically%0A%20%20%20%20%20%20%20drops%20a%20subset%20of%20cells%20entirely.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(flow_samples%2C%20mcmc_x%2C%20plt%2C%20svi_samples%2C%20true_samples)%3A%0A%20%20%20%20def%20_grid(ax%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20cw%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20j%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20white%20%3D%20(i%20%2B%20j)%20%25%202%20%3D%3D%201%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ax.add_patch(plt.Rectangle(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(-scale%20%2B%20i*cw%2C%20-scale%20%2B%20j*cw)%2C%20cw%2C%20cw%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20facecolor%3D%220.92%22%20if%20white%20else%20%221.0%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edgecolor%3D%220.7%22%2C%20linewidth%3D0.5%2C%20zorder%3D0%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20))%0A%0A%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(1%2C%204%2C%20figsize%3D(16%2C%204))%0A%20%20%20%20_titles%20%3D%20%5B%22true%20%24p_1%24%22%2C%20%22FM%20flow%20%24%5CPhi(z)%24%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22NeuTra-HMC%22%2C%20%22BNAF%20%2B%20SVI%22%5D%0A%20%20%20%20_data%20%3D%20%5Btrue_samples%2C%20flow_samples%2C%20mcmc_x%2C%20svi_samples%5D%0A%0A%20%20%20%20for%20_ax%2C%20_x%2C%20_ti%20in%20zip(_axes%2C%20_data%2C%20_titles)%3A%0A%20%20%20%20%20%20%20%20_grid(_ax)%0A%20%20%20%20%20%20%20%20_ax.scatter(_x%5B%3A%2C%200%5D%2C%20_x%5B%3A%2C%201%5D%2C%20s%3D2%2C%20alpha%3D0.5)%0A%20%20%20%20%20%20%20%20_ax.set_xlim(-2.5%2C%202.5)%3B%20_ax.set_ylim(-2.5%2C%202.5)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20%20%20%20%20_ax.set_title(_ti)%0A%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(flow_samples%2C%20mcmc_x%2C%20mo%2C%20np%2C%20svi_samples%2C%20true_samples)%3A%0A%20%20%20%20def%20_per_cell_mass(x%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20np.asarray(x)%0A%20%20%20%20%20%20%20%20edges%20%3D%20np.linspace(-scale%2C%20scale%2C%205)%0A%20%20%20%20%20%20%20%20H%2C%20_%2C%20_%20%3D%20np.histogram2d(x%5B%3A%2C%200%5D%2C%20x%5B%3A%2C%201%5D%2C%20bins%3D%5Bedges%2C%20edges%5D)%0A%20%20%20%20%20%20%20%20if%20H.sum()%20%3D%3D%200%3A%20return%20H%2C%200.0%2C%200%0A%20%20%20%20%20%20%20%20H%20%3D%20H%20%2F%20H.sum()%0A%20%20%20%20%20%20%20%20mw%20%3D%20((np.add.outer(np.arange(4)%2C%20np.arange(4))%20%25%202)%20%3D%3D%200)%0A%20%20%20%20%20%20%20%20n_white_visited%20%3D%20int(((H%20%3E%200.005)%20%26%20mw).sum())%0A%20%20%20%20%20%20%20%20return%20H%2C%20float(H%5Bmw%5D.sum())%2C%20n_white_visited%0A%0A%20%20%20%20_rows%20%3D%20%5B%5D%0A%20%20%20%20for%20_name%2C%20_x%20in%20%5B(%22true%20%24p_1%24%22%2C%20true_samples)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(%22FM%20flow%22%2C%20flow_samples)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(%22NeuTra-HMC%22%2C%20mcmc_x)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(%22BNAF%20%2B%20SVI%22%2C%20svi_samples)%5D%3A%0A%20%20%20%20%20%20%20%20_%2C%20_wm%2C%20_nvis%20%3D%20_per_cell_mass(_x)%0A%20%20%20%20%20%20%20%20_rows.append(f%22%7C%20%7B_name%7D%20%7C%20%7B_wm%3A.3f%7D%20%7C%20%7B_nvis%7D%20%2F%208%20%7C%22)%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%22%7C%20sampler%20%7C%20white-cell%20mass%20%7C%20white%20cells%20visited%20%7C%5Cn%22%0A%20%20%20%20%20%20%20%20%22%7C---%7C---%3A%7C---%3A%7C%5Cn%22%0A%20%20%20%20%20%20%20%20%2B%20%22%5Cn%22.join(_rows)%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Aside%3A%20does%20BNAF%20cope%20with%20R%26M's%20double-crescent%3F%0A%0A%20%20%20%20The%20original%20normalizing-flow%20paper%20(Rezende%20%26%20Mohamed%2C%202015)%20used%20%24K%3D8%24%0A%20%20%20%20and%20%24K%3D32%24%20stacks%20of%20*planar*%20flows%20to%20fit%20a%20smooth%20bimodal%20%22double%0A%20%20%20%20crescent%22%20target.%20A%20single%20planar%20layer%20in%202D%20has%20only%205%20parameters%2C%20so%0A%20%20%20%20even%20%24K%3D32%24%20totals%20%24%5Capprox%20160%24%20parameters.%0A%0A%20%20%20%20Modern%20flows%20are%20heavier%20per%20layer%20(one%20BNAF%20block%20at%20width%208%20already%0A%20%20%20%20has%20~400%20parameters)%2C%20but%20that%20should%20buy%20us%20layer%20count.%20Question%3A%20how%0A%20%20%20%20*few*%20BNAF%20layers%20does%20it%20take%3F%20The%20energy%20from%20the%20paper%3A%0A%0A%20%20%20%20%24%24%0A%20%20%20%20U(z)%20%3D%20%5Ctfrac%7B1%7D%7B2%7D%5C!%5Cleft(%5Cfrac%7B%5C%7Cz%5C%7C-2%7D%7B0.4%7D%5Cright)%5E%7B%5C!2%7D%20-%20%5Clog%5C!%5Cleft(e%5E%7B-%5Cfrac%7B1%7D%7B2%7D((z_1-2)%2F0.6)%5E2%7D%20%2B%20e%5E%7B-%5Cfrac%7B1%7D%7B2%7D((z_1%2B2)%2F0.6)%5E2%7D%5Cright).%0A%20%20%20%20%24%24%0A%0A%20%20%20%20A%20ring%20of%20radius%202%20modulated%20by%20a%20bimodal%20logsumexp%20along%20%24z_1%24.%20Smooth%2C%0A%20%20%20%20full%20support%2C%20gentle%20saddle%20between%20modes%20(about%202%20nats%20deep)%20%E2%80%94%20the%20kind%0A%20%20%20%20of%20target%20reverse-KL%20handles%20well.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(jax%2C%20jnp%2C%20np%2C%20plt)%3A%0A%20%20%20%20def%20crescent_log_prob(z)%3A%0A%20%20%20%20%20%20%20%20%22%22%22R%26M%202015%20'double%20crescent'%20energy%20U_1(z)%2C%20returned%20as%20log%20p%20(%3D%20-U).%22%22%22%0A%20%20%20%20%20%20%20%20r%20%3D%20jnp.linalg.norm(z%2C%20axis%3D-1)%0A%20%20%20%20%20%20%20%20ring%20%3D%200.5%20*%20((r%20-%202.0)%20%2F%200.4)%20**%202%0A%20%20%20%20%20%20%20%20a%20%3D%20-0.5%20*%20((z%5B...%2C%200%5D%20-%202.0)%20%2F%200.6)%20**%202%0A%20%20%20%20%20%20%20%20b%20%3D%20-0.5%20*%20((z%5B...%2C%200%5D%20%2B%202.0)%20%2F%200.6)%20**%202%0A%20%20%20%20%20%20%20%20bimodal%20%3D%20-jax.scipy.special.logsumexp(jnp.stack(%5Ba%2C%20b%5D%2C%20axis%3D-1)%2C%20axis%3D-1)%0A%20%20%20%20%20%20%20%20return%20-(ring%20%2B%20bimodal)%0A%0A%0A%20%20%20%20%23%20Visualize%20the%20target%0A%20%20%20%20_xx%2C%20_yy%20%3D%20jnp.meshgrid(jnp.linspace(-4%2C%204%2C%20200)%2C%20jnp.linspace(-4%2C%204%2C%20200))%0A%20%20%20%20_grid%20%3D%20jnp.stack(%5B_xx%2C%20_yy%5D%2C%20axis%3D-1)%0A%20%20%20%20_lp%20%3D%20np.asarray(crescent_log_prob(_grid))%0A%0A%20%20%20%20_fig%2C%20_ax%20%3D%20plt.subplots(figsize%3D(4%2C%204))%0A%20%20%20%20_ax.imshow(np.exp(_lp%20-%20_lp.max())%2C%20extent%3D%5B-4%2C%204%2C%20-4%2C%204%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20origin%3D%22lower%22%2C%20cmap%3D%22magma%22%2C%20aspect%3D%22equal%22)%0A%20%20%20%20_ax.set_title(%22R%26M%20double-crescent%20target%22)%0A%20%20%20%20_fig%0A%20%20%20%20return%20(crescent_log_prob%2C)%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20AutoBNAFNormal%2C%0A%20%20%20%20SVI%2C%0A%20%20%20%20Trace_ELBO%2C%0A%20%20%20%20crescent_log_prob%2C%0A%20%20%20%20dist%2C%0A%20%20%20%20jax%2C%0A%20%20%20%20jnp%2C%0A%20%20%20%20np%2C%0A%20%20%20%20numpyro%2C%0A%20%20%20%20optax%2C%0A)%3A%0A%20%20%20%20def%20crescent_model(*args%2C%20**kwargs)%3A%0A%20%20%20%20%20%20%20%20z%20%3D%20numpyro.sample(%22z%22%2C%20dist.Normal(jnp.zeros(2)%2C%205*jnp.ones(2)).to_event(1))%0A%20%20%20%20%20%20%20%20numpyro.factor(%22target%22%2C%20crescent_log_prob(z))%0A%0A%0A%20%20%20%20def%20fit_bnaf(num_flows%2C%20hidden_factors%2C%20n_steps%3D4000%2C%20lr%3D2e-3%2C%20num_particles%3D16%2C%20seed%3D0)%3A%0A%20%20%20%20%20%20%20%20g%20%3D%20AutoBNAFNormal(crescent_model%2C%20num_flows%3Dnum_flows%2C%20hidden_factors%3Dhidden_factors)%0A%20%20%20%20%20%20%20%20svi%20%3D%20SVI(crescent_model%2C%20g%2C%20optax.adam(optax.cosine_decay_schedule(lr%2C%20n_steps))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20Trace_ELBO(num_particles%3Dnum_particles))%0A%20%20%20%20%20%20%20%20res%20%3D%20svi.run(jax.random.PRNGKey(seed)%2C%20n_steps%2C%20progress_bar%3DFalse)%0A%20%20%20%20%20%20%20%20n_params%20%3D%20sum(p.size%20for%20p%20in%20jax.tree_util.tree_leaves(res.params))%0A%20%20%20%20%20%20%20%20samples%20%3D%20g.sample_posterior(jax.random.PRNGKey(seed%2B1)%2C%20res.params%2C%20sample_shape%3D(4000%2C))%0A%20%20%20%20%20%20%20%20return%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22name%22%3A%20f%22BNAF%20K%3D%7Bnum_flows%7D%20H%3D%7Bhidden_factors%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22n_params%22%3A%20n_params%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22final_loss%22%3A%20float(jnp.mean(res.losses%5B-100%3A%5D))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22losses%22%3A%20np.asarray(res.losses)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22samples%22%3A%20np.asarray(samples%5B%22z%22%5D)%2C%0A%20%20%20%20%20%20%20%20%7D%0A%0A%0A%20%20%20%20crescent_runs%20%3D%20%5B%0A%20%20%20%20%20%20%20%20fit_bnaf(num_flows%3D1%2C%20hidden_factors%3D%5B8%2C%208%5D)%2C%0A%20%20%20%20%20%20%20%20fit_bnaf(num_flows%3D2%2C%20hidden_factors%3D%5B8%2C%208%5D)%2C%0A%20%20%20%20%20%20%20%20fit_bnaf(num_flows%3D4%2C%20hidden_factors%3D%5B8%2C%208%5D)%2C%0A%20%20%20%20%5D%0A%0A%20%20%20%20for%20_r%20in%20crescent_runs%3A%0A%20%20%20%20%20%20%20%20print(f%22%7B_r%5B'name'%5D%3A30s%7D%20params%3D%7B_r%5B'n_params'%5D%3A5d%7D%20-ELBO%3D%7B_r%5B'final_loss'%5D%3A.3f%7D%22)%0A%0A%20%20%20%20return%20crescent_model%2C%20crescent_runs%0A%0A%0A%40app.cell%0Adef%20_(crescent_log_prob%2C%20crescent_runs%2C%20iaf_runs%2C%20jnp%2C%20np%2C%20plt)%3A%0A%20%20%20%20_xx2%2C%20_yy2%20%3D%20np.meshgrid(np.linspace(-4%2C%204%2C%20200)%2C%20np.linspace(-4%2C%204%2C%20200))%0A%20%20%20%20_grid2%20%3D%20np.stack(%5B_xx2%2C%20_yy2%5D%2C%20axis%3D-1)%0A%20%20%20%20_lp%20%3D%20np.asarray(crescent_log_prob(jnp.array(_grid2)))%0A%20%20%20%20_target_density%20%3D%20np.exp(_lp%20-%20_lp.max())%0A%0A%20%20%20%20_all_runs%20%3D%20crescent_runs%20%2B%20iaf_runs%0A%20%20%20%20_n%20%3D%20len(_all_runs)%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(2%2C%204%2C%20figsize%3D(16%2C%208))%0A%20%20%20%20_axes%20%3D%20_axes.flatten()%0A%0A%20%20%20%20%23%20panel%200%3A%20target%0A%20%20%20%20_axes%5B0%5D.imshow(_target_density%2C%20extent%3D%5B-4%2C4%2C-4%2C4%5D%2C%20origin%3D%22lower%22%2C%20cmap%3D%22magma%22)%0A%20%20%20%20_axes%5B0%5D.set_title(%22target%20%24p%24%22)%0A%0A%20%20%20%20%23%20panels%201..6%3A%20runs%0A%20%20%20%20for%20_ax%2C%20_run%20in%20zip(_axes%5B1%3A1%2B_n%5D%2C%20_all_runs)%3A%0A%20%20%20%20%20%20%20%20_ax.imshow(_target_density%2C%20extent%3D%5B-4%2C4%2C-4%2C4%5D%2C%20origin%3D%22lower%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20cmap%3D%22magma%22%2C%20alpha%3D0.4)%0A%20%20%20%20%20%20%20%20_ax.scatter(_run%5B%22samples%22%5D%5B%3A%2C0%5D%2C%20_run%5B%22samples%22%5D%5B%3A%2C1%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20s%3D2%2C%20alpha%3D0.4%2C%20color%3D%22white%22)%0A%20%20%20%20%20%20%20%20_s%20%3D%20_run%5B%22samples%22%5D%0A%20%20%20%20%20%20%20%20_left%20%3D%20int((_s%5B%3A%2C0%5D%20%3C%20-0.5).sum())%0A%20%20%20%20%20%20%20%20_right%20%3D%20int((_s%5B%3A%2C0%5D%20%3E%200.5).sum())%0A%20%20%20%20%20%20%20%20_ax.set_title(f%22%7B_run%5B'name'%5D%7D%5Cnparams%3D%7B_run%5B'n_params'%5D%7D%2C%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22L%2FR%3D%7B_left%7D%2F%7B_right%7D%2C%20-ELBO%3D%7B_run%5B'final_loss'%5D%3A.2f%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20fontsize%3D9)%0A%0A%20%20%20%20%23%20any%20leftover%20panels%3A%20hide%0A%20%20%20%20for%20_ax%20in%20_axes%5B1%2B_n%3A%5D%3A%0A%20%20%20%20%20%20%20%20_ax.axis(%22off%22)%0A%0A%20%20%20%20for%20_ax%20in%20_axes%5B%3A1%2B_n%5D%3A%0A%20%20%20%20%20%20%20%20_ax.set_xlim(-4%2C%204)%3B%20_ax.set_ylim(-4%2C%204)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20%20%20%20%20_ax.set_facecolor(%220.05%22)%0A%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(AutoIAFNormal%2C%20SVI%2C%20Trace_ELBO%2C%20crescent_model%2C%20jax%2C%20jnp%2C%20np%2C%20optax)%3A%0A%20%20%20%20def%20fit_iaf(num_flows%2C%20hidden_dims%2C%20n_steps%3D4000%2C%20lr%3D2e-3%2C%20num_particles%3D16%2C%20seed%3D0)%3A%0A%20%20%20%20%20%20%20%20g%20%3D%20AutoIAFNormal(crescent_model%2C%20num_flows%3Dnum_flows%2C%20hidden_dims%3Dhidden_dims)%0A%20%20%20%20%20%20%20%20svi%20%3D%20SVI(crescent_model%2C%20g%2C%20optax.adam(optax.cosine_decay_schedule(lr%2C%20n_steps))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20Trace_ELBO(num_particles%3Dnum_particles))%0A%20%20%20%20%20%20%20%20res%20%3D%20svi.run(jax.random.PRNGKey(seed)%2C%20n_steps%2C%20progress_bar%3DFalse)%0A%20%20%20%20%20%20%20%20n_params%20%3D%20sum(p.size%20for%20p%20in%20jax.tree_util.tree_leaves(res.params))%0A%20%20%20%20%20%20%20%20samples%20%3D%20g.sample_posterior(jax.random.PRNGKey(seed%2B1)%2C%20res.params%2C%20sample_shape%3D(4000%2C))%0A%20%20%20%20%20%20%20%20return%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22name%22%3A%20f%22IAF%20K%3D%7Bnum_flows%7D%20H%3D%7Bhidden_dims%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22n_params%22%3A%20n_params%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22final_loss%22%3A%20float(jnp.mean(res.losses%5B-100%3A%5D))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22losses%22%3A%20np.asarray(res.losses)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22samples%22%3A%20np.asarray(samples%5B%22z%22%5D)%2C%0A%20%20%20%20%20%20%20%20%7D%0A%0A%0A%20%20%20%20iaf_runs%20%3D%20%5B%0A%20%20%20%20%20%20%20%20fit_iaf(num_flows%3D2%2C%20hidden_dims%3D%5B16%2C%2016%5D)%2C%0A%20%20%20%20%20%20%20%20fit_iaf(num_flows%3D4%2C%20hidden_dims%3D%5B16%2C%2016%5D)%2C%0A%20%20%20%20%20%20%20%20fit_iaf(num_flows%3D8%2C%20hidden_dims%3D%5B16%2C%2016%5D)%2C%0A%20%20%20%20%5D%0A%0A%20%20%20%20for%20_r%20in%20iaf_runs%3A%0A%20%20%20%20%20%20%20%20print(f%22%7B_r%5B%22name%22%5D%3A30s%7D%20params%3D%7B_r%5B%22n_params%22%5D%3A5d%7D%20-ELBO%3D%7B_r%5B%22final_loss%22%5D%3A.3f%7D%22)%0A%0A%20%20%20%20return%20(iaf_runs%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20Findings%0A%0A%20%20%20%20Per-parameter%2C%20BNAF%20and%20IAF%20can%20both%20fit%20the%20double-crescent%20%E2%80%94%20but%20the%0A%20%20%20%20paths%20they%20take%20are%20very%20different.%0A%0A%20%20%20%20**BNAF.**%20%24K%3D1%2C%20H%3D%5B8%2C8%5D%24%20(388%20params%2C%20~2.4%C3%97%20planar's%20160)%20covers%20both%0A%20%20%20%20crescents%20and%20lays%20mass%20on%20the%20ring.%20Adding%20capacity%20*hurts*%3A%20%24K%3D2%24%20and%0A%20%20%20%20%24K%3D4%24%20both%20collapse%20onto%20the%20left%20crescent%20only.%20The%20richer%20per-layer%0A%20%20%20%20transform%20lets%20the%20flow%20carve%20a%20sharp%20single-mode%20density%20that's%20a%0A%20%20%20%20locally-optimal%20reverse-KL%20fit%2C%20and%20the%20optimizer%20can't%20escape%20it.%0A%0A%20%20%20%20**IAF.**%20All%20three%20configs%20(%24K%3D2%2C%204%2C%208%24)%20cover%20both%20modes.%20The%20affine%0A%20%20%20%20per-layer%20restriction%20prevents%20IAF%20from%20being%20sharp%20enough%20on%20a%20single%0A%20%20%20%20mode%20to%20%22win%22%20by%20collapsing%20%E2%80%94%20by%20the%20time%20enough%20layers%20compose%20to%20give%0A%20%20%20%20sharp%20marginals%2C%20the%20early-training%20trajectory%20has%20already%20locked%20in%0A%20%20%20%20bimodal%20coverage.%0A%0A%20%20%20%20**The%20general%20lesson%3A**%20per-layer%20expressivity%20is%20a%20double-edged%20sword.%0A%20%20%20%20It%20helps%20you%20fit%20a%20target%20you've%20already%20covered%2C%20and%20hurts%20you%20discover%0A%20%20%20%20coverage%20you%20don't%20yet%20have.%20This%20is%20the%20same%20trade-off%20as%20warm-up%0A%20%20%20%20schedules%20and%20learning-rate%20annealing%20%E2%80%94%20but%20at%20the%20architecture%20level%2C%0A%20%20%20%20where%20it's%20harder%20to%20control%20after%20the%20fact.%0A%0A%20%20%20%20The%20R%26M%20planar%20flows%20occupy%20a%20sweet%20spot%3A%20each%20layer%20is%20so%20weak%0A%20%20%20%20(rank-1%20perturbation)%20that%20the%20optimization%20landscape%20is%20forgiving%2C%20and%0A%20%20%20%20you%20compensate%20with%20depth.%20Modern%20flows%20pile%20capacity%20into%20layers%2C%20get%0A%20%20%20%20fewer-layer%20wins%20on%20smooth%20targets%2C%20and%20can%20quietly%20fail%20on%20multimodal%0A%20%20%20%20ones%20%E2%80%94%20the%20failure%20we%20just%20saw%2C%20which%20our%20checkerboard%20analysis%20would%0A%20%20%20%20have%20predicted.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Does%20IAF%20cope%20better%20with%20the%20checkerboard%3F%0A%0A%20%20%20%20The%20crescent%20experiment%20hinted%20that%20**IAF's%20per-layer%20affine%20restriction%0A%20%20%20%20keeps%20it%20broader%20during%20early%20training**%20%E2%80%94%20exactly%20the%20property%20we%20want%0A%20%20%20%20when%20the%20target%20is%20multimodal.%20Question%3A%20at%20fixed%20sharpness%20(no%20annealing)%2C%0A%20%20%20%20can%20IAF%20cover%20more%20checkerboard%20cells%20than%20BNAF%20did%3F%0A%0A%20%20%20%20We%20sweep%20over%20flow%20depth%20%24K%24%20and%20random%20seed.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20AutoIAFNormal%2C%0A%20%20%20%20SVI%2C%0A%20%20%20%20Trace_ELBO%2C%0A%20%20%20%20dist%2C%0A%20%20%20%20jax%2C%0A%20%20%20%20jnp%2C%0A%20%20%20%20np%2C%0A%20%20%20%20numpyro%2C%0A%20%20%20%20optax%2C%0A%20%20%20%20soft_checkerboard_log_prob%2C%0A)%3A%0A%20%20%20%20def%20fit_iaf_checkerboard(num_flows%2C%20hidden_dims%2C%20n_steps%3D8000%2C%20lr%3D2e-3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_particles%3D32%2C%20seed%3D0%2C%20sharpness%3D40.0)%3A%0A%20%20%20%20%20%20%20%20def%20m(*a%2C%20**k)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20numpyro.sample(%22x%22%2C%20dist.Normal(jnp.zeros(2)%2C%205*jnp.ones(2)).to_event(1))%0A%20%20%20%20%20%20%20%20%20%20%20%20numpyro.factor(%22target%22%2C%20soft_checkerboard_log_prob(x%2C%20sharpness%3Dsharpness))%0A%20%20%20%20%20%20%20%20g%20%3D%20AutoIAFNormal(m%2C%20num_flows%3Dnum_flows%2C%20hidden_dims%3Dhidden_dims)%0A%20%20%20%20%20%20%20%20svi%20%3D%20SVI(m%2C%20g%2C%20optax.adam(optax.cosine_decay_schedule(lr%2C%20n_steps))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20Trace_ELBO(num_particles%3Dnum_particles))%0A%20%20%20%20%20%20%20%20state%20%3D%20svi.run(jax.random.PRNGKey(seed)%2C%20n_steps%2C%20progress_bar%3DFalse)%0A%20%20%20%20%20%20%20%20samples%20%3D%20g.sample_posterior(%0A%20%20%20%20%20%20%20%20%20%20%20%20jax.random.PRNGKey(seed%20%2B%20100)%2C%20state.params%2C%20sample_shape%3D(4000%2C)%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20return%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22model%22%3A%20m%2C%20%22guide%22%3A%20g%2C%20%22params%22%3A%20state.params%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22n_params%22%3A%20sum(p.size%20for%20p%20in%20jax.tree_util.tree_leaves(state.params))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22samples%22%3A%20np.asarray(samples%5B%22x%22%5D)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22final_loss%22%3A%20float(jnp.mean(state.losses%5B-100%3A%5D))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22K%22%3A%20num_flows%2C%20%22seed%22%3A%20seed%2C%0A%20%20%20%20%20%20%20%20%7D%0A%0A%0A%20%20%20%20def%20coverage(s)%3A%0A%20%20%20%20%20%20%20%20s%20%3D%20np.asarray(s)%0A%20%20%20%20%20%20%20%20H%2C%20_%2C%20_%20%3D%20np.histogram2d(s%5B%3A%2C0%5D%2C%20s%5B%3A%2C1%5D%2C%20bins%3D%5Bnp.linspace(-2%2C%202%2C%205)%5D*2)%0A%20%20%20%20%20%20%20%20H%20%3D%20H%20%2F%20max(H.sum()%2C%201)%0A%20%20%20%20%20%20%20%20mw%20%3D%20((np.add.outer(np.arange(4)%2C%20np.arange(4))%20%25%202)%20%3D%3D%200)%0A%20%20%20%20%20%20%20%20return%20float(H%5Bmw%5D.sum())%2C%20int(((H%20%3E%200.005)%20%26%20mw).sum())%0A%0A%0A%20%20%20%20%23%20Sweep%0A%20%20%20%20iaf_checker_runs%20%3D%20%5B%5D%0A%20%20%20%20for%20_K%20in%20%5B4%2C%208%2C%2016%5D%3A%0A%20%20%20%20%20%20%20%20for%20_seed%20in%20%5B0%2C%201%2C%202%5D%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20_r%20%3D%20fit_iaf_checkerboard(num_flows%3D_K%2C%20hidden_dims%3D%5B16%2C%2016%5D%2C%20seed%3D_seed)%0A%20%20%20%20%20%20%20%20%20%20%20%20_wm%2C%20_nv%20%3D%20coverage(_r%5B%22samples%22%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20_r%5B%22white_mass%22%5D%2C%20_r%5B%22visited%22%5D%20%3D%20_wm%2C%20_nv%0A%20%20%20%20%20%20%20%20%20%20%20%20iaf_checker_runs.append(_r)%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22K%3D%7B_K%3A2d%7D%20seed%3D%7B_seed%7D%20-ELBO%3D%7B_r%5B'final_loss'%5D%3A.3f%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22white%3D%7B_wm%3A.3f%7D%20visited%3D%7B_nv%7D%2F8%22)%0A%0A%20%20%20%20return%20coverage%2C%20iaf_checker_runs%0A%0A%0A%40app.cell%0Adef%20_(iaf_checker_runs%2C%20plt)%3A%0A%20%20%20%20def%20_grid_panel(ax%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20cw%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20j%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20white%20%3D%20(i%20%2B%20j)%20%25%202%20%3D%3D%201%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ax.add_patch(plt.Rectangle(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(-scale%20%2B%20i*cw%2C%20-scale%20%2B%20j*cw)%2C%20cw%2C%20cw%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20facecolor%3D%220.92%22%20if%20white%20else%20%221.0%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edgecolor%3D%220.7%22%2C%20linewidth%3D0.5%2C%20zorder%3D0))%0A%0A%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(3%2C%203%2C%20figsize%3D(11%2C%2011))%0A%20%20%20%20for%20_ax%2C%20_r%20in%20zip(_axes.flatten()%2C%20iaf_checker_runs)%3A%0A%20%20%20%20%20%20%20%20_grid_panel(_ax)%0A%20%20%20%20%20%20%20%20_ax.scatter(_r%5B%22samples%22%5D%5B%3A%2C0%5D%2C%20_r%5B%22samples%22%5D%5B%3A%2C1%5D%2C%20s%3D2%2C%20alpha%3D0.4)%0A%20%20%20%20%20%20%20%20_ax.set_xlim(-2.5%2C%202.5)%3B%20_ax.set_ylim(-2.5%2C%202.5)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20%20%20%20%20_ax.set_title(f%22K%3D%7B_r%5B'K'%5D%7D%20seed%3D%7B_r%5B'seed'%5D%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22visited%3D%7B_r%5B'visited'%5D%7D%2F8%20white%3D%7B_r%5B'white_mass'%5D%3A.2f%7D%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20fontsize%3D9)%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20pattern%20matches%20the%20crescent%20experiment%20in%20spirit%20but%20is%20sobering%20in%0A%20%20%20%20extent.%20Increasing%20%24K%24%20from%204%20to%208%20helps%20coverage%20(4%2F8%20%E2%86%92%20up%20to%206%2F8).%20Going%0A%20%20%20%20further%20is%20unstable%20%E2%80%94%20at%20%24K%3D16%24%20one%20of%20three%20seeds%20diverged%20numerically.%0A%20%20%20%20No%20configuration%20covers%20all%208%20modes%3B%20the%20best%20is%20**6%2F8%20at%20%24K%3D8%24%2C%20seed%200**%2C%0A%20%20%20%20matching%20what%20BNAF%20%2B%20sharpness%20annealing%20gave%20us%20with%20target%20manipulation%0A%20%20%20%20we%20said%20we%20wouldn't%20allow.%0A%0A%20%20%20%20Two%20takeaways%3A%0A%0A%20%20%20%20-%20IAF's%20robustness%20on%20the%20crescent%20translates%20to%20a%20*constant-factor*%0A%20%20%20%20%20%20improvement%20on%20the%20checkerboard%20(BNAF's%204%2F8%20%E2%86%92%20IAF's%206%2F8).%20Real%0A%20%20%20%20%20%20improvement%2C%20not%20transformative.%0A%20%20%20%20-%20The%20remaining%20gap%20is%20the%20same%20fundamental%20obstruction%3A%0A%20%20%20%20%20%20%24%5Cmathrm%7Bsharpness%7D%3D40%24%20leaves%20%24%5Csim%2040%24-nat%20barriers%20between%20cells%0A%20%20%20%20%20%20in%20%24z%24-space%2C%20and%20no%20local%20optimization%20can%20drive%20mass%20across%20them.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20Combining%20the%20best%20IAF%20guide%20with%20NeuTra-HMC%0A%0A%20%20%20%20Take%20the%20best%20IAF%20run%20(%24K%3D8%24%2C%20seed%200%2C%206%20cells%20covered)%20and%20use%20it%20as%20a%0A%20%20%20%20NeuTra%20reparametrization.%20NumPyro%20has%20built-in%20support%20for%20this%20via%0A%20%20%20%20%60NeuTraReparam%60%2C%20which%20gives%20us%20a%20properly%20typed%20%24z%24-space%20target%20without%0A%20%20%20%20hand-rolling%20%60Transform%60%20classes%20(cf.%20our%20FM%20section%2C%20where%20we%20needed%20an%0A%20%20%20%20augmented-ODE%20Jacobian).%0A%0A%20%20%20%20Prediction%20(from%20theory)%3A%20NeuTra%20**cannot**%20rescue%20the%20missing%202%20cells%20%E2%80%94%0A%20%20%20%20HMC's%20local%20dynamics%20can't%20cross%2040-nat%20barriers%20any%20more%20easily%20after%0A%20%20%20%20reparametrization%20than%20before.%20But%20it%20**should**%20sharpen%20mass%20on%20the%206%0A%20%20%20%20cells%20the%20flow%20already%20visits%2C%20the%20same%20way%20it%20did%20with%20the%20FM%20flow.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(MCMC%2C%20NUTS%2C%20NeuTraReparam%2C%20coverage%2C%20iaf_checker_runs%2C%20jax%2C%20np)%3A%0A%20%20%20%20best_iaf%20%3D%20max(iaf_checker_runs%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20key%3Dlambda%20r%3A%20(r%5B%22visited%22%5D%2C%20r%5B%22white_mass%22%5D%2C%20-r%5B%22final_loss%22%5D))%0A%20%20%20%20print(f%22best%20IAF%3A%20K%3D%7Bbest_iaf%5B'K'%5D%7D%20seed%3D%7Bbest_iaf%5B'seed'%5D%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20f%22visited%3D%7Bbest_iaf%5B'visited'%5D%7D%2F8%20white%3D%7Bbest_iaf%5B'white_mass'%5D%3A.3f%7D%22)%0A%0A%20%20%20%20_neutra%20%3D%20NeuTraReparam(best_iaf%5B%22guide%22%5D%2C%20best_iaf%5B%22params%22%5D)%0A%20%20%20%20_neutra_model%20%3D%20_neutra.reparam(best_iaf%5B%22model%22%5D)%0A%20%20%20%20_kernel%20%3D%20NUTS(_neutra_model%2C%20max_tree_depth%3D6%2C%20target_accept_prob%3D0.8)%0A%20%20%20%20_mcmc%20%3D%20MCMC(_kernel%2C%20num_warmup%3D200%2C%20num_samples%3D500%2C%20num_chains%3D4%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chain_method%3D%22sequential%22%2C%20progress_bar%3DTrue)%0A%20%20%20%20_mcmc.run(jax.random.PRNGKey(42))%0A%0A%20%20%20%20iaf_neutra_samples%20%3D%20np.asarray(_mcmc.get_samples()%5B%22x%22%5D)%0A%20%20%20%20_wm%2C%20_nv%20%3D%20coverage(iaf_neutra_samples)%0A%20%20%20%20print(f%22IAF%20%2B%20NeuTra%20(4%20chains%2C%20500%20each)%3A%20white%3D%7B_wm%3A.3f%7D%20visited%3D%7B_nv%7D%2F8%22)%0A%0A%20%20%20%20_per_chain%20%3D%20_mcmc.get_samples(group_by_chain%3DTrue)%5B%22x%22%5D%0A%20%20%20%20for%20_c%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20_wmc%2C%20_nvc%20%3D%20coverage(_per_chain%5B_c%5D)%0A%20%20%20%20%20%20%20%20print(f%22%20%20chain%20%7B_c%7D%3A%20visited%3D%7B_nvc%7D%2F8%20white%3D%7B_wmc%3A.3f%7D%22)%0A%0A%20%20%20%20return%20best_iaf%2C%20iaf_neutra_samples%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%20Final%20picture%0A%0A%20%20%20%20All%20five%20samplers%2C%20side%20by%20side%2C%20on%20the%20same%20target%20%E2%80%94%20and%20a%20clear%20story%3A%0A%0A%20%20%20%20*%20The%20**flow%20alone**%20results%20(FM%2C%20BNAF%2Banneal%2C%20IAF%20best)%20all%20leak%20across%0A%20%20%20%20%20%20cell%20boundaries.%20None%20of%20them%20is%20sharp.%0A%20%20%20%20*%20**NeuTra-HMC%20sharpens%20whatever%20it's%20given**%3A%20with%20the%20FM%20flow%20it%0A%20%20%20%20%20%20reaches%208%2F8%20modes%20at%2098%25%20white%3B%20with%20the%20IAF%20flow%20(best%206%2F8%20modes)%0A%20%20%20%20%20%20it%20reaches%207%2F8%20at%2097%25%20white.%20**It%20does%20not%20invent%20the%20missing%20mode**%20%E2%80%94%0A%20%20%20%20%20%20because%20the%20IAF%20flow%20simply%20doesn't%20visit%20the%20bottom-right%20region%20of%0A%20%20%20%20%20%20%24z%24-space%2C%20the%20pullback%20target%20there%20has%20very%20low%20density%20(and%20very%0A%20%20%20%20%20%20low%20gradients)%2C%20and%20HMC%20has%20no%20reason%20to%20go.%0A%20%20%20%20*%20The%20fact%20that%20NeuTra%2BIAF%20*does*%20find%20one%20extra%20cell%20beyond%20IAF's%20flow%0A%20%20%20%20%20%20coverage%20is%20a%20small%20but%20real%20demonstration%20that%20HMC%20can%20reach%20into%0A%20%20%20%20%20%20*low-density-but-non-zero*%20tails%20that%20the%20flow%20nominally%20covers%20%E2%80%94%20but%0A%20%20%20%20%20%20cannot%20cross%20the%20steep%20walls%20created%20by%20missing%20flow%20support.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20best_iaf%2C%0A%20%20%20%20flow_samples%2C%0A%20%20%20%20iaf_neutra_samples%2C%0A%20%20%20%20mcmc_x%2C%0A%20%20%20%20plt%2C%0A%20%20%20%20svi_samples%2C%0A%20%20%20%20true_samples%2C%0A)%3A%0A%20%20%20%20def%20_grid(ax%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20cw%20%3D%202%20*%20scale%20%2F%204%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20j%20in%20range(4)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20white%20%3D%20(i%20%2B%20j)%20%25%202%20%3D%3D%200%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ax.add_patch(plt.Rectangle(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(-scale%20%2B%20i*cw%2C%20-scale%20%2B%20j*cw)%2C%20cw%2C%20cw%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20facecolor%3D%220.92%22%20if%20white%20else%20%221.0%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20edgecolor%3D%220.7%22%2C%20linewidth%3D0.5%2C%20zorder%3D0))%0A%0A%0A%20%20%20%20_iaf_only%20%3D%20best_iaf%5B%22samples%22%5D%0A%20%20%20%20_panels%20%3D%20%5B%0A%20%20%20%20%20%20%20%20(%22true%20%24p_1%24%22%2C%20true_samples)%2C%0A%20%20%20%20%20%20%20%20(%22FM%20flow%22%2C%20flow_samples)%2C%0A%20%20%20%20%20%20%20%20(%22FM%20%2B%20NeuTra-HMC%22%2C%20mcmc_x)%2C%0A%20%20%20%20%20%20%20%20(%22BNAF%20%2B%20SVI%20(anneal)%22%2C%20svi_samples)%2C%0A%20%20%20%20%20%20%20%20(%22IAF%20(best%2C%20K%3D8)%22%2C%20_iaf_only)%2C%0A%20%20%20%20%20%20%20%20(%22IAF%20%2B%20NeuTra-HMC%22%2C%20iaf_neutra_samples)%2C%0A%20%20%20%20%5D%0A%0A%20%20%20%20_fig%2C%20_axes%20%3D%20plt.subplots(2%2C%203%2C%20figsize%3D(13%2C%208))%0A%20%20%20%20for%20_ax%2C%20(_t%2C%20_x)%20in%20zip(_axes.flatten()%2C%20_panels)%3A%0A%20%20%20%20%20%20%20%20_grid(_ax)%0A%20%20%20%20%20%20%20%20_ax.scatter(_x%5B%3A%2C0%5D%2C%20_x%5B%3A%2C1%5D%2C%20s%3D2%2C%20alpha%3D0.5)%0A%20%20%20%20%20%20%20%20_ax.set_xlim(-2.5%2C%202.5)%3B%20_ax.set_ylim(-2.5%2C%202.5)%3B%20_ax.set_aspect(%22equal%22)%0A%20%20%20%20%20%20%20%20_ax.set_title(_t)%0A%20%20%20%20_fig.tight_layout()%0A%20%20%20%20_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20best_iaf%2C%0A%20%20%20%20flow_samples%2C%0A%20%20%20%20iaf_neutra_samples%2C%0A%20%20%20%20mcmc_x%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20np%2C%0A%20%20%20%20svi_samples%2C%0A%20%20%20%20true_samples%2C%0A)%3A%0A%20%20%20%20def%20_stats(x%2C%20scale%3D2.0)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20np.asarray(x)%0A%20%20%20%20%20%20%20%20edges%20%3D%20np.linspace(-scale%2C%20scale%2C%205)%0A%20%20%20%20%20%20%20%20H%2C%20_%2C%20_%20%3D%20np.histogram2d(x%5B%3A%2C%200%5D%2C%20x%5B%3A%2C%201%5D%2C%20bins%3D%5Bedges%2C%20edges%5D)%0A%20%20%20%20%20%20%20%20if%20H.sum()%20%3D%3D%200%3A%20return%200.0%2C%200%0A%20%20%20%20%20%20%20%20H%20%3D%20H%20%2F%20H.sum()%0A%20%20%20%20%20%20%20%20mw%20%3D%20((np.add.outer(np.arange(4)%2C%20np.arange(4))%20%25%202)%20%3D%3D%200)%0A%20%20%20%20%20%20%20%20return%20float(H%5Bmw%5D.sum())%2C%20int(((H%20%3E%200.005)%20%26%20mw).sum())%0A%0A%0A%20%20%20%20_rows%20%3D%20%5B%5D%0A%20%20%20%20for%20_name%2C%20_x%20in%20%5B%0A%20%20%20%20%20%20%20%20(%22true%20%24p_1%24%22%2C%20%20%20%20%20%20%20%20%20%20%20%20true_samples)%2C%0A%20%20%20%20%20%20%20%20(%22FM%20flow%22%2C%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20flow_samples)%2C%0A%20%20%20%20%20%20%20%20(%22FM%20%2B%20NeuTra-HMC%22%2C%20%20%20%20%20%20%20mcmc_x)%2C%0A%20%20%20%20%20%20%20%20(%22BNAF%20%2B%20SVI%20(anneal)%22%2C%20%20%20svi_samples)%2C%0A%20%20%20%20%20%20%20%20(%22IAF%20best%20(K%3D8%2C%20no%20anneal)%22%2C%20best_iaf%5B%22samples%22%5D)%2C%0A%20%20%20%20%20%20%20%20(%22IAF%20%2B%20NeuTra-HMC%22%2C%20%20%20%20%20%20iaf_neutra_samples)%2C%0A%20%20%20%20%5D%3A%0A%20%20%20%20%20%20%20%20_wm%2C%20_nv%20%3D%20_stats(_x)%0A%20%20%20%20%20%20%20%20_rows.append(f%22%7C%20%7B_name%7D%20%7C%20%7B_wm%3A.3f%7D%20%7C%20%7B_nv%7D%20%2F%208%20%7C%22)%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%22%7C%20sampler%20%7C%20white-cell%20mass%20%7C%20white%20cells%20visited%20%7C%5Cn%22%0A%20%20%20%20%20%20%20%20%22%7C---%7C---%3A%7C---%3A%7C%5Cn%22%20%2B%20%22%5Cn%22.join(_rows)%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
8bfb6ca4575dbf51e27ec3d38be75014