Random functions

Distributions over function-valued random variables (RandomFunction, GaussianRandomFunction) and over measure-valued random variables (RandomMeasure).

RandomFunction(*, name=None)

Bases: Distribution[Callable[[X], Y]]

A distribution over functions f: X → Y.

The primary interface is __call__. Calling the random function on a set of inputs returns the joint distribution over the corresponding function outputs, that is, a finite-dimensional distribution of the underlying stochastic process. Log-densities are typically unavailable because a random function has no density in the standard sense.

Sampling a random function means sampling an entire functional trajectory. For infinite-dimensional models (e.g., Gaussian processes), drawing an entire function realization may be impossible or require approximation. Finite-dimensional subclasses that support sampling should inherit SupportsSampling and implement _sample(key, sample_shape).

This class is generic in X (input type) and Y (output type).
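
As a rough illustration of the __call__ contract, the sketch below uses a standalone toy class (ToyRandomLine is a hypothetical name and does not subclass anything from probpipe). It models f(x) = w * x with a single Gaussian weight, and returns the finite-dimensional distribution as a plain (mean, covariance) pair rather than a Distribution object.

```python
import numpy as np


class ToyRandomLine:
    """Hypothetical stand-in: f(x) = w * x with w ~ N(0, sigma^2)."""

    def __init__(self, sigma=1.0):
        self.sigma = sigma

    def __call__(self, x):
        # Finite-dimensional distribution at the inputs x: a zero-mean
        # Gaussian with covariance sigma^2 * x x^T.
        x = np.asarray(x, dtype=float)
        mean = np.zeros_like(x)
        cov = self.sigma**2 * np.outer(x, x)
        return mean, cov

    def sample(self, rng):
        # The model has one weight, so a whole trajectory can be drawn exactly.
        w = rng.normal(0.0, self.sigma)
        return lambda x: w * np.asarray(x, dtype=float)


f = ToyRandomLine(sigma=2.0)
mean, cov = f([1.0, 2.0, 3.0])            # joint distribution at three inputs
draw = f.sample(np.random.default_rng(0))  # one function realization
```

Because the toy model is finite-dimensional, sampling a full trajectory is exact here; for an infinite-dimensional model this step would generally need an approximation.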

Source code in probpipe/core/_random_functions.py
def __init__(self, *, name: str | None = None):
    super().__init__(name=name or type(self).__name__)

input_shape property

Shape of a single input point (array-valued case).

output_shape property

Shape of a single output point (array-valued case).

__call__(x) abstractmethod

Return the distribution over outputs at input x.

This is the fundamental interface of a random function.

Source code in probpipe/core/_random_functions.py
@abstractmethod
def __call__(self, x: X) -> Distribution[Y]:
    """Return the distribution over outputs at input *x*.

    This is the fundamental interface of a random function.
    """
    ...

ArrayRandomFunction(input_shape, output_shape=(), *, name=None)

Bases: RandomFunction[Array, Array]

A random function mapping arrays to arrays.

Given prediction input X with shape (*extra_batch, n, *input_shape) (n = number of input points), predict returns a DistributionArray whose outer batch_shape covers axes that are independent across cells and whose per-cell event_shape covers axes that are jointly modeled:

+--------------+---------------+-----------+-------------------------+------------------+
| joint_inputs | joint_outputs | Cell type | Outer batch_shape       | Cell event_shape |
+==============+===============+===========+=========================+==================+
| False        | False         | Normal    | (*extra_batch, n, *out) | ()               |
+--------------+---------------+-----------+-------------------------+------------------+
| True         | False         | MVN       | (*extra_batch, *out)    | (n,)             |
+--------------+---------------+-----------+-------------------------+------------------+
| False        | True          | MVN       | (*extra_batch, n)       | output_shape     |
+--------------+---------------+-----------+-------------------------+------------------+
| True         | True          | MVN       | (*extra_batch,)         | (n, *out)        |
+--------------+---------------+-----------+-------------------------+------------------+

out is shorthand for output_shape. In all modes a sample has the same total shape: (*sample_shape, *extra_batch, n, *output_shape). The flags only change which axes are jointly modeled (event) vs independent (batch).
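
The batch/event partition can be restated as a small pure-Python function. This is only a sketch of the bookkeeping; partition_shapes is a hypothetical helper, not part of the library.

```python
from math import prod


def partition_shapes(extra_batch, n, output_shape, joint_inputs, joint_outputs):
    """Return (batch_shape, event_shape) for one combination of flags."""
    eb = tuple(extra_batch)
    out = tuple(output_shape)
    if joint_inputs and joint_outputs:
        return eb, (n, *out)
    if joint_inputs:
        return (*eb, *out), (n,)
    if joint_outputs:
        return (*eb, n), out
    return (*eb, n, *out), ()


eb, n, out = (4,), 5, (2, 3)
for ji in (False, True):
    for jo in (False, True):
        batch, event = partition_shapes(eb, n, out, ji, jo)
        # Every mode accounts for the same number of scalar predictions.
        assert prod(batch) * prod(event) == prod(eb) * n * prod(out)
        print(ji, jo, batch, event)
```

The assertion inside the loop reflects the statement that the flags only move axes between batch and event without changing the total sample shape.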

Parameters:

Name Type Description Default
input_shape tuple of int

Shape of a single input point, e.g. (3,) for 3-D inputs.

required
output_shape tuple of int

Shape of a single output, e.g. (2,) for two outputs, () for a scalar output.

()
Source code in probpipe/core/_random_functions.py
def __init__(
    self,
    input_shape: tuple[int, ...],
    output_shape: tuple[int, ...] = (),
    *,
    name: str | None = None,
) -> None:
    super().__init__(name=name)
    self._input_shape = tuple(input_shape)
    self._output_shape = tuple(output_shape)

input_shape property

Shape of a single input point.

output_shape property

Shape of a single output point.

__call__(X, *, joint_inputs=False, joint_outputs=False)

Return a predictive distribution over outputs at input points X.

Validates inputs, parses shapes, and delegates to predict.

Parameters:

Name Type Description Default
X array-like

Input points with shape (*extra_batch, n, *input_shape).

required
joint_inputs bool

If True, the n axis is part of the event (predictions are correlated across input points). Default: False.

False
joint_outputs bool

If True, output_shape is part of the event (predictions are correlated across outputs). Default: False.

False

Returns:

Type Description
DistributionArray

A DistributionArray of Normal (fully marginal) or MultivariateNormal (any joint axis) cells, with outer batch_shape and per-cell event_shape following the shape table in the class docstring.
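
The shape parsing step is not shown on this page. A minimal sketch of how (*extra_batch, n, *input_shape) might be split is given below; split_input_shape is a hypothetical helper written for illustration, and the library's own validation may differ.

```python
def split_input_shape(x_shape, input_shape):
    """Split x_shape into (extra_batch, n) given the per-point input_shape."""
    x_shape = tuple(x_shape)
    input_shape = tuple(input_shape)
    k = len(input_shape)
    if len(x_shape) < k + 1:
        raise ValueError(f"X needs at least {k + 1} axes, got shape {x_shape}")
    # Index from the front so that k == 0 (scalar inputs) works as well.
    split = len(x_shape) - k
    if x_shape[split:] != input_shape:
        raise ValueError(
            f"trailing axes {x_shape[split:]} do not match input_shape {input_shape}"
        )
    lead = x_shape[:split]
    return lead[:-1], lead[-1]


print(split_input_shape((7, 10, 3), (3,)))   # one extra batch axis, n = 10
print(split_input_shape((10,), ()))          # scalar inputs, no extra batch
```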

Source code in probpipe/core/_random_functions.py
def __call__(
    self,
    X: ArrayLike,
    *,
    joint_inputs: bool = False,
    joint_outputs: bool = False,
) -> Distribution:
    """Return a predictive distribution over outputs at input points *X*.

    Validates inputs, parses shapes, and delegates to :meth:`predict`.

    Parameters
    ----------
    X : array-like
        Input points with shape ``(*extra_batch, n, *input_shape)``.
    joint_inputs : bool, optional
        If True, the ``n`` axis is part of the event (predictions are
        correlated across input points).  Default: False.
    joint_outputs : bool, optional
        If True, ``output_shape`` is part of the event (predictions are
        correlated across outputs).  Default: False.

    Returns
    -------
    DistributionArray
        A ``DistributionArray`` of ``Normal`` (fully marginal) or
        ``MultivariateNormal`` (any joint axis) cells, with
        outer ``batch_shape`` and per-cell ``event_shape``
        following the shape table in the class docstring.
    """
    X = jnp.asarray(X)
    self._validate_joint_request(joint_inputs, joint_outputs)
    self._validate_X(X)
    return self.predict(X, joint_inputs=joint_inputs, joint_outputs=joint_outputs)

predict(X, *, joint_inputs=False, joint_outputs=False) abstractmethod

Subclass implementation of prediction.

When this method is called, X has already been validated and converted to a JAX array. Subclasses should return a DistributionArray whose outer batch_shape and per-cell event_shape conform to the shape table in the ArrayRandomFunction class docstring.

Parameters:

Name Type Description Default
X Array

Validated input points, shape (*extra_batch, n, *input_shape).

required
joint_inputs bool

Whether predictions should be joint across input points.

False
joint_outputs bool

Whether predictions should be joint across outputs.

False
Source code in probpipe/core/_random_functions.py
@abstractmethod
def predict(
    self,
    X: Array,
    *,
    joint_inputs: bool = False,
    joint_outputs: bool = False,
) -> Distribution:
    """Subclass implementation of prediction.

    When this method is called, ``X`` has already been validated and
    converted to a JAX array.  Subclasses should return a
    :class:`~probpipe.DistributionArray` whose outer
    ``batch_shape`` and per-cell ``event_shape`` conform to the
    shape table in the :class:`ArrayRandomFunction` class
    docstring.

    Parameters
    ----------
    X : Array
        Validated input points, shape ``(*extra_batch, n, *input_shape)``.
    joint_inputs : bool
        Whether predictions should be joint across input points.
    joint_outputs : bool
        Whether predictions should be joint across outputs.
    """
    ...

GaussianRandomFunction(input_shape, output_shape=(), *, name=None)

Bases: ArrayRandomFunction

Abstract random function with Gaussian predictive distributions.

Subclasses implement predict_mean and predict_variance at minimum. If the model supports joint modes it must also implement predict_covariance.

The base predict method assembles these into the appropriate Normal or MultivariateNormal distribution with the correct batch/event shape partition.

This class is not restricted to GPs — any model that produces Gaussian (or Gaussian-approximated) predictions can inherit from it.
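
In the fully marginal mode, the mean and marginal variance are enough to define independent Normal cells. The NumPy sketch below mimics that assembly without using probpipe; the array values are arbitrary and only the shapes matter.

```python
import numpy as np

rng = np.random.default_rng(0)
n, output_shape, sample_shape = 4, (2,), (1000,)

# Stand-ins for predict_mean(X) and predict_variance(X), no extra batch axes.
mean = rng.normal(size=(n, *output_shape))
variance = rng.uniform(0.5, 2.0, size=(n, *output_shape))

# A Normal is parameterized by its standard deviation, hence the square root.
scale = np.sqrt(variance)

# Independent draws per cell: (*sample_shape, n, *output_shape).
eps = rng.normal(size=(*sample_shape, n, *output_shape))
samples = mean + scale * eps
```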

Source code in probpipe/core/_random_functions.py
def __init__(
    self,
    input_shape: tuple[int, ...],
    output_shape: tuple[int, ...] = (),
    *,
    name: str | None = None,
) -> None:
    super().__init__(name=name)
    self._input_shape = tuple(input_shape)
    self._output_shape = tuple(output_shape)

predict_mean(X) abstractmethod

Predictive mean at input points X.

Parameters:

Name Type Description Default
X Array

Shape (*extra_batch, n, *input_shape).

required

Returns:

Type Description
Array

Shape (*extra_batch, n, *output_shape).

Source code in probpipe/distributions/gaussian_random_function.py
@abstractmethod
def predict_mean(self, X: Array) -> Array:
    """Predictive mean at input points X.

    Parameters
    ----------
    X : Array
        Shape ``(*extra_batch, n, *input_shape)``.

    Returns
    -------
    Array
        Shape ``(*extra_batch, n, *output_shape)``.
    """
    ...

predict_variance(X) abstractmethod

Marginal predictive variance at input points X.

Parameters:

Name Type Description Default
X Array

Shape (*extra_batch, n, *input_shape).

required

Returns:

Type Description
Array

Shape (*extra_batch, n, *output_shape). Each element is the marginal variance of the corresponding scalar prediction.

Source code in probpipe/distributions/gaussian_random_function.py
@abstractmethod
def predict_variance(self, X: Array) -> Array:
    """Marginal predictive variance at input points X.

    Parameters
    ----------
    X : Array
        Shape ``(*extra_batch, n, *input_shape)``.

    Returns
    -------
    Array
        Shape ``(*extra_batch, n, *output_shape)``.
        Each element is the marginal variance of the corresponding
        scalar prediction.
    """
    ...

predict_covariance(X, *, joint_inputs=False, joint_outputs=False)

Predictive covariance matrix.

Required only if the model supports joint modes. Returns the covariance over whichever axes are flagged as joint.

Parameters:

Name Type Description Default
X Array

Shape (*extra_batch, n, *input_shape).

required
joint_inputs bool

Include cross-input covariance.

False
joint_outputs bool

Include cross-output covariance.

False

Returns:

Type Description
Array

The shape depends on the joint flags:

  • joint_inputs=True, joint_outputs=True: (*extra_batch, n * prod(output_shape), n * prod(output_shape))

  • joint_inputs=True, joint_outputs=False: (*extra_batch, *output_shape, n, n)

  • joint_inputs=False, joint_outputs=True: (*extra_batch, n, prod(output_shape), prod(output_shape))

Raises:

Type Description
NotImplementedError

If the subclass does not support the requested joint mode.
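
To make the three covariance layouts concrete, the sketch below builds each of them with NumPy for a linear-in-weights model. That model is an assumption chosen for illustration, and there are no extra batch axes.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_out, d_w = 4, 2, 3
phi = rng.normal(size=(n, d_out, d_w))   # features per input and output
A = rng.normal(size=(d_w, d_w))
C = A @ A.T + np.eye(d_w)                # symmetric positive-definite weight covariance

# joint_inputs=True, joint_outputs=True: flatten (n, d_out) into one axis.
phi_flat = phi.reshape(n * d_out, d_w)
cov_full = phi_flat @ C @ phi_flat.T                      # (n*d_out, n*d_out)

# joint_inputs=True, joint_outputs=False: one (n, n) block per output.
cov_inputs = np.einsum("iow,wv,jov->oij", phi, C, phi)    # (d_out, n, n)

# joint_inputs=False, joint_outputs=True: one (d_out, d_out) block per input.
cov_outputs = np.einsum("iow,wv,ipv->iop", phi, C, phi)   # (n, d_out, d_out)
```

The two partially joint covariances are sub-blocks of the fully joint one, which is why a model that can produce cov_full can serve every mode.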

Source code in probpipe/distributions/gaussian_random_function.py
def predict_covariance(
    self,
    X: Array,
    *,
    joint_inputs: bool = False,
    joint_outputs: bool = False,
) -> Array:
    """Predictive covariance matrix.

    Required only if the model supports joint modes.  Returns the
    covariance over whichever axes are flagged as joint.

    Parameters
    ----------
    X : Array
        Shape ``(*extra_batch, n, *input_shape)``.
    joint_inputs : bool
        Include cross-input covariance.
    joint_outputs : bool
        Include cross-output covariance.

    Returns
    -------
    Array
        The shape depends on the joint flags:

        - ``joint_inputs=True, joint_outputs=True``:
          ``(*extra_batch, n * prod(output_shape), n * prod(output_shape))``

        - ``joint_inputs=True, joint_outputs=False``:
          ``(*extra_batch, *output_shape, n, n)``

        - ``joint_inputs=False, joint_outputs=True``:
          ``(*extra_batch, n, prod(output_shape), prod(output_shape))``

    Raises
    ------
    NotImplementedError
        If the subclass does not support the requested joint mode.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement predict_covariance "
        f"for joint_inputs={joint_inputs}, joint_outputs={joint_outputs}."
    )

__rmatmul__(other)

A @ grf — linear map of outputs.

other must be a 2-D array of shape (d_out, d_w) and self.output_shape must be 1-D (d_w,).
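
The internals of _LinearMapGRF are not shown here. For a Gaussian output, a linear map acts on the moments by the standard rule mean -> A m and covariance -> A S A^T, which the following NumPy sketch illustrates at a single input point (all values are arbitrary).

```python
import numpy as np

rng = np.random.default_rng(1)
d_w, d_out = 3, 2

m = rng.normal(size=d_w)             # output mean at one input point, (d_w,)
L = rng.normal(size=(d_w, d_w))
S = L @ L.T                          # output covariance at that point, (d_w, d_w)
A = rng.normal(size=(d_out, d_w))    # the linear map, (d_out, d_w)

mapped_mean = A @ m                  # (d_out,)
mapped_cov = A @ S @ A.T             # (d_out, d_out)
```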

Source code in probpipe/distributions/gaussian_random_function.py
def __rmatmul__(self, other):
    """``A @ grf`` — linear map of outputs.

    *other* must be a 2-D array of shape ``(d_out, d_w)`` and
    ``self.output_shape`` must be 1-D ``(d_w,)``.
    """
    return _LinearMapGRF(self, jnp.asarray(other))

predict(X, *, joint_inputs=False, joint_outputs=False)

Assemble a Gaussian-valued DistributionArray from mean / variance / covariance.

Returns a DistributionArray whose cells are Normal (fully marginal) or MultivariateNormal (any joint axis); the outer batch_shape covers axes that are independent across cells, the per-cell event_shape covers axes that are jointly modeled.

============== ============== ========== ========================= ================
joint_inputs   joint_outputs  Cell type  Outer batch_shape         Cell event_shape
============== ============== ========== ========================= ================
False          False          Normal     (*extra_batch, n, *out)   ()
True           False          MVN        (*extra_batch, *out)      (n,)
False          True           MVN        (*extra_batch, n)         output_shape
True           True           MVN        (*extra_batch,)           (n, *out)
============== ============== ========== ========================= ================

out is shorthand for output_shape. In all modes a sample has total shape (*sample_shape, *extra_batch, n, *output_shape); the flags only change which axes are jointly modeled (event) vs independent (batch).

Subclasses may override this if they need non-standard assembly (e.g. structured covariance representations).
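
The mean has to be rearranged to match each mode's event axes. The sketch below reproduces the two reshaping steps with NumPy in place of jax.numpy, using small made-up shapes.

```python
import numpy as np

extra_batch, n, output_shape = (2,), 5, (3, 4)
mean = np.arange(2 * 5 * 3 * 4, dtype=float).reshape(*extra_batch, n, *output_shape)

# joint_inputs only: move the output axes ahead of n so that n is the
# trailing (event) axis, giving shape (*extra_batch, *output_shape, n).
ndim_eb, ndim_out = len(extra_batch), len(output_shape)
source_axes = list(range(ndim_eb + 1, ndim_eb + 1 + ndim_out))
dest_axes = list(range(ndim_eb, ndim_eb + ndim_out))
mean_t = np.moveaxis(mean, source_axes, dest_axes)

# joint_outputs only: flatten the output axes into a single event axis,
# giving shape (*extra_batch, n, prod(output_shape)).
flat_mean = mean.reshape(*extra_batch, n, int(np.prod(output_shape)))
```

After the move, mean_t[b, o1, o2, i] holds the same value as mean[b, i, o1, o2].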

Source code in probpipe/distributions/gaussian_random_function.py
def predict(
    self,
    X: Array,
    *,
    joint_inputs: bool = False,
    joint_outputs: bool = False,
):
    """Assemble a Gaussian-valued :class:`DistributionArray` from
    mean / variance / covariance.

    Returns a :class:`~probpipe.DistributionArray` whose cells
    are :class:`~probpipe.Normal` (fully marginal) or
    :class:`~probpipe.MultivariateNormal` (any joint axis); the
    outer ``batch_shape`` covers axes that are independent
    across cells, the per-cell ``event_shape`` covers axes that
    are jointly modeled.

    ============== ============== ================ ============================== ===================
    joint_inputs   joint_outputs  Cell type        Outer ``batch_shape``          Cell ``event_shape``
    ============== ============== ================ ============================== ===================
    ``False``      ``False``      ``Normal``       ``(*extra_batch, n, *out)``    ``()``
    ``True``       ``False``      ``MVN``          ``(*extra_batch, *out)``       ``(n,)``
    ``False``      ``True``       ``MVN``          ``(*extra_batch, n)``          ``output_shape``
    ``True``       ``True``       ``MVN``          ``(*extra_batch,)``            ``(n, *out)``
    ============== ============== ================ ============================== ===================

    ``out`` is shorthand for ``output_shape``. In all modes a
    sample has total shape
    ``(*sample_shape, *extra_batch, n, *output_shape)``; the
    flags only change which axes are jointly modeled (event)
    vs independent (batch).

    Subclasses may override this if they need non-standard
    assembly (e.g. structured covariance representations).
    """
    from ..core._distribution_array import DistributionArray
    from . import MultivariateNormal, Normal

    mean = self.predict_mean(X)  # (*eb, n, *out)
    extra_batch, n = self._parse_X(X)

    # -- Fully marginal ---------------------------------------------------
    if not joint_inputs and not joint_outputs:
        variance = self.predict_variance(X)
        return DistributionArray.from_batched_params(
            Normal,
            batch_shape=(*extra_batch, n, *self._output_shape),
            loc=mean,
            scale=jnp.sqrt(variance),
            name=_PREDICTION_NAME,
        )

    # -- At least one joint axis — need covariance ------------------------
    cov = self.predict_covariance(
        X, joint_inputs=joint_inputs, joint_outputs=joint_outputs
    )
    d_out = prod(self._output_shape)
    scale_tril = jnp.linalg.cholesky(cov)

    if joint_inputs and joint_outputs:
        flat_dim = n * d_out if self._output_shape else n
        flat_mean = mean.reshape(*extra_batch, flat_dim)
        return DistributionArray.from_batched_params(
            MultivariateNormal,
            batch_shape=tuple(extra_batch),
            loc=flat_mean,
            scale_tril=scale_tril,
            name=_PREDICTION_NAME,
        )

    if joint_inputs and not joint_outputs:
        # Joint over n, independent over outputs.
        # mean: (*eb, n, *out) → need (*eb, *out, n)
        if self._output_shape:
            ndim_eb = len(extra_batch)
            ndim_out = len(self._output_shape)
            source_axes = list(range(ndim_eb + 1, ndim_eb + 1 + ndim_out))
            dest_axes = list(range(ndim_eb, ndim_eb + ndim_out))
            mean_t = jnp.moveaxis(mean, source_axes, dest_axes)
        else:
            mean_t = mean  # (*eb, n) — nothing to rearrange
        return DistributionArray.from_batched_params(
            MultivariateNormal,
            batch_shape=(*extra_batch, *self._output_shape),
            loc=mean_t,
            scale_tril=scale_tril,
            name=_PREDICTION_NAME,
        )

    # joint_outputs only (not joint_inputs)
    # mean: (*eb, n, *out) → flatten output dims: (*eb, n, d_out)
    flat_mean = mean.reshape(*extra_batch, n, d_out)
    return DistributionArray.from_batched_params(
        MultivariateNormal,
        batch_shape=(*extra_batch, n),
        loc=flat_mean,
        scale_tril=scale_tril,
        name=_PREDICTION_NAME,
    )

LinearBasisFunction(feature_map, weights, input_shape, output_shape=(), bias=None)

Bases: GaussianRandomFunction, SupportsSampling

Linear model with fixed Gaussian weights.

Implements the model:

.. math::

    f(x) = a + \Phi(x)\, w

where :math:`w \sim \mathcal{N}(m, C)` is a fixed Gaussian distribution over weights, :math:`\Phi(x)` is a user-supplied feature map, and :math:`a` is an optional bias.

The feature map maps each input to a vector (scalar output) or matrix (multi-output) of basis-function evaluations. The weight distribution is supplied as a MultivariateNormal.

This model always supports joint_inputs=True since the cross-input covariance :math:`\Phi(x_i)\, C\, \Phi(x_j)^T` is available analytically.
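
The model can be worked through by hand with NumPy. In the sketch below the polynomial feature map, the weight moments, and the bias are all made-up values for illustration, and no probpipe classes are used.

```python
import numpy as np


def feature_map(X):
    # Hypothetical polynomial basis [1, x, x^2]; X has shape (n, 1).
    x = X[..., 0]
    return np.stack([np.ones_like(x), x, x**2], axis=-1)   # (n, d_w)


m = np.array([0.5, -1.0, 0.25])          # weight mean, (d_w,)
C = np.diag([1.0, 0.5, 0.1])             # weight covariance, (d_w, d_w)
a = 2.0                                  # scalar bias

X = np.linspace(-1.0, 1.0, 5)[:, None]   # n = 5 points, input_shape = (1,)
phi = feature_map(X)

mean = a + phi @ m                       # (n,)
cov = phi @ C @ phi.T                    # (n, n), entries Phi(x_i) C Phi(x_j)^T

# Drawing one weight vector fixes an entire function realization.
rng = np.random.default_rng(0)
w = rng.multivariate_normal(m, C)
f_draw = a + phi @ w                     # that realization evaluated at X
```

Since the randomness lives entirely in the finite weight vector, sampling a whole function reduces to sampling w, which is consistent with the class also inheriting SupportsSampling.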

Parameters:

Name Type Description Default
feature_map callable

Maps input X of shape (*extra_batch, n, *input_shape) to features:

  • Scalar output: shape (*extra_batch, n, d_w)
  • Multi-output: shape (*extra_batch, n, d_out, d_w)

where d_w matches the dimensionality of weights.

required
weights MultivariateNormal

Fixed Gaussian distribution over weight vector, with event_shape = (d_w,).

required
input_shape tuple of int

Shape of a single input point.

required
output_shape tuple of int

Shape of a single output. Default () (scalar).

()
bias array-like

Additive bias of shape (*output_shape,). Defaults to zero.

None
Source code in probpipe/distributions/gaussian_random_function.py
def __init__(
    self,
    feature_map: Callable[[Array], Array],
    weights,  # MultivariateNormal — avoid top-level import
    input_shape: tuple[int, ...],
    output_shape: tuple[int, ...] = (),
    bias: ArrayLike | None = None,
) -> None:
    from . import MultivariateNormal

    if not isinstance(weights, MultivariateNormal):
        raise TypeError(
            f"weights must be a MultivariateNormal, "
            f"got {type(weights).__name__}"
        )

    self._feature_map = feature_map
    self._weights = weights
    self._w_mean = weights.loc          # (d_w,)
    self._w_cov = weights.cov           # (d_w, d_w)

    # Bias inherits dtype from the weight distribution.
    bias_dtype = self._w_mean.dtype
    if bias is not None:
        self._bias = jnp.asarray(bias, dtype=bias_dtype)
    elif output_shape:
        self._bias = jnp.zeros(output_shape, dtype=bias_dtype)
    else:
        self._bias = jnp.zeros((), dtype=bias_dtype)

    super().__init__(input_shape=input_shape, output_shape=output_shape)

    # Multi-output with shared weights implies coupled outputs.
    if self._output_shape:
        self.supports_joint_outputs = True

predict_mean(X)

Predictive mean: a + Phi(X) @ m.

Returns shape (*extra_batch, n, *output_shape).
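
The einsum pattern "...w,w->..." contracts only the trailing weight axis, so the same expression covers scalar and multi-output features. A NumPy sketch with arbitrary values:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_out, d_w = 4, 2, 3
w_mean = rng.normal(size=d_w)

phi_scalar = rng.normal(size=(n, d_w))          # scalar-output features
phi_multi = rng.normal(size=(n, d_out, d_w))    # multi-output features
bias = np.array([10.0, 20.0])                   # shape (*output_shape,)

mean_scalar = np.einsum("...w,w->...", phi_scalar, w_mean)        # (n,)
mean_multi = bias + np.einsum("...w,w->...", phi_multi, w_mean)   # (n, d_out)
```

The bias broadcasts against the trailing output axes, which is why it must have shape (*output_shape,).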

Source code in probpipe/distributions/gaussian_random_function.py
def predict_mean(self, X: Array) -> Array:
    r"""Predictive mean: ``a + Phi(X) @ m``.

    Returns shape ``(*extra_batch, n, *output_shape)``.
    """
    phi = self._feature_map(X)  # (*eb, n, [*out,] d_w)
    return self._bias + jnp.einsum("...w,w->...", phi, self._w_mean)

predict_variance(X)

Marginal predictive variance.

For scalar output: diag(Phi(X) C Phi(X)^T) element-wise. Returns shape (*extra_batch, n, *output_shape).
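
The einsum in this method computes the diagonal of Phi(X) C Phi(X)^T without forming the full n-by-n matrix. The NumPy sketch below checks that equivalence on random data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_w = 6, 3
phi = rng.normal(size=(n, d_w))
B = rng.normal(size=(d_w, d_w))
C = B @ B.T                                          # a valid covariance matrix

# Marginal variances directly: O(n * d_w^2) work, no (n, n) intermediate.
var = np.einsum("...w,wv,...v->...", phi, C, phi)    # (n,)

# Reference: build the full covariance and read off its diagonal.
full = phi @ C @ phi.T
```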

Source code in probpipe/distributions/gaussian_random_function.py
def predict_variance(self, X: Array) -> Array:
    """Marginal predictive variance.

    For scalar output: ``diag(Phi(X) C Phi(X)^T)`` element-wise.
    Returns shape ``(*extra_batch, n, *output_shape)``.
    """
    phi = self._feature_map(X)
    return jnp.einsum("...w,wv,...v->...", phi, self._w_cov, phi)

RandomMeasure(*, name=None)

Bases: Distribution[Distribution[T]]

A distribution over probability distributions on T.

A draw D ~ M is itself a Distribution[T]. Capabilities (sampling, expected distribution, random log-density) are declared via the protocols in probpipe.core.protocols — subclasses opt in by implementing the corresponding _method and inheriting the matching Supports* protocol.

The base class does not expose outer support / event_shape / batch_shape: those concepts apply to tensor-valued distributions and have no useful content for a distribution-valued one. Inner metadata (inner_support, inner_event_shape) lives on NumericRandomMeasure when T is array-like.

Batches of random measures should use DistributionArray, which treats RandomMeasure instances as scalar Distribution components like any other.
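
A toy example may help fix the idea of a distribution over distributions. The sketch below is standalone NumPy and does not use the RandomMeasure class: the random measure places Dirichlet-distributed weights on three fixed atoms, so each draw is itself a discrete distribution. The names draw_measure, atoms, and alpha are illustrative.

```python
import numpy as np

atoms = np.array([-1.0, 0.0, 2.0])       # fixed support points
alpha = np.array([1.0, 2.0, 3.0])        # Dirichlet concentration parameters


def draw_measure(rng):
    """Draw D ~ M: a discrete distribution on `atoms` with random weights."""
    weights = rng.dirichlet(alpha)

    def sample(rng_inner, size):
        # Sampling from the inner distribution D.
        return rng_inner.choice(atoms, size=size, p=weights)

    return weights, sample


rng = np.random.default_rng(0)
weights, sample = draw_measure(rng)      # outer draw: a distribution
xs = sample(rng, 10)                     # inner draw: values from that distribution

# The expected distribution E[D] puts weight alpha_k / sum(alpha) on atom k.
expected_weights = alpha / alpha.sum()
```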

Source code in probpipe/core/_random_measures.py
def __init__(self, *, name: str | None = None):
    super().__init__(name=name or type(self).__name__)

NumericRandomMeasure(*, name=None)

Bases: RandomMeasure[Array]

A random measure whose inner draws are array-valued distributions.

Adds two pieces of metadata that are only meaningful when the inner sample type is array-like:

  • inner_support — the Constraint that every inner Distribution[Array]'s samples satisfy.
  • inner_event_shape — the event_shape shared by the inner distributions; the shape of one sample drawn from any D ~ M.

These mirror NumericRecordDistribution's support / event_shape for the random-measure layer. Subclasses must override both.

Source code in probpipe/core/_random_measures.py
def __init__(self, *, name: str | None = None):
    super().__init__(name=name or type(self).__name__)

inner_support abstractmethod property

Support shared by every inner Distribution[Array]'s samples.

inner_event_shape abstractmethod property

event_shape shared by the inner distributions.