Modeling and inference
Models, likelihoods, inference methods, iterative transformations, and predictive checks. Registering a new inference method is documented under Extending ProbPipe → Custom inference methods.
Models
ProbabilisticModel(*, name)
Bases: Distribution[T]
Abstract base for probabilistic programming models.
A ProbabilisticModel is a first-class Distribution
that also supports named components. Subclasses declare their
parameter and data components and optionally provide _log_prob,
_sample, etc.
Conditioning is handled by the inference method registry — call
condition_on(model, data) rather than model._condition_on().
Source code in probpipe/core/_distribution_base.py
SimpleModel(prior, likelihood, *, name=None)
Bases: ProbabilisticModel[tuple[P, D]], SupportsLogProb
Probabilistic model as a joint distribution over (parameters, data).
A SimpleModel[P, D] is a Distribution[tuple[P, D]] — the joint
distribution \(p(\theta, y) = p(\theta) \, p(y \mid \theta)\).
The prior must support SupportsLogProb so that the joint
log-density is always computable.
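As a minimal sketch of the joint log-density a SimpleModel represents, assume a hypothetical model with a standard Normal prior on a scalar mean theta and unit-variance Normal observations. The helper names are illustrative only and are not ProbPipe API. The point is that log p(theta, y) splits into a log-prior term plus a log-likelihood term, which is why the prior has to provide a log-density.

```python
import math


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def joint_log_density(theta: float, y: list[float]) -> float:
    """log p(theta, y) = log p(theta) + log p(y | theta) for a toy model.

    Hypothetical model: theta ~ Normal(0, 1), y_i | theta ~ Normal(theta, 1).
    """
    log_prior = normal_logpdf(theta, 0.0, 1.0)
    log_lik = sum(normal_logpdf(y_i, theta, 1.0) for y_i in y)
    return log_prior + log_lik


print(joint_log_density(0.5, [0.2, 0.9, 0.4]))
```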
Named components: merged from the prior's record_template
and the likelihood's data_template when both are available.
For example, a GLM model might have
fields == ("X", "intercept", "slope", "y").
Falls back to ("parameters", "data") when templates are absent.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior | Distribution[P] that supports SupportsLogProb | Prior distribution over model parameters. | required |
| likelihood | Likelihood[P, D] | Must have a | required |
| name | str or None | Model name for provenance. | None |
Source code in probpipe/modeling/_simple.py
SimpleGenerativeModel(prior, likelihood, *, name=None)
Bases: ProbabilisticModel[tuple[P, D]], SupportsSampling
Generative probabilistic model as a joint over (parameters, data).
A SimpleGenerativeModel[P, D] pairs a prior that supports
sampling with a GenerativeLikelihood that can generate
synthetic data given parameters. Unlike SimpleModel, this
does not require a log-density — making it suitable for
simulation-based inference (SBI) and approximate Bayesian
computation (ABC) methods.
Sampling: implements SupportsSampling
via the obvious joint draw — sample parameters from the prior, then
call likelihood.generate_data(params, ...) for the data.
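The joint draw can be sketched with the standard library, assuming a hypothetical Normal prior and Normal simulator. Here random.Random stands in for a JAX PRNG key. Note that no density is evaluated anywhere, which is what makes this construction usable for SBI and ABC.

```python
import random


def sample_joint(rng: random.Random, n_obs: int) -> tuple[float, list[float]]:
    # Step 1: draw parameters from the (hypothetical) prior, theta ~ Normal(0, 1).
    theta = rng.gauss(0.0, 1.0)
    # Step 2: simulate data given theta, y_i ~ Normal(theta, 1); no log-density needed.
    y = [rng.gauss(theta, 1.0) for _ in range(n_obs)]
    return theta, y


theta, y = sample_joint(random.Random(0), n_obs=5)
```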
Named components: "parameters" (the prior) and "data"
(the generative likelihood). Only "data" is conditionable.
Conditioning: Use condition_on(model, data) — the inference
method registry auto-selects an appropriate SBI or ABC method.
SimpleGenerativeModel does not implement SupportsConditioning
directly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior | SupportsSampling[P] | Prior distribution over model parameters. Must support sampling. | required |
| likelihood | GenerativeLikelihood[P, D] | Must have a | required |
| name | str or None | Model name for provenance. | None |
Source code in probpipe/modeling/_simple_generative.py
Likelihoods
Likelihood
Bases: Protocol
Protocol for computing log-likelihood of data given parameters.
Generic in P (parameter type) and D (data type).
Any class that defines log_likelihood(params, data) -> float
satisfies this protocol.
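Because Likelihood is a typing Protocol, conformance is structural. The sketch below defines a hypothetical LikelihoodLike protocol that mirrors the description (it is not imported from ProbPipe), plus a Bernoulli likelihood that matches it without inheriting from anything. A runtime_checkable isinstance check only confirms that the method exists. It does not check the signature.

```python
import math
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class LikelihoodLike(Protocol):
    """Hypothetical stand-in for the Likelihood protocol described above."""

    def log_likelihood(self, params: Any, data: Any) -> float: ...


class BernoulliLikelihood:
    """No base class: defining log_likelihood is enough to match the protocol."""

    def log_likelihood(self, params: float, data: list[int]) -> float:
        # params is the success probability p; data holds 0/1 outcomes.
        return sum(math.log(params) if y == 1 else math.log1p(-params) for y in data)


lik = BernoulliLikelihood()
print(isinstance(lik, LikelihoodLike))  # True: only method presence is checked
```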
ConditionallyIndependentLikelihood
Bases: Likelihood[P, D], Protocol
Likelihood whose observations are conditionally independent given the parameters.
Formally, for observations y_1, ..., y_N the joint log-density
factorises into a sum of per-observation log-densities:
\[
\log p(y_1, \ldots, y_N \mid \theta) = \sum_{i=1}^N \log p(y_i \mid \theta).
\]
The "conditionally" refers to the conditioning on the parameters
θ: the y_i are independent given θ, not marginally.
For regression-style likelihoods each datum carries a covariate
x_i that the per-observation density depends on; the
factorisation then reads Σ_i log p(y_i | x_i, θ), with the
covariates treated as fixed inputs rather than random variables.
This is the "conditionally independent" case rather than the
stricter "i.i.d." (where every p(y_i | θ) is the same density).
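The regression case can be sketched with a hypothetical Gaussian linear model with unit noise. Each datum is an (x_i, y_i) pair, the per-datum term conditions on the fixed covariate x_i, and the full log-likelihood is simply the sum of these terms. Additivity over any split of the data is the property that minibatch and leave-one-out methods rely on.

```python
import math


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def per_datum_log_lik(theta: tuple[float, float], datum: tuple[float, float]) -> float:
    # datum = (x_i, y_i); the covariate x_i is a fixed input, not modeled.
    intercept, slope = theta
    x_i, y_i = datum
    return normal_logpdf(y_i, intercept + slope * x_i, 1.0)


def log_lik(theta: tuple[float, float], data: list[tuple[float, float]]) -> float:
    # Conditional independence given theta: the joint log-density is a sum.
    return sum(per_datum_log_lik(theta, d) for d in data)


data = [(0.0, 0.1), (1.0, 1.9), (2.0, 4.2)]
theta = (0.0, 2.0)
```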
Required by MinibatchedDistribution for
stochastic-gradient inference, and useful independently for
held-out predictive log-likelihoods, leave-one-out
cross-validation, and PSIS-LOO.
Implementations expose per_datum_log_likelihood; the helper
_default_per_datum_log_likelihood provides a length-1-batch
fallback for likelihoods that want a default rather than an
efficient override.
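The length-1-batch fallback can be sketched as follows. The function name mirrors _default_per_datum_log_likelihood, but its body here is an assumption based on the description. A Python list stands in for the array batch, whereas the GLMLikelihood notes describe the real fallback as log_likelihood(params, datum[None, ...]).

```python
import math


class BernoulliBatchLikelihood:
    """Hypothetical likelihood that only implements the full-batch method."""

    def log_likelihood(self, params: float, data: list[int]) -> float:
        return sum(math.log(params) if y == 1 else math.log1p(-params) for y in data)


def default_per_datum_log_likelihood(likelihood, params, datum):
    # Fallback: wrap the single datum in a length-1 batch and reuse the
    # full-batch method. Correct, but pays batching overhead on every call.
    return likelihood.log_likelihood(params, [datum])


lik = BernoulliBatchLikelihood()
```

An efficient override, like the one GLMLikelihood provides, evaluates the single datum directly and skips the wrapping step.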
per_datum_log_likelihood(params, datum)
Log-density of a single datum given parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | P | Model parameters. | required |
| datum | Any | One observation. The exact shape depends on the data format the likelihood was constructed against; for a regression model that's a single row | required |
Returns:
| Type | Description |
|---|---|
| Array | Scalar log-density of the datum under |
Source code in probpipe/modeling/_likelihood.py
GenerativeLikelihood
Bases: Protocol
Protocol for generating synthetic data given parameters.
Generic in P (parameter type) and D (data type).
Any class that defines generate_data(params, n_samples, *, key) -> D
satisfies this protocol.
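A hypothetical class that satisfies this protocol might look like the sketch below. An integer seed passed as key stands in for a JAX PRNG key, so the same key reproduces the same synthetic data. The Bernoulli simulator is illustrative only.

```python
import random
from typing import Optional


class BernoulliSimulator:
    """Hypothetical generative likelihood: y_i ~ Bernoulli(params)."""

    def generate_data(self, params: float, n_samples: int, *, key: Optional[int] = None) -> list[int]:
        # An integer seed stands in for a JAX PRNG key in this stdlib sketch;
        # the same key always reproduces the same data.
        rng = random.Random(key)
        return [1 if rng.random() < params else 0 for _ in range(n_samples)]


sim = BernoulliSimulator()
y = sim.generate_data(0.3, 10, key=42)
```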
generate_data(params, n_samples, *, key=None)
Generate n_samples synthetic data points from params.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | P | Model parameters. | required |
| n_samples | int | Number of data points to generate. | required |
| key | PRNGKey or None | JAX PRNG key for reproducible generation. | None |
Source code in probpipe/modeling/_likelihood.py
GLMLikelihood(family, x=None, *, fit_intercept=True, seed=0)
Wraps a TFP GLM family + design matrix into a Likelihood and GenerativeLikelihood.
Two accepted data forms:
- data = Record(X=X_covariates, y=y_observed): both fields explicit; the canonical form. X is the covariate matrix only; do not include a constant column for the intercept.
- data = y_observed (a bare response array) when X was supplied at construction time; the construction-time X is used.
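Joint bootstrapping of covariates and response uses the Record form::

    # Bundle covariates and response so rows are resampled together.
    Xy = Record(X=X_covariates, y=y_observed)
    bootstrap = BootstrapReplicateDistribution(EmpiricalDistribution(Xy))
    # Condition on bootstrap replicates of the data (bagging).
    bagged = condition_on(model, bootstrap, n_broadcast_samples=16)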
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| family | ExponentialFamily | TFP GLM family (e.g., | required |
| x | array-like or None | Default covariate matrix of shape | None |
| fit_intercept | bool | When True, the likelihood expects | True |
| seed | int | Random seed for data generation. | 0 |
Source code in probpipe/modeling/_glm.py
data_template (property)
Named structure of GLM data: X (design matrix) and y (response).
log_likelihood(params, data)
Log-likelihood: sum of per-observation log-probs.
params and data can be raw arrays or Record objects.
When data is Record(X=..., y=...), both the covariate
matrix and response are extracted. The linear predictor
eta = intercept + X @ slopes is computed via
_linear_predictor so the fit_intercept convention is
respected uniformly across the public methods.
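A plain-Python sketch of the same computation for a Poisson family with a log link is shown below. The parameter layout (intercept first, then slopes) is an assumption made for illustration. Neither TFP families nor ProbPipe's _linear_predictor are used. Routing every method through a single linear-predictor helper is what keeps the fit_intercept convention consistent.

```python
import math


def linear_predictor(params: list[float], x_row: list[float], fit_intercept: bool = True) -> float:
    # Assumed layout: with fit_intercept, params[0] is the intercept and
    # params[1:] are the slopes; X never carries a constant column.
    if fit_intercept:
        intercept, slopes = params[0], params[1:]
    else:
        intercept, slopes = 0.0, params
    return intercept + sum(b * x for b, x in zip(slopes, x_row))


def poisson_glm_log_likelihood(params, X, y, fit_intercept=True):
    # Poisson family with the canonical log link: log(rate) = eta.
    total = 0.0
    for x_row, y_i in zip(X, y):
        eta = linear_predictor(params, x_row, fit_intercept)
        # log Poisson(y | rate) = y * log(rate) - rate - log(y!)
        total += y_i * eta - math.exp(eta) - math.lgamma(y_i + 1)
    return total
```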
Source code in probpipe/modeling/_glm.py
per_datum_log_likelihood(params, datum)
Log-density of a single observation given parameters.
Satisfies ConditionallyIndependentLikelihood.
Evaluates family.log_prob(y_i, x_i @ params) directly on a
scalar response, bypassing the length-1-batch reshape that the
default fallback (log_likelihood(params, datum[None, ...]))
would add. The saved per-call overhead matters when this method
is called inside a stochastic-gradient inner loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Array or Record | Coefficient vector of shape | required |
| datum | Record |  | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If |
Source code in probpipe/modeling/_glm.py
generate_data(params, n_samples, *, key=None)
Generate synthetic data from the GLM.
Uses the stored design matrix (first n_samples rows).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Array or Record | Parameter vector of shape | required |
| n_samples | int | Number of observations to generate (per batch element). | required |
| key | PRNGKey | JAX PRNG key. | None |
Source code in probpipe/modeling/_glm.py
IncrementalConditioner(prior, likelihood, *, condition_fn=None, **condition_kwargs)
Bases: Module
Iteratively update a posterior by conditioning on data batches.
Maintains a current posterior (initially the prior) and provides
update for single-batch conditioning and
update_all for multi-batch iteration. Both methods update
the internal state so that subsequent calls continue from the
latest posterior.
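To see why the running posterior can act as the next prior, here is a toy sketch with a conjugate Beta-Bernoulli model, where conditioning is exact and closed-form. The class is hypothetical and much simpler than IncrementalConditioner (it has no step, update_all, or condition_fn). It only shows that two sequential updates yield the same posterior as one update on the pooled data. With approximate inference methods such as MCMC, the match holds only approximately.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Beta:
    a: float
    b: float


def condition_beta(prior: Beta, batch: list[int]) -> Beta:
    # Conjugate Beta-Bernoulli update: successes add to a, failures to b.
    successes = sum(batch)
    return Beta(prior.a + successes, prior.b + len(batch) - successes)


class ToyIncrementalConditioner:
    """Hypothetical stateful wrapper; not ProbPipe's IncrementalConditioner."""

    def __init__(self, prior: Beta):
        self.curr_posterior = prior

    def update(self, data: list[int]) -> Beta:
        # The current posterior acts as the prior for the next batch.
        self.curr_posterior = condition_beta(self.curr_posterior, data)
        return self.curr_posterior


conditioner = ToyIncrementalConditioner(Beta(1, 1))
conditioner.update([1, 0, 1])
conditioner.update([1, 1])
```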
The step property exposes the underlying step function
for direct use with iterate
and combinators.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior | Distribution[P] | Initial prior distribution over model parameters. | required |
| likelihood | Likelihood[P, D] | Likelihood object. | required |
| condition_fn | callable or None | Conditioning callable; defaults to the global | None |
| **condition_kwargs | Any | Extra keyword arguments forwarded to condition_fn on every call (e.g., | {} |
Examples:
::

    conditioner = IncrementalConditioner(prior, likelihood)

    # Single-batch update (stateful):
    posterior1 = conditioner.update(data=batch1)
    posterior2 = conditioner.update(data=batch2)
    conditioner.curr_posterior  # is posterior2

    # Multi-batch update (stateful, returns sequence):
    dists = conditioner.update_all(data_batches=[batch3, batch4])
    conditioner.curr_posterior  # is dists[-1]

    # Functional escape hatch (for combinators):
    dists = iterate(conditioner.step, prior, all_batches)
Source code in probpipe/modeling/_likelihood.py
curr_posterior (property)
The current posterior (initially the prior).
step (property)
The underlying step function, for use with iterate.
update(data=None, **kwargs)
Condition on new data, updating the current posterior.
Data can be passed positionally, as data=, or as named
keyword arguments that are bundled into a Record::

    conditioner.update(y_obs)                # positional / data=
    conditioner.update(X=new_X, y=new_y)     # named kwargs → Record
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | D | New observed data to condition on. | None |
| **kwargs |  | Named data fields, bundled into a | {} |
Returns:
| Type | Description |
|---|---|
| Distribution[P] | The updated posterior distribution. |
Source code in probpipe/modeling/_likelihood.py
update_all(data_batches)
Condition on multiple data batches sequentially.
Calls iterate(self.step, self.curr_posterior, data_batches)
and updates the internal state to the final posterior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_batches | Iterable[D] | Sequence of data batches to condition on. | required |
Returns:
| Type | Description |
|---|---|
| list[Distribution[P]] | Sequence |
Source code in probpipe/modeling/_likelihood.py
Inference methods
condition_on dispatches through the
inference-method registry: methods are tried in descending priority order
and the first whose check() returns feasible=True runs. Pass
method="<name>" to override the auto-selection;
inference_method_registry.set_priorities(...) reorders the table at
runtime.
Priorities follow a single-axis convention: values above 50 mark exact
methods, values in (0, 50] mark inexact methods, and 0 means
opt-in only (selectable by name but skipped during auto-dispatch). The
five-axis selection criteria and the tier ranges contributors should use
when choosing a number for a new method are documented under
Extending ProbPipe → Setting priority for a new method.
Built-in methods:
| Name | Priority | Requires | Backend |
|---|---|---|---|
| nutpie_nuts | 85 | StanModel or PyMCModel + nutpie | nutpie |
| cmdstan_nuts | 82 | StanModel + cmdstanpy | CmdStan |
| pymc_nuts | 81 | PyMCModel + pymc | PyMC |
| tfp_nuts | 75 | SupportsLogProb + JAX-traceable | TFP |
| tfp_hmc | 65 | SupportsLogProb + JAX-traceable | TFP |
| tfp_rwmh | 55 | SupportsLogProb | TFP |
| blackjax_sgld | 45 | SimpleModel + ConditionallyIndependentLikelihood + batch_size= | BlackJAX |
| blackjax_sghmc | 42 | SimpleModel + ConditionallyIndependentLikelihood + batch_size= | BlackJAX |
| pymc_advi | 25 | PyMCModel + pymc | PyMC |
| sbijax_smcabc | 5 | SimpleGenerativeModel + sbijax | sbijax |
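The dispatch rule and the priority tiers fit in a few lines. The sketch below is hypothetical: capability sets stand in for real model types, each check is reduced to a bool (the real check() returns a result carrying feasible), and only a few names and priorities are borrowed from the table, with simplified requirements. Method, tier, select_method, and my_opt_in are illustrative names, not ProbPipe API.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Method:
    name: str
    priority: int
    # Simplified: returns feasibility as a bool rather than a check() result.
    check: Callable[[set], bool]


def tier(priority: int) -> str:
    # Single-axis convention: > 50 exact, (0, 50] inexact, 0 opt-in only.
    if priority > 50:
        return "exact"
    if priority > 0:
        return "inexact"
    return "opt-in"


def select_method(methods: list[Method], caps: set, method: Optional[str] = None) -> Method:
    if method is not None:
        # An explicit name bypasses auto-selection, including opt-in methods.
        for m in methods:
            if m.name == method:
                return m
        raise KeyError(method)
    # Auto-dispatch: try methods in descending priority; skip priority-0 ones.
    for m in sorted(methods, key=lambda m: m.priority, reverse=True):
        if m.priority > 0 and m.check(caps):
            return m
    raise LookupError("no feasible inference method")


methods = [
    Method("tfp_nuts", 75, lambda caps: {"log_prob", "jax"} <= caps),
    Method("tfp_rwmh", 55, lambda caps: "log_prob" in caps),
    Method("sbijax_smcabc", 5, lambda caps: "generative" in caps),
    Method("my_opt_in", 0, lambda caps: True),  # hypothetical opt-in method
]
```

A model that only exposes a log-density (no JAX tracing) falls through tfp_nuts to tfp_rwmh, and the priority-0 method is reachable only by name.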
ApproximateDistribution(chains, *, weights=None, name=None, record_template=None)
Bases: RecordEmpiricalDistribution
Empirical distribution with chain structure.
Stores per-chain sample arrays for chain-structured access via
draws. Algorithm metadata, sample statistics, warmup
samples, and the ArviZ InferenceData object live in
dist.auxiliary (a DataTree on the Distribution base class),
not as attributes of this class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| chains | list of Array | Per-chain sample arrays, each of shape | required |
| weights | array-like, Weights, or None | Optional per-sample importance weights (across all chains). | None |
| name | str or None | Distribution name for provenance. | None |
Notes
When record_template is multi-field, __init__ slices the
concatenated chain into per-top-level-field arrays so
fields, event_shapes, dtypes,
_mean / _variance, and the public ops
(mean(post) / variance(post)) all return Records whose
keys match fields. Nested RecordTemplate fields are
stored as a flat (n, nested_flat_size) array under the
top-level field name; the nested structure is recoverable via
record_template[field] and via draws, which walks
the full template (including nesting) using the original
per-chain samples.
Source code in probpipe/inference/_approximate_distribution.py
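The per-field slicing described in the Notes can be pictured with nested lists. This sketch assumes that each draw is a flat vector laid out field by field in template order. That layout is an illustrative assumption, not ProbPipe's documented format. The real class works on arrays and keeps the original per-chain samples for draws.

```python
def split_fields(chains: list[list[list[float]]], field_sizes: dict[str, int]) -> dict[str, list[list[float]]]:
    # Concatenate chains along the draw axis: chain 0's draws, then chain 1's, ...
    flat = [draw for chain in chains for draw in chain]
    out, start = {}, 0
    for name, size in field_sizes.items():
        # Each top-level field gets its own contiguous slice of every draw.
        out[name] = [draw[start:start + size] for draw in flat]
        start += size
    return out


chains = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0]]]
fields = split_fields(chains, {"intercept": 1, "slope": 2})
```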
chains (property)
Per-chain sample arrays.
num_chains (property)
Number of chains.
num_draws (property)
Number of draws per chain (assumes equal-length chains).
algorithm (property)
Name of the inference algorithm (read from provenance).
inference_data (property)
The auxiliary DataTree, for ArviZ compatibility.
Alias for self.auxiliary. Use ArviZ functions for diagnostics::

    import arviz as az
    az.summary(posterior.inference_data)
warmup_samples (property)
Per-chain warmup samples extracted from auxiliary data.
draws(chain=None, *, include_warmup=False)
Access draws, optionally named via record_template.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| chain | int or None | Chain index. If | None |
| include_warmup | bool | If | False |
Returns:
| Type | Description |
|---|---|
| Array or Record | If |
Source code in probpipe/inference/_approximate_distribution.py
rwmh(dist, data=None, *, log_prob_fn=None, num_results=1000, num_warmup=500, num_chains=1, step_size=0.1, init=None, random_seed=0)
Gradient-free random-walk Metropolis-Hastings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | SupportsUnnormalizedLogProb | Distribution providing | required |
| data | array-like or None | Observed data. | None |
| log_prob_fn | callable or None |  | None |
| num_results | int | MCMC tuning parameter: number of draws. | 1000 |
| num_warmup | int | MCMC tuning parameter: number of warmup iterations. | 500 |
| num_chains | int | MCMC tuning parameter: number of chains. | 1 |
| step_size | float | MCMC tuning parameter: random-walk proposal step size. | 0.1 |
| random_seed | int | Random seed. | 0 |
| init | array-like or None | Initial chain state. Tries | None |
Returns:
| Type | Description |
|---|---|
| ApproximateDistribution | Posterior samples with chain structure and auxiliary DataTree. |
Source code in probpipe/inference/_rwmh.py
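To build intuition, here is a stdlib-only sketch of the algorithm. It is not ProbPipe's implementation, which the table above lists as TFP-backed and which supports multiple chains. The names and defaults only echo the signature, and the target is a hypothetical unnormalized Normal(3, 1) log-density.

```python
import math
import random
from typing import Callable


def rwmh_sketch(log_prob: Callable[[float], float], init: float, num_results: int = 1000,
                num_warmup: int = 500, step_size: float = 0.1, seed: int = 0) -> list[float]:
    """Single-chain, one-dimensional random-walk Metropolis-Hastings."""
    rng = random.Random(seed)
    x, lp = init, log_prob(init)
    samples = []
    for i in range(num_warmup + num_results):
        # Symmetric Gaussian proposal, so the Hastings ratio reduces to the
        # ratio of target densities; no gradients are needed.
        proposal = x + step_size * rng.gauss(0.0, 1.0)
        lp_proposal = log_prob(proposal)
        # Accept with probability min(1, exp(lp_proposal - lp)); 1 - random()
        # lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < lp_proposal - lp:
            x, lp = proposal, lp_proposal
        if i >= num_warmup:  # discard warmup draws
            samples.append(x)
    return samples


draws = rwmh_sketch(lambda t: -0.5 * (t - 3.0) ** 2, init=0.0,
                    num_results=5000, step_size=1.0)
```

Only log_prob evaluations are needed, which is why this method applies to any model that supports SupportsLogProb, even when it is not JAX-traceable.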
condition_on_nutpie(model, data=None, *, num_results=1000, num_warmup=500, num_chains=4, random_seed=0, **kwargs)
MCMC sampling via nutpie (Rust-based NUTS).
Accepts a StanModel or
PyMCModel.
Source code in probpipe/inference/_nutpie.py
sbi_learn_conditional(*args, **kwargs)
sbi_learn_likelihood(*args, **kwargs)
MinibatchedDistribution(prior, likelihood, data, batch_size, *, with_replacement=False, name=None)
Bases: RandomMeasure[Record], SupportsRandomUnnormalizedLogProb
Random measure realised by uniform minibatching.
A draw from this measure is a fixed-minibatch target: an
unnormalized stochastic surrogate of the full-data unnormalized
log-posterior, with the likelihood term rescaled by N / b so that its
gradient is an unbiased estimator of the full-data gradient. It is not a
posterior in the strict (normalized) sense.
For a model with prior \(p(\theta)\) and likelihood
\(p(\mathcal{D} \mid \theta) = \prod_i p(d_i \mid \theta)\),
a draw's unnormalized log-density is
\[
\log \tilde{D}_B(\theta) = \log p(\theta) + \frac{N}{b} \sum_{d \in B} \log p(d \mid \theta),
\]
where \(B\) is a minibatch of \(b\) observations drawn from the \(N\) observations in \(\mathcal{D}\).
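The rescaling can be checked numerically. Under uniform without-replacement sampling, every size-b subset is equally likely. Averaging the surrogate over all of them therefore gives its exact expectation, which matches the full-data unnormalized log-posterior. The Normal prior and likelihood below are hypothetical placeholders.

```python
import itertools
import math


def normal_logpdf(x: float, loc: float, scale: float = 1.0) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_prior(theta: float) -> float:
    return normal_logpdf(theta, 0.0, 1.0)  # hypothetical prior


def log_lik_datum(theta: float, d: float) -> float:
    return normal_logpdf(d, theta, 1.0)  # hypothetical per-datum likelihood


def minibatch_log_density(theta: float, data: list[float], batch: tuple[int, ...]) -> float:
    # log p(theta) + (N / b) * sum of per-datum log-likelihoods over the batch.
    n, b = len(data), len(batch)
    return log_prior(theta) + (n / b) * sum(log_lik_datum(theta, data[i]) for i in batch)


def full_log_density(theta: float, data: list[float]) -> float:
    return log_prior(theta) + sum(log_lik_datum(theta, d) for d in data)


data = [0.3, -1.2, 0.8, 2.1, 0.5]
theta, b = 0.4, 2
# Every size-b subset is equally likely under without-replacement sampling,
# so averaging over all of them gives the exact expectation.
batches = list(itertools.combinations(range(len(data)), b))
expected = sum(minibatch_log_density(theta, data, B) for B in batches) / len(batches)
```

Each index appears in a fraction b / N of the subsets, which the N / b factor cancels exactly. Any single draw, however, is noisy.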
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior | SupportsLogProb | Prior distribution over parameters; provides the log-prior term \(\log p(\theta)\). | required |
| likelihood | ConditionallyIndependentLikelihood | Likelihood that factorises as \(\prod_i p(d_i \mid \theta)\). | required |
| data | array-like, Record, or RecordArray | Observed data. Indexed along its leading axis to draw minibatches; must have leading-axis length | required |
| batch_size | int | Minibatch size \(b\). | required |
| with_replacement | bool | Sample minibatch indices with replacement. Default is without-replacement (uniform permutation, take first | False |
| name | str | Distribution name. | None |
Raises:
| Type | Description |
|---|---|
| TypeError | If |
| ValueError | If |
Source code in probpipe/inference/_minibatch.py
dataset_size (property)
Total number of observations in the dataset (len(data)).
Named dataset_size rather than .n to avoid colliding
with STYLE_GUIDE §1.9's "how many items does this hold?"
semantics — MinibatchedDistribution is not a
finite-sample distribution; it doesn't hold a finite
collection of realisations.
batch_size (property)
Minibatch size \(b\).
with_replacement (property)
Whether minibatch indices are drawn with replacement.
prior (property)
The prior distribution over parameters.
likelihood (property)
The conditionally-independent likelihood.
data (property)
The full dataset (not the minibatched view).
inference_method_registry = MethodRegistry() (module attribute)
Iterative transformations
Step functions folded over inputs by iterate, with with_conversion and
with_resampling as step-function wrappers.
iterate(step_fn, initial, inputs, *, callback=None)
Fold a step function over inputs, accumulating a distribution sequence.
Starting from initial, applies step_fn(dist, inp) for each
element of inputs, collecting the resulting distributions into a
list. The returned list includes the initial distribution at
index 0.
Provenance is automatically attached to each output distribution (linking it to the previous distribution) unless the step function has already set provenance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_fn | callable |  | required |
| initial | Distribution[T] | The starting distribution. | required |
| inputs | Iterable[S] | Sequence of inputs to pass to the step function. | required |
| callback | callable or None | Called as | None |
Returns:
| Type | Description |
|---|---|
| list[Distribution[T]] | The full sequence: |
Source code in probpipe/core/transition.py
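Without provenance tracking and the callback, iterate is a left fold that keeps every intermediate state. In this hypothetical sketch, plain numbers stand in for distributions. The output has one more element than inputs because the initial value sits at index 0.

```python
from typing import Callable, Iterable, TypeVar

D = TypeVar("D")
S = TypeVar("S")


def iterate_sketch(step_fn: Callable[[D, S], D], initial: D, inputs: Iterable[S]) -> list[D]:
    # Left fold that keeps every intermediate state, starting with `initial`.
    dists = [initial]
    for inp in inputs:
        dists.append(step_fn(dists[-1], inp))
    return dists


history = iterate_sketch(lambda state, x: state + x, 0, [1, 2, 3])
```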
with_conversion(step_fn, target_type, **convert_kwargs)
Wrap a step function to convert its output after each step.
After calling step_fn, converts the resulting distribution to
target_type using ProbPipe's standard from_distribution
operation (which dispatches through the converter registry).
The pre-conversion distribution is accessible via the converted
distribution's provenance parents (set by the converter).
This is useful when the step function produces samples (e.g., MCMC output) but the next iteration needs a parametric distribution as input.
The returned wrapper is a WorkflowFunction, so it appears
as a node in the ProbPipe workflow DAG.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_fn | callable | The underlying step function. | required |
| target_type | type | Distribution type to convert to (e.g., | required |
| **convert_kwargs | Any | Extra keyword arguments passed to | {} |
Returns:
| Type | Description |
|---|---|
| WorkflowFunction | A new step function with the same call signature. |
Source code in probpipe/core/transition.py
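The wrapping pattern can be sketched as follows: the wrapper keeps the step's call signature and converts only its output. In this hypothetical sketch, a list of samples is moment-matched to a toy Normal class. In ProbPipe, the conversion goes through from_distribution, and the wrapper is a WorkflowFunction rather than a plain closure.

```python
import statistics
from dataclasses import dataclass
from functools import wraps


@dataclass(frozen=True)
class Normal:
    loc: float
    scale: float


def samples_to_normal(samples: list[float]) -> Normal:
    # Moment matching; stands in for ProbPipe's converter-registry dispatch.
    return Normal(statistics.fmean(samples), statistics.pstdev(samples))


def with_conversion_sketch(step_fn, convert):
    @wraps(step_fn)
    def wrapped(dist, inp):
        # Same call signature as step_fn; only the output is post-processed.
        return convert(step_fn(dist, inp))
    return wrapped


def shift_step(prior: Normal, shift: float) -> list[float]:
    # Hypothetical sample-producing step (a stand-in for MCMC output).
    return [prior.loc + shift + prior.scale * z for z in (-1.0, 0.0, 1.0)]


step = with_conversion_sketch(shift_step, samples_to_normal)
post = step(Normal(0.0, 1.0), 1.0)
```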
with_resampling(step_fn, *, ess_threshold=0.5, seed=0)
Wrap a step function to resample when particle weights degenerate.
After calling step_fn, if the result is an
EmpiricalDistribution with
ESS / N < ess_threshold, performs multinomial resampling to
produce equally-weighted particles.
The pre-resampling ESS is stored in provenance metadata of the
resampled distribution (dist.source.metadata["ess"]) since
this information would otherwise be lost after resampling to
uniform weights.
The returned wrapper is a WorkflowFunction, so it appears
as a node in the ProbPipe workflow DAG.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_fn | callable | The underlying step function. | required |
| ess_threshold | float | Resample when | 0.5 |
| seed | int | Base random seed; combined with a call counter for deterministic reproducibility. | 0 |
Returns:
| Type | Description |
|---|---|
| WorkflowFunction | A new step function with the same call signature. |
Notes
This API is likely to evolve as typical use cases become clearer.
A future direction is a SupportsResampling protocol that would
decouple this combinator from the concrete
EmpiricalDistribution type.
Source code in probpipe/core/transition.py
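Here is a hypothetical, list-based sketch of the resampling rule. It assumes the Kish effective sample size, since the description does not say which ESS estimator is used, and it uses random.choices for multinomial resampling. ProbPipe operates on EmpiricalDistribution objects and records the pre-resampling ESS in provenance instead of returning it.

```python
import random


def ess(weights: list[float]) -> float:
    # Kish effective sample size: (sum w)^2 / sum w^2; equals N for equal weights.
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


def maybe_resample(particles, weights, ess_threshold=0.5, seed=0):
    n = len(particles)
    pre_ess = ess(weights)
    if pre_ess / n >= ess_threshold:
        return particles, weights, pre_ess  # weights still healthy: no-op
    rng = random.Random(seed)
    # Multinomial resampling: n independent draws with probability proportional
    # to weight, then reset to uniform weights.
    resampled = rng.choices(particles, weights=weights, k=n)
    # Return the pre-resampling ESS, which uniform weights would otherwise hide.
    return resampled, [1.0 / n] * n, pre_ess
```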
Predictive checks
predictive_check(distribution, generative_likelihood, test_fn, observed_data=None, *, n_samples=None, n_replications=500, key=None)
Predictive check that works as either a prior or a posterior check.
Draws parameter samples from distribution, generates replicated data via generative_likelihood, and computes test_fn on each replicate.
When observed_data is provided, also computes test_fn on the observed data and returns a calibration p-value, making this a posterior predictive check. Without observed_data, this is a prior predictive check — useful for understanding the implications of the prior.
When generate_data accepts a key keyword argument, all
replications are generated in a single vectorized call (by passing
a batch of parameter vectors), giving a large speedup. The test
function is then applied via jax.vmap when possible, with an
automatic fallback to a Python loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| distribution | Distribution[P] | Prior or posterior to sample parameters from. | required |
| generative_likelihood | GenerativeLikelihood[P, D] | Must have | required |
| test_fn | Callable[[D], float] | Test statistic mapping data to a scalar. | required |
| observed_data | D or None | If provided, compute the observed test statistic and p-value. | None |
| n_samples | int | Number of observations per replicated dataset. Required if observed_data is not provided; otherwise defaults to | None |
| n_replications | int | Number of replicated datasets to generate. | 500 |
| key | PRNGKey | JAX PRNG key. Auto-generated if | None |
Returns:
| Type | Description |
|---|---|
| dict | Always contains: When observed_data is provided, also contains: |
Source code in probpipe/validation/_predictive_check.py
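For intuition, here is a sketch of the replicate-and-compare loop using the standard library. It assumes an upper-tail p-value, Pr(T(y_rep) >= T(y_obs)), and a simulator that takes a random.Random in place of a JAX key. The real function's p-value convention and the keys of its returned dict are not spelled out above, so this sketch does not imply either.

```python
import random
import statistics


def predictive_check_sketch(param_draws, generate_data, test_fn, observed_data, n_samples, seed=0):
    rng = random.Random(seed)
    # One replicated dataset per parameter draw, summarized by test_fn.
    t_rep = [test_fn(generate_data(theta, n_samples, rng)) for theta in param_draws]
    t_obs = test_fn(observed_data)
    # Assumed upper-tail p-value: share of replicates at least as large as t_obs.
    p_value = sum(t >= t_obs for t in t_rep) / len(t_rep)
    return t_rep, t_obs, p_value


def normal_data(theta, n, rng):
    # Hypothetical simulator: n draws from Normal(theta, 1).
    return [rng.gauss(theta, 1.0) for _ in range(n)]


draws = [0.0] * 200  # stand-in for draws from a prior or posterior
_, t_obs, p = predictive_check_sketch(draws, normal_data, statistics.fmean, [10.0] * 20, n_samples=20)
```

A p-value near 0 or 1 signals that the observed statistic is extreme relative to what the model generates.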