
Internals

Most of the classes and helpers on this page are not part of the stable public API. They are documented here for contributors, for advanced users who need to understand the broadcasting / protocol-dispatch machinery, and for anyone debugging a failing isinstance(obj, SupportsX) check. Private names start with _ to reflect this, and their signatures may change between PRs without deprecation warnings.

If you find yourself reaching for something on this page in user code, there's probably a public replacement — check Operations, Distributions, Records, or Extending ProbPipe first, and open an issue if there isn't.

Record-distribution views

dist["field"] on a RecordDistribution returns a _RecordDistributionView — a lightweight reference that preserves correlation when multiple views from the same parent are used in @workflow_function broadcasting. The view's protocol membership (SupportsSampling, SupportsMean, SupportsVariance, SupportsLogProb, SupportsCovariance) is computed dynamically from the parent's capabilities, so isinstance(view, SupportsMean) is true iff the parent is.

_RecordDistributionView(parent, key)

Bases: Distribution

Lightweight reference to a single named field of a Record-based distribution.

The Record-world analog of DistributionView. Preserves correlation when multiple views from the same parent are used in WorkflowFunction broadcasting.

Dynamic protocol support: this base class intentionally does not inherit any SupportsFoo protocols. Each concrete instance is a cached subclass built by _view_class_for_parent, which mixes in only the protocols the parent actually implements. Calling _RecordDistributionView(parent, key) routes through __new__ and picks the right subclass automatically, so isinstance(view, SupportsSampling) is True iff the parent is.

Parameters:

  • parent (Distribution, required): A distribution with record_template set.
  • key (str, required): Field name in the parent's record_template.

Source code in probpipe/core/_record_distribution.py
def __init__(self, parent: Distribution, key: str):
    template = parent.record_template
    if template is None or key not in template:
        raise KeyError(
            f"No field {key!r} in record_template "
            f"(available: {template.fields if template is not None else ()})"
        )
    # Bypass Distribution.__init__ validation (view name comes from
    # the field key, not a user-supplied argument).
    self._name = key
    self._parent = parent
    self._key = key
    self._key_path = (key,)
    self._template_field = template[key]
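The dispatch-in-__new__ pattern described above can be sketched in plain Python. This is an illustration under assumptions, not ProbPipe's implementation: FieldView, the two mixins, the cache, and the toy parents are hypothetical, and the two protocols are simplified stand-ins for ProbPipe's SupportsSampling and SupportsMean.

```python
from typing import Protocol, runtime_checkable


# Simplified stand-ins for ProbPipe's capability protocols (assumption).
@runtime_checkable
class SupportsSampling(Protocol):
    def sample(self, key, sample_shape=()): ...


@runtime_checkable
class SupportsMean(Protocol):
    def mean(self): ...


# Hypothetical mixins: each delegates one capability to the parent's field.
class _SamplingMixin:
    def sample(self, key, sample_shape=()):
        return self._parent.sample(key, sample_shape)[self._key]


class _MeanMixin:
    def mean(self):
        return self._parent.mean()[self._key]


_MIXIN_FOR = {SupportsSampling: _SamplingMixin, SupportsMean: _MeanMixin}
_VIEW_CLASS_CACHE: dict = {}


class FieldView:
    """Hypothetical field view; the base class defines no capability methods."""

    def __new__(cls, parent, key):
        if cls is FieldView:
            # Mix in only the protocols the parent satisfies, caching one
            # subclass per distinct capability set.
            mixins = tuple(m for p, m in _MIXIN_FOR.items() if isinstance(parent, p))
            sub = _VIEW_CLASS_CACHE.get(mixins)
            if sub is None:
                sub = type("FieldView", (*mixins, FieldView), {})
                _VIEW_CLASS_CACHE[mixins] = sub
            cls = sub
        return super().__new__(cls)

    def __init__(self, parent, key):
        self._parent = parent
        self._key = key


# Toy parents used to exercise the dispatch.
class SampleOnlyParent:
    def sample(self, key, sample_shape=()):
        return {"x": 1.0, "y": 2.0}


class FullParent(SampleOnlyParent):
    def mean(self):
        return {"x": 0.0, "y": 0.0}
```

Because the subclass is chosen before __init__ runs, callers just write FieldView(parent, key); a view of a SampleOnlyParent passes isinstance(view, SupportsSampling) but not isinstance(view, SupportsMean), and views of parents with the same capability set share one cached class.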

parent property

The RecordDistribution this view points at.

Shared-identity signal for the WorkflowFunction sweep layer: views with the same parent co-sample (preserve correlation) when passed as sibling broadcast args to a workflow function. Matches the _RecordArrayView.parent surface.
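To see why shared parent identity matters, here is a toy sweep that draws once per distinct parent and then reads each field from that joint draw. ToyParent, View, and broadcast_call are hypothetical; the real WorkflowFunction sweep layer is not shown on this page and may differ in detail.

```python
import jax
import jax.numpy as jnp


class ToyParent:
    """Hypothetical joint distribution over fields x and y, with y = 2 * x."""

    def sample(self, key, n):
        x = jax.random.normal(key, (n,))
        return {"x": x, "y": 2.0 * x}


class View:
    """Hypothetical field view: just a (parent, field) reference."""

    def __init__(self, parent, field):
        self.parent = parent
        self.field = field


def broadcast_call(fn, *views, key, n):
    """Hypothetical sweep: one joint draw per distinct parent, shared by its views."""
    draws = {}
    for v in views:
        pid = id(v.parent)
        if pid not in draws:
            key, sub = jax.random.split(key)
            draws[pid] = v.parent.sample(sub, n)
    return fn(*(draws[id(v.parent)][v.field] for v in views))


parent = ToyParent()
diff = broadcast_call(lambda a, b: b - 2.0 * a, View(parent, "x"), View(parent, "y"),
                      key=jax.random.PRNGKey(0), n=1000)
# Co-sampled views keep y == 2 * x, so every entry of diff is exactly 0.0.
```

Passing views of two different ToyParent instances instead produces independent draws, so the same difference is generally nonzero.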

field property

Name of the viewed field (the top-level key into the parent).

shape property

Shape of one draw from this view — equals event_shape.

dtype property

Dtype of a single draw, if the parent exposes dtypes.

ndim property

Number of axes in a single draw (len(event_shape)).

Flat / Record view helpers

FlattenedDistributionView and NumericRecordDistributionView are the public view classes for the flat ↔ Record-keyed bridge. Both follow the same dynamic-protocol pattern as _RecordDistributionView — only the protocols the base distribution supports are attached to the view.

FlattenedDistributionView is a FlatNumericRecordDistribution: it exposes any distribution as flat (single field, event_shape=(N,)), for interop with algorithms that expect a flat parameter vector. Construct via as_flat_distribution.

NumericRecordDistributionView is the inverse: it takes a FlatNumericRecordDistribution and a NumericRecordTemplate, and presents the distribution under the template's named-field structure. Construct via as_record_distribution.

FlatNumericRecordDistribution(*, name)

Bases: NumericRecordDistribution

A NumericRecordDistribution whose samples are flat 1-D vectors.

The flat contract:

  • exactly one field (len(fields) == 1)
  • event_shape == (N,) for some N
  • samples shaped sample_shape + (N,)

Algorithms that operate on a flat parameter vector — MCMC kernels, optimisers, Hessian / curvature builders, variational families, Pathfinder / Laplace surrogates — should declare their input as FlatNumericRecordDistribution. The natively-multivariate parametrics (MultivariateNormal, Dirichlet, Multinomial, VonMisesFisher) and FlattenedDistributionView all satisfy this contract.

Scalar parametrics (Normal, Beta, …) have event_shape == () and do not satisfy the contract directly; call as_flat_distribution to get a FlattenedDistributionView (whose event_shape is (1,)).

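This class is also the home of as_record_distribution, the inverse of as_flat_distribution. Because the method is defined on FlatNumericRecordDistribution (the receiver type) rather than on every NumericRecordDistribution, calling it on a non-flat distribution is a type error instead of a failure caught by a runtime shape check.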

Source code in probpipe/core/_distribution_base.py
def __init__(self, *, name: str):
    if not isinstance(name, str) or not name:
        raise TypeError(
            f"{type(self).__name__} requires a non-empty name= argument"
        )
    self._name = name

flat_size property

Number of scalar elements — equal to event_shape[0].

Validates the flat contract on access: subclasses with non-1-D event_shape raise TypeError here rather than silently truncating to the first dimension.
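The flat contract and the flat_size validation above amount to a couple of shape checks. The sketch below works on plain {field: event_shape} dicts rather than distribution objects; flat_size and flattened_event_shape here are hypothetical helpers, not ProbPipe functions, and the error messages are invented.

```python
from math import prod


def flat_size(event_shapes: dict) -> int:
    """Hypothetical flat-contract check over a {field: event_shape} mapping."""
    if len(event_shapes) != 1:
        raise TypeError(f"flat contract needs exactly one field, got {len(event_shapes)}")
    (shape,) = event_shapes.values()
    if len(shape) != 1:
        # Refuse rather than silently using shape[0] of a multi-axis shape.
        raise TypeError(f"flat contract needs event_shape == (N,), got {shape}")
    return shape[0]


def flattened_event_shape(event_shapes: dict) -> tuple:
    """Hypothetical: the (total_size,) event shape a flattening view would report."""
    return (sum(prod(s) for s in event_shapes.values()),)


print(flat_size({"theta": (3,)}))          # a MultivariateNormal-like field
print(flattened_event_shape({"mu": ()}))   # a scalar field becomes (1,)
```

A scalar field such as {"mu": ()} fails flat_size directly, which is why scalar parametrics go through as_flat_distribution first; its flattened shape (1,) then satisfies the contract.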

as_record_distribution(*, template, name=None)

Lift this flat distribution to a Record-keyed view under template.

Inverse of as_flat_distribution. Samples come back as NumericRecord / NumericRecordArray keyed by template.fields.

Parameters:

  • template (NumericRecordTemplate, required): Target structural skeleton. Must be a NumericRecordTemplate; opaque (None) leaves cannot be reconstructed from a flat numeric array.
  • name (str, default None): Name for the lifted distribution. Defaults to self.name.

Returns:

  • NumericRecordDistribution: A thin view over self. Sampling, log-prob, moments, and expectation delegate to the source and reshape via the template. Capability protocols match the source.

Raises:

  • TypeError: If template is not a NumericRecordTemplate.
  • ValueError: If self.flat_size does not match template.flat_size.

Source code in probpipe/core/_numeric_record_distribution.py
def as_record_distribution(
    self,
    *,
    template: NumericRecordTemplate,
    name: str | None = None,
) -> NumericRecordDistribution:
    """Lift this flat distribution to a Record-keyed view under *template*.

    Inverse of :meth:`~NumericRecordDistribution.as_flat_distribution`.
    Samples come back as :class:`NumericRecord` /
    :class:`NumericRecordArray` keyed by ``template.fields``.

    Parameters
    ----------
    template : NumericRecordTemplate
        Target structural skeleton. Must be a
        :class:`NumericRecordTemplate` — opaque (``None``) leaves
        cannot be reconstructed from a flat numeric array.
    name : str, optional
        Name for the lifted distribution. Defaults to ``self.name``.

    Returns
    -------
    NumericRecordDistribution
        A thin view over ``self``. Sampling, log-prob, moments, and
        ``expectation`` delegate to the source and reshape via the
        template. Capability protocols match the source.

    Raises
    ------
    TypeError
        If ``template`` is not a ``NumericRecordTemplate``.
    ValueError
        If ``self.flat_size`` does not match ``template.flat_size``.
    """
    from .record import NumericRecordTemplate
    if not isinstance(template, NumericRecordTemplate):
        raise TypeError(
            f"as_record_distribution requires a NumericRecordTemplate, "
            f"got {type(template).__name__}. Opaque (None) leaves "
            f"cannot be reconstructed from a flat numeric array."
        )
    if self.flat_size != template.flat_size:
        raise ValueError(
            f"flat_size mismatch: source flat_size={self.flat_size}, "
            f"template.flat_size={template.flat_size}."
        )
    cls = _numeric_record_distribution_view_class_for_base(self)
    return cls(self, template, name=name)
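The "reshape via the template" step can be sketched with a toy template. ToyTemplate and its flatten / unflatten methods are hypothetical, and concatenating fields in declaration order is an assumption; NumericRecordTemplate may lay fields out differently. The ValueError mirrors the flat_size check in the source above.

```python
from math import prod

import jax.numpy as jnp


class ToyTemplate:
    """Hypothetical stand-in for NumericRecordTemplate: ordered {field: shape}."""

    def __init__(self, **shapes):
        self.shapes = {f: tuple(s) for f, s in shapes.items()}
        self.fields = tuple(self.shapes)

    @property
    def flat_size(self):
        return sum(prod(s) for s in self.shapes.values())

    def flatten(self, record):
        # Ravel each field's event axes, keep leading sample axes, concatenate.
        parts = []
        for f, s in self.shapes.items():
            v = jnp.asarray(record[f])
            batch = v.shape[: v.ndim - len(s)]
            parts.append(v.reshape(batch + (prod(s),)))
        return jnp.concatenate(parts, axis=-1)

    def unflatten(self, flat):
        # Split the trailing axis into per-field chunks and restore each shape.
        flat = jnp.asarray(flat)
        if flat.shape[-1] != self.flat_size:
            raise ValueError(
                f"flat_size mismatch: got {flat.shape[-1]}, template expects {self.flat_size}"
            )
        batch, out, start = flat.shape[:-1], {}, 0
        for f, s in self.shapes.items():
            size = prod(s)
            out[f] = flat[..., start:start + size].reshape(batch + s)
            start += size
        return out


tpl = ToyTemplate(mu=(), scale=(2,))
rec = tpl.unflatten(jnp.arange(6.0).reshape(2, 3))  # sample_shape (2,), flat_size 3
```

Here rec["mu"] has shape (2,) and rec["scale"] has shape (2, 2): the sample_shape prefix is kept and each field gets its template shape back, while flatten inverts the mapping.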

FlattenedDistributionView(base)

Bases: FlatNumericRecordDistribution

Wraps any distribution as a FlatNumericRecordDistribution.

Each draw is a flat vector of shape (event_size,), so a batch of draws has shape sample_shape + (event_size,). _log_prob accepts flat vectors, unflattens them, and delegates to the wrapped distribution.

This is the primary interoperability mechanism: any algorithm written against FlatNumericRecordDistribution works with an arbitrary RecordDistribution / NumericRecordDistribution via dist.as_flat_distribution().
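The flatten-then-delegate idea can be sketched with jax.flatten_util.ravel_pytree, used here as a stand-in for ProbPipe's own unflatten_value machinery (not shown on this page). base_log_prob is a hypothetical structured log-density over a dict of independent standard normals.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import norm


def base_log_prob(value):
    """Hypothetical structured base: independent N(0, 1) entries in a dict."""
    return norm.logpdf(value["loc"]).sum() + norm.logpdf(value["s"])


# A template value fixes the pytree structure; ravel_pytree returns the flat
# vector plus an unravel function that inverts the flattening.
example = {"loc": jnp.zeros(2), "s": jnp.zeros(())}
flat_example, unravel = ravel_pytree(example)


def flat_log_prob(flat):
    # Unflatten the vector, then delegate to the structured log-density.
    return base_log_prob(unravel(flat))


# Any flat-vector algorithm (e.g. a gradient-based sampler) can differentiate this.
grad = jax.grad(flat_log_prob)(jnp.ones(flat_example.shape))
```

For a standard normal the gradient of the log-density is -x, so grad is a length-3 vector of -1.0; the algorithm never needs to know that entry 2 was originally the scalar field "s".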

Dynamic protocol support: the view's isinstance compliance matches the base's capabilities — a log-prob-only base produces a view that is not SupportsSampling, and a sampling-only base produces one that is not SupportsLogProb.

Source code in probpipe/core/_numeric_record_distribution.py
def __init__(self, base: Distribution):
    self._base = base

supports property

Per-field support — the flattened view is real-valued.

base_distribution property

The underlying distribution.

unflatten_sample(flat_sample)

Convenience: unflatten a flat sample back to the pytree structure.

Source code in probpipe/core/_numeric_record_distribution.py
def unflatten_sample(self, flat_sample: ArrayLike):
    """Convenience: unflatten a flat sample back to the pytree structure."""
    return self._base.unflatten_value(jnp.asarray(flat_sample))

NumericRecordDistributionView(base, template, *, name=None)

Bases: NumericRecordDistribution

View that lifts a flat distribution to a Record-keyed structure.

Inverse of FlattenedDistributionView. self._base is a FlatNumericRecordDistribution (single-field, event_shape == (N,)); self.record_template is the user-supplied NumericRecordTemplate (not the source's auto-template).

Sampling, log-prob, and moments delegate to self._base and reshape via the template's flatten / unflatten machinery. Capability protocols match the source via _numeric_record_distribution_view_class_for_base.

Constructed via FlatNumericRecordDistribution.as_record_distribution.

Source code in probpipe/core/_numeric_record_distribution.py
def __init__(
    self,
    base: Distribution,
    template: NumericRecordTemplate,
    *,
    name: str | None = None,
):
    # Match ``FlattenedDistributionView``'s convention of skipping
    # ``Distribution.__init__``: the base may itself be a view with
    # no ``_name`` set, so going through ``super().__init__(name=...)``
    # would raise. Set the attribute directly and accept ``None`` —
    # ``self.name`` will raise on access in that case, consistent
    # with ``FlattenedDistributionView``.
    self._base = base
    self._name = name if name is not None else getattr(base, "_name", None)
    # Pre-set the user-supplied template so the auto-build path in
    # ``NumericRecordDistribution.record_template`` is skipped.
    object.__setattr__(self, "_record_template", template)

event_shape property

Single-field shortcut: the lone field's shape.

Raises TypeError via _single_field_name for multi-field templates; reach for event_shapes (dict) in that case.

event_shapes property

Per-field event shapes from the user-supplied template.

dtypes property

Per-field dtypes — all fields inherit the source's single dtype.

supports property

Per-field supports — all fields inherit the source's single support.

base_distribution property

The underlying single-field flat distribution.

Sampling primitive

_vmap_sample(dist, key, sample_shape=())

Draw samples via jax.vmap over dist._sample(key, ()).

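Convenience for distributions whose _sample implementation is naturally a single-draw function: call this helper from _sample and it handles the sample_shape prefix by splitting keys and vmap-ing over the single-draw path. Because the helper itself calls dist._sample(key, ()), the calling _sample must produce the single draw directly when sample_shape == () and delegate to _vmap_sample only for non-empty shapes; otherwise the two functions recurse without end.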

Parameters:

  • dist (required): Distribution whose _sample(key, ()) draws one unbatched sample (array or pytree of arrays).
  • key (PRNGKey, required): JAX PRNG key.
  • sample_shape (tuple of int, default ()): Shape prefix for independent draws.

Source code in probpipe/core/_numeric_record_distribution.py
def _vmap_sample(
    dist,
    key: PRNGKey,
    sample_shape: tuple[int, ...] = (),
) -> Any:
    """Draw samples via ``jax.vmap`` over ``dist._sample(key, ())``.

    Convenience for distributions whose ``_sample`` implementation is
    naturally a single-draw function: call this helper from ``_sample``
    and it will handle the ``sample_shape`` prefix by splitting keys
    and vmap-ing over the single-draw path.

    Parameters
    ----------
    dist
        Distribution whose ``_sample(key, ())`` draws one unbatched
        sample (array or pytree of arrays).
    key : PRNGKey
        JAX PRNG key.
    sample_shape : tuple of int
        Shape prefix for independent draws.
    """
    def _one(k: PRNGKey) -> Any:
        return dist._sample(k, ())

    if sample_shape == ():
        return _one(key)
    n = prod(sample_shape)
    keys = jax.random.split(key, n)
    flat_samples = jax.vmap(_one)(keys)
    return jax.tree.map(
        lambda x: x.reshape(*sample_shape, *x.shape[1:]),
        flat_samples,
    )
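A distribution might use the helper as sketched below, handling the single-draw case inline so the call back into _sample(key, ()) terminates. ToyLocScale is hypothetical; _vmap_sample is copied from the source above so the block runs on its own, assuming the source's prod is math.prod.

```python
from math import prod

import jax
import jax.numpy as jnp


def _vmap_sample(dist, key, sample_shape=()):
    # Copy of the helper above (assumes prod is math.prod).
    def _one(k):
        return dist._sample(k, ())

    if sample_shape == ():
        return _one(key)
    n = prod(sample_shape)
    keys = jax.random.split(key, n)
    flat_samples = jax.vmap(_one)(keys)
    return jax.tree.map(lambda x: x.reshape(*sample_shape, *x.shape[1:]), flat_samples)


class ToyLocScale:
    """Hypothetical distribution whose natural sampler draws one pytree at a time."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def _sample(self, key, sample_shape=()):
        if sample_shape == ():
            # Single-draw path: must not call _vmap_sample, or it would recurse forever.
            z = jax.random.normal(key, (2,))
            return {"x": self.loc + self.scale * z, "norm": jnp.linalg.norm(z)}
        # Batched path: split keys and vmap over the single-draw path above.
        return _vmap_sample(self, key, sample_shape)


draws = ToyLocScale(1.0, 2.0)._sample(jax.random.PRNGKey(0), (4, 3))
```

Every leaf of the returned pytree gets the (4, 3) prefix: draws["x"] has shape (4, 3, 2) and draws["norm"] has shape (4, 3).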

_mc_expectation(dist, f, *, key=None, num_evaluations=None, return_dist=None)

Estimate E[f(X)] where X ~ dist via Monte Carlo.
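Written out, the return_dist=False path computes the sample mean below (leaf by leaf when f returns a pytree), with n = num_evaluations, falling back to DEFAULT_NUM_EVALUATIONS, and assuming dist._sample returns independent draws. The standard-error line is general Monte Carlo background included to show why estimation uncertainty is worth capturing; this page does not say how BootstrapDistribution summarises it.

```latex
\hat{\mu}_n = \frac{1}{n} \sum_{i=1}^{n} f(x_i) \approx \mathbb{E}[f(X)],
\qquad x_1, \dots, x_n \overset{\text{iid}}{\sim} \mathrm{dist},
\qquad n = \mathrm{num\_evaluations}

\mathrm{SE}(\hat{\mu}_n) = \frac{\sigma_f}{\sqrt{n}},
\qquad \sigma_f^2 = \mathrm{Var}[f(X)]
```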

Parameters:

  • dist (required): Distribution with a _sample(key, sample_shape) method.
  • f (callable, required): Function mapping a single sample to an array (or pytree of arrays).
  • key (PRNGKey, default None): JAX PRNG key for sampling. Auto-generated if None.
  • num_evaluations (int, default None): Number of samples to draw. If None, uses DEFAULT_NUM_EVALUATIONS.
  • return_dist (bool, default None): If True, return a BootstrapDistribution capturing estimation uncertainty. If False, return the Monte Carlo mean as a plain array (or pytree of arrays). If None, use the global RETURN_APPROX_DIST setting.

Source code in probpipe/core/_numeric_record_distribution.py
def _mc_expectation(
    dist,
    f: Callable,
    *,
    key: PRNGKey | None = None,
    num_evaluations: int | None = None,
    return_dist: bool | None = None,
) -> Any:
    """Estimate ``E[f(X)]`` where ``X ~ dist`` via Monte Carlo.

    Parameters
    ----------
    dist
        Distribution with a ``_sample(key, sample_shape)`` method.
    f : callable
        Function mapping a single sample to an array (or pytree of arrays).
    key : PRNGKey, optional
        JAX PRNG key for sampling.  Auto-generated if ``None``.
    num_evaluations : int, optional
        Number of samples to draw.  If ``None``, uses
        ``DEFAULT_NUM_EVALUATIONS``.
    return_dist : bool, optional
        If ``True``, return a ``BootstrapDistribution`` capturing
        estimation uncertainty.  If ``False``, return a plain array.
        If ``None``, use the global ``RETURN_APPROX_DIST`` setting.
    """
    n = num_evaluations if num_evaluations is not None else _base.DEFAULT_NUM_EVALUATIONS
    if key is None:
        key = _auto_key()
    samples = dist._sample(key, sample_shape=(n,))
    evals = jax.vmap(f)(samples)

    rd = return_dist if return_dist is not None else _base.RETURN_APPROX_DIST
    if rd:
        return BootstrapDistribution(evals, name="E[f(X)]")
    return jax.tree.map(lambda v: jnp.mean(v, axis=0), evals)
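Stripped of key auto-generation and the BootstrapDistribution branch, the estimator reduces to draw, vmap, average. The sketch below reimplements only the return_dist=False path with an explicit key and n; StdNormal and mc_mean are hypothetical names, and f returns a dict to show that the mean is taken leaf by leaf.

```python
import jax
import jax.numpy as jnp


class StdNormal:
    """Hypothetical distribution exposing the _sample(key, sample_shape) hook."""

    def _sample(self, key, sample_shape=()):
        return jax.random.normal(key, sample_shape)


def mc_mean(dist, f, key, num_evaluations):
    # Mirrors the return_dist=False branch of _mc_expectation.
    samples = dist._sample(key, sample_shape=(num_evaluations,))
    evals = jax.vmap(f)(samples)  # f sees one sample at a time
    return jax.tree.map(lambda v: jnp.mean(v, axis=0), evals)


est = mc_mean(StdNormal(), lambda x: {"m1": x, "m2": x ** 2},
              jax.random.PRNGKey(0), 100_000)
```

With 100,000 draws, est["m1"] lands near E[X] = 0 and est["m2"] near E[X²] = 1, each within a few hundredths.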