Modeling and inference

Models, likelihoods, inference methods, iterative transformations, and predictive checks. Registering a new inference method is documented under Extending ProbPipe → Custom inference methods.

Models

ProbabilisticModel(*, name)

Bases: Distribution[T]

Abstract base for probabilistic programming models.

A ProbabilisticModel is a first-class Distribution that also supports named components. Subclasses declare their parameter and data components and optionally provide _log_prob, _sample, etc.

Conditioning is handled by the inference method registry — call condition_on(model, data) rather than model._condition_on().

Source code in probpipe/core/_distribution_base.py
def __init__(self, *, name: str):
    if not isinstance(name, str) or not name:
        raise TypeError(
            f"{type(self).__name__} requires a non-empty name= argument"
        )
    self._name = name

fields abstractmethod property

Names of all model components (parameters + data).

parameter_names abstractmethod property

Names of the model's parameters (latent variables).

__getitem__(key) abstractmethod

Access a component by name.

Source code in probpipe/modeling/_base.py
@abstractmethod
def __getitem__(self, key: str) -> Any:
    """Access a component by name."""
    ...

SimpleModel(prior, likelihood, *, name=None)

Bases: ProbabilisticModel[tuple[P, D]], SupportsLogProb

Probabilistic model as a joint distribution over (parameters, data).

A SimpleModel[P, D] is a Distribution[tuple[P, D]] — the joint distribution \(p(\theta, y) = p(\theta) \, p(y \mid \theta)\). The prior must support SupportsLogProb so that the joint log-density is always computable.
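As a rough illustration of the factorisation above, the sketch below evaluates log p(θ, y) = log p(θ) + log p(y | θ) for a toy Normal prior and Normal likelihood in plain Python. The helper names (normal_log_pdf, prior_log_prob, log_likelihood, joint_log_prob) are illustrative only and are not part of ProbPipe.

```python
import math

def normal_log_pdf(x: float, mean: float, std: float) -> float:
    """Log-density of Normal(mean, std) at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def prior_log_prob(theta: float) -> float:
    # Toy prior (assumption for this sketch): theta ~ Normal(0, 1).
    return normal_log_pdf(theta, 0.0, 1.0)

def log_likelihood(theta: float, data: list[float]) -> float:
    # Toy likelihood: y_i ~ Normal(theta, 1), independent given theta.
    return sum(normal_log_pdf(y, theta, 1.0) for y in data)

def joint_log_prob(theta: float, data: list[float]) -> float:
    # log p(theta, y) = log p(theta) + log p(y | theta)
    return prior_log_prob(theta) + log_likelihood(theta, data)
```

The joint density is only computable because both terms are available, which is why a prior without a log-density is rejected at construction.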

Named components: merged from the prior's record_template and the likelihood's data_template when both are available. For example, a GLM model might have fields == ("X", "intercept", "slope", "y"). Falls back to ("parameters", "data") when templates are absent.
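The merge rule can be pictured with plain dictionaries standing in for the two templates. This is only a sketch of the documented behaviour when both templates are present: merge_templates and the {name: shape} dictionaries are hypothetical stand-ins, not the RecordTemplate API, and the sorted field order is an assumption chosen to match the example ordering above.

```python
def merge_templates(prior_tpl: dict, data_tpl: dict) -> tuple:
    """Merge two {field_name: shape} mappings, rejecting duplicate names.

    Hypothetical stand-in for the template merge described above.
    """
    overlap = set(prior_tpl) & set(data_tpl)
    if overlap:
        raise ValueError(
            f"Parameter and data field names overlap: {sorted(overlap)}"
        )
    merged = {**prior_tpl, **data_tpl}
    # Sorted field order is an assumption made for this sketch.
    return tuple(sorted(merged))

fields = merge_templates(
    {"intercept": (), "slope": (2,)},   # parameter fields
    {"X": (10, 2), "y": (10,)},         # data fields
)
```

Rejecting overlapping names keeps each component name unambiguous, so a name alone identifies whether a value is a parameter or data.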

Parameters:

Name Type Description Default
prior Distribution[P] that supports SupportsLogProb

Prior distribution over model parameters.

required
likelihood Likelihood[P, D]

Must have a log_likelihood(params, data) method.

required
name str or None

Model name for provenance.

None
Source code in probpipe/modeling/_simple.py
def __init__(
    self,
    prior: SupportsLogProb[P],
    likelihood: Likelihood[P, D],
    *,
    name: str | None = None,
):
    # Type-annotated as ``SupportsLogProb[P]`` so static type
    # checkers catch a wrong-type prior at the call site. The
    # isinstance check remains as a backstop for callers who
    # bypass the type system.
    if not isinstance(prior, SupportsLogProb):
        raise TypeError(
            f"SimpleModel requires a prior that supports SupportsLogProb, "
            f"got {type(prior).__name__}"
        )
    self._prior = prior
    self._likelihood = likelihood
    self._name_str = name

    # Build merged record_template: prior params + likelihood data fields.
    # This makes fields include both parameter and data names,
    # so condition_on can use component names as the sole signal for
    # splitting data kwargs from inference kwargs.
    prior_tpl = prior.record_template
    data_tpl = getattr(likelihood, 'data_template', None)
    # Convert legacy Record templates to RecordTemplate
    if isinstance(data_tpl, Record) and not isinstance(data_tpl, RecordTemplate):
        data_tpl = RecordTemplate.from_record(data_tpl)
    if prior_tpl is not None and data_tpl is not None:
        overlap = set(prior_tpl.fields) & set(data_tpl.fields)
        if overlap:
            raise ValueError(
                f"Parameter and data field names overlap: {overlap}"
            )
        merged = {}
        for f in prior_tpl.fields:
            merged[f] = prior_tpl[f]
        for f in data_tpl.fields:
            merged[f] = data_tpl[f]
        self._record_template = RecordTemplate(merged)
    elif prior_tpl is not None:
        self._record_template = prior_tpl

prior property

The prior distribution over parameters.

likelihood property

The likelihood function log p(D | params).

SimpleGenerativeModel(prior, likelihood, *, name=None)

Bases: ProbabilisticModel[tuple[P, D]], SupportsSampling

Generative probabilistic model as a joint over (parameters, data).

A SimpleGenerativeModel[P, D] pairs a prior that supports sampling with a GenerativeLikelihood that can generate synthetic data given parameters. Unlike SimpleModel, this does not require a log-density — making it suitable for simulation-based inference (SBI) and approximate Bayesian computation (ABC) methods.

Sampling: implements SupportsSampling via the obvious joint draw — sample parameters from the prior, then call likelihood.generate_data(params, ...) for the data.
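The joint draw amounts to ancestral sampling. A minimal pure-Python sketch, assuming a toy Gamma prior and an Exponential simulator (both chosen only for illustration; sample_prior, generate_data, and sample_joint here are hypothetical functions, not ProbPipe calls):

```python
import random

def sample_prior(rng: random.Random) -> float:
    # Toy prior (assumption): rate ~ Gamma(shape=2, scale=1).
    return rng.gammavariate(2.0, 1.0)

def generate_data(rate: float, n_samples: int, rng: random.Random) -> list[float]:
    # Toy simulator: waiting times y_i ~ Exponential(rate).
    return [rng.expovariate(rate) for _ in range(n_samples)]

def sample_joint(n_samples: int, seed: int = 0) -> tuple[float, list[float]]:
    """Draw (parameters, data): prior first, then the simulator."""
    rng = random.Random(seed)
    params = sample_prior(rng)
    data = generate_data(params, n_samples, rng)
    return params, data

params, data = sample_joint(n_samples=5, seed=0)
```

No density is evaluated anywhere in this sketch, which is the property that makes such models usable with simulation-based methods.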

Named components: "parameters" (the prior) and "data" (the generative likelihood). Only "data" is conditionable.

Conditioning: Use condition_on(model, data) — the inference method registry auto-selects an appropriate SBI or ABC method. SimpleGenerativeModel does not implement SupportsConditioning directly.
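To give a feel for likelihood-free conditioning, here is a sketch of plain rejection ABC, a simpler relative of the SMC-ABC method listed among the built-ins. It is not what ProbPipe runs; the prior, simulator, summary statistic, and tolerance are all assumptions made for the example.

```python
import random
import statistics

def rejection_abc(observed, n_draws=2000, tolerance=0.2, seed=0):
    """Plain rejection ABC for the mean of Normal(theta, 1) data.

    Keeps prior draws whose simulated sample mean lands within
    `tolerance` of the observed sample mean.
    """
    rng = random.Random(seed)
    obs_summary = statistics.fmean(observed)
    accepted = []
    for _ in range(n_draws):
        theta = rng.gauss(0.0, 3.0)                            # prior draw
        simulated = [rng.gauss(theta, 1.0) for _ in observed]  # simulate data
        if abs(statistics.fmean(simulated) - obs_summary) < tolerance:
            accepted.append(theta)                             # keep close matches
    return accepted

observed = [1.8, 2.3, 2.1, 1.6, 2.2]
posterior_draws = rejection_abc(observed)
```

Only prior sampling and data simulation are used, matching the capabilities a SimpleGenerativeModel exposes.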

Parameters:

Name Type Description Default
prior SupportsSampling[P]

Prior distribution over model parameters. Must support sampling.

required
likelihood GenerativeLikelihood[P, D]

Must have a generate_data(params, n_samples, *, key) method.

required
name str or None

Model name for provenance.

None
Source code in probpipe/modeling/_simple_generative.py
def __init__(
    self,
    prior: SupportsSampling[P],
    likelihood: GenerativeLikelihood[P, D],
    *,
    name: str | None = None,
):
    # Type-annotated as ``SupportsSampling[P]`` so static type
    # checkers catch a wrong-type prior at the call site. The
    # isinstance check remains as a backstop for callers who
    # bypass the type system.
    if not isinstance(prior, SupportsSampling):
        raise TypeError(
            f"SimpleGenerativeModel requires a prior that supports SupportsSampling, "
            f"got {type(prior).__name__}"
        )
    if not isinstance(likelihood, GenerativeLikelihood):
        raise TypeError(
            f"SimpleGenerativeModel requires a GenerativeLikelihood, "
            f"got {type(likelihood).__name__}"
        )
    self._prior = prior
    self._likelihood = likelihood
    self._name_str = name

prior property

The prior distribution over parameters.

likelihood property

The generative likelihood (provides generate_data).

Likelihoods

Likelihood

Bases: Protocol

Protocol for computing log-likelihood of data given parameters.

Generic in P (parameter type) and D (data type). Any class that defines log_likelihood(params, data) -> float satisfies this protocol.
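Structural typing means no inheritance is needed to satisfy the protocol. The sketch below uses a hypothetical stand-in protocol, LikelihoodLike, rather than importing the real Likelihood; marking it runtime_checkable is an assumption made so that isinstance can be demonstrated.

```python
import math
from typing import Protocol, Sequence, runtime_checkable

@runtime_checkable
class LikelihoodLike(Protocol):
    """Hypothetical stand-in for the Likelihood protocol described above."""

    def log_likelihood(self, params, data) -> float: ...

class BernoulliLikelihood:
    """Satisfies the protocol structurally; no base class required."""

    def log_likelihood(self, params: float, data: Sequence[int]) -> float:
        # params is the success probability; data holds 0/1 outcomes.
        return sum(math.log(params) if y else math.log1p(-params) for y in data)

lik = BernoulliLikelihood()
value = lik.log_likelihood(0.5, [1, 0, 1])
```

A runtime isinstance check against a protocol only verifies that the method exists, not its signature or return type.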

ConditionallyIndependentLikelihood

Bases: Likelihood[P, D], Protocol

Likelihood whose observations are conditionally independent given the parameters.

Formally, for observations y_1, ..., y_N the joint log-density factorises into a sum of per-observation log-densities:

\[
\log p(y_1, \ldots, y_N \mid \theta) = \sum_{i=1}^{N} \log p(y_i \mid \theta).
\]

The "conditionally" refers to the conditioning on the parameters θ: the y_i are independent given θ, not marginally. For regression-style likelihoods each datum carries a covariate x_i that the per-observation density depends on; the factorisation then reads Σ_i log p(y_i | x_i, θ), with the covariates treated as fixed inputs rather than random variables. This is the "conditionally independent" case rather than the stricter "i.i.d." (where every p(y_i | θ) is the same density).

Required by MinibatchedDistribution for stochastic-gradient inference, and useful independently for held-out predictive log-likelihoods, leave-one-out cross-validation, and PSIS-LOO.
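The factorisation is what makes subsampling possible. The following sketch uses a toy Normal regression likelihood in plain Python; the N/B rescaling shown is the standard minibatch estimator and is an assumption here, not a description of how MinibatchedDistribution is implemented.

```python
import math

def per_datum_log_likelihood(theta, datum):
    """log p(y_i | x_i, theta) for y_i ~ Normal(theta[0] + theta[1] * x_i, 1)."""
    x_i, y_i = datum
    resid = y_i - (theta[0] + theta[1] * x_i)
    return -0.5 * resid * resid - 0.5 * math.log(2.0 * math.pi)

def log_likelihood(theta, data):
    # Conditional independence: the joint log-density is a plain sum.
    return sum(per_datum_log_likelihood(theta, d) for d in data)

def minibatch_estimate(theta, data, batch_indices):
    # Standard N/B rescaling of a minibatch sum (estimates the full sum).
    scale = len(data) / len(batch_indices)
    return scale * sum(
        per_datum_log_likelihood(theta, data[i]) for i in batch_indices
    )

data = [(0.0, 0.1), (1.0, 1.9), (2.0, 4.2), (3.0, 5.8)]  # rows (x_i, y_i)
theta = (0.0, 2.0)                                        # (intercept, slope)
full = log_likelihood(theta, data)
```

Averaging the estimate over all single-datum batches recovers the full-data log-likelihood, which is the sense in which the minibatch estimate is unbiased.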

Implementations expose per_datum_log_likelihood; the helper _default_per_datum_log_likelihood provides a length-1-batch fallback for likelihoods that want a default rather than an efficient override.

per_datum_log_likelihood(params, datum)

Log-density of a single datum given parameters.

Parameters:

Name Type Description Default
params P

Model parameters.

required
datum Any

One observation. The exact shape depends on the data format the likelihood was constructed against — for a regression model that's a single row (x_i, y_i); for a scalar response, it's a single value. Subclasses define the shape.

required

Returns:

Type Description
Array

Scalar log-density of the datum under params.

Source code in probpipe/modeling/_likelihood.py
def per_datum_log_likelihood(self, params: P, datum: Any) -> Any:
    """Log-density of a single datum given parameters.

    Parameters
    ----------
    params : P
        Model parameters.
    datum : Any
        One observation. The exact shape depends on the data format
        the likelihood was constructed against — for a regression
        model that's a single row ``(x_i, y_i)``; for a scalar
        response, it's a single value. Subclasses define the shape.

    Returns
    -------
    Array
        Scalar log-density of the datum under ``params``.
    """
    ...

GenerativeLikelihood

Bases: Protocol

Protocol for generating synthetic data given parameters.

Generic in P (parameter type) and D (data type). Any class that defines generate_data(params, n_samples, *, key) -> D satisfies this protocol.

generate_data(params, n_samples, *, key=None)

Generate n_samples synthetic data points from params.

Parameters:

Name Type Description Default
params P

Model parameters.

required
n_samples int

Number of data points to generate.

required
key PRNGKey or None

JAX PRNG key for reproducible generation.

None
Source code in probpipe/modeling/_likelihood.py
def generate_data(self, params: P, n_samples: int, *, key: PRNGKey | None = None) -> D:
    """Generate ``n_samples`` synthetic data points from ``params``.

    Parameters
    ----------
    params : P
        Model parameters.
    n_samples : int
        Number of data points to generate.
    key : PRNGKey or None
        JAX PRNG key for reproducible generation.
    """
    ...

GLMLikelihood(family, x=None, *, fit_intercept=True, seed=0)

Wraps a TFP GLM family + design matrix into a Likelihood and GenerativeLikelihood.

Two accepted data forms:

  • data = Record(X=X_covariates, y=y_observed) — both fields explicit; the canonical form. X is the covariate matrix only; do not include a constant column for the intercept.
  • data = y_observed (a bare response array) when X was supplied at construction time — the construction-time X is used.

Joint bootstrapping of covariates and response uses the Record form:

Xy = Record(X=X_covariates, y=y_observed)
bootstrap = BootstrapReplicateDistribution(EmpiricalDistribution(Xy))
bagged = condition_on(model, bootstrap, n_broadcast_samples=16)

Parameters:

Name Type Description Default
family ExponentialFamily

TFP GLM family (e.g., Poisson(), Bernoulli(), NegativeBinomial()).

required
x array-like or None

Default covariate matrix of shape (n, p). If None, must be provided per-call via data=Record(X=..., y=...). Should contain only the covariates — the intercept is fit separately when fit_intercept=True.

None
fit_intercept bool

When True, the likelihood expects params to flatten to (intercept, *slopes) of length p + 1 and computes eta = intercept + X @ slopes. When False, params flattens to length p and the likelihood computes eta = X @ params directly — useful when the user wants to carry the intercept as a constant column in X themselves (the classical "model matrix" convention).

True
seed int

Random seed for data generation.

0
Source code in probpipe/modeling/_glm.py
def __init__(
    self,
    family: tfp_glm.ExponentialFamily,
    x: ArrayLike | None = None,
    *,
    fit_intercept: bool = True,
    seed: int = 0,
):
    self.family = family
    if x is not None:
        self._x = jnp.atleast_2d(jnp.asarray(x))
        if self._x.ndim == 2 and self._x.shape[0] == 1 and self._x.shape[1] > 1:
            self._x = self._x.T
    else:
        self._x = None
    self._fit_intercept = bool(fit_intercept)
    self._key = jax.random.PRNGKey(seed)

data_template property

Named structure of GLM data: X (design matrix) and y (response).

log_likelihood(params, data)

Log-likelihood: sum of per-observation log-probs.

params and data can be raw arrays or Record objects. When data is Record(X=..., y=...), both the covariate matrix and response are extracted. The linear predictor eta = intercept + X @ slopes is computed via _linear_predictor so the fit_intercept convention is respected uniformly across the public methods.
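The two fit_intercept conventions produce the same linear predictor once the constant column is accounted for. A pure-Python sketch follows; linear_predictor here is a hypothetical stand-in written for illustration, since the private _linear_predictor helper is not shown on this page.

```python
def linear_predictor(X, params, fit_intercept=True):
    """Return eta for each row of X (a list of covariate rows).

    fit_intercept=True:  params = (intercept, *slopes), eta = intercept + x . slopes
    fit_intercept=False: params has one entry per column, eta = x . params
    """
    if fit_intercept:
        intercept, slopes = params[0], params[1:]
    else:
        intercept, slopes = 0.0, params
    return [
        intercept + sum(x_j * b_j for x_j, b_j in zip(row, slopes))
        for row in X
    ]

X = [[1.0, 2.0], [3.0, 4.0]]                   # covariates only, no constant column
eta_a = linear_predictor(X, [0.5, 1.0, -1.0])  # intercept carried in params[0]

X_const = [[1.0] + row for row in X]           # classical model matrix
eta_b = linear_predictor(X_const, [0.5, 1.0, -1.0], fit_intercept=False)
```

Both calls yield the same eta, so the choice is purely about where the intercept lives: in params or in X.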

Source code in probpipe/modeling/_glm.py
def log_likelihood(self, params: ArrayLike | Record, data: ArrayLike | Record) -> float:
    """Log-likelihood: sum of per-observation log-probs.

    *params* and *data* can be raw arrays or ``Record`` objects.
    When *data* is ``Record(X=..., y=...)``, both the covariate
    matrix and response are extracted. The linear predictor
    ``eta = intercept + X @ slopes`` is computed via
    :meth:`_linear_predictor` so the ``fit_intercept`` convention is
    respected uniformly across the public methods.
    """
    X, y = self._extract_X_y(data)
    eta = self._linear_predictor(X, params)
    return jnp.sum(self.family.log_prob(y, eta))

per_datum_log_likelihood(params, datum)

Log-density of a single observation given parameters.

Satisfies ConditionallyIndependentLikelihood. Evaluates family.log_prob(y_i, eta_i) directly on a scalar response, where eta_i = intercept + x_i @ slopes when fit_intercept=True and eta_i = x_i @ params otherwise. This bypasses the length-1-batch reshape that the default fallback (log_likelihood(params, datum[None, ...])) would add. The saved per-call overhead matters when this method is called inside a stochastic-gradient inner loop.

Parameters:

Name Type Description Default
params Array or Record

Coefficient vector of shape (p,), or (p + 1,) laid out as (intercept, *slopes) when fit_intercept=True.

required
datum Record

Record(X=x_i, y=y_i) with x_i of shape (p,) and y_i scalar.

required

Raises:

Type Description
TypeError

If datum is not a Record(X=..., y=...).

Source code in probpipe/modeling/_glm.py
def per_datum_log_likelihood(
    self, params: ArrayLike | Record, datum: Record,
) -> Array:
    """Log-density of a single observation given parameters.

    Satisfies :class:`~probpipe.ConditionallyIndependentLikelihood`.
    Evaluates ``family.log_prob(y_i, x_i @ params)`` directly on a
    scalar response, bypassing the length-1-batch reshape that the
    default fallback (``log_likelihood(params, datum[None, ...])``)
    would add. The saved per-call overhead matters when this method
    is called inside a stochastic-gradient inner loop.

    Parameters
    ----------
    params : Array or Record
        Coefficient vector of shape ``(p,)``, or ``(p + 1,)`` laid out as
        ``(intercept, *slopes)`` when ``fit_intercept=True``.
    datum : Record
        ``Record(X=x_i, y=y_i)`` with ``x_i`` of shape ``(p,)``
        and ``y_i`` scalar.

    Raises
    ------
    TypeError
        If ``datum`` is not a ``Record(X=..., y=...)``.
    """
    if not (isinstance(datum, Record) and "X" in datum and "y" in datum):
        raise TypeError(
            "GLMLikelihood.per_datum_log_likelihood requires "
            "datum=Record(X=x_i, y=y_i)."
        )
    # `atleast_1d` accommodates the single-covariate case where the
    # per-observation X leaf is naturally scalar — the matmul below
    # still needs a 1-D vector.
    x_i = jnp.atleast_1d(jnp.asarray(datum["X"]))
    y_i = jnp.asarray(datum["y"])
    beta = _coerce_array(params)
    if self._fit_intercept:
        eta = beta[0] + x_i @ beta[1:]
    else:
        eta = x_i @ beta
    return self.family.log_prob(y_i, eta)

generate_data(params, n_samples, *, key=None)

Generate synthetic data from the GLM.

Uses the stored design matrix (first n_samples rows).

Parameters:

Name Type Description Default
params Array or Record

Parameter vector of shape (p,) or a batch (*batch, p). With fit_intercept=True the last axis has length p + 1, laid out as (intercept, *slopes).

required
n_samples int

Number of observations to generate (per batch element).

required
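key PRNGKey or None

JAX PRNG key. If None, a fresh key is split from the internal key that was seeded by seed at construction.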

None
Source code in probpipe/modeling/_glm.py
def generate_data(
    self,
    params: ArrayLike | Record,
    n_samples: int,
    *,
    key: PRNGKey | None = None,
) -> Array:
    """Generate synthetic data from the GLM.

    Uses the stored design matrix (first ``n_samples`` rows).

    Parameters
    ----------
    params : Array or Record
        Parameter vector of shape ``(p,)`` or a batch ``(*batch, p)``.
        With ``fit_intercept=True`` the last axis has length ``p + 1``.
    n_samples : int
        Number of observations to generate (per batch element).
    key : PRNGKey, optional
        JAX PRNG key.
    """
    if key is None:
        self._key, key = jax.random.split(self._key)
    X = self._x
    if X is None:
        raise ValueError(
            "generate_data requires a stored design matrix (pass x at construction)"
        )
    # Linear predictor — same convention as log_likelihood:
    #   fit_intercept=True  → eta = intercept + X @ slopes
    #   fit_intercept=False → eta = X @ params (classical model-matrix form)
    # Batched params come in as ``(*batch, p_total)``; broadcast against
    # ``X[:n_samples]`` of shape ``(n_samples, p)``.
    beta = _coerce_array(params)
    Xn = X[:n_samples]
    if self._fit_intercept:
        slopes = beta[..., 1:]
        intercept = beta[..., 0:1]  # keep last axis for broadcasting
        eta = intercept + slopes @ Xn.T
    else:
        eta = beta @ Xn.T
    dist = self.family.as_distribution(eta)
    return dist.sample(seed=key)

IncrementalConditioner(prior, likelihood, *, condition_fn=None, **condition_kwargs)

Bases: Module

Iteratively update a posterior by conditioning on data batches.

Maintains a current posterior (initially the prior) and provides update for single-batch conditioning and update_all for multi-batch iteration. Both methods update the internal state so that subsequent calls continue from the latest posterior.

The step property exposes the underlying step function for direct use with iterate and combinators.
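The stateful pattern can be sketched with a conjugate Beta-Bernoulli model, where each conditioning step has a closed form. BetaBernoulliConditioner is a hypothetical toy class that mirrors the update / update_all / step shape described above; it does not use the inference registry, and posteriors are represented as plain (a, b) tuples.

```python
class BetaBernoulliConditioner:
    """Toy stateful conditioner: Beta(a, b) prior, batches of 0/1 outcomes."""

    def __init__(self, a: float, b: float):
        self.curr_posterior = (a, b)      # starts out as the prior

    @staticmethod
    def step(posterior, batch):
        # Conjugate update: add successes to a, failures to b.
        a, b = posterior
        successes = sum(batch)
        return (a + successes, b + len(batch) - successes)

    def update(self, batch):
        # Single-batch update; the result becomes the new state.
        self.curr_posterior = self.step(self.curr_posterior, batch)
        return self.curr_posterior

    def update_all(self, batches):
        # Returns [starting_posterior, post_1, post_2, ...].
        dists = [self.curr_posterior]
        for batch in batches:
            dists.append(self.step(dists[-1], batch))
        self.curr_posterior = dists[-1]
        return dists

conditioner = BetaBernoulliConditioner(1.0, 1.0)
first = conditioner.update([1, 1, 0])
history = conditioner.update_all([[0, 0], [1]])
```

Because step is a pure function of (posterior, batch), it can be reused outside the stateful wrapper, which is the role the step property plays for iterate.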

Parameters:

Name Type Description Default
prior Distribution[P]

Initial prior distribution over model parameters.

required
likelihood Likelihood[P, D]

Likelihood object.

required
condition_fn callable or None

Conditioning callable; defaults to the global condition_on operation (which dispatches through the inference method registry).

None
**condition_kwargs Any

Extra keyword arguments forwarded to condition_fn on every call (e.g., method="tfp_nuts", num_results=2000).

{}

Examples:

conditioner = IncrementalConditioner(prior, likelihood)

# Single-batch update (stateful):
posterior1 = conditioner.update(data=batch1)
posterior2 = conditioner.update(data=batch2)
conditioner.curr_posterior  # is posterior2

# Multi-batch update (stateful, returns sequence):
dists = conditioner.update_all(data_batches=[batch3, batch4])
conditioner.curr_posterior  # is dists[-1]

# Functional escape hatch (for combinators):
dists = iterate(conditioner.step, prior, all_batches)
Source code in probpipe/modeling/_likelihood.py
def __init__(
    self,
    prior: Distribution[P],
    likelihood: Likelihood[P, D],
    *,
    condition_fn: Callable[..., Distribution[P]] | None = None,
    **condition_kwargs: Any,
):
    self._prior = prior
    self._likelihood = likelihood
    self._curr_posterior: Distribution[P] = prior
    self._step: _ConditioningStep[P, D] = _ConditioningStep(
        likelihood,
        condition_fn=condition_fn,
        **condition_kwargs,
    )

curr_posterior property

The current posterior (initially the prior).

step property

The underlying step function, for use with iterate.

update(data=None, **kwargs)

Condition on new data, updating the current posterior.

Data can be passed positionally, as data=, or as named keyword arguments that are bundled into a Record:

conditioner.update(y_obs)                # positional / data=
conditioner.update(X=new_X, y=new_y)     # named kwargs → Record

Parameters:

Name Type Description Default
data D

New observed data to condition on.

None
**kwargs

Named data fields, bundled into a Record automatically. Cannot be combined with data.

{}

Returns:

Type Description
Distribution[P]

The updated posterior distribution.

Source code in probpipe/modeling/_likelihood.py
def update(self, data: D | None = None, **kwargs) -> Distribution[P]:
    """Condition on new data, updating the current posterior.

    Data can be passed positionally, as ``data=``, or as named
    keyword arguments that are bundled into a ``Record``::

        conditioner.update(y_obs)                # positional / data=
        conditioner.update(X=new_X, y=new_y)     # named kwargs → Record

    Parameters
    ----------
    data : D, optional
        New observed data to condition on.
    **kwargs
        Named data fields, bundled into a ``Record`` automatically.
        Cannot be combined with *data*.

    Returns
    -------
    Distribution[P]
        The updated posterior distribution.
    """
    if kwargs:
        if data is not None:
            raise ValueError(
                "Cannot provide both `data` and named data kwargs"
            )
        from ..core.record import Record
        data = Record(kwargs)
    posterior = self._step(self._curr_posterior, data)
    self._curr_posterior = posterior
    return posterior

update_all(data_batches)

Condition on multiple data batches sequentially.

Calls iterate(self.step, self.curr_posterior, data_batches) and updates the internal state to the final posterior.

Parameters:

Name Type Description Default
data_batches Iterable[D]

Sequence of data batches to condition on.

required

Returns:

Type Description
list[Distribution[P]]

Sequence [starting_posterior, post_1, post_2, ...].

Source code in probpipe/modeling/_likelihood.py
def update_all(self, data_batches: Iterable[D]) -> list[Distribution[P]]:
    """Condition on multiple data batches sequentially.

    Calls ``iterate(self.step, self.curr_posterior, data_batches)``
    and updates the internal state to the final posterior.

    Parameters
    ----------
    data_batches : Iterable[D]
        Sequence of data batches to condition on.

    Returns
    -------
    list[Distribution[P]]
        Sequence ``[starting_posterior, post_1, post_2, ...]``.
    """
    dists = iterate(self._step, self._curr_posterior, data_batches)
    self._curr_posterior = dists[-1]
    return dists

Inference methods

condition_on dispatches through the inference-method registry: methods are tried in descending priority order and the first whose check() returns feasible=True runs. Pass method="<name>" to override the auto-selection; inference_method_registry.set_priorities(...) reorders the table at runtime.

Priorities follow a single-axis convention: values above 50 mark exact methods, values in (0, 50] mark inexact methods, and 0 means opt-in only (selectable by name but skipped during auto-dispatch). The five-axis selection criteria and the tier ranges contributors should use when choosing a number for a new method are documented under Extending ProbPipe → Setting priority for a new method.
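The dispatch rule can be sketched as a priority-ordered scan. Everything below is a hypothetical stand-in for illustration: Method, select_method, and the registry entries are invented names, check is simplified to return a bool rather than an object with a feasible attribute, and the sketch does not cover whether an explicit method= override still runs check().

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Method:
    name: str
    priority: int
    check: Callable[[dict], bool]   # True when the method is feasible for the model

def select_method(methods, model, method: Optional[str] = None) -> Method:
    """Explicit name wins; otherwise take the highest-priority feasible method."""
    if method is not None:
        # Explicit override: opt-in (priority 0) methods are reachable this way.
        return next(m for m in methods if m.name == method)
    for m in sorted(methods, key=lambda m: m.priority, reverse=True):
        if m.priority > 0 and m.check(model):   # priority 0 is skipped here
            return m
    raise LookupError("no feasible inference method")

registry = [
    Method("exact_gradient", 75, lambda model: model.get("differentiable", False)),
    Method("exact_gradient_free", 55, lambda model: model.get("has_log_prob", False)),
    Method("experimental", 0, lambda model: True),
]
model = {"has_log_prob": True, "differentiable": False}
chosen = select_method(registry, model)
```

In this toy registry the gradient-based entry is infeasible for the model, so auto-dispatch falls through to the next entry, while the priority-0 entry is reachable only by name.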

Built-in methods:

| Name | Priority | Requires | Backend |
| --- | --- | --- | --- |
| nutpie_nuts | 85 | StanModel or PyMCModel + nutpie | nutpie |
| cmdstan_nuts | 82 | StanModel + cmdstanpy | CmdStan |
| pymc_nuts | 81 | PyMCModel + pymc | PyMC |
| tfp_nuts | 75 | SupportsLogProb + JAX-traceable | TFP |
| tfp_hmc | 65 | SupportsLogProb + JAX-traceable | TFP |
| tfp_rwmh | 55 | SupportsLogProb | TFP |
| blackjax_sgld | 45 | SimpleModel + ConditionallyIndependentLikelihood + batch_size= | BlackJAX |
| blackjax_sghmc | 42 | SimpleModel + ConditionallyIndependentLikelihood + batch_size= | BlackJAX |
| pymc_advi | 25 | PyMCModel + pymc | PyMC |
| sbijax_smcabc | 5 | SimpleGenerativeModel + sbijax | sbijax |

ApproximateDistribution(chains, *, weights=None, name=None, record_template=None)

Bases: RecordEmpiricalDistribution

Empirical distribution with chain structure.

Stores per-chain sample arrays for chain-structured access via draws. Algorithm metadata, sample statistics, warmup samples, and the ArviZ InferenceData object live in dist.auxiliary (a DataTree on the Distribution base class), not as attributes of this class.

Parameters:

Name Type Description Default
chains list of Array

Per-chain sample arrays, each of shape (num_draws, *event_shape).

required
weights array-like, Weights, or None

Optional per-sample importance weights (across all chains).

None
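name str or None

Distribution name for provenance.

None
record_template RecordTemplate or None

Optional template naming the fields of each draw. A multi-field template makes __init__ split the concatenated chain into per-field arrays (see Notes).

None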
Notes

When record_template is multi-field, __init__ slices the concatenated chain into per-top-level-field arrays so fields, event_shapes, dtypes, _mean / _variance, and the public ops (mean(post) / variance(post)) all return Records whose keys match fields. Nested RecordTemplate fields are stored as a flat (n, nested_flat_size) array under the top-level field name; the nested structure is recoverable via record_template[field] and via draws, which walks the full template (including nesting) using the original per-chain samples.
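The slicing step can be illustrated on a single flat draw. In this sketch the template is a plain {name: shape} dictionary and split_flat_draw is a hypothetical helper; it covers only flat (non-nested) fields and leaves non-scalar fields as flat chunks instead of reshaping them.

```python
import math

def split_flat_draw(flat, template):
    """Split one flat draw into named fields using {name: shape} sizes.

    Raises if the flat length does not match the template's total size.
    """
    sizes = {name: math.prod(shape) for name, shape in template.items()}
    total = sum(sizes.values())
    if len(flat) != total:
        raise ValueError(
            f"flat size {len(flat)} does not match template size {total}"
        )
    fields, offset = {}, 0
    for name, size in sizes.items():
        chunk = flat[offset:offset + size]
        # Scalar fields (shape ()) unwrap to a single value.
        fields[name] = chunk[0] if template[name] == () else chunk
        offset += size
    return fields

template = {"intercept": (), "slope": (2,)}
draw = split_flat_draw([0.5, 1.0, -1.0], template)
```

Checking the total size before slicing matters because Python (like array libraries) silently returns short or empty slices when an offset runs past the end.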

Source code in probpipe/inference/_approximate_distribution.py
def __init__(
    self,
    chains: list[Array],
    *,
    weights: ArrayLike | Weights | None = None,
    name: str | None = None,
    record_template: RecordTemplate | None = None,
):
    if not chains:
        raise ValueError("Must provide at least one chain")

    self._chains = [jnp.asarray(c) for c in chains]
    self._concatenated: Array | None = None

    flat = self._concat_chains()
    # Track whether the user explicitly supplied a template; we use
    # this in ``draws()`` to decide whether to wrap the output.
    self._user_template = record_template is not None
    # Multi-field template → split the flat chain by top-level
    # field. Nested ``RecordTemplate`` fields are stored as a
    # 2-D ``(n, nested_flat_size)`` slice under the top-level
    # field name; the nested structure is recovered via
    # ``record_template[field]`` and ``draws()``. Slice sizes use
    # ``_spec_size``, which already handles both flat and nested
    # specs.
    if record_template is not None and len(record_template.fields) > 1:
        # Compute per-field sizes upfront so we can sanity-check the
        # chain's last dim against the template's total flat size
        # (catching template/data mismatch before silent slicing past
        # the end produces zero-sized chunks). ``_spec_size`` raises
        # on opaque (``spec=None``) leaves; pre-validate here so the
        # error names the offending field rather than the generic
        # ``_spec_size`` message.
        sizes: list[int] = []
        for field_name in record_template.fields:
            spec = record_template[field_name]
            if spec is None:
                raise ValueError(
                    f"ApproximateDistribution requires a numeric "
                    f"template; field {field_name!r} has spec=None "
                    f"(opaque). Opaque leaves don't have a flat size."
                )
            sizes.append(_spec_size(spec))
        total = sum(sizes)
        if flat.shape[-1] != total:
            raise ValueError(
                f"chain last dim ({flat.shape[-1]}) doesn't match "
                f"template total flat size ({total}); template "
                f"fields={record_template.fields}, sizes={sizes}."
            )
        offset = 0
        fields: dict[str, Array] = {}
        for field_name, size in zip(record_template.fields, sizes):
            spec = record_template[field_name]
            chunk = flat[..., offset : offset + size]
            if isinstance(spec, RecordTemplate):
                # Nested: keep flat-per-top-level-field. Shape is
                # ``(*sample_shape, nested_flat_size)``.
                fields[field_name] = chunk
            else:
                shape = spec if spec is not None else ()
                fields[field_name] = chunk.reshape(*flat.shape[:-1], *shape)
            offset += size
        super().__init__(Record(fields), weights=weights, name=name or "posterior")
        self._record_template = record_template
    else:
        # Single-field path: ``name`` (default ``"posterior"``)
        # becomes the auto-wrapped field name. If the user passed a
        # single-field template, rename to honor it.
        field_name = name or "posterior"
        if record_template is not None and len(record_template.fields) == 1:
            field_name = record_template.fields[0]
        super().__init__(flat, weights=weights, name=field_name)
        if record_template is not None:
            self._record_template = record_template

chains property

Per-chain sample arrays.

num_chains property

Number of chains.

num_draws property

Number of draws per chain (assumes equal-length chains).

algorithm property

Name of the inference algorithm (read from provenance).

inference_data property

The auxiliary DataTree, for ArviZ compatibility.

Alias for self.auxiliary. Use ArviZ functions for diagnostics:

import arviz as az
az.summary(posterior.inference_data)

warmup_samples property

Per-chain warmup samples extracted from auxiliary data.

draws(chain=None, *, include_warmup=False)

Access draws, optionally named via record_template.

Parameters:

Name Type Description Default
chain int or None

Chain index. If None, concatenates all chains.

None
include_warmup bool

If True and warmup samples are in the auxiliary DataTree, prepend them.

False

Returns:

Type Description
Array or Record

If record_template is set, returns a Record with named fields. Otherwise returns a raw array.

Source code in probpipe/inference/_approximate_distribution.py
def draws(
    self,
    chain: int | None = None,
    *,
    include_warmup: bool = False,
) -> Array | Record:
    """Access draws, optionally named via record_template.

    Parameters
    ----------
    chain : int or None
        Chain index.  If ``None``, concatenates all chains.
    include_warmup : bool
        If ``True`` and warmup samples are in the auxiliary DataTree,
        prepend them.

    Returns
    -------
    Array or Record
        If ``record_template`` is set, returns a :class:`~probpipe.Record`
        with named fields.  Otherwise returns a raw array.
    """
    if chain is not None:
        samples = self._chains[chain]
        if include_warmup:
            warmup = self.warmup_samples
            if warmup is not None:
                samples = jnp.concatenate([warmup[chain], samples], axis=0)
    else:
        parts = list(self._chains)
        if include_warmup:
            warmup = self.warmup_samples
            if warmup is not None:
                parts = [jnp.concatenate([w, c], axis=0)
                         for w, c in zip(warmup, parts)]
        samples = jnp.concatenate(parts, axis=0)

    # Honor any user-supplied template (single-field or multi-field).
    # Without one, return the raw concatenated array — matches the
    # historical behaviour of single-field auto-wrap empiricals
    # under the previous numeric-array hierarchy.
    if getattr(self, "_user_template", False):
        from ..core._record_array import NumericRecordArray
        return NumericRecordArray.unflatten(samples, template=self.record_template)
    return samples

rwmh(dist, data=None, *, log_prob_fn=None, num_results=1000, num_warmup=500, num_chains=1, step_size=0.1, init=None, random_seed=0)

Gradient-free random-walk Metropolis-Hastings.

Parameters:

Name Type Description Default
dist SupportsUnnormalizedLogProb

Distribution providing _unnormalized_log_prob. RWMH uses only the unnormalized density because the missing log normalizer cancels out of every accept/reject step.

required
data array-like or None

Observed data.

None
log_prob_fn callable or None

log_prob_fn(params, data) -> float. Combined with dist._unnormalized_log_prob(params) to form the target density. Used only when data is also provided.

None
num_results int

Number of result draws (MCMC tuning parameter).

1000
num_warmup int

Number of warmup steps (MCMC tuning parameter).

500
num_chains int

Number of chains (MCMC tuning parameter).

1
step_size float

Proposal step size (MCMC tuning parameter).

0.1
random_seed int

Seed for the JAX PRNG key.

0
init array-like or None

Initial chain state. Tries dist._mean(), then zeros.

None

Returns:

Type Description
ApproximateDistribution

Posterior samples with chain structure and auxiliary DataTree.
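
The accept/reject rule described above can be illustrated with a minimal, self-contained sketch. It is not ProbPipe's implementation: `rwmh_sketch` is a hypothetical helper that uses only the Python standard library, runs a single one-dimensional chain, and targets an unnormalized standard normal chosen purely for illustration. It shows why only differences of log-densities enter the acceptance test.

```python
import math
import random


def rwmh_sketch(log_density, init, *, num_results=1000, num_warmup=500,
                step_size=0.1, seed=0):
    """Single-chain, 1-D random-walk Metropolis-Hastings (illustrative only)."""
    rng = random.Random(seed)
    x, logp = init, log_density(init)
    kept, accepts = [], 0
    total = num_warmup + num_results
    for t in range(total):
        # Symmetric Gaussian proposal centred on the current state.
        prop = x + step_size * rng.gauss(0.0, 1.0)
        logp_prop = log_density(prop)
        # Accept with probability min(1, exp(logp_prop - logp)). Only the
        # difference of log-densities matters, so a constant log normalizer
        # drops out.
        u = 1.0 - rng.random()  # uniform on (0, 1], keeps log(u) finite
        if math.log(u) < logp_prop - logp:
            x, logp = prop, logp_prop
            accepts += 1
        if t >= num_warmup:
            kept.append(x)
    return kept, accepts / total


# Unnormalized standard normal: the -0.5 * log(2 * pi) term is omitted on purpose.
samples, accept_rate = rwmh_sketch(lambda x: -0.5 * x * x, 0.0,
                                   num_results=5000, step_size=1.0)
sample_mean = sum(samples) / len(samples)
```

As in the function documented here, the acceptance rate in this sketch is computed over warmup and kept steps together.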

Source code in probpipe/inference/_rwmh.py
@workflow_function
def rwmh(
    dist: SupportsUnnormalizedLogProb,
    data: ArrayLike | None = None,
    *,
    log_prob_fn: Any | None = None,
    num_results: int = 1000,
    num_warmup: int = 500,
    num_chains: int = 1,
    step_size: float = 0.1,
    init: ArrayLike | None = None,
    random_seed: int = 0,
) -> ApproximateDistribution:
    """Gradient-free random-walk Metropolis-Hastings.

    Parameters
    ----------
    dist : SupportsUnnormalizedLogProb
        Distribution providing ``_unnormalized_log_prob``.  RWMH uses
        only the unnormalized density because the missing log
        normalizer cancels out of every accept/reject step.
    data : array-like or None
        Observed data.
    log_prob_fn : callable or None
        ``log_prob_fn(params, data) -> float``.  Combined with
        ``dist._unnormalized_log_prob(params)`` to form the target density.
    num_results, num_warmup, num_chains, step_size, random_seed
        MCMC tuning parameters.
    init : array-like or None
        Initial chain state.  Tries ``dist._mean()``, then zeros.

    Returns
    -------
    ApproximateDistribution
        Posterior samples with chain structure and auxiliary DataTree.
    """
    if not isinstance(dist, SupportsUnnormalizedLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support log_prob "
            f"(does not implement SupportsUnnormalizedLogProb)"
        )

    if log_prob_fn is not None and data is not None:
        data_jnp = jnp.asarray(data)
        def target_log_prob(params):
            return dist._unnormalized_log_prob(params) + log_prob_fn(params, data_jnp)
    else:
        def target_log_prob(params):
            return dist._unnormalized_log_prob(params)

    init_state = _get_init_state(dist, init, data)

    d = init_state.shape[0]
    key = jax.random.PRNGKey(random_seed)

    chains = []
    warmup_chains = []
    total_accepts = 0
    total_steps = 0

    for _ in range(num_chains):
        key, chain_key = jax.random.split(key)
        mu_curr = jnp.array(init_state)
        logp_curr = float(target_log_prob(mu_curr))

        warmup_samples: list[jnp.ndarray] = []
        kept: list[jnp.ndarray] = []
        chain_accepts = 0
        chain_total = num_warmup + num_results

        for t in range(chain_total):
            chain_key, subkey_prop, subkey_accept = jax.random.split(chain_key, 3)
            noise = jax.random.normal(subkey_prop, shape=(d,), dtype=mu_curr.dtype)
            mu_prop = mu_curr + step_size * noise
            logp_prop = float(target_log_prob(mu_prop))

            u = jax.random.uniform(subkey_accept, dtype=mu_curr.dtype)
            if jnp.log(u) < min(0.0, logp_prop - logp_curr):
                mu_curr = mu_prop
                logp_curr = logp_prop
                chain_accepts += 1

            if t < num_warmup:
                warmup_samples.append(mu_curr)
            else:
                kept.append(mu_curr)

        chains.append(jnp.stack(kept))
        warmup_chains.append(jnp.stack(warmup_samples) if warmup_samples else None)
        total_accepts += chain_accepts
        total_steps += chain_total

    accept_rate = total_accepts / total_steps
    warmup = warmup_chains if all(w is not None for w in warmup_chains) else None

    auxiliary = _build_mcmc_datatree(chains, warmup_chains=warmup)

    record_template = _extract_record_template(dist)
    return make_posterior(
        chains, parents=(dist,), algorithm="rwmh",
        auxiliary=auxiliary, record_template=record_template,
        num_results=num_results, num_warmup=num_warmup, num_chains=num_chains,
        step_size=step_size, accept_rate=accept_rate,
    )

condition_on_nutpie(model, data=None, *, num_results=1000, num_warmup=500, num_chains=4, random_seed=0, **kwargs)

MCMC sampling via nutpie (Rust-based NUTS).

Accepts a StanModel or PyMCModel.

Source code in probpipe/inference/_nutpie.py
@workflow_function
def condition_on_nutpie(
    model: Any,
    data: ArrayLike | None = None,
    *,
    num_results: int = 1000,
    num_warmup: int = 500,
    num_chains: int = 4,
    random_seed: int = 0,
    **kwargs: Any,
) -> ApproximateDistribution:
    """MCMC sampling via nutpie (Rust-based NUTS).

    Accepts a :class:`~probpipe.modeling.StanModel` or
    :class:`~probpipe.modeling.PyMCModel`.
    """
    try:
        import nutpie
    except ImportError as e:
        raise ImportError(
            "nutpie is required for condition_on_nutpie. "
            "Install it with: pip install nutpie"
        ) from e

    compiled = _compile_for_nutpie(model, data)
    trace = nutpie.sample(
        compiled, draws=num_results, tune=num_warmup,
        chains=num_chains, seed=random_seed, **kwargs,
    )

    chains, param_names = _extract_chains(trace, num_chains)

    return make_posterior(
        chains, parents=(model,), algorithm="nutpie_nuts",
        auxiliary=trace, record_template=getattr(model, "record_template", None),
        num_results=num_results, num_warmup=num_warmup, num_chains=num_chains,
    )

sbi_learn_conditional(*args, **kwargs)

Placeholder that raises when sbijax is not installed.

Source code in probpipe/inference/__init__.py
def sbi_learn_conditional(*args, **kwargs):  # type: ignore[misc]
    """Placeholder that raises when sbijax is not installed."""
    raise ImportError(_SBI_INSTALL_MSG)

sbi_learn_likelihood(*args, **kwargs)

Placeholder that raises when sbijax is not installed.

Source code in probpipe/inference/__init__.py
def sbi_learn_likelihood(*args, **kwargs):  # type: ignore[misc]
    """Placeholder that raises when sbijax is not installed."""
    raise ImportError(_SBI_INSTALL_MSG)

MinibatchedDistribution(prior, likelihood, data, batch_size, *, with_replacement=False, name=None)

Bases: RandomMeasure[Record], SupportsRandomUnnormalizedLogProb

Random measure realised by uniform minibatching.

A draw from this measure is a fixed-minibatch target: an unnormalized stochastic surrogate of the full-data unnormalized log-posterior, in which the minibatch log-likelihood is rescaled by N / b so that its gradient is an unbiased estimator of the full-data gradient. It is not a posterior in the strict (normalized) sense.

For a model with prior \(p(\theta)\) and likelihood \(p(\mathcal{D} \mid \theta) = \prod_i p(d_i \mid \theta)\), a draw's unnormalized log-density is

\[
\log \tilde{D}_B(\theta) = \log p(\theta) + \frac{N}{b} \sum_{d \in B} \log p(d \mid \theta),
\]

where \(N\) is the dataset size, \(b\) is the minibatch size, and \(B\) is the sampled minibatch.
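
The role of the N / b factor can be checked numerically. The following sketch is not taken from ProbPipe: it uses a toy unit-variance Gaussian likelihood and hypothetical helper names. It averages the rescaled minibatch term over every size-b subset of a small dataset and compares the result with the full-data log-likelihood.

```python
import itertools
import math


def per_datum_log_lik(theta, d):
    """Unit-variance Gaussian log-likelihood of one datum, up to a constant."""
    return -0.5 * (d - theta) ** 2


def minibatch_log_lik(theta, batch, n_total):
    """Rescaled minibatch term: (N / b) times the sum over the batch."""
    return (n_total / len(batch)) * sum(per_datum_log_lik(theta, d) for d in batch)


data = [0.3, -1.2, 2.5, 0.8, 1.1]
theta, b = 0.5, 2

full = sum(per_datum_log_lik(theta, d) for d in data)

# Average over every size-b subset; each is equally likely under uniform
# without-replacement minibatching.
batches = list(itertools.combinations(data, b))
expected = sum(minibatch_log_lik(theta, B, len(data)) for B in batches) / len(batches)

matches_full = math.isclose(expected, full, rel_tol=1e-9)
```

Because the expectation of the rescaled sum equals the full sum for every value of theta, the same holds for its gradient by linearity. The prior term is not rescaled and is therefore unaffected.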

Parameters:

Name Type Description Default
prior SupportsLogProb

Prior distribution over parameters; provides the log-prior term \(\log p(\theta)\).

required
likelihood ConditionallyIndependentLikelihood

Likelihood that factorises as \(\log p(\mathcal{D} \mid \theta) = \sum_i \log p(d_i \mid \theta)\); supplies the per-datum log-density used in the rescaled sum.

required
data array-like, Record, or RecordArray

Observed data. Indexed along its leading axis to draw minibatches; must have leading-axis length >= batch_size.

required
batch_size int

Minibatch size \(b\). Must satisfy 1 <= b <= len(data).

required
with_replacement bool

Sample minibatch indices with replacement. Default is without-replacement (uniform permutation, take first b).

False
name str

Distribution name.

None

Raises:

Type Description
TypeError

If prior is not SupportsLogProb or likelihood is not ConditionallyIndependentLikelihood.

ValueError

If batch_size is not in [1, len(data)].
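
The two sampling schemes and the batch-size check described in the tables above can be sketched as follows. This is an assumption-laden illustration using the standard library; `draw_minibatch_indices` is a hypothetical helper, not a ProbPipe function, and the real class draws indices with JAX.

```python
import random


def draw_minibatch_indices(n, batch_size, *, with_replacement=False, seed=0):
    """Return batch_size indices into a dataset of length n."""
    if batch_size < 1 or batch_size > n:
        raise ValueError(
            f"batch_size must be in [1, len(data)={n}]; got {batch_size}."
        )
    rng = random.Random(seed)
    if with_replacement:
        # Independent uniform draws; indices may repeat.
        return [rng.randrange(n) for _ in range(batch_size)]
    perm = list(range(n))
    rng.shuffle(perm)         # uniform random permutation
    return perm[:batch_size]  # take the first b entries


unique_idx = draw_minibatch_indices(10, 4)
repeat_idx = draw_minibatch_indices(10, 4, with_replacement=True)
```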

Source code in probpipe/inference/_minibatch.py
def __init__(
    self,
    prior: SupportsLogProb,
    likelihood: "ConditionallyIndependentLikelihood",
    data: ArrayLike | Record | RecordArray,
    batch_size: int,
    *,
    with_replacement: bool = False,
    name: str | None = None,
):
    # Lazy import: inference can't import modeling at module load.
    from ..modeling._likelihood import ConditionallyIndependentLikelihood

    if not isinstance(prior, SupportsLogProb):
        raise TypeError(
            f"MinibatchedDistribution requires prior to satisfy "
            f"SupportsLogProb; got {type(prior).__name__}."
        )
    if not isinstance(likelihood, ConditionallyIndependentLikelihood):
        raise TypeError(
            f"MinibatchedDistribution requires likelihood to satisfy "
            f"ConditionallyIndependentLikelihood; got "
            f"{type(likelihood).__name__}. Implement "
            f"per_datum_log_likelihood(params, datum) on the "
            f"likelihood class."
        )

    # Validate data + batch_size.
    n = _data_size(data)
    if batch_size < 1 or batch_size > n:
        raise ValueError(
            f"batch_size must be in [1, len(data)={n}]; got {batch_size}."
        )

    self._prior = prior
    self._likelihood = likelihood
    self._data = data
    self._n = int(n)
    self._batch_size = int(batch_size)
    self._with_replacement = bool(with_replacement)
    self._rescale_factor = float(self._n / batch_size)

    if name is None:
        name = f"MinibatchedDistribution(batch_size={batch_size})"
    super().__init__(name=name)

dataset_size property

Total number of observations in the dataset (len(data)).

Named dataset_size rather than .n to avoid colliding with STYLE_GUIDE §1.9's "how many items does this hold?" semantics — MinibatchedDistribution is not a finite-sample distribution; it doesn't hold a finite collection of realisations.

batch_size property

Minibatch size \(b\).

with_replacement property

Whether minibatch indices are drawn with replacement.

prior property

The prior distribution over parameters.

likelihood property

The conditionally-independent likelihood.

data property

The full dataset (not the minibatched view).

inference_method_registry = MethodRegistry() module-attribute

Iterative transformations

Step functions folded over inputs by iterate, with with_conversion and with_resampling as step-function wrappers.

iterate(step_fn, initial, inputs, *, callback=None)

Fold a step function over inputs, accumulating a distribution sequence.

Starting from initial, applies step_fn(dist, inp) for each element of inputs, collecting the resulting distributions into a list. The returned list includes the initial distribution at index 0.

Provenance is automatically attached to each output distribution (linking it to the previous distribution) unless the step function has already set provenance.

Parameters:

Name Type Description Default
step_fn callable

(Distribution[T], S) -> Distribution[T]. Any callable matching this signature — plain functions, WorkflowFunction instances, or bound methods.

required
initial Distribution[T]

The starting distribution.

required
inputs Iterable[S]

Sequence of inputs to pass to the step function.

required
callback callable or None

Called as callback(i, dist) after each step, where i is the step index and dist is the newly produced distribution. If it returns exactly False, iteration stops early.

None

Returns:

Type Description
list[Distribution[T]]

The full sequence: [initial, dist_1, dist_2, ...].
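
The fold and the early-stop rule can be illustrated without any distributions. The sketch below is a simplified stand-in, with plain numbers in place of Distribution objects and a hypothetical name `iterate_sketch`. It omits the type check and provenance handling of the real function.

```python
def iterate_sketch(step_fn, initial, inputs, *, callback=None):
    """Fold step_fn over inputs, keeping every intermediate state."""
    states = [initial]
    current = initial
    for i, inp in enumerate(inputs):
        current = step_fn(current, inp)
        states.append(current)
        # Only an exact False stops the loop; None or 0 does not.
        if callback is not None and callback(i, current) is False:
            break
    return states


running_sums = iterate_sketch(lambda s, x: s + x, 0, [1, 2, 3])

# The callback returns False at step index 1, so the third input is never used.
stopped = iterate_sketch(lambda s, x: s + x, 0, [1, 2, 3],
                         callback=lambda i, s: i < 1)

# A callback that returns None (for example, one that only logs) never stops early.
logged = iterate_sketch(lambda s, x: s + x, 0, [1, 2, 3],
                        callback=lambda i, s: None)
```

Note that the state produced at the stopping step is still appended before the loop exits, which matches the order of operations in the source listing.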

Source code in probpipe/core/transition.py
@workflow_function
def iterate[T, S](
    step_fn: Callable[[Distribution[T], S], Distribution[T]],
    initial: Distribution[T],
    inputs: Iterable[S],
    *,
    callback: Callable[[int, Distribution[T]], Any] | None = None,
) -> list[Distribution[T]]:
    """Fold a step function over inputs, accumulating a distribution sequence.

    Starting from *initial*, applies ``step_fn(dist, inp)`` for each
    element of *inputs*, collecting the resulting distributions into a
    list.  The returned list includes the initial distribution at
    index 0.

    Provenance is automatically attached to each output distribution
    (linking it to the previous distribution) unless the step function
    has already set provenance.

    Parameters
    ----------
    step_fn : callable
        ``(Distribution[T], S) -> Distribution[T]``.
        Any callable matching this signature — plain functions,
        :class:`WorkflowFunction` instances, or bound methods.
    initial : Distribution[T]
        The starting distribution.
    inputs : Iterable[S]
        Sequence of inputs to pass to the step function.
    callback : callable or None
        Called as ``callback(i, dist)`` after each step, where *i* is
        the step index and *dist* is the newly produced distribution.
        If it returns exactly ``False``, iteration stops early.

    Returns
    -------
    list[Distribution[T]]
        The full sequence: ``[initial, dist_1, dist_2, ...]``.
    """
    dists: list[Distribution[T]] = [initial]
    current = initial

    for i, inp in enumerate(inputs):
        result = step_fn(current, inp)
        if not isinstance(result, Distribution):
            raise TypeError(
                f"Step function at index {i} returned "
                f"{type(result).__name__}, expected Distribution."
            )

        # Auto-attach provenance if not already set
        if result.source is None:
            try:
                result.with_source(
                    Provenance(
                        "iterate",
                        parents=(current,),
                        metadata={"step": i},
                    )
                )
            except RuntimeError:
                pass  # write-once guard

        dists.append(result)
        current = result

        if callback is not None:
            cont = callback(i, result)
            if cont is False:
                break

    return dists

with_conversion(step_fn, target_type, **convert_kwargs)

Wrap a step function to convert its output after each step.

After calling step_fn, converts the resulting distribution to target_type using ProbPipe's standard from_distribution operation (which dispatches through the converter registry). The pre-conversion distribution is accessible via the converted distribution's provenance parents (set by the converter).

This is useful when the step function produces samples (e.g., MCMC output) but the next iteration needs a parametric distribution as input.

The returned wrapper is a WorkflowFunction, so it appears as a node in the ProbPipe workflow DAG.

Parameters:

Name Type Description Default
step_fn callable

The underlying step function.

required
target_type type

Distribution type to convert to (e.g., MultivariateNormal). Can also be a protocol (e.g., SupportsLogProb).

required
**convert_kwargs Any

Extra keyword arguments passed to from_distribution.

{}

Returns:

Type Description
WorkflowFunction

A new step function with the same call signature.
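
The wrapping pattern can be sketched with toy stand-ins. In this illustration, which does not use ProbPipe's converter registry, a "parametric distribution" is a (mean, std) tuple, the step function returns a list of samples, and the conversion is moment matching. All names are hypothetical.

```python
import statistics


def with_conversion_sketch(step_fn, convert):
    """Return a step function that applies convert to step_fn's output."""
    def wrapped(state, inp):
        return convert(step_fn(state, inp))
    return wrapped


def shift_step(state, inp):
    """Toy step: turn a (mean, std) summary into three shifted 'samples'."""
    mean, std = state
    return [mean + inp - std, mean + inp, mean + inp + std]


def moment_match(samples):
    """Toy conversion: summarise samples by mean and population std."""
    return (statistics.fmean(samples), statistics.pstdev(samples))


step = with_conversion_sketch(shift_step, moment_match)
state = step((0.0, 1.0), 2.0)   # a (mean, std) tuple again
next_state = step(state, -1.0)  # so the output can be fed straight back in
```

The point of the design is that the wrapped step has the same input and output type, which is what lets a fold such as iterate chain it across many inputs.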

Source code in probpipe/core/transition.py
def with_conversion(
    step_fn: Callable,
    target_type: type,
    **convert_kwargs: Any,
) -> WorkflowFunction:
    """Wrap a step function to convert its output after each step.

    After calling *step_fn*, converts the resulting distribution to
    *target_type* using ProbPipe's standard ``from_distribution``
    operation (which dispatches through the converter registry).
    The pre-conversion distribution is accessible via the converted
    distribution's provenance parents (set by the converter).

    This is useful when the step function produces samples (e.g.,
    MCMC output) but the next iteration needs a parametric
    distribution as input.

    The returned wrapper is a :class:`WorkflowFunction`, so it appears
    as a node in the ProbPipe workflow DAG.

    Parameters
    ----------
    step_fn : callable
        The underlying step function.
    target_type : type
        Distribution type to convert to (e.g., ``MultivariateNormal``).
        Can also be a protocol (e.g., ``SupportsLogProb``).
    **convert_kwargs
        Extra keyword arguments passed to ``from_distribution``.

    Returns
    -------
    WorkflowFunction
        A new step function with the same call signature.
    """
    inner_name = _step_fn_name(step_fn)

    def _with_conversion_impl(dist: Distribution, inp: Any) -> Distribution:
        from .ops import from_distribution

        result = step_fn(dist, inp)
        return from_distribution(result, target_type, **convert_kwargs)

    return WorkflowFunction(
        func=_with_conversion_impl,
        name=f"with_conversion({inner_name}, {target_type.__name__})",
    )

with_resampling(step_fn, *, ess_threshold=0.5, seed=0)

Wrap a step function to resample when particle weights degenerate.

After calling step_fn, if the result is an EmpiricalDistribution with ESS / N < ess_threshold, performs multinomial resampling to produce equally-weighted particles.

The pre-resampling ESS is stored in provenance metadata of the resampled distribution (dist.source.metadata["ess"]) since this information would otherwise be lost after resampling to uniform weights.

The returned wrapper is a WorkflowFunction, so it appears as a node in the ProbPipe workflow DAG.

Parameters:

Name Type Description Default
step_fn callable

The underlying step function.

required
ess_threshold float

Resample when ESS / N drops below this value (default 0.5).

0.5
seed int

Base random seed; combined with a counter of resampling events for deterministic reproducibility.

0

Returns:

Type Description
WorkflowFunction

A new step function with the same call signature.

Notes

This API is likely to evolve as typical use cases become clearer. A future direction is a SupportsResampling protocol that would decouple this combinator from the concrete EmpiricalDistribution type.
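
The resampling rule can be sketched in a few lines. This illustration makes two assumptions that the text above does not confirm: that the effective sample size is Kish's estimate, 1 / sum(w_i^2) for normalized weights, and that particles are plain numbers. The helper names are hypothetical and the standard library stands in for JAX.

```python
import random


def effective_sample_size(weights):
    """Kish ESS, 1 / sum(w_i^2), after normalizing weights to sum to 1."""
    total = sum(weights)
    return 1.0 / sum((w / total) ** 2 for w in weights)


def resample_if_degenerate(particles, weights, *, ess_threshold=0.5, seed=0):
    """Multinomial resampling when ESS / N falls below ess_threshold."""
    n = len(particles)
    ess_ratio = effective_sample_size(weights) / n
    if ess_ratio >= ess_threshold:
        return particles, weights, ess_ratio
    rng = random.Random(seed)
    # Draw n particles with probability proportional to their weights.
    resampled = rng.choices(particles, weights=weights, k=n)
    uniform = [1.0 / n] * n  # resampled particles are equally weighted
    return resampled, uniform, ess_ratio


particles = [0.1, 0.4, 0.7, 1.0]
kept, kept_w, kept_ratio = resample_if_degenerate(particles, [0.25] * 4)
new, new_w, new_ratio = resample_if_degenerate(particles, [0.97, 0.01, 0.01, 0.01])
```

Returning the pre-resampling ratio mirrors the reason the wrapper records ESS in provenance metadata: once weights are uniform, the degeneracy that triggered resampling can no longer be recovered from the output.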

Source code in probpipe/core/transition.py
def with_resampling(
    step_fn: Callable,
    *,
    ess_threshold: float = 0.5,
    seed: int = 0,
) -> WorkflowFunction:
    """Wrap a step function to resample when particle weights degenerate.

    After calling *step_fn*, if the result is an
    :class:`~probpipe.core.distribution.EmpiricalDistribution` with
    ``ESS / N < ess_threshold``, performs multinomial resampling to
    produce equally-weighted particles.

    The pre-resampling ESS is stored in provenance metadata of the
    resampled distribution (``dist.source.metadata["ess"]``) since
    this information would otherwise be lost after resampling to
    uniform weights.

    The returned wrapper is a :class:`WorkflowFunction`, so it appears
    as a node in the ProbPipe workflow DAG.

    Parameters
    ----------
    step_fn : callable
        The underlying step function.
    ess_threshold : float
        Resample when ``ESS / N`` drops below this value (default 0.5).
    seed : int
        Base random seed; combined with a counter of resampling
        events for deterministic reproducibility.

    Returns
    -------
    WorkflowFunction
        A new step function with the same call signature.

    Notes
    -----
    This API is likely to evolve as typical use cases become clearer.
    A future direction is a ``SupportsResampling`` protocol that would
    decouple this combinator from the concrete
    :class:`~probpipe.core.distribution.EmpiricalDistribution` type.
    """
    import jax
    import jax.numpy as jnp

    inner_name = _step_fn_name(step_fn)
    call_count = 0

    def _with_resampling_impl(dist: Distribution, inp: Any) -> Distribution:
        nonlocal call_count
        from .distribution import EmpiricalDistribution

        out_dist = step_fn(dist, inp)

        if isinstance(out_dist, EmpiricalDistribution):
            n = out_dist.n
            ess = float(out_dist.effective_sample_size)
            ess_ratio = ess / n

            if ess_ratio < ess_threshold:
                key = jax.random.PRNGKey(seed + call_count)
                call_count += 1
                indices = out_dist._w.choice(key, shape=(n,))
                # Resample per-field (samples is a NumericRecord; index
                # each field's stacked array along the sample axis).
                from .record import Record
                new_record = Record({
                    f: out_dist.samples[f][indices]
                    for f in out_dist.samples.fields
                })
                resampled = EmpiricalDistribution(
                    new_record, name=out_dist.name,
                )
                resampled.with_source(
                    Provenance(
                        "resample",
                        parents=(out_dist,),
                        metadata={"ess": ess, "ess_ratio": ess_ratio},
                    )
                )
                return resampled

        return out_dist

    return WorkflowFunction(
        func=_with_resampling_impl,
        name=f"with_resampling({inner_name})",
    )

Predictive checks

predictive_check(distribution, generative_likelihood, test_fn, observed_data=None, *, n_samples=None, n_replications=500, key=None)

Predictive check — works as both prior and posterior check.

Draws parameter samples from distribution, generates replicated data via generative_likelihood, and computes test_fn on each replicate.

When observed_data is provided, also computes test_fn on the observed data and returns a calibration p-value, making this a posterior predictive check. Without observed_data, this is a prior predictive check — useful for understanding the implications of the prior.

When generate_data accepts a key keyword argument, all replications are generated in a single vectorized call (by passing a batch of parameter vectors), giving a large speedup. The test function is then applied via jax.vmap when possible, with an automatic fallback to a Python loop.

Parameters:

Name Type Description Default
distribution Distribution[P]

Prior or posterior to sample parameters from.

required
generative_likelihood GenerativeLikelihood[P, D]

Must provide generate_data(params, n_samples) returning data of type D. The key keyword argument (key: PRNGKey | None = None) is optional: when generate_data accepts it, the vectorized fast path is used; otherwise a loop-based fallback is used.

required
test_fn Callable[[D], float]

Test statistic mapping data to a scalar.

required
observed_data D or None

If provided, compute the observed test statistic and p-value.

None
n_samples int

Number of observations per replicated dataset. Required if observed_data is not provided; otherwise defaults to len(observed_data).

None
n_replications int

Number of replicated datasets to generate.

500
key PRNGKey

JAX PRNG key. Auto-generated if None.

None

Returns:

Type Description
dict

Always contains:

  • "replicated_statistics": RecordEmpiricalDistribution over the test statistic values from replicated data.
  • "test_fn_name": name of test_fn (its __name__, or its repr if it has none).

When observed_data is provided, also contains:

  • "observed_statistic": test_fn(observed_data).
  • "p_value": fraction of replicates whose test statistic is greater than or equal to the observed value (a one-sided, upper-tail p-value).
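
The procedure described above can be sketched as a plain loop. This is a simplified illustration and not the library's implementation: `predictive_check_sketch` and its callables are hypothetical, parameters come from a user-supplied sampler in place of a Distribution, and the p-value follows the upper-tail convention used in the source listing (statistic >= observed).

```python
import random
import statistics


def predictive_check_sketch(sample_params, generate_data, test_fn,
                            observed_data=None, *, n_samples=None,
                            n_replications=500, seed=0):
    """Loop-based predictive check with an upper-tail p-value."""
    if n_samples is None:
        if observed_data is None:
            raise ValueError(
                "n_samples is required when observed_data is not provided"
            )
        n_samples = len(observed_data)
    rng = random.Random(seed)
    stats = []
    for _ in range(n_replications):
        params = sample_params(rng)                        # draw parameters
        replicate = generate_data(params, n_samples, rng)  # simulate a dataset
        stats.append(test_fn(replicate))                   # summarise it
    result = {"replicated_statistics": stats}
    if observed_data is not None:
        obs = test_fn(observed_data)
        result["observed_statistic"] = obs
        result["p_value"] = sum(s >= obs for s in stats) / len(stats)
    return result


observed = [0.2, -0.5, 1.3, 0.7, -0.1]
check = predictive_check_sketch(
    sample_params=lambda rng: rng.gauss(0.0, 1.0),  # mean drawn from N(0, 1)
    generate_data=lambda mu, n, rng: [rng.gauss(mu, 1.0) for _ in range(n)],
    test_fn=statistics.fmean,
    observed_data=observed,
)
prior_only = predictive_check_sketch(
    sample_params=lambda rng: rng.gauss(0.0, 1.0),
    generate_data=lambda mu, n, rng: [rng.gauss(mu, 1.0) for _ in range(n)],
    test_fn=statistics.fmean,
    n_samples=3,
    n_replications=50,
)
```

A p-value near 0 or 1 indicates that the observed statistic sits in a tail of the replicated distribution. Without observed data, only the replicated statistics are returned.
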
Source code in probpipe/validation/_predictive_check.py
@workflow_function
def predictive_check[P, D](
    distribution: SupportsSampling,
    generative_likelihood: GenerativeLikelihood[P, D],
    test_fn: Callable[[D], float],
    observed_data: D | None = None,
    *,
    n_samples: int | None = None,
    n_replications: int = 500,
    key: PRNGKey | None = None,
) -> dict:
    """Predictive check — works as both prior and posterior check.

    Draws parameter samples from *distribution*, generates replicated
    data via *generative_likelihood*, and computes *test_fn* on each
    replicate.

    When *observed_data* is provided, also computes *test_fn* on the
    observed data and returns a calibration p-value, making this a
    posterior predictive check.  Without *observed_data*, this is a
    prior predictive check — useful for understanding the implications
    of the prior.

    When ``generate_data`` accepts a ``key`` keyword argument, all
    replications are generated in a single vectorized call (by passing
    a batch of parameter vectors), giving a large speedup.  The test
    function is then applied via ``jax.vmap`` when possible, with an
    automatic fallback to a Python loop.

    Parameters
    ----------
    distribution : Distribution[P]
        Prior or posterior to sample parameters from.
    generative_likelihood : GenerativeLikelihood[P, D]
        Must provide ``generate_data(params, n_samples)`` returning data
        of type ``D``.  The ``key`` keyword argument
        (``key: PRNGKey | None = None``) is optional: when
        ``generate_data`` accepts it, the vectorized fast path is used;
        otherwise a loop-based fallback is used.
    test_fn : Callable[[D], float]
        Test statistic mapping data to a scalar.
    observed_data : D or None, optional
        If provided, compute the observed test statistic and p-value.
    n_samples : int, optional
        Number of observations per replicated dataset.  Required if
        *observed_data* is not provided; otherwise defaults to
        ``len(observed_data)``.
    n_replications : int
        Number of replicated datasets to generate.
    key : PRNGKey, optional
        JAX PRNG key.  Auto-generated if ``None``.

    Returns
    -------
    dict
        Always contains:

        - ``"replicated_statistics"``: ``RecordEmpiricalDistribution``
          over the test statistic values from replicated data.
        - ``"test_fn_name"``: name of *test_fn* (its ``__name__``, or
          its ``repr`` if it has none).

        When *observed_data* is provided, also contains:

        - ``"observed_statistic"``: ``test_fn(observed_data)``
        - ``"p_value"``: fraction of replicates whose test statistic
          is greater than or equal to the observed value (one-sided,
          upper tail).
    """
    if n_samples is None:
        if observed_data is None:
            raise ValueError(
                "n_samples is required when observed_data is not provided"
            )
        n_samples = len(observed_data)

    if key is None:
        key = _auto_key()

    # -- Fast path: batched generation + vmap test_fn -----------------------
    if _supports_key_arg(generative_likelihood):
        stats_array = _predictive_check_batched(
            distribution, generative_likelihood, test_fn,
            n_samples, n_replications, key,
        )
    else:
        stats_array = _predictive_check_loop(
            distribution, generative_likelihood, test_fn,
            n_samples, n_replications, key,
        )

    replicated_dist = RecordEmpiricalDistribution(
        stats_array, name="replicated_statistics",
    )

    test_fn_name = getattr(test_fn, "__name__", repr(test_fn))
    result = {
        "replicated_statistics": replicated_dist,
        "test_fn_name": test_fn_name,
    }

    if observed_data is not None:
        obs_stat = float(test_fn(observed_data))
        p_value = float(np.mean(stats_array >= obs_stat))
        result["observed_statistic"] = obs_stat
        result["p_value"] = p_value

    # Attach to the distribution for easy access
    if hasattr(distribution, "validation_results"):
        distribution.validation_results.append(result)

    return result