External backends
Prerequisite: the Flexible inference tutorial,
which runs each of these backends end-to-end on a Ricker population
example. This notebook is a tighter reference for the dispatch
mechanism itself: how condition_on(model, data) decides which
backend to call, and how to pin one explicitly.
condition_on(model, data) in ProbPipe is the single call site for
Bayesian inference. Under the hood it consults the inference method
registry, which inspects the model's protocol support and picks a
compatible sampler. This notebook walks through:
- What methods are currently registered, and what each needs.
- How condition_on ranks them for a given model.
- How to pin a specific method.
- How optional backends (Stan, PyMC, nutpie, sbijax) plug into the same interface.
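The dispatch rule described above can be sketched in a few lines of plain Python. Treat this as a toy model, not ProbPipe's implementation. The ToyMethod class, the needs_kind / needs_log_prob checks and the dict-based "models" are hypothetical. The priority numbers are copied from the listing in Section 6.

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Hypothetical stand-in for a registered method adapter (not ProbPipe's class).
@dataclass(frozen=True)
class ToyMethod:
    name: str
    priority: int
    # Returns None if the method can run on the model, else a reason string.
    check: Callable[[dict], Optional[str]]


def needs_kind(*kinds: str) -> Callable[[dict], Optional[str]]:
    """Build a check that passes only for the listed model kinds."""
    def check(model: dict) -> Optional[str]:
        if model.get("kind") in kinds:
            return None
        return "Requires " + " or ".join(kinds)
    return check


def needs_log_prob(model: dict) -> Optional[str]:
    return None if model.get("has_log_prob") else "Requires SupportsLogProb"


TOY_REGISTRY = [
    ToyMethod("nutpie_nuts", 85, needs_kind("StanModel", "PyMCModel")),
    ToyMethod("cmdstan_nuts", 82, needs_kind("StanModel")),
    ToyMethod("pymc_nuts", 81, needs_kind("PyMCModel")),
    ToyMethod("tfp_nuts", 75, needs_log_prob),
    ToyMethod("tfp_hmc", 65, needs_log_prob),
    ToyMethod("tfp_rwmh", 55, needs_log_prob),
]


def select_method(model: dict) -> str:
    """Return the name of the highest-priority method whose check passes."""
    feasible = [m for m in TOY_REGISTRY if m.check(model) is None]
    if not feasible:
        raise TypeError("No registered inference method can run on this model")
    return max(feasible, key=lambda m: m.priority).name


print(select_method({"kind": "SimpleModel", "has_log_prob": True}))
print(select_method({"kind": "StanModel", "has_log_prob": True}))
```

Feasibility and preference are kept separate here. The checks filter the candidates, and only then does priority pick among the survivors. This mirrors the split between check(...) and the priority ordering described in this notebook.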
The upshot is that switching from the default TFP NUTS to
cmdstan's NUTS, or to simulation-based inference via sbijax, is a
one-line change once the model has the shape that backend requires
(see Section 5): condition_on(model, data, method="cmdstan_nuts")
or method="sbijax_smcabc".
```python
import jax
import jax.numpy as jnp
from probpipe import (Normal, SimpleModel, ProductDistribution,
                      inference_method_registry, condition_on)
from probpipe.modeling import Likelihood, GenerativeLikelihood
```
1. The registry
The registry holds every backend ProbPipe knows how to use. list_methods
returns them all — feasibility for a specific model is checked separately.
```python
for name in inference_method_registry.list_methods():
    print(f" {name}")
```
```text
 nutpie_nuts
 cmdstan_nuts
 pymc_nuts
 tfp_nuts
 tfp_hmc
 tfp_rwmh
 blackjax_sgld
 blackjax_sghmc
 pymc_advi
 sbijax_smcabc
```
Each entry is a method adapter that knows:
- What model shape it needs: cmdstan_nuts needs a StanModel; sbijax_smcabc needs a SimpleGenerativeModel; the TFP samplers only need SupportsLogProb.
- What JAX / external dependencies it requires. Optional backends (Stan, PyMC, nutpie, sbijax) register themselves at import time when their dependency is available.
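A requirement such as "only needs SupportsLogProb" is structural: any object that exposes the right method qualifies, whatever class it inherits from. One way to express that in plain Python is a runtime-checkable typing.Protocol. The sketch below is an assumption about the general idea, not ProbPipe's protocol definition. ToySupportsLogProb, ToyGaussianModel and SimulatorOnly are hypothetical names.

```python
import random
from typing import Protocol, runtime_checkable


@runtime_checkable
class ToySupportsLogProb(Protocol):
    """Hypothetical protocol: anything exposing log_prob(params) qualifies."""
    def log_prob(self, params): ...


class ToyGaussianModel:
    # Unnormalised log-density of a standard normal over one parameter.
    def log_prob(self, params):
        return -0.5 * params[0] ** 2


class SimulatorOnly:
    # A generative model that can draw samples but has no log_prob.
    def simulate(self, params, rng: random.Random):
        return rng.gauss(params[0], 1.0)


for obj in (ToyGaussianModel(), SimulatorOnly()):
    ok = isinstance(obj, ToySupportsLogProb)  # checks method presence only
    print(f"{type(obj).__name__:<18s} supports log_prob: {ok}")
```

Note that isinstance against a runtime-checkable protocol only checks that the method exists, not its signature or return type. A real feasibility check may need to do more.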
2. Method dispatch for a simple model
Let's build a minimal SimpleModel with a Gaussian prior and a
Gaussian likelihood, and see which backends are available.
```python
class GaussianLikelihood(Likelihood):
    """(mu, log_sigma) → N(data | mu, exp(log_sigma))"""

    def log_likelihood(self, params, data):
        mu, log_sigma = params[0], params[1]
        # Gaussian log-density summed over observations; the constant
        # -0.5 * log(2 * pi) per point is dropped (irrelevant for MCMC).
        return jnp.sum(-0.5 * ((data - mu) / jnp.exp(log_sigma)) ** 2 - log_sigma)


prior = ProductDistribution(
    Normal(loc=0.0, scale=2.0, name="mu"),
    Normal(loc=0.0, scale=0.5, name="log_sigma"),
)
model = SimpleModel(prior, GaussianLikelihood())
data = jnp.array([1.0, 0.8, 1.2, 0.9, 1.1])
print(model)
```
```text
SimpleModel(prior=ProductDistribution, likelihood=GaussianLikelihood)
```
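GaussianLikelihood.log_likelihood omits the additive constant -½·log(2π) per observation. The posterior is unchanged because MCMC only needs the log-density up to a constant. The quick stdlib check below uses the same five data points and compares against statistics.NormalDist. The test point (mu, log_sigma) is an arbitrary choice, and obs is used instead of data so the notebook's variable is not overwritten.

```python
import math
from statistics import NormalDist

obs = [1.0, 0.8, 1.2, 0.9, 1.1]
mu, log_sigma = 0.9, -0.7  # arbitrary test point
sigma = math.exp(log_sigma)

# Same expression as GaussianLikelihood.log_likelihood (no 2*pi term).
unnormalised = sum(-0.5 * ((x - mu) / sigma) ** 2 - log_sigma for x in obs)

# Fully normalised Gaussian log-density, summed over observations.
normalised = sum(math.log(NormalDist(mu, sigma).pdf(x)) for x in obs)

# The gap is exactly the dropped constant: -n/2 * log(2*pi).
gap = normalised - unnormalised
print(f"gap = {gap:.6f}, -n/2*log(2*pi) = {-0.5 * len(obs) * math.log(2 * math.pi):.6f}")
```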
inference_method_registry.check(model, data, method=NAME) asks one
specific method whether it can run on this model. We iterate over every
registered method:
```python
print(f"{'method':<20s} {'feasible':<10s} why")
print("-" * 72)
for name in inference_method_registry.list_methods():
    info = inference_method_registry.check(model, data, method=name)
    reason = info.description or ("(native JAX backend)" if info.feasible else "")
    print(f"{name:<20s} {str(info.feasible):<10s} {reason}")
```
```text
method               feasible   why
------------------------------------------------------------------------
nutpie_nuts          False      Requires StanModel or PyMCModel
cmdstan_nuts         False      Requires StanModel
pymc_nuts            False      Requires PyMCModel
tfp_nuts             True       (native JAX backend)
tfp_hmc              True       (native JAX backend)
tfp_rwmh             True       (native JAX backend)
blackjax_sgld        False      blackjax_sgld requires model.likelihood to satisfy ConditionallyIndependentLikelihood; got GaussianLikelihood.
blackjax_sghmc       False      blackjax_sghmc requires model.likelihood to satisfy ConditionallyIndependentLikelihood; got GaussianLikelihood.
pymc_advi            False      Requires PyMCModel
sbijax_smcabc        False      Requires SimpleGenerativeModel
```
The three tfp_* methods (NUTS, HMC, RWMH) are feasible because they
only require SupportsLogProb, a property SimpleModel satisfies
through its prior + likelihood composition. The other seven are gated
behind either a specific model subclass (StanModel, PyMCModel,
SimpleGenerativeModel) or, for the two BlackJAX stochastic-gradient
samplers, a likelihood that satisfies ConditionallyIndependentLikelihood.
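Here is a plausible reading of the ConditionallyIndependentLikelihood requirement reported for blackjax_sgld and blackjax_sghmc in the check output. Stochastic-gradient samplers evaluate the likelihood on random minibatches. That is only valid when the log-likelihood is a sum of per-observation terms, in which case the rescaled minibatch sum (N/B)·Σ over the batch is an unbiased estimate of the full sum. The stdlib sketch below checks this by averaging over every possible minibatch. per_datum_loglik is a hypothetical helper, and the gradient step itself is not shown.

```python
import itertools
import math

obs = [1.0, 0.8, 1.2, 0.9, 1.1]
mu, log_sigma = 0.9, -0.7  # arbitrary parameter values


def per_datum_loglik(x: float) -> float:
    # One observation's term of the (unnormalised) Gaussian log-likelihood.
    return -0.5 * ((x - mu) / math.exp(log_sigma)) ** 2 - log_sigma


full = sum(per_datum_loglik(x) for x in obs)

N, B = len(obs), 2
# Rescaled minibatch estimate for every possible batch of size B.
estimates = [N / B * sum(per_datum_loglik(x) for x in batch)
             for batch in itertools.combinations(obs, B)]
mean_estimate = sum(estimates) / len(estimates)
print(f"full = {full:.6f}, mean over all minibatches = {mean_estimate:.6f}")
```

Each individual estimate is noisy, but their average matches the full sum. A likelihood that couples observations has no such per-term decomposition to subsample.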
3. Running inference
Without method=, condition_on picks the highest-priority feasible
method. For this model that is tfp_nuts, the native JAX NUTS:
```python
posterior = condition_on(model, data, num_results=500)
print(type(posterior).__name__)
print(f"posterior fields: {posterior.fields}")
print(f"posterior n_samples: {posterior.n}")
mu_mean = float(jnp.mean(jnp.asarray(posterior.samples["mu"])))
logsig_mean = float(jnp.mean(jnp.asarray(posterior.samples["log_sigma"])))
print(f"posterior mean: mu≈{mu_mean:.3f} log_sigma≈{logsig_mean:.3f}")
print(f"data mean for reference: {float(data.mean()):.3f}")
```
```text
ApproximateDistribution
posterior fields: ('mu', 'log_sigma')
posterior n_samples: 500
posterior mean: mu≈0.982 log_sigma≈-0.796
data mean for reference: 1.000
```
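Here is a rough sanity check on mu≈0.982. If sigma were held fixed, the Normal(0, 2) prior on mu would be conjugate to the Gaussian likelihood, and the posterior mean of mu would have a closed form. Fixing sigma at exp(-0.796) is a simplifying assumption, since the real posterior integrates over log_sigma. The check only confirms that a value slightly below the data mean is expected.

```python
import math

obs = [1.0, 0.8, 1.2, 0.9, 1.1]
n = len(obs)
xbar = sum(obs) / n

tau = 2.0                 # prior scale of mu: Normal(loc=0, scale=2)
sigma = math.exp(-0.796)  # assumption: sigma fixed at exp(posterior mean of log_sigma)

# Conjugate Normal-Normal update with prior mean 0: a precision-weighted average.
data_precision = n / sigma**2
prior_precision = 1.0 / tau**2
post_mean = data_precision * xbar / (data_precision + prior_precision)
post_sd = math.sqrt(1.0 / (data_precision + prior_precision))
print(f"closed-form mu mean ≈ {post_mean:.3f} (sd ≈ {post_sd:.3f})")
```

The weak prior contributes little precision relative to five observations, so the shrinkage toward 0 is small. The gap to the sampled 0.982 is well inside the posterior spread.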
4. Pinning a method
Pass method= to force a specific backend. The registry first calls
check(...) and raises a clear error if the method can't run on the
given model, which is cheaper than waiting for the backend to fail at startup.
```python
# TFP RWMH instead of the default NUTS
rwmh_posterior = condition_on(model, data, method='tfp_rwmh', num_results=500)
rwmh_mu = float(jnp.mean(jnp.asarray(rwmh_posterior.samples['mu'])))
rwmh_logsig = float(jnp.mean(jnp.asarray(rwmh_posterior.samples['log_sigma'])))
print(f'RWMH posterior mean: mu≈{rwmh_mu:.3f} log_sigma≈{rwmh_logsig:.3f}')
```
```text
RWMH posterior mean: mu≈0.991 log_sigma≈-0.617
```
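Random-walk Metropolis needs only log-density evaluations, never gradients. Its draws are also typically more autocorrelated than NUTS's. With the same num_results its estimates are noisier, which is one plausible reason the log_sigma mean above differs from the NUTS run. Below is a minimal stdlib sketch of the algorithm on a 1-D Gaussian target. It is an illustration, not TFP's RandomWalkMetropolis, and the step size, target and burn-in length are arbitrary choices.

```python
import math
import random


def rwmh(log_prob, x0, step, n_steps, seed=0):
    """Random-walk Metropolis for a scalar parameter (gradient-free)."""
    rng = random.Random(seed)
    x, lp = x0, log_prob(x0)
    draws, accepted = [], 0
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step)  # symmetric Gaussian proposal
        lp_prop = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(x)), compared in log space.
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < lp_prop - lp:
            x, lp = proposal, lp_prop
            accepted += 1
        draws.append(x)  # a rejection repeats the current state
    return draws, accepted / n_steps


# Target: N(1.0, 0.5^2), known only up to an additive constant.
def target(x):
    return -0.5 * ((x - 1.0) / 0.5) ** 2


draws, acc_rate = rwmh(target, x0=0.0, step=0.8, n_steps=20_000)
kept = draws[2_000:]  # discard burn-in
rw_mean = sum(kept) / len(kept)
print(f"mean ≈ {rw_mean:.3f}, acceptance ≈ {acc_rate:.2f}")
```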
Asking for a method the model doesn't satisfy surfaces the reason immediately:
```python
try:
    condition_on(model, data, method="cmdstan_nuts", num_results=100)
except (TypeError, ValueError, RuntimeError) as e:
    print(f"refused: {type(e).__name__}: {str(e)[:120]}")
```
```text
refused: TypeError: Method 'cmdstan_nuts' is not applicable: Requires StanModel
```
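Pinning fails fast because the method's feasibility check runs before any sampling work. The toy sketch below uses hypothetical names (TOY_METHODS, condition_on_pinned) and is not ProbPipe's code. Its error text only mimics the message shown above.

```python
from typing import Callable, Dict, Optional, Tuple

Check = Callable[[str], Optional[str]]  # returns None if feasible, else a reason
Run = Callable[[str], str]              # stand-in for "start the sampler"


def _requires(kind: str) -> Check:
    return lambda model_kind: None if model_kind == kind else f"Requires {kind}"


# Hypothetical registry: method name -> (check, run).
TOY_METHODS: Dict[str, Tuple[Check, Run]] = {
    "cmdstan_nuts": (_requires("StanModel"), lambda m: "stan draws"),
    "tfp_nuts": (lambda m: None, lambda m: "tfp draws"),  # toy: accepts any model
}


def condition_on_pinned(model_kind: str, method: str) -> str:
    """Run `method` only after its check passes; otherwise fail fast."""
    check, run = TOY_METHODS[method]
    reason = check(model_kind)
    if reason is not None:
        # Raised before any compilation or sampling would start.
        raise TypeError(f"Method {method!r} is not applicable: {reason}")
    return run(model_kind)


try:
    condition_on_pinned("SimpleModel", "cmdstan_nuts")
except TypeError as e:
    print(f"refused: {e}")
```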
5. External backends: when do you reach for each?
The backends above (cmdstan_nuts, pymc_nuts, pymc_advi,
nutpie_nuts, sbijax_*) plug into condition_on through the same
interface but require different model shapes and runtime dependencies.
A quick decision table:
| Backend | Optional install | Model shape | Reach for it when |
|---|---|---|---|
| tfp_nuts (default) | none | SupportsLogProb | General-purpose sampler for a differentiable likelihood |
| tfp_hmc | none | SupportsLogProb | You want plain HMC with a fixed trajectory length instead of NUTS's adaptive one |
| tfp_rwmh | none | SupportsLogProb | Gradients are expensive or unavailable; or discrete parameters |
| cmdstan_nuts | cmdstanpy | StanModel | You already have a Stan model and want Stan's battle-tested NUTS |
| pymc_nuts / pymc_advi | pymc | PyMCModel | Model is a PyMC model; you want NUTS or ADVI |
| nutpie_nuts | nutpie | StanModel or PyMCModel | Fast compiled NUTS on a Stan- or PyMC-defined model |
| sbijax_smcabc | sbijax | SimpleGenerativeModel | Likelihood is intractable; you can only simulate |
| sbi_learn_conditional | sbijax | SimpleGenerativeModel | Amortised NPE: simulate once, query many observations |
| sbi_learn_likelihood | sbijax | SimpleGenerativeModel | NLE: learn a surrogate likelihood, then run MCMC |
When the likelihood is tractable, start with the default: tfp_nuts
is the most general-purpose option. Move to cmdstan_nuts or pymc_nuts if
you're already working in that ecosystem; use nutpie_nuts if you
need the speed of a compiled backend (it still requires a StanModel or PyMCModel).
When the likelihood is intractable (stochastic simulator, no
closed form), you give up SupportsLogProb and move to the SBI
backends, which only need to simulate. See the
Flexible inference tutorial
for an end-to-end walk-through of sbijax_smcabc,
sbi_learn_conditional (NPE), and sbi_learn_likelihood (NLE) on a
stochastic Ricker model.
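The rule of thumb in the two paragraphs above can be written down as a small helper. The function and its arguments are hypothetical. It restates the decision table and is not something ProbPipe provides.

```python
from typing import Optional


def suggest_backend(tractable_likelihood: bool,
                    ecosystem: Optional[str] = None,
                    want_compiled_speed: bool = False) -> str:
    """Restate the decision table as code (illustrative only)."""
    if not tractable_likelihood:
        # Only simulation is available: SBI. SMC-ABC here; NPE/NLE are alternatives.
        return "sbijax_smcabc"
    if ecosystem in ("stan", "pymc") and want_compiled_speed:
        return "nutpie_nuts"  # nutpie needs a StanModel or PyMCModel
    if ecosystem == "stan":
        return "cmdstan_nuts"
    if ecosystem == "pymc":
        return "pymc_nuts"
    return "tfp_nuts"  # general-purpose default for SupportsLogProb models


print(suggest_backend(True))
print(suggest_backend(True, ecosystem="stan", want_compiled_speed=True))
print(suggest_backend(False))
```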
6. Priorities and overrides
Each registered method has an effective priority, and by default the
registry picks the highest-priority feasible method: tfp_nuts beats
tfp_hmc, which beats tfp_rwmh, when all three are available. To change
the preference globally (e.g. to prefer cmdstan_nuts over nutpie_nuts
for Stan models, since nutpie_nuts ranks higher by default), call
inference_method_registry.set_priorities(...) at startup.
```python
# Example: inspect the effective priority ordering (higher → preferred)
ranked = sorted(
    inference_method_registry.list_methods(),
    key=lambda name: -inference_method_registry.get_method(name).priority,
)
for name in ranked:
    m = inference_method_registry.get_method(name)
    print(f" {name:<20s} priority={m.priority}")
```
```text
 nutpie_nuts          priority=85
 cmdstan_nuts         priority=82
 pymc_nuts            priority=81
 tfp_nuts             priority=75
 tfp_hmc              priority=65
 tfp_rwmh             priority=55
 blackjax_sgld        priority=45
 blackjax_sghmc       priority=42
 pymc_advi            priority=25
 sbijax_smcabc        priority=5
```
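The ordering above is just a sort on one integer per method. The sketch below mimics a global override by merging user priorities over the printed defaults and re-ranking. The overrides dictionary is a hypothetical stand-in, since the exact arguments of set_priorities are not shown here. rank_methods is likewise an illustrative helper, named so it does not shadow the ranked list above.

```python
from typing import Dict, List, Optional

# Default priorities as printed above.
DEFAULT_PRIORITIES: Dict[str, int] = {
    "nutpie_nuts": 85, "cmdstan_nuts": 82, "pymc_nuts": 81,
    "tfp_nuts": 75, "tfp_hmc": 65, "tfp_rwmh": 55,
    "blackjax_sgld": 45, "blackjax_sghmc": 42,
    "pymc_advi": 25, "sbijax_smcabc": 5,
}


def rank_methods(defaults: Dict[str, int],
                 overrides: Optional[Dict[str, int]] = None) -> List[str]:
    """Return method names, most preferred first, after applying overrides."""
    effective = {**defaults, **(overrides or {})}
    return sorted(effective, key=effective.get, reverse=True)


print(rank_methods(DEFAULT_PRIORITIES)[:3])
# Hypothetical override: prefer cmdstan_nuts over nutpie_nuts.
print(rank_methods(DEFAULT_PRIORITIES, {"cmdstan_nuts": 90})[:3])
```

Overriding a single number is enough to reorder the preference. Feasibility is still checked per model, so a boosted cmdstan_nuts would only win for models it can actually run on.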
Summary
- condition_on(model, data) dispatches through inference_method_registry; the same call site works for native JAX NUTS, Stan, PyMC, nutpie, and simulation-based inference.
- inference_method_registry.list_methods() lists every registered backend; check(model, data, method=NAME) tells you whether one can run on a given model.
- Passing method="..." pins the choice; omitting it selects the highest-priority feasible backend.
- Optional backends (Stan, PyMC, nutpie, sbijax) register themselves when their dependency is installed. Tractable-likelihood models should default to the native JAX NUTS (tfp_nuts); intractable likelihoods move to the SBI backends via a SimpleGenerativeModel.