
Extending ProbPipe

ProbPipe's extension surface is small and grouped by capability. The table below maps each kind of extension to the contract you implement against and the registry (if any) you register with. Each row links to the section on this page that covers it in detail.

| To add a... | Implement | Register with |
|---|---|---|
| New distribution family | Subclass of Distribution, RecordDistribution, NumericRecordDistribution, or TFPDistribution | (none; capability is detected by isinstance against the matching protocol) |
| New op support on an existing distribution | The matching underscore method (_sample, _log_prob, _mean, ...) on the class | (none; see Protocols for which method backs which op) |
| New inference method (custom sampler, optimiser, ...) | Subclass of Method declaring supported_types, priority, check(), and execute() | inference_method_registry.register(...); see Custom inference methods |
| New distribution-to-distribution converter | Subclass of Converter with check() / convert() | converter_registry.register(...); see Custom converters |
| New canonical bijector for a Constraint | A factory returning a TFP bijector | register_bijector(constraint_or_class, factory); see Custom bijectors |
| New auxiliary-metadata adapter (custom array-like) | capture and restore callables | register_aux(leaf_type, capture, restore); see Custom auxiliary metadata |
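
The Method row implies a selection step: given a distribution, the registry has to choose among registered methods using supported_types, check(), and priority. The sketch below shows one plausible rule under a stated assumption, namely that the highest-priority method whose supported_types and check() both accept the input wins. ToyMethod, ToyRegistry, and Smooth are hypothetical stand-ins, not ProbPipe's Method class or inference_method_registry.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyMethod:
    """Hypothetical stand-in for a Method subclass (not ProbPipe's API)."""

    name: str
    supported_types: tuple  # types this method can handle
    priority: int  # higher wins when several methods apply
    check: Callable[[Any], bool] = lambda dist: True  # extra applicability test

    def execute(self, dist: Any) -> str:
        return f"{self.name} ran on {type(dist).__name__}"


@dataclass
class ToyRegistry:
    """Hypothetical registry: the highest-priority applicable method wins."""

    methods: list = field(default_factory=list)

    def register(self, method: ToyMethod) -> ToyMethod:
        self.methods.append(method)
        return method

    def select(self, dist: Any) -> ToyMethod:
        # Keep only methods whose declared types and check() accept dist.
        candidates = [
            m for m in self.methods
            if isinstance(dist, m.supported_types) and m.check(dist)
        ]
        if not candidates:
            raise LookupError(f"no registered method accepts {type(dist).__name__}")
        return max(candidates, key=lambda m: m.priority)


class Smooth:
    """A toy 'distribution' type that a specialised method could target."""


registry = ToyRegistry()
registry.register(ToyMethod("rwmh", (object,), priority=10))  # generic fallback
registry.register(ToyMethod("nuts", (Smooth,), priority=50))  # specialised
```

Under this rule the generic fallback stays registered, and a more specialised method simply outranks it for the inputs it accepts.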

Two further references, the Broadcasting internals section and the Internals page, document classes that an extension rarely constructs directly but may need to reference.

Distribution base classes

Distribution is the abstract root. RecordDistribution and NumericRecordDistribution specialise it for distributions whose _sample() returns a Record or NumericRecord respectively. TFPDistribution wraps an existing TFP Distribution.

Distribution(*, name)

Bases: ABC

Abstract base for all ProbPipe distributions, parameterized by value type T.

Every distribution has a name. Leaf distributions (Normal, Gamma, etc.) require an explicit name= argument; composite distributions (ProductDistribution, EmpiricalDistribution, etc.) auto-generate a name from their components when one is not provided.

Provides naming, provenance, conversion, and approximation tracking. Sampling and expectation capabilities come from the separate SupportsSampling and SupportsExpectation protocols.

Source code in probpipe/core/_distribution_base.py
def __init__(self, *, name: str):
    if not isinstance(name, str) or not name:
        raise TypeError(
            f"{type(self).__name__} requires a non-empty name= argument"
        )
    self._name = name

is_approximate property

Whether this distribution is an approximation (e.g., from sampling or MCMC).

validation_results property

Results from predictive_check runs.

Each entry is a dict with at least "replicated_statistics" and "test_fn_name". Posterior checks also include "observed_statistic" and "p_value".

record_template property

Structural template for this distribution's samples, or None.

Deprecated: this property is migrating to RecordDistribution; non-Record distributions should not rely on it.

auxiliary property

An xarray DataTree of auxiliary information (diagnostics, sample statistics, algorithm metadata), or None.

Populated by inference methods. Follows ArviZ group conventions (posterior, sample_stats, warmup, etc.) with metadata stored as DataTree attributes.

with_source(source)

Attach provenance to this distribution (write-once).

Source code in probpipe/core/_distribution_base.py
def with_source(self, source: Provenance) -> Distribution:
    """Attach provenance to this distribution (write-once)."""
    if getattr(self, "_source", None) is not None:
        raise RuntimeError(
            f"Source already set on {self!r}. "
            "Provenance is write-once; create a new distribution instead."
        )
    self._source = source
    return self

renamed(new_name)

Return a shallow copy with a different name.

The copy shares all internal state but has a new name. Provenance is tracked: the copy's source records the rename operation and points to the original as parent. Any cached record_template is cleared so it regenerates with the new name (relevant for TFPDistribution).

Source code in probpipe/core/_distribution_base.py
def renamed(self, new_name: str) -> Distribution:
    """Return a shallow copy with a different name.

    The copy shares all internal state but has a new ``name``.
    Provenance is tracked: the copy's ``source`` records the rename
    operation and points to the original as parent.  Any cached
    ``record_template`` is cleared so it regenerates with the new
    name (relevant for ``TFPDistribution``).
    """
    clone = _copy.copy(self)
    object.__setattr__(clone, "_name", new_name)
    object.__setattr__(clone, "_record_template", None)
    # Bypass write-once guard so rename provenance can be attached
    object.__setattr__(clone, "_source", None)
    clone.with_source(
        Provenance(
            "renamed",
            parents=(self,),
            metadata={"old_name": self.name, "new_name": new_name},
        )
    )
    return clone
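
with_source and renamed interact: the guard makes provenance write-once, so renamed clears the provenance on its shallow copy before attaching the rename record. The following plain-Python sketch reproduces that pattern; ToyDist and ToyProvenance are hypothetical stand-ins for Distribution and Provenance.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyProvenance:
    """Hypothetical stand-in for ProbPipe's Provenance record."""

    operation: str
    parents: tuple = ()
    metadata: dict = field(default_factory=dict)


class ToyDist:
    def __init__(self, *, name: str, params: list):
        self._name = name
        self._params = params  # internal state shared by shallow copies
        self._source = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def source(self):
        return self._source

    def with_source(self, source: ToyProvenance) -> "ToyDist":
        # Write-once: a second call is an error, not an overwrite.
        if self._source is not None:
            raise RuntimeError("Provenance is write-once")
        self._source = source
        return self

    def renamed(self, new_name: str) -> "ToyDist":
        clone = copy.copy(self)  # shallow: _params is shared, not copied
        # object.__setattr__ mirrors the source above, which uses it to
        # sidestep any attribute guards on the class.
        object.__setattr__(clone, "_name", new_name)
        # Clear the copied provenance so the write-once guard accepts the
        # rename record; the original keeps its own source untouched.
        object.__setattr__(clone, "_source", None)
        return clone.with_source(
            ToyProvenance(
                "renamed",
                parents=(self,),
                metadata={"old_name": self.name, "new_name": new_name},
            )
        )
```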

from_batched_params(*, name, batch_shape=None, **batched_params) classmethod

Class-method alias for DistributionArray.from_batched_params.

Lets users write the ergonomic per-class form:

Normal.from_batched_params(loc=jnp.zeros(5), scale=1.0, name="x")

instead of the universal entry point:

DistributionArray.from_batched_params(
    Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)

Both produce the same DistributionArray — the alias is a thin classmethod that calls the universal factory with cls bound. Subclasses inherit the alias automatically; no per-family override is needed.

See DistributionArray.from_batched_params for the full contract (dispatch on SupportsArrayBackend, batch_shape inference, per-cell name suffixing).

Source code in probpipe/core/_distribution_base.py
@classmethod
def from_batched_params(
    cls,
    *,
    name: str,
    batch_shape: tuple[int, ...] | None = None,
    **batched_params,
) -> "DistributionArray":
    """Class-method alias for :meth:`DistributionArray.from_batched_params`.

    Lets users write the ergonomic per-class form::

        Normal.from_batched_params(loc=jnp.zeros(5), scale=1.0, name="x")

    instead of the universal entry point::

        DistributionArray.from_batched_params(
            Normal, loc=jnp.zeros(5), scale=1.0, name="x",
        )

    Both produce the same ``DistributionArray`` — the alias is a
    thin classmethod that calls the universal factory with
    ``cls`` bound. Subclasses inherit the alias automatically;
    no per-family override is needed.

    See :meth:`DistributionArray.from_batched_params` for the full
    contract (dispatch on
    :class:`~probpipe.core.protocols.SupportsArrayBackend`,
    ``batch_shape`` inference, per-cell name suffixing).
    """
    # Local import: ``DistributionArray`` lives in the same
    # subpackage and importing at module top would create a cycle
    # (DistributionArray inherits from Distribution).
    from ._distribution_array import DistributionArray
    return DistributionArray.from_batched_params(
        cls, name=name, batch_shape=batch_shape, **batched_params,
    )
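
The alias pattern is a classmethod on the base class that forwards to a universal factory with cls bound, so every subclass inherits it. A minimal sketch, assuming a 1-D batch where list-valued parameters vary per cell and scalars are broadcast; ToyDistributionArray, ToyDistribution, ToyNormal, and the name_i suffix format are hypothetical, and the real factory's dispatch on SupportsArrayBackend is omitted.

```python
class ToyDistributionArray:
    """Hypothetical stand-in for DistributionArray (one instance per cell)."""

    def __init__(self, cells):
        self.cells = list(cells)

    @classmethod
    def from_batched_params(cls, dist_cls, *, name, batch_shape=None, **batched_params):
        # Assumption: 1-D batches only; list-valued params vary per cell,
        # every other value is broadcast to all cells.
        if batch_shape is None:
            lengths = {len(v) for v in batched_params.values() if isinstance(v, list)}
            if len(lengths) != 1:
                raise ValueError("cannot infer a unique batch_shape")
            batch_shape = (lengths.pop(),)
        (n,) = batch_shape
        cells = []
        for i in range(n):
            params = {k: (v[i] if isinstance(v, list) else v) for k, v in batched_params.items()}
            # Hypothetical per-cell name suffix; the real format is not shown here.
            cells.append(dist_cls(**params, name=f"{name}_{i}"))
        return cls(cells)


class ToyDistribution:
    @classmethod
    def from_batched_params(cls, *, name, batch_shape=None, **batched_params):
        # Thin alias: forward to the universal factory with cls bound.
        return ToyDistributionArray.from_batched_params(
            cls, name=name, batch_shape=batch_shape, **batched_params
        )


class ToyNormal(ToyDistribution):  # inherits the alias; no override needed
    def __init__(self, *, loc, scale, name):
        self.loc, self.scale, self.name = loc, scale, name
```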

RecordDistribution(*, name)

Bases: Distribution[Record]

Generic Record-based distribution.

Provides named component access (fields, __getitem__, select()) and Record-aware flatten / unflatten. Does NOT impose numeric shape / dtype conventions (dtype, support, event_shape) — those belong on NumericRecordDistribution and its consumers.

Concrete subclasses must set _record_template (a RecordTemplate describing the named structure) and implement the relevant sampling / log-prob protocols.

Source code in probpipe/core/_distribution_base.py
def __init__(self, *, name: str):
    if not isinstance(name, str) or not name:
        raise TypeError(
            f"{type(self).__name__} requires a non-empty name= argument"
        )
    self._name = name

record_template property

Structural template describing this distribution's samples.

Returns a RecordTemplate with field names and per-field shapes, or None if no template is set.

fields property

Field names from the record_template, or an empty tuple if no template is set.

event_size property

Total number of scalar elements in one sample.

Sums the sizes of every numeric leaf described by the template; opaque leaves contribute zero. For a NumericRecordTemplate, the already-cached flat_size is reused instead of recomputing the sum.

event_shapes property

Per-field event shapes.

For untemplated distributions (no record_template) returns {}; use event_shape for the scalar shape of a single unnamed field. Nested sub-templates collapse to () at the top level.

shape property

Shape of one draw (equals the sole field's event_shape).

ndim property

Number of axes in one draw.

select(*fields, **mapping)

Select named fields as views for workflow function broadcasting.

Positional args use the field name as the argument name. Keyword args remap: select(x="field_name").

Usage:

predict(**posterior.select("r", "K", "phi"), x=x_grid)
Source code in probpipe/core/_record_distribution.py
def select(self, *fields: str, **mapping: str) -> dict[str, _RecordDistributionView]:
    """Select named fields as views for workflow function broadcasting.

    Positional args use the field name as the argument name.
    Keyword args remap: ``select(x="field_name")``.

    Usage::

        predict(**posterior.select("r", "K", "phi"), x=x_grid)
    """
    result: dict[str, _RecordDistributionView] = {}
    for f in fields:
        result[f] = self[f]
    for arg_name, field_name in mapping.items():
        result[arg_name] = self[field_name]
    return result

select_all()

Return every component as a view, for splatting into function calls.

Shorthand for select(*self.fields). Matches Record.select_all and RecordArray.select_all so the splat-all pattern works uniformly across the three field-bearing container types. Cross-field correlation is preserved by the parent-identity machinery in the WorkflowFunction sweep layer.

Source code in probpipe/core/_record_distribution.py
def select_all(self) -> dict[str, _RecordDistributionView]:
    """Return every component as a view, for splatting into function calls.

    Sugar for ``select(*self.fields)``. Matches
    :meth:`Record.select_all` / :meth:`RecordArray.select_all` so
    the splat-all pattern works uniformly across the three field-
    bearing container types. Preserves cross-field correlation via
    the parent-identity machinery in the ``WorkflowFunction`` sweep
    layer.
    """
    return self.select(*self.fields)

keys()

Iterate over component names.

Source code in probpipe/core/_record_distribution.py
def keys(self) -> Iterator[str]:
    """Iterate over component names."""
    return iter(self.fields)

values()

Iterate over component views.

Source code in probpipe/core/_record_distribution.py
def values(self) -> Iterator[_RecordDistributionView]:
    """Iterate over component views."""
    for name in self.fields:
        yield self[name]

items()

Iterate over (name, view) pairs.

Source code in probpipe/core/_record_distribution.py
def items(self) -> Iterator[tuple[str, _RecordDistributionView]]:
    """Iterate over (name, view) pairs."""
    for name in self.fields:
        yield name, self[name]
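
select, select_all, keys, and items give a RecordDistribution a mapping-like surface whose values are per-field views. The sketch below mimics that surface with hypothetical ToyView and ToyRecordDistribution classes and a hypothetical predict workflow function. Each view keeps a reference to its parent, which is the kind of identity a sweep layer could use to keep fields from the same joint draw together; how ProbPipe's WorkflowFunction actually does this is not shown on this page.

```python
class ToyView:
    """Hypothetical field view; keeps a reference to its parent distribution."""

    def __init__(self, parent, field):
        self.parent, self.field = parent, field


class ToyRecordDistribution:
    def __init__(self, fields):
        self._fields = tuple(fields)

    @property
    def fields(self):
        return self._fields

    def __getitem__(self, field):
        if field not in self._fields:
            raise KeyError(field)
        return ToyView(self, field)

    def select(self, *fields, **mapping):
        # Positional: argument name == field name. Keyword: arg_name="field_name".
        result = {f: self[f] for f in fields}
        result.update({arg: self[f] for arg, f in mapping.items()})
        return result

    def select_all(self):
        return self.select(*self.fields)

    def keys(self):
        return iter(self.fields)

    def items(self):
        for name in self.fields:
            yield name, self[name]


def predict(*, r, capacity, x):
    """Hypothetical workflow function: reports which fields it received."""
    return (r.field, capacity.field, x)


posterior = ToyRecordDistribution(["r", "K", "phi"])
# Keyword remap: the field "K" is passed as the argument "capacity".
result = predict(**posterior.select("r", capacity="K"), x=1.5)
```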

flatten_value(value)

Flatten a NumericRecord or NumericRecordArray sample to a flat array.

The flatten operation is numeric-only, so Record inputs must be convertible to NumericRecord (all leaves numeric). Raw arrays are returned unchanged.

Source code in probpipe/core/_record_distribution.py
def flatten_value(self, value) -> Array:
    """Flatten a NumericRecord or NumericRecordArray sample to a flat array.

    The flatten operation is numeric-only, so ``Record`` inputs must
    be convertible to ``NumericRecord`` (all leaves numeric). Raw
    arrays are returned unchanged.
    """
    from ._numeric_record import NumericRecord
    from ._record_array import NumericRecordArray
    if isinstance(value, NumericRecordArray):
        return value.flatten()
    if isinstance(value, NumericRecord):
        return value.flatten()
    if isinstance(value, Record):
        return NumericRecord.from_record(value).flatten()
    return value

unflatten_value(flat)

Reconstruct a NumericRecord or NumericRecordArray from a flat array.

Source code in probpipe/core/_record_distribution.py
def unflatten_value(self, flat: Array):
    """Reconstruct a NumericRecord or NumericRecordArray from a flat array."""
    from ._numeric_record import NumericRecord
    from ._record_array import NumericRecordArray
    tpl = self.record_template
    if tpl is None:
        raise RuntimeError("Cannot unflatten without record_template")
    flat = jnp.asarray(flat)
    if flat.ndim < 2:
        return NumericRecord.unflatten(flat, template=tpl)
    return NumericRecordArray.unflatten(flat, template=tpl)
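
Flattening walks the template in field order and concatenates each leaf's row-major values; unflattening consumes the flat vector in the same order, using each field's shape to decide how many values to take. The sketch below shows that bookkeeping for a single sample with plain dicts and nested lists instead of NumericRecord and jax arrays; flatten_record, unflatten_record, and the dict-based template are hypothetical. The total length equals event_size, the sum of the per-field sizes.

```python
from math import prod


def _flatten_leaf(x, ndim):
    # Row-major flattening of a nested-list leaf with `ndim` axes.
    if ndim == 0:
        return [float(x)]
    return [v for sub in x for v in _flatten_leaf(sub, ndim - 1)]


def _reshape(values, shape):
    # Inverse of _flatten_leaf: rebuild nested lists of the given shape.
    if not shape:
        return values[0]
    step = prod(shape[1:])
    return [_reshape(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def flatten_record(record, template):
    """template: hypothetical {field: event_shape} dict, in field order."""
    flat = []
    for field, shape in template.items():
        flat.extend(_flatten_leaf(record[field], len(shape)))
    return flat


def unflatten_record(flat, template):
    out, offset = {}, 0
    for field, shape in template.items():
        size = prod(shape)  # prod(()) == 1 for scalar fields
        out[field] = _reshape(flat[offset:offset + size], shape)
        offset += size
    if offset != len(flat):
        raise ValueError(f"expected {offset} values, got {len(flat)}")
    return out


template = {"r": (), "K": (2,), "W": (2, 2)}
event_size = sum(prod(s) for s in template.values())  # 1 + 2 + 4
```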

as_flat_distribution()

View this distribution as a flat FlatNumericRecordDistribution.

Returns a FlattenedDistributionView with event_shape = (event_size,) for algorithms expecting flat vectors (MCMC, optimizers, VI methods). Inverse: as_record_distribution.

Source code in probpipe/core/_record_distribution.py
def as_flat_distribution(self) -> FlattenedDistributionView:
    """View this distribution as a flat ``FlatNumericRecordDistribution``.

    Returns a :class:`~probpipe.core._numeric_record_distribution.FlattenedDistributionView`
    with ``event_shape = (event_size,)`` for algorithms expecting
    flat vectors (MCMC, optimizers, VI methods). Inverse:
    :meth:`~probpipe.core._numeric_record_distribution.FlatNumericRecordDistribution.as_record_distribution`.
    """
    from ._numeric_record_distribution import FlattenedDistributionView
    return FlattenedDistributionView(self)

NumericRecordDistribution(*, name)

Bases: RecordDistribution

Distribution over numeric arrays with Record support.

Extends RecordDistribution with numeric-specific metadata. This is the most general numeric random-variable class in ProbPipe: samples are a pytree of jax.Array leaves named via RecordTemplate. Single-leaf distributions (Normal, Beta, MultivariateNormal, …) are the trivial case; the same machinery covers future multi-leaf joint distributions.

A Distribution represents one random variable. Collections of independent distributions live in DistributionArray.

Canonical / convenience accessor pairs

Per-field accessors (canonical) are the source of truth; scalar accessors (convenience) are derived shortcuts that raise on multi-leaf templates. Subclasses override the canonical side; convenience accessors are inherited and derived automatically.

| Concept | Canonical (per-leaf) | Convenience (single-leaf) |
|---|---|---|
| Structure | record_template | (none) |
| Pytree | treedef (from template) | (none) |
| Shapes | event_shapes: dict | event_shape: tuple (raises on multi-leaf) |
| Dtypes | dtypes: dict (raises if not declared) | dtype: dtype or None (the unique dtype, else None) |
| Supports | supports: dict (raises if not declared) | support: Constraint (raises on multi-leaf) |
| Flat dim | event_size: int | (none) |
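
The rule behind the table is that subclasses implement only the canonical column and the convenience column is derived from it. A sketch of that derivation, assuming plain dicts for templates and strings for dtypes; ToyNumericBase and ToyJoint are hypothetical, and the choice of TypeError for event_shape on a multi-leaf template is an assumption (the table says only that it raises).

```python
from math import prod


class ToyNumericBase:
    """Hypothetical base: subclasses override only the canonical accessors."""

    # Canonical, per-leaf accessors.
    @property
    def event_shapes(self) -> dict:
        raise NotImplementedError

    @property
    def dtypes(self) -> dict:
        raise NotImplementedError

    # Convenience accessors, derived from the canonical ones.
    @property
    def event_shape(self) -> tuple:
        shapes = self.event_shapes
        if len(shapes) != 1:
            raise TypeError(f"event_shape is ambiguous for fields {list(shapes)}; use event_shapes")
        (shape,) = shapes.values()
        return shape

    @property
    def dtype(self):
        unique = set(self.dtypes.values())
        return unique.pop() if len(unique) == 1 else None

    @property
    def event_size(self) -> int:
        return sum(prod(s) for s in self.event_shapes.values())


class ToyJoint(ToyNumericBase):
    def __init__(self, shapes: dict, dtypes: dict):
        self._shapes, self._dtypes = dict(shapes), dict(dtypes)

    @property
    def event_shapes(self) -> dict:
        return dict(self._shapes)

    @property
    def dtypes(self) -> dict:
        return dict(self._dtypes)
```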

Single-field auto-template

Any concrete subclass that declares an event_shape and is constructed with a name= gets an auto-built single-field RecordTemplate (RecordTemplate(**{name: event_shape})) on first read of record_template. Subclasses that need a multi-field template (joint distributions) override record_template directly to skip the auto-build.
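
A lazily built, cached single-field template, with a multi-field subclass that overrides the property to skip the auto-build, could look like the sketch below. ToyLeaf and ToyJointLeaf are hypothetical, and a plain dict stands in for the immutable RecordTemplate.

```python
class ToyLeaf:
    """Hypothetical single-field distribution with an auto-built template."""

    def __init__(self, *, name: str, event_shape: tuple):
        self._name = name
        self._event_shape = tuple(event_shape)
        self._record_template = None

    @property
    def record_template(self) -> dict:
        # Build {name: event_shape} on first read, then reuse the cached value.
        if self._record_template is None:
            object.__setattr__(self, "_record_template", {self._name: self._event_shape})
        return self._record_template


class ToyJointLeaf(ToyLeaf):
    """Multi-field subclass: overrides record_template, so no auto-build."""

    def __init__(self, *, name: str, fields: dict):
        super().__init__(name=name, event_shape=())
        self._fields = dict(fields)

    @property
    def record_template(self) -> dict:
        return dict(self._fields)
```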

_sample contract (Story A)
  • Single-leaf templates → _sample(key, sample_shape) returns a raw jax.Array of shape sample_shape + event_shape.
  • Multi-leaf templates → _sample(key, sample_shape) returns a NumericRecord (or NumericRecordArray for non-empty sample_shape) keyed by record_template.fields.

The treedef property locks this relationship by deriving from record_template.
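
The two branches of the _sample contract can be checked by shape alone. In the sketch below, the hypothetical ToyStandardNormal returns nested lists of shape sample_shape + event_shape for a single-leaf template and a dict keyed by field (standing in for NumericRecord or NumericRecordArray) for a multi-leaf one; random.Random seeded with key replaces JAX PRNG keys.

```python
import random


def _draw(rng, shape):
    # Nested lists of N(0, 1) draws with the given shape; a float when shape == ().
    if not shape:
        return rng.gauss(0.0, 1.0)
    return [_draw(rng, shape[1:]) for _ in range(shape[0])]


def shape_of(x):
    # Shape of a rectangular nested list; () for a scalar.
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


class ToyStandardNormal:
    """Hypothetical distribution obeying the single-/multi-leaf _sample contract."""

    def __init__(self, template: dict):
        self.template = dict(template)  # {field: event_shape}

    def _sample(self, key, sample_shape=()):
        rng = random.Random(key)
        sample_shape = tuple(sample_shape)
        if len(self.template) == 1:
            # Single leaf: a raw array of shape sample_shape + event_shape.
            (event_shape,) = self.template.values()
            return _draw(rng, sample_shape + event_shape)
        # Multi-leaf: one entry per field (a dict stands in for NumericRecord).
        return {f: _draw(rng, sample_shape + es) for f, es in self.template.items()}
```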

Standard distributions (Normal, Gamma, Poisson, etc.) inherit from this class via TFPDistribution.

Source code in probpipe/core/_distribution_base.py
def __init__(self, *, name: str):
    if not isinstance(name, str) or not name:
        raise TypeError(
            f"{type(self).__name__} requires a non-empty name= argument"
        )
    self._name = name

record_template property

Auto-build a single-field RecordTemplate from name + event_shape when the subclass hasn't set one.

Cached via object.__setattr__ on first read. Multi-field subclasses (joint distributions) override this property to skip the auto-build.

dtypes property

Per-field dtypes — canonical, subclasses must override.

Returns a {field: dtype} dict aligned with record_template.fields. The default raises NotImplementedError rather than returning a silent default float dtype for every field, which would be wrong for integer-valued distributions such as Bernoulli, Poisson, and Categorical.

supports property

Per-field support constraints — canonical, subclasses must override.

Subclasses should override to provide meaningful constraints. Default raises NotImplementedError.

dtype property

Convenience: scalar dtype if all fields share one, else None.

Derived from dtypes. dtypes is the canonical per-field accessor; subclasses override that, not this.

support property

Convenience: support for a single-field distribution.

Derived from supports. supports is the canonical per-field accessor; subclasses override that (or, for TFPDistribution-backed classes that follow the existing single-field override pattern, override support directly to short-circuit this derivation).

Raises TypeError (via _single_field_name) on multi-field distributions; use supports for those.

treedef property

Treedef of one sample, derived from record_template.

Locks the relationship between the structural template and the sample's pytree shape:

  • Single-leaf template (len(fields) <= 1) → a leaf treedef (jax.tree.structure(None)). Matches the _sample contract that single-leaf distributions return a raw jax.Array.
  • Multi-leaf template → the treedef of a NumericRecord skeleton with the same field names. Matches the _sample contract that multi-leaf distributions return a NumericRecord.

Cached on first read; the underlying template is immutable post-construction so the cache is always valid.

flat_event_shapes property

List of per-field event shapes in template field order.

Tree-walk over event_shapes: list(event_shapes.values()). For a single-field distribution this is [event_shape]; for a multi-leaf distribution it's one entry per leaf.

flatten_value(value)

Flatten a sample (Record, NumericRecordArray, or array) to a flat trailing axis.

Delegates to RecordDistribution.flatten_value for Record-like inputs. For raw arrays, flattens event dimensions preserving leading batch/sample dims.

Source code in probpipe/core/_numeric_record_distribution.py
def flatten_value(self, value) -> Array:
    """Flatten a sample (Record, NumericRecordArray, or array) to flat trailing axis.

    Delegates to ``RecordDistribution.flatten_value`` for Record-like
    inputs.  For raw arrays, flattens event dimensions preserving
    leading batch/sample dims.
    """
    from .record import Record
    from ._record_array import NumericRecordArray
    if isinstance(value, (NumericRecordArray, Record)):
        return super().flatten_value(value)
    value = jnp.asarray(value)
    es = self.event_shape
    n_event = prod(es)
    if not es:
        return value[..., None]
    n_batch = value.ndim - len(es)
    batch_dims = value.shape[:n_batch]
    return value.reshape(*batch_dims, n_event)

unflatten_value(flat)

Unflatten a flat trailing axis back to event dims, Record, or NumericRecordArray.

When record_template is set with multiple fields, delegates to RecordDistribution.unflatten_value (returns NumericRecord or NumericRecordArray). For single-field leaf distributions, reshapes to (*batch, *event_shape) for _log_prob compat.

Source code in probpipe/core/_numeric_record_distribution.py
def unflatten_value(self, flat: ArrayLike):
    """Unflatten a flat trailing axis back to event dims, Record, or NumericRecordArray.

    When ``record_template`` is set with multiple fields, delegates
    to ``RecordDistribution.unflatten_value`` (returns NumericRecord
    or NumericRecordArray).  For single-field leaf distributions,
    reshapes to ``(*batch, *event_shape)`` for ``_log_prob`` compat.
    """
    tpl = self.record_template
    if tpl is not None and len(tpl.fields) > 1:
        return super().unflatten_value(flat)
    flat = jnp.asarray(flat)
    es = self.event_shape
    if not es:
        return flat[..., 0]
    batch_dims = flat.shape[:-1]
    return flat.reshape(*batch_dims, *es)
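
For raw arrays, the two methods are inverse reshapes that preserve leading batch or sample axes. The sketch below tracks only the shapes, so it runs without JAX; flattened_shape and unflattened_shape are hypothetical helpers that mirror the branches of the code above, including the length-1 trailing axis used for scalar events.

```python
from math import prod


def flattened_shape(value_shape: tuple, event_shape: tuple) -> tuple:
    """Shape produced by flatten_value for a raw array (mirrors the code above)."""
    if not event_shape:
        return value_shape + (1,)  # scalar event: append a length-1 axis
    n_batch = len(value_shape) - len(event_shape)
    return value_shape[:n_batch] + (prod(event_shape),)


def unflattened_shape(flat_shape: tuple, event_shape: tuple) -> tuple:
    """Shape produced by unflatten_value for a single-field distribution."""
    if not event_shape:
        return flat_shape[:-1]  # drop the length-1 trailing axis
    return flat_shape[:-1] + event_shape
```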

as_flat_distribution()

View this distribution as a flat distribution.

Returns a FlattenedDistributionView wrapping this distribution. The view satisfies the FlatNumericRecordDistribution contract regardless of self's structure (multi-field, multi-dim event, …) — its event_shape is always (self.event_size,).

Inverse: FlatNumericRecordDistribution.as_record_distribution.

Source code in probpipe/core/_numeric_record_distribution.py
def as_flat_distribution(self) -> FlatNumericRecordDistribution:
    """View this distribution as a flat distribution.

    Returns a :class:`FlattenedDistributionView` wrapping this
    distribution. The view satisfies the
    :class:`FlatNumericRecordDistribution` contract regardless of
    ``self``'s structure (multi-field, multi-dim event, …) — its
    ``event_shape`` is always ``(self.event_size,)``.

    Inverse: :meth:`FlatNumericRecordDistribution.as_record_distribution`.
    """
    return FlattenedDistributionView(self)

as_record_distribution(*, template, name=None)

Lift this distribution to a Record-keyed view under template.

Only available on FlatNumericRecordDistribution subclasses. Calling this on a non-flat NumericRecordDistribution raises TypeError with a hint to call as_flat_distribution first.

See FlatNumericRecordDistribution.as_record_distribution for the actual implementation and parameters.

Source code in probpipe/core/_numeric_record_distribution.py
def as_record_distribution(
    self,
    *,
    template: NumericRecordTemplate,
    name: str | None = None,
) -> NumericRecordDistribution:
    """Lift this distribution to a Record-keyed view under *template*.

    **Only available on :class:`FlatNumericRecordDistribution` subclasses.**
    Calling this on a non-flat :class:`NumericRecordDistribution`
    raises :class:`TypeError` with a hint to call
    :meth:`as_flat_distribution` first.

    See :meth:`FlatNumericRecordDistribution.as_record_distribution`
    for the actual implementation and parameters.
    """
    raise TypeError(
        f"as_record_distribution is only available on "
        f"FlatNumericRecordDistribution subclasses. "
        f"{type(self).__name__} is not flat. Chain: "
        f"source.as_flat_distribution().as_record_distribution(template=...)."
    )

TFPDistribution(*, name)

Bases: NumericRecordDistribution, SupportsSampling, SupportsLogProb, SupportsMean, SupportsVariance, SupportsCovariance

Base class for distributions backed by a tfd.Distribution instance.

Subclasses set self._tfp_dist in __init__. The private protocol methods _sample, _expectation, _log_prob, _mean, and _variance all delegate to TFP (or use MC fallback for expectations).

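Inherits from SupportsSampling, SupportsLogProb (which provides _prob, _unnormalized_log_prob, and _unnormalized_prob defaults), SupportsMean, SupportsVariance, and SupportsCovariance, matching the Bases line above. Because it also implements _expectation, it satisfies SupportsExpectation structurally.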

Rejects batched parameters

Per the framework hierarchy "one random variable per Distribution" rule (CONTRIBUTING.md), the constructor raises ValueError when the underlying tfd.Distribution has a non-empty batch_shape. Wrap multiple distributions in a DistributionArray instead — the migration factory is from_batched_params (or the per-class alias Distribution.from_batched_params).

The check runs in __init__ after super().__init__(name=name) completes, so concrete subclasses that set self._tfp_dist before calling super().__init__ (the standard pattern used by Normal, Beta, Gamma, …) are validated. Subclasses that set _tfp_dist after super().__init__ (e.g. KDEDistribution) skip the check via the getattr(self, "_tfp_dist", None) guard; those classes are responsible for their own shape invariants and don't go through TFP's batched-parameter convention.

Internal infrastructure that legitimately needs the batched form (the _TFPArrayBackend fused storage, converters, sequential joints, GRF predictions) opts into the bypass via _allow_batched_tfp_init.

Final-stage initializer for TFP-backed distributions.

Concrete subclasses (Normal, Beta, …) set self._tfp_dist in their own __init__ before calling super().__init__(name=name), so by the time we get here the TFP backend is fully constructed and we can validate its batch_shape.

Source code in probpipe/distributions/_tfp_base.py
def __init__(self, *, name: str) -> None:
    """Final-stage initializer for TFP-backed distributions.

    Concrete subclasses (``Normal``, ``Beta``, …) set
    ``self._tfp_dist`` in their own ``__init__`` *before* calling
    ``super().__init__(name=name)``, so by the time we get here
    the TFP backend is fully constructed and we can validate its
    ``batch_shape``.
    """
    super().__init__(name=name)
    if _BATCHED_INIT_BYPASS.get():
        return
    # KDE-style subclasses set ``_tfp_dist`` *after* this call;
    # skip the check rather than crash on a missing attribute.
    # Such classes are responsible for their own shape invariants.
    tfp_dist = getattr(self, "_tfp_dist", None)
    if tfp_dist is None:
        return
    actual = tuple(tfp_dist.batch_shape)
    if actual != ():
        cls_name = type(self).__name__
        raise ValueError(
            f"{cls_name} parameters imply batch_shape={actual}; "
            f"wrap multiple distributions in a DistributionArray "
            f"instead. See "
            f"DistributionArray.from_batched_params({cls_name}, ...) "
            f"(or the alias {cls_name}.from_batched_params(...)) "
            f"for the factory."
        )

dtypes property

Per-field dtypes — the canonical accessor for TFPDistribution. Reads self._tfp_dist.dtype and spreads it across every field of the auto-built single-field template. dtype (the convenience) is inherited from the base and derives from this dict.

support property

The support of this distribution. Override in subclasses.

supports property

Per-field support constraints — spreads the single-field support (overridden by each concrete TFP-backed subclass) across the auto-built template.

Protocols

Protocols define capabilities that distributions may support. Each protocol is @runtime_checkable; compliance is checked via isinstance at dispatch time. External types satisfy a protocol structurally by implementing the underscore method (_sample, _log_prob, ...) — no inheritance required.
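
Structural compliance means a class never has to import or subclass the protocol. The sketch below defines a hypothetical one-method protocol in the same style and shows isinstance accepting a class that merely defines _log_prob.

```python
import math
from typing import Protocol, runtime_checkable


@runtime_checkable
class ToySupportsLogProb(Protocol):
    """Hypothetical one-method protocol in the style described above."""

    def _log_prob(self, value): ...


class Exponential1:
    """Unit-rate exponential; never inherits from the protocol."""

    def _log_prob(self, value):
        return -value if value >= 0 else -math.inf


class Opaque:
    pass
```

Note that a runtime_checkable isinstance check only confirms that the named attributes exist; it does not check signatures or return types.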

SupportsSampling

Bases: Protocol

Distribution that can produce samples via _sample(key, sample_shape).

Only requires _sample(key, sample_shape); concrete classes choose their own implementation strategy (TFP batched sampling, index resampling, vmap over a local single-draw helper, etc.).

Does NOT extend SupportsExpectation — not all samplable distributions support array-valued expectations (e.g., random functions). Classes that support both should inherit both protocols.

Return-type convention

The shape of the return value depends on whether the distribution emits structured samples and whether the caller asks for a batch:

| Distribution kind | sample_shape == () | sample_shape == (S1, S2, ...) |
|---|---|---|
| Numeric (raw array) | `Array[*event_shape]` | `Array[*sample_shape, *event_shape]` |
| RecordDistribution | Record / NumericRecord | `NumericRecordArray(batch_shape=sample_shape)` |

To draw a single sample, call _sample(key, ()). Implementations that find it clearer to factor out a single-draw helper should define it as a private method (e.g. _one_bootstrap) and have _sample dispatch on sample_shape internally.
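
A distribution that factors out a single-draw helper might look like the sketch below: the hypothetical ToyBootstrapMean defines _one_bootstrap and has _sample dispatch on sample_shape, returning a bare value for () and nested lists of shape sample_shape otherwise (a stand-in for a numeric array with scalar events).

```python
import random
from math import prod
from statistics import fmean


def _reshape(values, shape):
    # Rebuild nested lists of the given shape from a flat list (row-major).
    if not shape:
        return values[0]
    step = prod(shape[1:])
    return [_reshape(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


class ToyBootstrapMean:
    """Hypothetical bootstrap distribution of the sample mean of `data`."""

    def __init__(self, data):
        self.data = list(data)

    def _one_bootstrap(self, rng):
        # Private single-draw helper: mean of one resample with replacement.
        return fmean(rng.choices(self.data, k=len(self.data)))

    def _sample(self, key, sample_shape=()):
        rng = random.Random(key)
        sample_shape = tuple(sample_shape)
        if sample_shape == ():
            return self._one_bootstrap(rng)  # one draw, no batch axes
        draws = [self._one_bootstrap(rng) for _ in range(prod(sample_shape))]
        return _reshape(draws, sample_shape)
```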

SupportsExpectation

Bases: Protocol

Distribution that can compute E[f(X)].

SupportsLogProb

Bases: SupportsUnnormalizedLogProb, Protocol

Distribution with a (normalized) log-density.

Extends SupportsUnnormalizedLogProb because any distribution with a normalized density also has an unnormalized one (they coincide). The base Distribution class provides _unnormalized_log_prob defaulting to _log_prob.
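
Because a normalized density is also an unnormalized one, a single _log_prob is enough to supply the other density methods. The sketch below shows that defaulting with a hypothetical mixin, ToyLogProbDefaults; which ProbPipe class actually hosts these defaults is described in the text, not modelled here.

```python
import math


class ToyLogProbDefaults:
    """Hypothetical mixin: derives the other density methods from _log_prob."""

    def _log_prob(self, value):
        raise NotImplementedError

    def _unnormalized_log_prob(self, value):
        # A normalized density is also a valid unnormalized one.
        return self._log_prob(value)

    def _prob(self, value):
        return math.exp(self._log_prob(value))


class ToyStdNormal(ToyLogProbDefaults):
    def _log_prob(self, value):
        return -0.5 * value * value - 0.5 * math.log(2 * math.pi)
```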

SupportsUnnormalizedLogProb

Bases: Protocol

Distribution with an unnormalized log-density.

Provides _unnormalized_log_prob(value).

SupportsRandomLogProb

Bases: Protocol

Distribution over distributions with a random (normalized) log-density.

For a RandomMeasure[T] M, _random_log_prob returns the random function x ↦ log D(x) where D ~ M as a RandomFunction. The op layer (random_log_prob) optionally forwards an input value by calling the returned random function; that two-argument convenience is purely op-layer sugar — concrete subclasses implement only the zero-argument method here.

Mirrors SupportsLogProb for the random-measure setting.

SupportsRandomUnnormalizedLogProb

Bases: Protocol

Distribution over distributions with a random unnormalized log-density.

Mirrors SupportsUnnormalizedLogProb for the random-measure setting. _random_unnormalized_log_prob returns the random function x ↦ log D̃(x), where D̃ is the unnormalized density of a draw D ~ M, as a RandomFunction. The op layer (random_unnormalized_log_prob) optionally forwards an input value by calling the returned random function; that two-argument convenience is purely op-layer sugar, and concrete subclasses implement only the zero-argument method here.

SupportsMean

Bases: Protocol

Distribution with an exact mean via _mean().

The return type is T-shaped where T is the distribution's sample type. For the common cases this is:

  • NumericRecordDistribution and friends (T = Array) — returns Array.
  • RecordDistribution and friends (T = Record) — returns Record.
  • RandomMeasure[T] (sample type Distribution[T]): returns the marginalised Distribution[T] with marginal D̄(A) = ∫ D(A) dM(D).

The protocol is sample-type-polymorphic by design: the array-valued and structured paths are unchanged; RandomMeasure opts in by implementing _mean to return its expected distribution.

Independent of SupportsExpectation. The ops layer falls back to MC estimation via SupportsExpectation when this protocol is absent. Concrete classes that want the MC default can apply @compute_expectation to their _mean implementation (only valid when T is array-like).
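
The fallback order described above (exact _mean when present, Monte Carlo via SupportsExpectation otherwise) can be sketched as a small dispatcher. The protocols, the mean function, and the _expectation(f, *, num_samples, key) signature are hypothetical; the page does not show the real op or method signatures.

```python
import random
from statistics import fmean
from typing import Protocol, runtime_checkable


@runtime_checkable
class ToySupportsMean(Protocol):
    def _mean(self): ...


@runtime_checkable
class ToySupportsExpectation(Protocol):
    def _expectation(self, f, *, num_samples, key): ...


def mean(dist, *, num_samples=10_000, key=0):
    """Hypothetical op: exact mean if available, else a Monte Carlo estimate."""
    if isinstance(dist, ToySupportsMean):
        return dist._mean()
    if isinstance(dist, ToySupportsExpectation):
        return dist._expectation(lambda x: x, num_samples=num_samples, key=key)
    raise TypeError(f"{type(dist).__name__} supports neither _mean nor _expectation")


class ExactUniform:
    def _mean(self):
        return 0.5


class SampledUniform:
    def _expectation(self, f, *, num_samples, key):
        # Plain Monte Carlo average of f over Uniform(0, 1) draws.
        rng = random.Random(key)
        return fmean(f(rng.random()) for _ in range(num_samples))
```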

SupportsVariance

Bases: Protocol

Distribution with an exact variance via _variance().

Independent of SupportsExpectation. The ops layer falls back to MC estimation via SupportsExpectation when this protocol is absent. Concrete classes that want the MC default can apply @compute_expectation to their _variance implementation.

SupportsCovariance

Bases: Protocol

Distribution with an exact covariance via _cov().

Independent of SupportsExpectation. The ops layer falls back to MC estimation via SupportsExpectation when this protocol is absent. Concrete classes that want the MC default can apply @compute_expectation to their _cov implementation.

SupportsConditioning

Bases: Protocol

Distribution that has a fast, built-in condition_on path.

Implemented by distributions whose _condition_on produces a posterior without calling into the inference registry — either closed-form (conjugate updates, joint Gaussian marginalization) or amortized (e.g., a pre-trained SBI posterior that just runs a forward pass). When condition_on(dist, observed) is called and dist implements this protocol, the built-in path is used directly; otherwise the inference method registry selects an algorithm (NUTS, RWMH, variational, ...).

Probabilistic models whose conditioning requires on-the-fly MCMC or variational inference should not implement this protocol — let the registry handle algorithm selection instead.
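A minimal sketch of the dispatch rule, using a conjugate Beta prior on a Bernoulli rate as the closed-form case. BetaBernoulliPrior and the registry_execute callback (standing in for the inference method registry) are hypothetical; only the isinstance-then-fallback shape is taken from the description above.

```python
from dataclasses import dataclass
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupportsConditioning(Protocol):
    """Local stand-in for the built-in conditioning capability."""
    def _condition_on(self, observed): ...


@dataclass(frozen=True)
class BetaBernoulliPrior:
    """Beta(a, b) prior on a Bernoulli success probability."""
    a: float
    b: float

    def _condition_on(self, observed: Sequence[int]) -> "BetaBernoulliPrior":
        # Closed-form conjugate update: add successes to a, failures to b.
        successes = sum(observed)
        return BetaBernoulliPrior(self.a + successes, self.b + len(observed) - successes)


def condition_on(dist, observed, registry_execute):
    """Use the built-in path when present; otherwise defer to the registry."""
    if isinstance(dist, SupportsConditioning):
        return dist._condition_on(observed)
    return registry_execute(dist, observed)
```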

SupportsArrayBackend is the only class-level protocol: its declared method (_make_array_backend) is a @classmethod, so the runtime check is isinstance(MyDistribution, SupportsArrayBackend) against the class itself, not an instance.

SupportsArrayBackend

Bases: Protocol

Distribution class that supports efficient batched construction.

Used by DistributionArray.from_batched_params to fuse storage when the caller's components are homogeneous instances of the same class. Implementations construct an internal _DistributionArrayBackend that owns the batched parameters and the vectorised ops; DistributionArray becomes a thin consumer.

Distribution classes that don't implement this protocol still work in a DistributionArray via the literal-array fallback (one Distribution instance per cell) — slower but correct.

The protocol attaches to the class, not to instances. The runtime check is isinstance(MyDistribution, SupportsArrayBackend) (i.e. the class itself implements _make_array_backend). isinstance(an_instance, SupportsArrayBackend) returns True too — instances inherit class attributes, and runtime_checkable just looks for the named attribute — but the result is misleading because the contract is at class scope.
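The class-versus-instance behaviour can be checked directly with a locally declared runtime-checkable protocol. This stand-in only mirrors the method name documented here; WithBackend, WithoutBackend and the dict returned as a "backend" are hypothetical.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsArrayBackend(Protocol):
    """Local stand-in mirroring the documented class-level contract."""
    @classmethod
    def _make_array_backend(cls, *, name, batch_shape, **batched_params): ...


class WithBackend:
    @classmethod
    def _make_array_backend(cls, *, name, batch_shape, **batched_params):
        # Hypothetical backend: just record what it was asked to build.
        return {"cls": cls, "name": name, "batch_shape": batch_shape, **batched_params}


class WithoutBackend:
    pass
```

The meaningful check is against the class; the instance check also passes, which is why it says little about the class-level contract.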

The protocol is internal to the library; user code never calls _make_array_backend directly. DistributionArray is the sole consumer.

Examples:

A distribution class declares the capability by implementing the classmethod:

class MyDistribution(Distribution[T]):
    @classmethod
    def _make_array_backend(
        cls,
        *,
        name: str,
        batch_shape: tuple[int, ...],
        **batched_params,
    ) -> _DistributionArrayBackend:
        return _MyArrayBackend(
            cls=cls, name=name, batch_shape=batch_shape,
            **batched_params,
        )

Custom inference methods

Method subclasses register with inference_method_registry and declare supported_types, a priority, and check() / execute() methods. When condition_on runs, the registry tries methods in descending priority order and the first whose check() reports feasibility wins. The built-in methods table is on Modeling and inference → Inference methods.

Setting priority for a new method

The integer returned by priority carries semantics: it tells the registry whether your method should auto-dispatch, and if so, where it ranks against the alternatives.

  • priority > 50: exact. Auto-dispatched; higher = preferred among exact alternatives.
  • 0 < priority <= 50: inexact. Auto-dispatched; higher = preferred among inexact alternatives. The 50 break is documentary; the registry walks every positive priority uniformly.
  • priority == 0: opt-in only. The registry skips the method during auto-dispatch; the method is reachable by name via method="...". This is the default: a Method subclass that doesn't override priority gets opt-in behaviour automatically.
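The three bands reduce to a simple auto-dispatch rule: drop priority-0 methods, then try the rest from highest to lowest. The sketch below models that rule in isolation; ToyMethod, auto_dispatch_order and select are hypothetical names, and the boolean check stands in for the MethodInfo a real check() returns.

```python
from dataclasses import dataclass
from typing import Callable, Optional

OPT_IN_ONLY_PRIORITY = 0  # the sentinel documented above


@dataclass(frozen=True)
class ToyMethod:
    """Hypothetical method record: a name, a priority, and a feasibility check."""
    name: str
    priority: int
    check: Callable[[object], bool]


def auto_dispatch_order(methods: list[ToyMethod]) -> list[str]:
    """Names tried during auto-dispatch: opt-in-only dropped, highest priority first."""
    eligible = [m for m in methods if m.priority != OPT_IN_ONLY_PRIORITY]
    return [m.name for m in sorted(eligible, key=lambda m: m.priority, reverse=True)]


def select(methods: list[ToyMethod], problem: object) -> Optional[str]:
    """First method in auto-dispatch order whose check passes, else None."""
    by_name = {m.name: m for m in methods}
    for name in auto_dispatch_order(methods):
        if by_name[name].check(problem):
            return name
    return None
```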

Selection criteria

Choose a number with these axes in mind, roughly in order of weight:

  1. Robustness when applicable — how often the method gives a usable answer without per-model tuning, conditional on check() passing.
  2. Computational cost per effective sample (or per converged result). Two kinds of cost advantage deserve separate consideration: algorithmic specialisation that exploits model structure for an asymptotic speedup (Kalman, INLA, conjugate updates), and engineering specialisation — same algorithm, faster backend (nutpie's Rust-backed NUTS vs. TFP's; Stan's compiled gradients vs. JAX traces).
  3. Approximation quality — analytical exact > controlled-error approximations > asymptotically-exact MCMC > intrinsic approximations.
  4. Diagnostic richness — methods that fail silently rank below methods with built-in failure signals, all else equal.
  5. Model-class breadth as a tiebreaker only. A broader-applicability method does not need a higher priority than a narrow one; whichever applies wins via check().

Tier ranges — exact (51–100)

Five tiers, each 10 wide. Criteria are stated as positive properties of the method.

Range Criterion
91–100 Per-call cost in a strictly better complexity class than general-purpose alternatives; the speedup comes from exploiting model structure.
81–90 Optimised implementation of a more general algorithm; lower constant-factor cost than the reference implementation within its applicable model class.
71–80 Self-tuning; converges robustly without per-model hyperparameter selection.
61–70 Well-understood with strong convergence theory; performs well once hand-tuned.
51–60 Slow per effective sample or unreliable in typical use.

Tier ranges — inexact (1–50)

Four named tiers ordered by the strength of the asymptotic-to-exact story. The slot at 11–20 is intentionally reserved for methods with intermediate guarantees that don't fit cleanly into a named tier.

Range Criterion
41–50 Asymptotically exact under algorithmic refinement; bias is a knob the user can tighten (step size, mini-batch size).
31–40 Particle-based approximation refinable by particle count; quality improves with more particles, though convergence may be slow or unstable.
21–30 Parametric posterior approximation; error bounded by family expressiveness or by regularity conditions on the posterior shape.
11–20 (reserved for methods with intermediate guarantees not covered by neighbouring tiers)
1–10 No asymptotic-to-exact guarantee in practice; quality bounded by intrinsic information loss (summary statistics, fixed tolerance, learned representations).
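As a reading aid, the two tier tables can be encoded as a lookup. describe_priority is a hypothetical helper, not part of ProbPipe, and its labels abbreviate the criteria above.

```python
EXACT_TIERS = [
    (91, 100, "structure-exploiting, better complexity class"),
    (81, 90, "optimised implementation of a more general algorithm"),
    (71, 80, "self-tuning"),
    (61, 70, "strong theory, needs hand-tuning"),
    (51, 60, "slow or unreliable in typical use"),
]
INEXACT_TIERS = [
    (41, 50, "asymptotically exact under refinement"),
    (31, 40, "particle-based"),
    (21, 30, "parametric approximation"),
    (11, 20, "reserved"),
    (1, 10, "no asymptotic-to-exact guarantee"),
]


def describe_priority(priority: int) -> str:
    """Map a priority integer to its documented band and tier."""
    if priority == 0:
        return "opt-in only"
    for lo, hi, label in EXACT_TIERS:
        if lo <= priority <= hi:
            return f"exact: {label}"
    for lo, hi, label in INEXACT_TIERS:
        if lo <= priority <= hi:
            return f"inexact: {label}"
    return "outside the documented 0-100 convention"
```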

Setting priority on a Method subclass

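class MyNutsMethod(Method):
    # name, supported_types(), check() and execute() are also required
    # (all abstract on Method); they are omitted here for brevity.

    @property
    def priority(self) -> int:
        # Tier 71-80 (self-tuning, broadly applicable).
        return 75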

A method that should not auto-dispatch — perhaps it's experimental, has sharp failure modes, or exists only for method= testing — leaves priority at the inherited default of 0. The registry will exclude it from the auto-dispatch walk; users can still invoke it explicitly by name.

MethodRegistry()

Generic priority-based method registry.

Methods are tried in descending priority order. The first method whose check() returns feasible=True wins. Users can also select a specific method by name.

Methods whose effective priority equals OPT_IN_ONLY_PRIORITY (0) are excluded from the auto-dispatch walk and are reachable only by name via method="...".

Source code in probpipe/core/_registry.py
def __init__(self) -> None:
    self._methods: list[M] = []
    self._name_index: dict[str, M] = {}
    self._priority_overrides: dict[str, int] = {}
    self._type_cache: dict[type, list[M]] = {}

register(method)

Register a method (invalidates the lookup cache).

Source code in probpipe/core/_registry.py
def register(self, method: M) -> None:
    """Register a method (invalidates the lookup cache)."""
    if method.name in self._name_index:
        raise ValueError(
            f"Method name {method.name!r} is already registered"
        )
    self._methods.append(method)
    self._name_index[method.name] = method
    self._sort_methods()

set_priorities(**name_to_priority)

Override the priority of one or more methods.

Higher priority methods are tried first during auto-selection. Overrides are unrestricted: the new value can be any integer, including the opt-in-only sentinel 0. When an override moves a method into or out of 0, the registry emits a UserWarning because that crossing changes whether the method participates in auto-dispatch at all.
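Only overrides that cross the sentinel warn. The predicate below reproduces, as a standalone sketch, the condition used in the set_priorities source on this page; crosses_opt_in_boundary and warn_if_crossing are hypothetical helper names.

```python
import warnings

OPT_IN_ONLY_PRIORITY = 0


def crosses_opt_in_boundary(old: int, new: int) -> bool:
    """True when exactly one of old/new is the opt-in-only sentinel."""
    return (old == OPT_IN_ONLY_PRIORITY) != (new == OPT_IN_ONLY_PRIORITY)


def warn_if_crossing(name: str, old: int, new: int) -> None:
    """Emit a UserWarning when an override changes auto-dispatch participation."""
    if crosses_opt_in_boundary(old, new):
        direction = "out of opt-in-only" if old == OPT_IN_ONLY_PRIORITY else "into opt-in-only"
        warnings.warn(
            f"Priority override for {name!r} moves it {direction} ({old} -> {new})",
            UserWarning,
            stacklevel=2,
        )
```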

Parameters:

Name Type Description Default
**name_to_priority int

Keyword arguments mapping method names to new priorities. e.g., set_priorities(tfp_rwmh=200, tfp_nuts=50)

{}

Raises:

Type Description
KeyError

If a method name is not registered.

Source code in probpipe/core/_registry.py
def set_priorities(self, **name_to_priority: int) -> None:
    """Override the priority of one or more methods.

    Higher priority methods are tried first during auto-selection.
    Overrides are unrestricted: the new value can be any integer,
    including the opt-in-only sentinel ``0``. When an override
    moves a method *into* or *out of* ``0``, the registry emits a
    :class:`UserWarning` because that crossing changes whether the
    method participates in auto-dispatch at all.

    Parameters
    ----------
    **name_to_priority
        Keyword arguments mapping method names to new priorities.
        e.g., ``set_priorities(tfp_rwmh=200, tfp_nuts=50)``

    Raises
    ------
    KeyError
        If a method name is not registered.
    """
    for name in name_to_priority:
        if name not in self._name_index:
            available = ", ".join(sorted(self._name_index)) or "(none)"
            raise KeyError(
                f"No method named {name!r}. Available: {available}"
            )
    for name, new_priority in name_to_priority.items():
        old_priority = self._effective_priority(self._name_index[name])
        if (old_priority == OPT_IN_ONLY_PRIORITY) != (
            new_priority == OPT_IN_ONLY_PRIORITY
        ):
            direction = (
                "out of opt-in-only"
                if old_priority == OPT_IN_ONLY_PRIORITY
                else "into opt-in-only"
            )
            warnings.warn(
                f"Priority override for {name!r} moves it {direction} "
                f"({old_priority} -> {new_priority}); auto-dispatch "
                f"participation changes accordingly.",
                UserWarning,
                stacklevel=2,
            )
    self._priority_overrides.update(name_to_priority)
    self._sort_methods()

get_method(name)

Look up a method by name. Raises KeyError if not found.

Source code in probpipe/core/_registry.py
def get_method(self, name: str) -> M:
    """Look up a method by name.  Raises ``KeyError`` if not found."""
    try:
        return self._name_index[name]
    except KeyError:
        available = ", ".join(sorted(self._name_index)) or "(none)"
        raise KeyError(
            f"No method named {name!r}. Available: {available}"
        ) from None

list_methods()

Return method names in priority order (highest first).

Source code in probpipe/core/_registry.py
def list_methods(self) -> list[str]:
    """Return method names in priority order (highest first)."""
    return [m.name for m in self._methods]

check(*args, method=None, **kwargs)

Check feasibility. Auto-selects or uses the named method.

Source code in probpipe/core/_registry.py
def check(
    self, *args: Any, method: str | None = None, **kwargs: Any
) -> MethodInfo:
    """Check feasibility.  Auto-selects or uses the named method."""
    if method is not None:
        m = self.get_method(method)
        return m.check(*args, **kwargs)

    key_type = type(args[0]) if args else object
    for m in self._find_methods(key_type):
        info = m.check(*args, **kwargs)
        if info.feasible:
            return info

    return MethodInfo(
        feasible=False,
        description=f"No applicable method for {key_type.__name__}",
    )

execute(*args, method=None, **kwargs)

Execute using the best (or named) method.

Raises TypeError if no method is applicable. Raises KeyError if method is not registered.

Source code in probpipe/core/_registry.py
def execute(
    self, *args: Any, method: str | None = None, **kwargs: Any
) -> Any:
    """Execute using the best (or named) method.

    Raises ``TypeError`` if no method is applicable.
    Raises ``KeyError`` if *method* is not registered.
    """
    if method is not None:
        m = self.get_method(method)
        info = m.check(*args, **kwargs)
        if not info.feasible:
            raise TypeError(
                f"Method {method!r} is not applicable: {info.description}"
            )
        return m.execute(*args, **kwargs)

    key_type = type(args[0]) if args else object
    for m in self._find_methods(key_type):
        info = m.check(*args, **kwargs)
        if info.feasible:
            return m.execute(*args, **kwargs)

    raise TypeError(
        f"No method registered for {key_type.__name__}. "
        f"Available: {self.list_methods()}"
    )

Method

Bases: ABC

Abstract base for a pluggable method in a registry.

Subclasses declare a unique name, which supported_types they handle (for fast filtering), a priority (higher = tried first), and implement check/execute.

name abstractmethod property

Unique identifier for this method (e.g., 'tfp_nuts').

priority property

Auto-dispatch ordering.

Higher priority methods are tried first during auto-selection. The default of 0 is the opt-in-only sentinel: a method that doesn't override this property is reachable only by name via method="...". See Setting priority for a new method on this page for the > 50 / <= 50 / == 0 convention contributors should follow when choosing a number.

supported_types() abstractmethod

Types this method can operate on (fast pre-filter).

Source code in probpipe/core/_registry.py
@abstractmethod
def supported_types(self) -> tuple[type, ...]:
    """Types this method can operate on (fast pre-filter)."""
    ...

check(*args, **kwargs) abstractmethod

Probe whether this method is applicable (must be cheap).

Source code in probpipe/core/_registry.py
@abstractmethod
def check(self, *args: Any, **kwargs: Any) -> MethodInfo:
    """Probe whether this method is applicable (must be cheap)."""
    ...

execute(*args, **kwargs) abstractmethod

Run the method and return the result.

Source code in probpipe/core/_registry.py
@abstractmethod
def execute(self, *args: Any, **kwargs: Any) -> Any:
    """Run the method and return the result."""
    ...

MethodInfo(feasible, method_name='', description='') dataclass

Metadata describing whether a method is applicable.

Returned by a method's check() to describe feasibility without performing the actual computation.
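Putting the abstract members together, a complete subclass looks roughly like the following. To stay self-contained, the block defines local stand-ins for Method and MethodInfo that mirror the members documented above; in real code you would subclass ProbPipe's Method and register an instance with inference_method_registry.register(...). The model type BetaCounts, the method name toy_beta_bernoulli, and the (model, observed) call signature are all assumptions for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class MethodInfo:
    """Stand-in with the documented fields."""
    feasible: bool
    method_name: str = ""
    description: str = ""


class Method(ABC):
    """Stand-in mirroring the documented abstract surface."""

    @property
    @abstractmethod
    def name(self) -> str: ...

    @property
    def priority(self) -> int:
        return 0  # opt-in-only default

    @abstractmethod
    def supported_types(self) -> tuple[type, ...]: ...

    @abstractmethod
    def check(self, *args: Any, **kwargs: Any) -> MethodInfo: ...

    @abstractmethod
    def execute(self, *args: Any, **kwargs: Any) -> Any: ...


class BetaCounts:
    """Hypothetical model type: Beta(a, b) prior over a Bernoulli rate."""
    def __init__(self, a: float, b: float):
        self.a, self.b = a, b


class ConjugateBetaBernoulli(Method):
    @property
    def name(self) -> str:
        return "toy_beta_bernoulli"

    @property
    def priority(self) -> int:
        return 95  # tier 91-100: closed form exploits model structure

    def supported_types(self) -> tuple[type, ...]:
        return (BetaCounts,)

    def check(self, model, observed=None, **kwargs) -> MethodInfo:
        # Cheap probe: right model type and 0/1 observations only.
        ok = isinstance(model, BetaCounts) and observed is not None and all(
            x in (0, 1) for x in observed
        )
        return MethodInfo(ok, self.name, "" if ok else "needs 0/1 observations")

    def execute(self, model, observed, **kwargs) -> BetaCounts:
        k = sum(observed)
        return BetaCounts(model.a + k, model.b + len(observed) - k)
```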

Custom converters

Subclass Converter, implement check() / convert(), and register with converter_registry.register(...). The built-in priorities and the registry handle itself are documented under Conversion and interop.

Custom bijectors

register_bijector overrides the canonical bijector returned by bijector_for(c) for a given Constraint instance or class. See Constraints → Bijectors.

Custom auxiliary metadata

register_aux extends the Record ↔ NumericRecord round-trip to a new array-like type. See Records → Auxiliary-metadata registry.

Broadcasting internals (exposed for extension)

DistributionArray is the shape-indexed container produced by parameter-sweep workflow functions whose inner call returns a Distribution. BroadcastDistribution is the joint container produced by @workflow_function when include_inputs=True.

DistributionArray(components, *, batch_shape=None, name=None)

Bases: Distribution[T]

Ordered collection of independent scalar distributions addressed by a (multi-d) batch_shape.

Exposes only the container surface (indexing, iteration, components, batch_shape, event_shape). Vectorized ops (sample, mean, variance, log_prob, …) are delivered by the WorkflowFunction sweep layer, which treats the array as Array[Distribution] and dispatches cell-by-cell.

Parameters:

Name Type Description Default
components sequence of Distribution

The n component distributions. Must be non-empty and share event_shape.

required
batch_shape tuple of int

Leading batch shape. Defaults to (len(components),) for the 1-D form; prod(batch_shape) must equal len(components).

None
name str

Name for provenance / introspection. Defaults to "distribution_array".

None
Notes

How ops work on a DistributionArray. The class deliberately does not implement _sample / _mean / _log_prob / etc. — those would couple the array to specific component capabilities. Vectorization is handled at a different layer:

  1. sample(da, ...) calls the sample WorkflowFunction, whose dispatch sees a DistributionArray argument where the op's annotation expects a scalar SupportsSampling.
  2. WF dispatches cell-by-cell: each da[i] is sampled, results are stacked along batch_shape and returned as a NumericRecordArray (or RecordArray for non-numeric components). For ops whose inner return is itself a Distribution (e.g. posterior-predictive sweeps), the result is a nested DistributionArray.
  3. Multiple swept arguments combine by the product rule: passing two DistributionArray args of shapes (m,) and (n,) produces an output of shape (m, n).
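The cell-by-cell, product-rule sweep can be illustrated without ProbPipe. In this sketch sweep is a hypothetical helper, plain Python values stand in for DistributionArray cells, and results come back as nested lists rather than a NumericRecordArray.

```python
from itertools import product


def sweep(op, *arrays):
    """Apply op cell-by-cell over the product of 1-D argument arrays.

    Returns (nested results, output shape); inputs of lengths m and n
    give an (m, n) output.
    """
    shape = tuple(len(a) for a in arrays)
    flat = [op(*cells) for cells in product(*arrays)]  # row-major order
    # Fold the flat row-major list back into nested lists, innermost axis first.
    nested = flat
    for dim in reversed(shape[1:]):
        nested = [nested[i:i + dim] for i in range(0, len(nested), dim)]
    return nested, shape
```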

Consequences of this design:

  • Calling da._sample(key) directly raises AttributeError, because DistributionArray doesn't have _sample. Always use the public op (sample(da, key=...)).
  • isinstance(da, SupportsSampling) is False even when every component supports sampling. The protocol attaches to individual cells, not to the array.
  • Component capabilities don't have to be uniform: an array where some cells are SupportsLogProb and some are not will fail at op-dispatch on the first non-supporting cell, rather than rejecting at construction.
Source code in probpipe/core/_distribution_array.py
def __init__(
    self,
    components,
    *,
    batch_shape: tuple[int, ...] | None = None,
    name: str | None = None,
):
    components = tuple(components)
    if not components:
        raise ValueError("DistributionArray requires at least one component")
    # Components must share event_shape. Batching lives on the
    # DistributionArray itself; per the "one random variable per
    # Distribution" rule, components have no batch_shape.
    es0 = getattr(components[0], "event_shape", ())
    for i, c in enumerate(components):
        es = getattr(c, "event_shape", ())
        if es != es0:
            raise ValueError(
                f"DistributionArray requires matching event_shape "
                f"across components; components[0].event_shape={es0} "
                f"but components[{i}].event_shape={es}."
            )
    # ``batch_shape`` defaults to (n,) for backward compatibility
    # with the 1-D-only form used until now. Multi-d broadcasting
    # passes the full sweep shape explicitly.
    if batch_shape is None:
        batch_shape = (len(components),)
    else:
        batch_shape = tuple(batch_shape)
        if prod(batch_shape) != len(components):
            raise ValueError(
                f"DistributionArray batch_shape={batch_shape} implies "
                f"{prod(batch_shape)} components but got "
                f"{len(components)}."
            )
    self._components = components
    self._batch_shape = batch_shape
    # Set only by :meth:`_from_backend`. The literal-array path
    # leaves it ``None`` and uses ``_components`` as the
    # storage-of-truth.
    self._backend = None
    if name is None:
        name = "distribution_array"
    super().__init__(name=name)
    # A DistributionArray holding MC-marginal components inherits
    # their approximation status; if any component is approximate
    # (a _MixtureMarginal or RecordEmpiricalDistribution), so is
    # the stack.
    self._approximate = any(
        getattr(c, "is_approximate", False) for c in components
    )

components property

Flat tuple of component distributions, in row-major order across the leading batch_shape.

For backend-delegated arrays the tuple is materialised lazily on first access via backend.cell(i) for each flat index and cached. Cells are still freshly constructed inside cell() (no de-duplication), so successive components accesses return the same cached tuple but indexing via __getitem__ / _flat_component always returns a fresh scalar.
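The caching behaviour can be sketched with functools.cached_property. The mechanism is an assumption, not ProbPipe's code: ToyBackend and ToyArray are hypothetical, and cells are plain dicts so that equality and identity can be compared separately.

```python
from functools import cached_property


class ToyBackend:
    """Hypothetical backend that builds a fresh cell object on every call."""
    def __init__(self, locs):
        self.locs = list(locs)

    def cell(self, i):
        return {"name": f"x_{i}", "loc": self.locs[i]}  # new dict each time


class ToyArray:
    def __init__(self, backend):
        self._backend = backend

    @cached_property
    def components(self):
        # Materialised once, on first access, then reused.
        return tuple(self._backend.cell(i) for i in range(len(self._backend.locs)))

    def __getitem__(self, i):
        # Bypasses the cache: always asks the backend for a fresh cell.
        return self._backend.cell(i)
```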

batch_shape property

Batch axes of this DistributionArray.

Components are scalar (one random variable per Distribution), so this is simply self._batch_shape — there is no inherit-from-component composition. Multi-d broadcasting outputs pass the full sweep shape; the default 1-D form is (n,).

event_shape property

Shared event_shape across components.

dtype property

Per-cell dtype.

Cells share an event shape and (in practice) a dtype because homogeneous backends produce uniformly-typed cells and literal-array constructions inherit from the source. Backend-delegated arrays read it from the backend; literal arrays read it from the first component.

size property

Total number of cells (prod(batch_shape)).

Mirrors np.ndarray.size / jax.Array.size: len(da) is the leading-axis dim, da.size is the total cell count.

from_batched_params(dist_cls, *, name, batch_shape=None, **batched_params) classmethod

Construct a DistributionArray of homogeneous components.

The recommended way to build a DistributionArray whose cells are all instances of the same class (most often a TFP-backed family like Normal) without manually constructing each cell:

DistributionArray.from_batched_params(
    Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)

When dist_cls implements SupportsArrayBackend (every TFP-backed concrete class does — Normal, Beta, Gamma, MultivariateNormal, …), the factory dispatches onto the backend's fused-storage path: a single batched TFP backend owns the parameters; cells are materialised lazily on demand. Otherwise the factory falls back to the literal-array path: one dist_cls instance per cell with per-cell parameters auto-sliced and names auto-suffixed f"{name}_{flat_index}".
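For a class without the protocol, the fallback amounts to slicing each array-valued parameter per cell and suffixing the name. The sketch below is a simplified, 1-D-only model of that behaviour: literal_cells and ToyNormal are hypothetical, Python lists stand in for JAX arrays, and array-valued parameters must share one length instead of being broadcast.

```python
from dataclasses import dataclass


def _is_array(v) -> bool:
    """Treat lists and tuples as array-valued parameters in this sketch."""
    return isinstance(v, (list, tuple))


def literal_cells(dist_cls, *, name, batch_shape=None, **batched_params):
    """1-D sketch of the literal-array fallback: slice sequences, broadcast scalars."""
    lengths = {len(v) for v in batched_params.values() if _is_array(v)}
    if batch_shape is None:
        if not lengths:
            raise ValueError("cannot infer batch_shape: no array-valued params")
        if len(lengths) > 1:
            raise ValueError(f"array params disagree on length: {sorted(lengths)}")
        batch_shape = (lengths.pop(),)
    (n,) = batch_shape  # this sketch only handles 1-D batch shapes
    cells = []
    for i in range(n):
        kwargs = {k: (v[i] if _is_array(v) else v) for k, v in batched_params.items()}
        cells.append(dist_cls(name=f"{name}_{i}", **kwargs))
    return cells


@dataclass
class ToyNormal:
    """Hypothetical scalar distribution used only for this illustration."""
    name: str
    loc: float
    scale: float
```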

Parameters:

Name Type Description Default
dist_cls type

A Distribution subclass. The factory does not instantiate dist_cls directly when the protocol path is taken; per-cell scalars are produced by the backend.

required
name str

Base name; per-cell scalars are named f"{name}_{flat_index}" (row-major over batch_shape).

required
batch_shape tuple of int

Leading shape of the batched parameters. Inferred from batched_params (broadcast shape of array-valued entries) when omitted.

None
**batched_params

Constructor kwargs for dist_cls with leading batch_shape already applied. Scalars are broadcast across every cell.

{}

Returns:

Type Description
DistributionArray

Backend-delegated when dist_cls implements SupportsArrayBackend; literal-array fallback otherwise.

Raises:

Type Description
ValueError

If batch_shape cannot be inferred (no array-valued params) and the caller did not pass it explicitly.

Examples:

Backend-delegated TFP path:

da = DistributionArray.from_batched_params(
    Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)
da.batch_shape       # (5,)
da[0].name           # "x_0"
da[0].loc            # 0.0
da._backend          # _TFPArrayBackend(...)

Literal-array fallback (any class without the protocol):

da = DistributionArray.from_batched_params(
    MyCustomDist, param=jnp.arange(4), name="z",
)
da._backend          # None — fallback path
da[0].name           # "z_0"
Source code in probpipe/core/_distribution_array.py
@classmethod
def from_batched_params(
    cls,
    dist_cls: type,
    *,
    name: str,
    batch_shape: tuple[int, ...] | None = None,
    **batched_params,
) -> "DistributionArray":
    """Construct a ``DistributionArray`` of homogeneous components.

    The recommended way to build a ``DistributionArray`` whose
    cells are all instances of the same class — most often a
    TFP-backed family like ``Normal`` — without manually
    constructing each cell::

        DistributionArray.from_batched_params(
            Normal, loc=jnp.zeros(5), scale=1.0, name="x",
        )

    When ``dist_cls`` implements
    :class:`~probpipe.core.protocols.SupportsArrayBackend` (every
    TFP-backed concrete class does — ``Normal``, ``Beta``,
    ``Gamma``, ``MultivariateNormal``, …), the factory dispatches
    onto the backend's fused-storage path: a single batched
    TFP backend owns the parameters; cells are materialised lazily
    on demand. Otherwise the factory falls back to the
    literal-array path: one ``dist_cls`` instance per cell with
    per-cell parameters auto-sliced and names auto-suffixed
    ``f"{name}_{flat_index}"``.

    Parameters
    ----------
    dist_cls : type
        A ``Distribution`` subclass. The factory does not
        instantiate ``dist_cls`` directly when the protocol path
        is taken; per-cell scalars are produced by the backend.
    name : str
        Base name; per-cell scalars are named
        ``f"{name}_{flat_index}"`` (row-major over
        ``batch_shape``).
    batch_shape : tuple of int, optional
        Leading shape of the batched parameters. Inferred from
        ``batched_params`` (broadcast shape of array-valued
        entries) when omitted.
    **batched_params
        Constructor kwargs for ``dist_cls`` with leading
        ``batch_shape`` already applied. Scalars are broadcast
        across every cell.

    Returns
    -------
    DistributionArray
        Backend-delegated when ``dist_cls`` implements
        ``SupportsArrayBackend``; literal-array fallback otherwise.

    Raises
    ------
    ValueError
        If ``batch_shape`` cannot be inferred (no array-valued
        params) and the caller did not pass it explicitly.

    Examples
    --------
    Backend-delegated TFP path::

        da = DistributionArray.from_batched_params(
            Normal, loc=jnp.zeros(5), scale=1.0, name="x",
        )
        da.batch_shape       # (5,)
        da[0].name           # "x_0"
        da[0].loc            # 0.0
        da._backend          # _TFPArrayBackend(...)

    Literal-array fallback (any class without the protocol)::

        da = DistributionArray.from_batched_params(
            MyCustomDist, param=jnp.arange(4), name="z",
        )
        da._backend          # None — fallback path
        da[0].name           # "z_0"
    """
    inferred_shape = _infer_batch_shape(batched_params, batch_shape)
    if isinstance(dist_cls, SupportsArrayBackend):
        backend = dist_cls._make_array_backend(
            name=name,
            batch_shape=inferred_shape,
            **batched_params,
        )
        return cls._from_backend(backend, name=name)
    return cls._from_literal_components(
        dist_cls,
        name=name,
        batch_shape=inferred_shape,
        batched_params=batched_params,
    )

__getitem__(key)

Index a single component, slice, or multi-d tuple.

Supported key forms:

  • int → index along the leading batch axis.
  • slice → slice along the leading batch axis; returns a new DistributionArray containing the sliced subset.
  • tuple → multi-axis index (int or slice per axis); collapses along int axes and slices along slice axes.

Indexing uses row-major order across batch_shape. The pure-int path translates the key directly to a flat index via np.ravel_multi_index without materialising a shape=batch_shape object array.
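The pure-int path amounts to a row-major (Horner-style) fold over the axes, which is what np.ravel_multi_index computes for in-range indices. flat_index is a hypothetical helper; it wraps negative indices the same way and additionally rejects out-of-range indices explicitly.

```python
def flat_index(key, shape):
    """Row-major flat index for an all-int key, wrapping negative indices."""
    if len(key) != len(shape):
        raise IndexError(f"expected {len(shape)} indices, got {len(key)}")
    flat = 0
    for k, dim in zip(key, shape):
        k = int(k)
        if not -dim <= k < dim:
            raise IndexError(f"index {k} out of range for axis of size {dim}")
        flat = flat * dim + (k % dim)  # accumulate row-major offset
    return flat
```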

Source code in probpipe/core/_distribution_array.py
def __getitem__(self, key):
    """Index a single component, slice, or multi-d tuple.

    Supported key forms:

    * ``int`` → index along the leading batch axis.
    * ``slice`` → slice along the leading batch axis; returns a
      new ``DistributionArray`` containing the sliced subset.
    * ``tuple`` → multi-axis index (int or slice per axis);
      collapses along int axes and slices along slice axes.

    Indexing uses row-major order across ``batch_shape``.
    The pure-int path translates the key directly to a flat index
    via ``np.ravel_multi_index`` without materialising a
    ``shape=batch_shape`` object array.
    """
    bshape = self._batch_shape
    # Normalise the key to a tuple, one entry per leading axis.
    if isinstance(key, tuple):
        if len(key) > len(bshape):
            raise IndexError(
                f"DistributionArray has {len(bshape)} leading batch "
                f"axes; got {len(key)}-tuple key."
            )
        key_tuple = key + (slice(None),) * (len(bshape) - len(key))
    else:
        key_tuple = (key,) + (slice(None),) * (len(bshape) - 1)

    # Fast path: all axes addressed by int (or int-like). Compute
    # the flat index directly; no object-array materialisation.
    # ``np.ravel_multi_index`` rejects negative indices, so wrap
    # them into the positive range first (``dists[-1]`` is a
    # common pattern — e.g. "last posterior in an iterate output").
    if all(
        isinstance(k, (int, np.integer)) or hasattr(k, "__index__")
        for k in key_tuple
    ):
        indices = tuple(
            int(k) % dim for k, dim in zip(key_tuple, bshape)
        )
        flat = int(np.ravel_multi_index(indices, bshape))
        if self._backend is not None:
            return self._backend.cell(flat)
        return self._components[flat]

    # General path: object-array view for slice / mixed-key
    # support. Only materialised when slices are actually used,
    # and only for the axes that involve them. Backend-delegated
    # arrays materialise their components on this path too — slice
    # indexing is rare and a one-time materialisation cost is
    # acceptable.
    components_nd = np.asarray(self.components, dtype=object).reshape(bshape)
    sliced = components_nd[key_tuple]
    if isinstance(sliced, Distribution):
        return sliced
    if sliced.ndim == 0:
        return sliced.item()
    new_components = list(sliced.ravel())
    if not new_components:
        raise ValueError(
            "DistributionArray index produced an empty sequence; "
            "at least one component is required."
        )
    return _make_distribution_array(
        new_components,
        batch_shape=sliced.shape,
        name=self._name,
    )

__iter__()

Iterate the leading axis (numpy / jax convention).

len(self) items are yielded:

  • ndim == 1 (the common case): each item is a scalar Distribution cell.
  • ndim >= 2: each item is a DistributionArray of shape batch_shape[1:] — a leading-axis slice, mirroring iter(np.zeros((2, 3))) yielding two (3,)-shaped views.
  • ndim == 0 (batch_shape == ()): raises TypeError to match iter(np.zeros(())). Reach for _flat_component (or components) to access the single cell — those work uniformly across every batch_shape including ().

For flat row-major access over every cell (the pre-#178 behaviour), use components or range(self.size) with _flat_component.
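Leading-axis iteration over row-major storage can be modelled on a flat sequence. iter_leading_axis is a hypothetical helper that yields each N-d slice as a (cells, sub-shape) pair rather than a DistributionArray, and, like the documented method, raises eagerly for the 0-d case.

```python
from math import prod


def iter_leading_axis(flat_cells, shape):
    """Iterate the leading axis of a row-major flat sequence with the given shape."""
    if not shape:
        raise TypeError("iteration over a 0-d array")
    if len(shape) == 1:
        return iter(flat_cells)  # 1-D: yield the scalar cells themselves
    inner = prod(shape[1:])  # cells per leading-axis slice
    return (
        (tuple(flat_cells[i * inner:(i + 1) * inner]), shape[1:])
        for i in range(shape[0])
    )
```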

Source code in probpipe/core/_distribution_array.py
def __iter__(self):
    """Iterate the leading axis (numpy / jax convention).

    ``len(self)`` items are yielded:

    * ``ndim == 1`` (the common case): each item is a scalar
      :class:`~probpipe.Distribution` cell.
    * ``ndim >= 2``: each item is a ``DistributionArray`` of
      shape ``batch_shape[1:]`` — a leading-axis slice, mirroring
      ``iter(np.zeros((2, 3)))`` yielding two ``(3,)``-shaped
      views.
    * ``ndim == 0`` (``batch_shape == ()``): raises
      ``TypeError`` to match ``iter(np.zeros(()))``. Reach for
      :meth:`_flat_component` (or :attr:`components`) to access
      the single cell — those work uniformly across every
      ``batch_shape`` including ``()``.

    For flat row-major access over every cell (the pre-#178
    behaviour), use :attr:`components` or
    ``range(self.size)`` with :meth:`_flat_component`.
    """
    bshape = self._batch_shape
    if not bshape:
        raise TypeError(
            "iteration over a 0-d DistributionArray "
            f"(batch_shape={bshape}). Reach for "
            "da.components or da._flat_component(0) for the "
            "single cell."
        )
    n_lead = bshape[0]
    if len(bshape) == 1:
        if self._backend is not None:
            return (self._backend.cell(i) for i in range(n_lead))
        return iter(self._components)
    # Multi-d: __getitem__(int) on a multi-d DA returns a
    # sub-DistributionArray of shape batch_shape[1:].
    return (self[i] for i in range(n_lead))
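The three cases can be shown without depending on ProbPipe internals. The pure-Python sketch below defines ToyGrid, a hypothetical stand-in that is not part of ProbPipe. It stores cells flat in row-major order (like components) and copies the dispatch in __iter__ above: 1-d yields cells, n-d yields leading-axis sub-grids, and 0-d raises TypeError.

```python
from math import prod


class ToyGrid:
    """Hypothetical stand-in that mimics DistributionArray iteration.

    Cells are stored flat in row-major order, like ``components``.
    """

    def __init__(self, cells, batch_shape):
        self.batch_shape = tuple(batch_shape)
        self.cells = list(cells)
        # prod(()) == 1, so a 0-d grid holds exactly one cell.
        if len(self.cells) != prod(self.batch_shape):
            raise ValueError("len(cells) must equal prod(batch_shape)")

    def __len__(self):
        if not self.batch_shape:
            raise TypeError("len() of a 0-d ToyGrid")
        return self.batch_shape[0]

    def __getitem__(self, i):
        # Non-negative integer index on the leading axis only.
        if not 0 <= i < len(self):
            raise IndexError(i)
        inner = self.batch_shape[1:]
        stride = prod(inner)
        block = self.cells[i * stride:(i + 1) * stride]
        # 1-d: a single cell; n-d: a sub-grid of shape batch_shape[1:].
        return block[0] if not inner else ToyGrid(block, inner)

    def __iter__(self):
        if not self.batch_shape:
            # Same contract as iter(np.zeros(())).
            raise TypeError("iteration over a 0-d ToyGrid")
        if len(self.batch_shape) == 1:
            return iter(self.cells)
        return (self[i] for i in range(self.batch_shape[0]))
```

For example, list(ToyGrid(list("abcdef"), (2, 3))) yields two rows with batch_shape (3,). A 0-d grid refuses iteration, but its single cell is still reachable through the flat cells list, which plays the role of components here.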

BroadcastDistribution(input_samples, output_samples, weights=None, *, log_weights=None, output_distributions=None, broadcast_args, name=None)

Bases: Distribution[dict], SupportsSampling

Joint distribution over broadcast inputs and function output.

Stores the paired input–output samples from a WorkflowFunction broadcast. Supports joint sampling (resampling paired input–output tuples) and named component access.

Call marginalize to obtain the output-only marginal, which supports moment protocols (mean, variance, etc.) when the output data permits.

Note: BroadcastDistribution does not inherit from JointDistribution. JointDistribution requires all leaves to be NumericRecordDistribution instances with TFP shape semantics (batch_shape, event_shape). A broadcast output, however, can be any type (arrays, distributions, strings, etc.), and input samples are plain arrays without distribution metadata. The two hierarchies serve different roles: JointDistribution models structured probabilistic variables; BroadcastDistribution records the empirical input–output mapping of a function evaluation.
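Joint sampling keeps each input–output pair intact. The minimal NumPy sketch below shows the idea. It assumes multinomial resampling of row indices with probabilities proportional to the weights. resample_pairs and its "output" key are hypothetical names, not ProbPipe API, and ProbPipe's own sampler may differ (for example, in how it handles random keys).

```python
import numpy as np


def resample_pairs(input_samples, output_samples, weights, num, seed=0):
    """Hypothetical sketch: draw `num` paired (inputs, output) tuples.

    Row i of every input array and of the outputs is selected together,
    so the input-output pairing survives resampling.
    """
    w = np.asarray(weights, dtype=float)
    if np.any(w < 0) or w.sum() <= 0:
        raise ValueError("weights must be non-negative and not all zero")
    p = w / w.sum()  # normalise, as the weights are normalised internally
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(p), size=num, p=p)  # one index array for all columns
    draws = {name: np.asarray(arr)[idx] for name, arr in input_samples.items()}
    draws["output"] = np.asarray(output_samples)[idx]
    return draws
```

A single index array selects from every input and from the output. As a result, draws["output"][k] always comes from the same function evaluation as the inputs at position k.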

Parameters:

  • input_samples (dict[str, Array], required): {arg_name: (n, *event_shape)} for each broadcast argument.
  • output_samples (Array or list, required): (n, *event_shape) for array outputs, or a list of length n.
  • weights (array-like, Weights, or None; default None): Non-negative weights (normalized internally). A pre-built Weights object is also accepted. Mutually exclusive with log_weights. Leave both weights and log_weights as None for uniform weights.
  • log_weights (array-like, Weights, or None; default None): Log-unnormalized weights. A pre-built Weights object is also accepted. Mutually exclusive with weights.
  • output_distributions (list of Distribution or None; default None): When each function evaluation returns a Distribution, these are the n component distributions for the mixture marginal.
  • broadcast_args (list of str, required): Ordered names of the broadcast arguments.
  • name (str or None; default None): Distribution name for provenance. When None, the name "broadcast" is used.
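The parameters above tie every broadcast input and the output together through a shared leading length n. The sketch below shows one way to check that contract. leading_length mirrors the shape[0]-or-len() logic in __init__. check_broadcast_shapes is a hypothetical helper: the __init__ source below reads n from the first broadcast argument only, so this does not claim that ProbPipe validates the others.

```python
import numpy as np


def leading_length(x):
    # Arrays report shape[0]; plain lists fall back to len().
    return x.shape[0] if hasattr(x, "shape") else len(x)


def check_broadcast_shapes(input_samples, output_samples, broadcast_args):
    """Hypothetical validator: every input and the output share n. Returns n."""
    if not broadcast_args:
        raise ValueError("at least one broadcast argument is required")
    n = leading_length(input_samples[broadcast_args[0]])
    for name in broadcast_args[1:]:
        m = leading_length(input_samples[name])
        if m != n:
            raise ValueError(f"{name!r} has leading length {m}, expected {n}")
    # output_samples may be an (n, *event_shape) array or a length-n list.
    m = leading_length(output_samples)
    if m != n:
        raise ValueError(f"output has leading length {m}, expected {n}")
    return n


# Example: a vector and a batch of matrices as inputs, list-valued outputs.
inputs = {"mu": np.zeros(5), "cov": np.zeros((5, 2, 2))}
n = check_broadcast_shapes(inputs, ["out"] * 5, ["mu", "cov"])
```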
Source code in probpipe/core/_broadcast_distributions.py
def __init__(
    self,
    input_samples: dict[str, Any],
    output_samples: Any,
    weights: Array | Weights | None = None,
    *,
    log_weights: Array | Weights | None = None,
    output_distributions: list | None = None,
    broadcast_args: list[str],
    name: str | None = None,
):
    self._input_samples = input_samples
    self._output_samples = output_samples
    self._output_distributions = output_distributions

    # Determine n from first broadcast arg
    first_key = list(broadcast_args)[0]
    first_arr = input_samples[first_key]
    n = first_arr.shape[0] if hasattr(first_arr, 'shape') else len(first_arr)
    self._w = Weights(n=n, weights=weights, log_weights=log_weights)
    self._broadcast_args = list(broadcast_args)
    if name is None:
        name = "broadcast"
    super().__init__(name=name)
    self._approximate = True
    self._marginal_cache: MarginalizedBroadcastDistribution | None = None
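The weights / log_weights arguments forwarded to Weights(n=..., ...) follow a few rules. At most one of the two may be given, passing neither means uniform, and log-weights are exponentiated and then normalised. The minimal NumPy sketch below follows those rules. normalise_weights is a hypothetical helper, not the Weights class, and unlike Weights it does not accept pre-built Weights objects.

```python
import numpy as np


def normalise_weights(n, weights=None, log_weights=None):
    """Hypothetical helper returning normalised weights of shape (n,)."""
    if weights is not None and log_weights is not None:
        raise ValueError("weights and log_weights are mutually exclusive")
    if weights is None and log_weights is None:
        return np.full(n, 1.0 / n)  # uniform
    if weights is not None:
        w = np.asarray(weights, dtype=float)
        if w.shape != (n,) or np.any(w < 0) or w.sum() <= 0:
            raise ValueError("weights must be (n,), non-negative, not all zero")
        return w / w.sum()
    lw = np.asarray(log_weights, dtype=float)
    if lw.shape != (n,):
        raise ValueError("log_weights must have shape (n,)")
    # Subtract the max before exponentiating so large log-weights do not overflow.
    w = np.exp(lw - lw.max())
    return w / w.sum()
```

Working in log space lets callers pass unnormalised log-densities directly (for example, importance log-ratios) without overflow.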

n property

Number of input–output pairs.

weights property

Normalised weights, shape (n,).

input_samples property

Broadcast input samples: {arg_name: (n, *event_shape)}.

samples property

Output samples (forwarded to the output marginal for backward compatibility).

output property

Alias for marginalize.

marginalize()

Return the output marginal distribution.

Lazy — the marginal is constructed on first call and cached. The marginal inherits this distribution's provenance (if any) so the lineage is preserved without a direct reference to the BroadcastDistribution.

Source code in probpipe/core/_broadcast_distributions.py
def marginalize(self) -> MarginalizedBroadcastDistribution:
    """Return the output marginal distribution.

    Lazy — the marginal is constructed on first call and cached.
    The marginal inherits this distribution's provenance (if any)
    so the lineage is preserved without a direct reference to the
    ``BroadcastDistribution``.
    """
    if self._marginal_cache is None:
        self._marginal_cache = _make_marginal(
            self._output_samples,
            self._w,
            output_distributions=self._output_distributions,
        )
        if self.source is not None and isinstance(self._marginal_cache, Distribution):
            self._marginal_cache.with_source(self.source)
    return self._marginal_cache
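The lazy, cached construction in marginalize can be sketched independently of ProbPipe. ToyBroadcast and WeightedMarginal below are hypothetical classes, not ProbPipe API. The marginal is assumed to be a weighted empirical distribution over array-valued outputs, so its mean and variance are weighted sums over the leading axis. The sketch omits the provenance propagation (with_source) performed by the real method, and models output as an alias returning the same cached object.

```python
import numpy as np


class WeightedMarginal:
    """Hypothetical output-only marginal over weighted array samples."""

    def __init__(self, samples, weights):
        self.samples = np.asarray(samples, dtype=float)  # (n, *event_shape)
        self.weights = np.asarray(weights, dtype=float)  # (n,), normalised

    def mean(self):
        # Contract the weights with the leading (sample) axis.
        return np.tensordot(self.weights, self.samples, axes=1)

    def variance(self):
        centred = self.samples - self.mean()
        return np.tensordot(self.weights, centred ** 2, axes=1)


class ToyBroadcast:
    """Hypothetical holder that builds its marginal once, on first request."""

    def __init__(self, output_samples, weights):
        self._output_samples = output_samples
        self._weights = weights
        self._marginal_cache = None
        self.builds = 0  # counts constructions, for illustration only

    def marginalize(self):
        if self._marginal_cache is None:
            self.builds += 1
            self._marginal_cache = WeightedMarginal(self._output_samples, self._weights)
        return self._marginal_cache

    @property
    def output(self):
        # Alias: same cached object as marginalize().
        return self.marginalize()
```

Caching means repeated calls (or mixing marginalize() with output) never rebuild the marginal, and every caller sees the same object.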

The truly private machinery (_RecordDistributionView, _vmap_sample, _mc_expectation) is documented on the Internals page, alongside the public-but-rarely-constructed FlattenedDistributionView and NumericRecordDistributionView.