
Operations

Standalone workflow functions for sampling, density evaluation, moments, conditioning, and conversion. Each op dispatches via the matching protocol, participates in broadcasting, and is subject to Prefect orchestration when configured.
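The capability check that each op performs before dispatching can be pictured with a runtime-checkable typing.Protocol. The sketch below is a self-contained illustration, not probpipe's code: SupportsLogProbLike, ToyExponential, NoDensity, and toy_log_prob are hypothetical names, and broadcasting and Prefect orchestration are not modelled.

```python
import math
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsLogProbLike(Protocol):
    """Hypothetical capability protocol: any object with a _log_prob method qualifies."""

    def _log_prob(self, value: float) -> float: ...


class ToyExponential:
    """Hypothetical distribution that implements only the log-density hook."""

    def __init__(self, rate: float) -> None:
        self.rate = rate

    def _log_prob(self, value: float) -> float:
        # log(rate * exp(-rate * x)) for x >= 0
        return math.log(self.rate) - self.rate * value


class NoDensity:
    """Hypothetical distribution without a log-density hook."""


def toy_log_prob(dist: object, value: float) -> float:
    # isinstance on a runtime_checkable Protocol only checks that the
    # required methods exist; it does not check their signatures.
    if not isinstance(dist, SupportsLogProbLike):
        raise TypeError(f"{type(dist).__name__} does not support log_prob")
    return dist._log_prob(value)
```

Because the check is structural, a distribution class opts into an op simply by defining the matching private method; no registration or inheritance is needed in this sketch.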

Sampling

sample(dist, *, key=None, sample_shape=())

Draw samples from a distribution.

Parameters:

  • dist (SupportsSampling, required): Distribution to sample from.
  • key (PRNGKey, default None): JAX PRNG key. Auto-generated if None.
  • sample_shape (tuple of int, default ()): Shape prefix for independent draws.

Source code in probpipe/core/ops.py
@workflow_function
def sample(
    dist: SupportsSampling,
    *,
    key: PRNGKey | None = None,
    sample_shape: tuple[int, ...] = (),
) -> Any:
    """Draw samples from a distribution.

    Parameters
    ----------
    dist : SupportsSampling
        Distribution to sample from.
    key : PRNGKey, optional
        JAX PRNG key.  Auto-generated if ``None``.
    sample_shape : tuple of int
        Shape prefix for independent draws.
    """
    if not isinstance(dist, SupportsSampling):
        raise TypeError(
            f"{type(dist).__name__} does not support sampling "
            f"(does not implement SupportsSampling)"
        )
    if key is None:
        key = _auto_key()
    return dist._sample(key, sample_shape)
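Because sample_shape is a shape prefix, the result should have shape sample_shape + event_shape. The pure-Python sketch below illustrates only that shape rule with nested lists; toy_sample and shape_of are hypothetical helpers, and an integer seed for random.Random stands in for a JAX PRNG key. Passing the same seed reproduces the draws here; an explicit key plays the analogous role for sample.

```python
import random
from typing import Any, Tuple


def toy_sample(event_size: int, sample_shape: Tuple[int, ...], seed: int) -> Any:
    """Draw standard-normal vectors of length event_size, nested by sample_shape.

    Hypothetical illustration: an integer seed stands in for a JAX PRNG key.
    """
    rng = random.Random(seed)

    def build(shape: Tuple[int, ...]) -> Any:
        if not shape:
            # Innermost level: one draw with the event shape.
            return [rng.gauss(0.0, 1.0) for _ in range(event_size)]
        # Each leading dimension adds one level of independent draws.
        return [build(shape[1:]) for _ in range(shape[0])]

    return build(sample_shape)


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a non-empty, regularly nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)
```

For example, shape_of(toy_sample(3, (4, 2), seed=0)) is (4, 2, 3): eight independent draws of a length-3 event, arranged as a 4 by 2 grid.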

Density evaluation

log_prob(dist, value)

Evaluate the normalized log-density at value.

Source code in probpipe/core/ops.py
@workflow_function
def log_prob(dist: SupportsLogProb, value: Any) -> Array:
    """Evaluate the normalized log-density at *value*."""
    if not isinstance(dist, SupportsLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support log_prob"
        )
    return dist._log_prob(value)

prob(dist, value)

Evaluate the density at value (exp(log_prob)).

Source code in probpipe/core/ops.py
@workflow_function
def prob(dist: SupportsLogProb, value: Any) -> Array:
    """Evaluate the density at *value* (``exp(log_prob)``)."""
    if not isinstance(dist, SupportsLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support prob "
            f"(missing _log_prob method)"
        )
    return jnp.exp(dist._log_prob(value))
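Since prob is exp(log_prob), it inherits floating-point limits that the log form avoids. The standalone sketch below uses the textbook Normal log-density, not probpipe's implementation, to show that the two agree in the bulk and that the exponentiated density underflows to 0.0 in the far tail while the log-density stays finite. normal_log_prob and normal_prob are illustrative names only.

```python
import math


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # log N(x | loc, scale) = -z²/2 - log(scale) - log(2π)/2
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def normal_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # Same relationship as prob(): exponentiate the log-density.
    return math.exp(normal_log_prob(x, loc, scale))
```

At x = 40 the log-density is about -800.9, which is representable, but exp(-800.9) is below the smallest positive double, so normal_prob(40.0) returns 0.0. Working with log_prob is the safer default when values may lie far in the tails.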

unnormalized_log_prob(dist, value)

Evaluate the unnormalized log-density at value.

Source code in probpipe/core/ops.py
@workflow_function
def unnormalized_log_prob(
    dist: SupportsUnnormalizedLogProb, value: Any,
) -> Array:
    """Evaluate the unnormalized log-density at *value*."""
    if not isinstance(dist, SupportsUnnormalizedLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support unnormalized_log_prob "
            f"(missing _unnormalized_log_prob method)"
        )
    return dist._unnormalized_log_prob(value)

unnormalized_prob(dist, value)

Evaluate the unnormalized density at value (exp(unnormalized_log_prob)).

Source code in probpipe/core/ops.py
@workflow_function
def unnormalized_prob(
    dist: SupportsUnnormalizedLogProb, value: Any,
) -> Array:
    """Evaluate the unnormalized density at *value* (``exp(unnormalized_log_prob)``)."""
    if not isinstance(dist, SupportsUnnormalizedLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support unnormalized_prob "
            f"(missing _unnormalized_log_prob method)"
        )
    return jnp.exp(dist._unnormalized_log_prob(value))
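The unnormalized and normalized log-densities differ only by the constant log Z, where Z is the normalizing constant. Density ratios, and therefore Metropolis-style acceptance probabilities, do not depend on Z, which is why samplers can work from unnormalized_log_prob alone. The sketch below checks this for a standard Gaussian kernel; the function names are illustrative and independent of probpipe.

```python
import math
from typing import Callable


def kernel_log_prob(x: float) -> float:
    # Standard Gaussian kernel without its normalizing constant.
    return -0.5 * x * x


# log Z, with Z = ∫ exp(-x²/2) dx = sqrt(2π)
LOG_Z = 0.5 * math.log(2.0 * math.pi)


def normalized_log_prob(x: float) -> float:
    return kernel_log_prob(x) - LOG_Z


def density_ratio(log_p: Callable[[float], float], a: float, b: float) -> float:
    # p(a) / p(b) computed in log space; any constant offset such as log Z cancels.
    return math.exp(log_p(a) - log_p(b))
```

Both density_ratio(kernel_log_prob, 1.0, 0.0) and density_ratio(normalized_log_prob, 1.0, 0.0) equal exp(-1/2), even though the two log-densities differ by LOG_Z at every point.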

random_log_prob(dist, value=None)

Return the random (normalized) log-density of a random measure.

For a RandomMeasure[T] M with draws D ~ M, the random function x ↦ log D(x) is itself a callable returning a distribution over scalars at every input.

When value is omitted, returns that callable as a RandomFunction. When value is provided, returns the Distribution[Array] over log D(value) directly — equivalent to random_log_prob(dist)(value). The two-argument form mirrors log_prob for non-random distributions.

Concrete subclasses implement a single method _random_log_prob() returning a RandomFunction; the optional value dispatch lives entirely in this op, not on the protocol.

Source code in probpipe/core/ops.py
@workflow_function
def random_log_prob(
    dist: SupportsRandomLogProb,
    value: Any = None,
) -> RandomFunction | Distribution:
    """Return the random (normalized) log-density of a random measure.

    For a ``RandomMeasure[T]`` ``M`` with draws ``D ~ M``, the random
    function ``x ↦ log D(x)`` is itself a callable returning a
    distribution over scalars at every input.

    When *value* is omitted, returns that callable as a
    :class:`~probpipe.core._random_functions.RandomFunction`. When
    *value* is provided, returns the ``Distribution[Array]`` over
    ``log D(value)`` directly — equivalent to
    ``random_log_prob(dist)(value)``. The two-argument form mirrors
    :func:`log_prob` for non-random distributions.

    Concrete subclasses implement a single method
    ``_random_log_prob()`` returning a ``RandomFunction``; the optional
    *value* dispatch lives entirely in this op, not on the protocol.
    """
    if not isinstance(dist, SupportsRandomLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support random_log_prob "
            f"(does not implement SupportsRandomLogProb)"
        )
    rf = dist._random_log_prob()
    return rf if value is None else rf(value)
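To make the one- and two-argument forms concrete, the sketch below builds a toy random measure whose draws are D = Normal(mu, 1), with mu taking one of a few values with fixed probabilities. At a fixed x, log D(x) is then a discrete scalar random variable, represented here as a list of (value, probability) pairs. ToyRandomMeasure and toy_random_log_prob are hypothetical, and the list representation stands in for probpipe's RandomFunction and Distribution objects; only the last line of toy_random_log_prob mirrors the op above.

```python
import math
from typing import Callable, List, Optional, Tuple, Union

Outcome = List[Tuple[float, float]]  # (value of log D(x), probability)


class ToyRandomMeasure:
    """Hypothetical random measure: D = Normal(mu, 1) with mu drawn from a finite set."""

    def __init__(self, locs: List[float], weights: List[float]) -> None:
        self.locs = locs
        self.weights = weights

    def _random_log_prob(self) -> Callable[[float], Outcome]:
        def rf(x: float) -> Outcome:
            # One value of log D(x) per possible draw of mu, carrying that draw's weight.
            return [
                (-0.5 * (x - mu) ** 2 - 0.5 * math.log(2.0 * math.pi), w)
                for mu, w in zip(self.locs, self.weights)
            ]

        return rf


def toy_random_log_prob(
    dist: ToyRandomMeasure, value: Optional[float] = None
) -> Union[Callable[[float], Outcome], Outcome]:
    rf = dist._random_log_prob()
    # Same dispatch as the op: return the random function, or evaluate it at value.
    return rf if value is None else rf(value)


M = ToyRandomMeasure(locs=[-1.0, 0.0, 1.0], weights=[0.25, 0.5, 0.25])
```

Here toy_random_log_prob(M, 0.5) and toy_random_log_prob(M)(0.5) produce the same three (log-density, weight) pairs, which is the equivalence the docstring describes.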

random_unnormalized_log_prob(dist, value=None)

Return the random unnormalized log-density of a random measure.

For a RandomMeasure[T] M with draws D ~ M, the random function x ↦ log D̃(x) (where D̃ is the unnormalized density of D) is itself a callable returning a distribution over scalars at every input.

When value is omitted, returns that callable as a RandomFunction. When value is provided, returns the Distribution[Array] over log D̃(value) directly — equivalent to random_unnormalized_log_prob(dist)(value). The two-argument form mirrors unnormalized_log_prob for non-random distributions.

Concrete subclasses implement a single method _random_unnormalized_log_prob() returning a RandomFunction; the optional value dispatch lives entirely in this op, not on the protocol.

Source code in probpipe/core/ops.py
@workflow_function
def random_unnormalized_log_prob(
    dist: SupportsRandomUnnormalizedLogProb,
    value: Any = None,
) -> RandomFunction | Distribution:
    """Return the random unnormalized log-density of a random measure.

    For a ``RandomMeasure[T]`` ``M`` with draws ``D ~ M``, the random
    function ``x ↦ log D̃(x)`` (where ``D̃`` is the unnormalized density
    of ``D``) is itself a callable returning a distribution over
    scalars at every input.

    When *value* is omitted, returns that callable as a
    :class:`~probpipe.core._random_functions.RandomFunction`. When
    *value* is provided, returns the ``Distribution[Array]`` over
    ``log D̃(value)`` directly — equivalent to
    ``random_unnormalized_log_prob(dist)(value)``. The two-argument
    form mirrors :func:`unnormalized_log_prob` for non-random
    distributions.

    Concrete subclasses implement a single method
    ``_random_unnormalized_log_prob()`` returning a ``RandomFunction``;
    the optional *value* dispatch lives entirely in this op, not on
    the protocol.
    """
    if not isinstance(dist, SupportsRandomUnnormalizedLogProb):
        raise TypeError(
            f"{type(dist).__name__} does not support random_unnormalized_log_prob "
            f"(does not implement SupportsRandomUnnormalizedLogProb)"
        )
    rf = dist._random_unnormalized_log_prob()
    return rf if value is None else rf(value)

Moments and expectations

mean(dist)

Compute E[X] where X ~ dist.

The return type is T-shaped where T is dist's sample type:

  • Numeric distributions (T = Array): returns Array.
  • Structured distributions (T = Record): returns Record.
  • RandomMeasure[T], whose draws are themselves Distribution[T] objects: returns the marginalised Distribution[T] with marginal D̄(A) = ∫ D(A) dM(D).

Requires the distribution to implement SupportsMean.
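When M places mass on finitely many distributions D_1, ..., D_k with weights w_1, ..., w_k, the integral defining the marginal reduces to a weighted sum, D̄(A) = Σ_i w_i D_i(A). The sketch below evaluates that sum for three Bernoulli draws represented as outcome-to-probability dicts; marginal_prob is a hypothetical helper, not part of probpipe.

```python
from typing import Dict, Iterable, List


def marginal_prob(
    draws: List[Dict[int, float]], weights: List[float], event: Iterable[int]
) -> float:
    """D̄(A) = Σ_i w_i D_i(A) for a random measure with finitely many atoms."""
    event = list(event)  # iterated once per draw, so materialize it
    return sum(w * sum(d.get(x, 0.0) for x in event) for d, w in zip(draws, weights))


# Three Bernoulli distributions D_i and the probabilities M assigns to them.
draws = [{0: 0.9, 1: 0.1}, {0: 0.5, 1: 0.5}, {0: 0.2, 1: 0.8}]
weights = [0.5, 0.3, 0.2]

p_one = marginal_prob(draws, weights, {1})  # 0.5*0.1 + 0.3*0.5 + 0.2*0.8 = 0.36
```

The marginal is therefore Bernoulli(0.36): a random measure over Bernoulli distributions averages to a single Bernoulli whose success probability is the weighted mean of the individual ones.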

Source code in probpipe/core/ops.py
@workflow_function
def mean(dist: SupportsMean) -> Any:
    """Compute ``E[X]`` where ``X ~ dist``.

    The return type is ``T``-shaped where ``T`` is *dist*'s sample type:

    * Numeric distributions (``T = Array``) — returns
      :class:`~probpipe.custom_types.Array`.
    * Structured distributions (``T = Record``) — returns
      :class:`~probpipe.record.Record`.
    * :class:`~probpipe.core._random_measures.RandomMeasure[T]` (``T``
      itself a :class:`~probpipe.core._distribution_base.Distribution[T]`)
      — returns the marginalised ``Distribution[T]`` with marginal
      ``D̄(A) = ∫ D(A) dM(D)``.

    Requires the distribution to implement :class:`SupportsMean`.
    """
    if not isinstance(dist, SupportsMean):
        raise TypeError(
            f"{type(dist).__name__} does not support mean "
            f"(does not implement SupportsMean)"
        )
    return dist._mean()

variance(dist)

Compute Var[X].

Requires the distribution to implement SupportsVariance.

Source code in probpipe/core/ops.py
@workflow_function
def variance(dist: SupportsVariance) -> Any:
    """Compute Var[X].

    Requires the distribution to implement :class:`SupportsVariance`.
    """
    if not isinstance(dist, SupportsVariance):
        raise TypeError(
            f"{type(dist).__name__} does not support variance "
            f"(does not implement SupportsVariance)"
        )
    return dist._variance()

cov(dist)

Compute the covariance matrix.

Requires the distribution to implement SupportsCovariance.

Source code in probpipe/core/ops.py
@workflow_function
def cov(dist: SupportsCovariance) -> Array:
    """Compute the covariance matrix.

    Requires the distribution to implement :class:`SupportsCovariance`.
    """
    if not isinstance(dist, SupportsCovariance):
        raise TypeError(
            f"{type(dist).__name__} does not support covariance "
            f"(does not implement SupportsCovariance)"
        )
    return dist._cov()
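For a vector-valued X, Var[X] is conventionally the diagonal of the covariance matrix. Whether probpipe's variance returns exactly that is an assumption here, not something this page states. The sketch below checks the relationship on empirical moments using only the standard library; sample_cov is a hypothetical helper that uses the n - 1 denominator, matching statistics.variance.

```python
import statistics
from typing import List


def sample_cov(rows: List[List[float]]) -> List[List[float]]:
    """Unbiased (n - 1) sample covariance matrix of equal-length vectors."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    return [
        [sum((r[i] - means[i]) * (r[j] - means[j]) for r in rows) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


rows = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
C = sample_cov(rows)
# The diagonal of C holds the per-component variances.
per_component_var = [statistics.variance([r[j] for r in rows]) for j in range(2)]
```

Here C is [[5/3, 10/3], [10/3, 20/3]]; the off-diagonal entries carry the dependence between components that variance alone discards.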

expectation(dist, f, *, key=None, num_evaluations=None, return_dist=None)

Compute E[f(X)] where X ~ dist.

Source code in probpipe/core/ops.py
@workflow_function
def expectation(
    dist: SupportsExpectation,
    f: Any,
    *,
    key: PRNGKey | None = None,
    num_evaluations: int | None = None,
    return_dist: bool | None = None,
) -> Any:
    """Compute E[f(X)] where X ~ dist."""
    if not isinstance(dist, SupportsExpectation):
        raise TypeError(
            f"{type(dist).__name__} does not support expectation"
        )
    return dist._expectation(
        f, key=key, num_evaluations=num_evaluations, return_dist=return_dist,
    )
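When no closed form is available, E[f(X)] is typically estimated by averaging f over independent draws. This page does not document how expectation uses num_evaluations or return_dist; the sketch below assumes, purely for illustration, that num_evaluations plays the role of the Monte Carlo sample count, and it does not model return_dist. mc_expectation is hypothetical and uses Python's random module rather than JAX keys.

```python
import random
from typing import Callable, TypeVar

T = TypeVar("T")


def mc_expectation(
    sampler: Callable[[random.Random], T],
    f: Callable[[T], float],
    num_evaluations: int,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of E[f(X)] from num_evaluations independent draws of X."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_evaluations):
        total += f(sampler(rng))
    return total / num_evaluations


# E[X²] = 1 for X ~ Normal(0, 1); the estimate's standard error is about sqrt(2 / n).
estimate = mc_expectation(lambda rng: rng.gauss(0.0, 1.0), lambda x: x * x, num_evaluations=100_000)
```

The standard error shrinks as 1/sqrt(num_evaluations), so quadrupling the number of draws roughly halves the Monte Carlo error.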

Conditioning

condition_on(dist, observed=None, *, method=None, **kwargs)

Condition a distribution on observed values.

Observed data can be passed positionally or as named keyword arguments:

# Positional (backward compatible):
condition_on(model, y_obs)

# Named data kwargs — bundled into Record(X=..., y=...):
condition_on(model, X=bootstrap["X"], y=bootstrap["y"],
             n_broadcast_samples=16)

When named data kwargs are distribution views from the same parent, the workflow function broadcasting machinery samples the parent once and distributes the fields, preserving joint correlation.
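Why the parent is sampled once can be seen in a toy simulation. Below, a hypothetical parent draws X and a noisy linear function y of X. Taking both fields from the same parent draw keeps their correlation near 1; taking each field from a separate draw destroys it. All names are illustrative, and this models only the idea, not probpipe's broadcasting machinery.

```python
import random
from typing import Dict, List, Tuple


def sample_parent(rng: random.Random) -> Dict[str, float]:
    # Hypothetical joint parent: y depends on X, so the two fields are correlated.
    x = rng.gauss(0.0, 1.0)
    return {"X": x, "y": 2.0 * x + rng.gauss(0.0, 0.1)}


def shared_draws(n: int, seed: int) -> List[Tuple[float, float]]:
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        parent = sample_parent(rng)  # one parent draw feeds both fields
        pairs.append((parent["X"], parent["y"]))
    return pairs


def separate_draws(n: int, seed: int) -> List[Tuple[float, float]]:
    rng = random.Random(seed)
    # Each field comes from its own parent draw, so the pairing is lost.
    return [(sample_parent(rng)["X"], sample_parent(rng)["y"]) for _ in range(n)]


def correlation(pairs: List[Tuple[float, float]]) -> float:
    """Pearson correlation of a list of (a, b) pairs."""
    n = len(pairs)
    mx = sum(a for a, _ in pairs) / n
    my = sum(b for _, b in pairs) / n
    sxy = sum((a - mx) * (b - my) for a, b in pairs)
    sxx = sum((a - mx) ** 2 for a, _ in pairs)
    syy = sum((b - my) ** 2 for _, b in pairs)
    return sxy / (sxx * syy) ** 0.5
```

With this parent the population correlation is 2 / sqrt(4.01), about 0.999, for shared draws and 0 for separate draws; conditioning on the separately drawn pairs would fit data with no relationship between X and y.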

Dispatch priority:

  1. Explicit override: method="tfp_nuts" (or any registered name) routes directly to the named inference method.
  2. Exact conditioning: if dist implements SupportsConditioning, its _condition_on is called for a closed-form result (e.g., conjugate updates, joint marginalization).
  3. Registry auto-select: the inference method registry picks the highest-priority feasible algorithm (NUTS, HMC, RWMH, etc.).
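The three-step priority above can be sketched as follows. ToyRegistry, toy_condition_on, Conjugate, and the method names "nuts" and "rwmh" are hypothetical; the sketch tests for a _condition_on attribute where the real op checks SupportsConditioning, and it omits data kwargs and inference parameters.

```python
from typing import Any, Callable, List, Optional, Tuple

Method = Tuple[int, str, Callable[[Any], bool], Callable[[Any, Any], Any]]


class ToyRegistry:
    """Hypothetical inference-method registry of (priority, name, is_feasible, run)."""

    def __init__(self) -> None:
        self.methods: List[Method] = []

    def register(self, name: str, priority: int,
                 is_feasible: Callable[[Any], bool],
                 run: Callable[[Any, Any], Any]) -> None:
        self.methods.append((priority, name, is_feasible, run))

    def execute(self, dist: Any, observed: Any, method: Optional[str] = None) -> Any:
        if method is not None:
            for _, name, _, run in self.methods:
                if name == method:
                    return run(dist, observed)
            raise KeyError(f"unknown inference method {method!r}")
        feasible = [m for m in self.methods if m[2](dist)]
        if not feasible:
            raise TypeError(f"no feasible inference method for {type(dist).__name__}")
        # Highest priority wins among the feasible methods.
        best = max(feasible, key=lambda m: m[0])
        return best[3](dist, observed)


def toy_condition_on(dist: Any, observed: Any, registry: ToyRegistry,
                     method: Optional[str] = None) -> Any:
    if method is not None:                    # 1. explicit override
        return registry.execute(dist, observed, method=method)
    if hasattr(dist, "_condition_on"):        # 2. exact conditioning
        return dist._condition_on(observed)
    return registry.execute(dist, observed)   # 3. registry auto-select


class Conjugate:
    """Hypothetical model with a closed-form update."""

    def _condition_on(self, observed: Any) -> str:
        return "exact"


registry = ToyRegistry()
registry.register("nuts", priority=10, is_feasible=lambda d: True, run=lambda d, o: "nuts")
registry.register("rwmh", priority=1, is_feasible=lambda d: True, run=lambda d, o: "rwmh")
```

Note that an explicit method wins even for a model with a closed-form update: toy_condition_on(Conjugate(), 1.0, registry, method="rwmh") runs the sampler rather than the exact update.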

Parameters:

  • dist (Distribution, required): Distribution or model to condition. It need not implement SupportsConditioning; the registry provides inference methods for common model types.
  • observed (Any, default None): Observed values to condition on.
  • method (str or None, default None): If provided, use the named inference method from the registry instead of the default dispatch.
  • **kwargs (Any, default {}): Inference parameters (e.g., num_results, num_warmup, random_seed) and/or named data kwargs. Any kwarg whose name matches a distribution component name is treated as observed data; everything else is an inference parameter.

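The kwargs rule in the last entry, together with the refusal to mix a positional observed value with named data kwargs, can be sketched as below. split_data_kwargs and bundle_observed are hypothetical stand-ins (the real op uses an internal helper and wraps the data in a Record); a plain dict takes the place of Record here.

```python
from typing import Any, Dict, Iterable, Optional, Tuple


def split_data_kwargs(component_names: Iterable[str],
                      kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Kwargs whose names match a component are data; the rest are inference options."""
    names = set(component_names)
    data = {k: v for k, v in kwargs.items() if k in names}
    options = {k: v for k, v in kwargs.items() if k not in names}
    return data, options


def bundle_observed(observed: Optional[Any], data: Dict[str, Any]) -> Any:
    # Mirrors the ValueError raised when both forms of observed data are supplied.
    if data:
        if observed is not None:
            raise ValueError(
                "Cannot provide both positional `observed` and named "
                f"data kwargs ({', '.join(data)})"
            )
        return dict(data)  # stand-in for Record(data)
    return observed


data, options = split_data_kwargs(
    {"X", "y", "beta"},
    {"X": [1.0, 2.0], "y": [0.5, 1.5], "num_results": 500, "num_warmup": 200},
)
```

A consequence of matching by name is that a model component called, say, num_results would swallow that inference option as data, so component names that collide with inference parameters are best avoided.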
Source code in probpipe/core/ops.py
@workflow_function
def condition_on(
    dist: Distribution,
    observed: Any = None,
    *,
    method: str | None = None,
    **kwargs: Any,
) -> Distribution:
    """Condition a distribution on observed values.

    Observed data can be passed positionally or as named keyword
    arguments::

        # Positional (backward compatible):
        condition_on(model, y_obs)

        # Named data kwargs — bundled into Record(X=..., y=...):
        condition_on(model, X=bootstrap["X"], y=bootstrap["y"],
                     n_broadcast_samples=16)

    When named data kwargs are distribution views from the same parent,
    the workflow function broadcasting machinery samples the parent once
    and distributes the fields, preserving joint correlation.

    Dispatch priority:

    1. **Explicit override** — ``method="tfp_nuts"`` (or any registered
       name) routes directly to the named inference method.
    2. **Exact conditioning** — if *dist* implements
       ``SupportsConditioning``, its ``_condition_on`` is called for a
       closed-form result (e.g., conjugate updates, joint marginalization).
    3. **Registry auto-select** — the inference method registry picks
       the highest-priority feasible algorithm (NUTS, HMC, RWMH, etc.).

    Parameters
    ----------
    dist : Distribution
        Distribution or model to condition.  Need not implement
        ``SupportsConditioning`` — the registry provides inference
        methods for common model types.
    observed : Any
        Observed values to condition on.
    method : str or None
        If provided, use the named inference method from the registry
        instead of the default dispatch.
    **kwargs
        Inference parameters (e.g., ``num_results``, ``num_warmup``,
        ``random_seed``) and/or named data kwargs.  Any kwarg whose
        name matches a distribution component name is treated as
        observed data; everything else is an inference parameter.
    """
    from ..inference import inference_method_registry
    from .record import Record

    # Separate data kwargs (names matching fields) from
    # inference kwargs (everything else like num_results, num_warmup).
    data_kwargs, inference_kwargs = _split_data_kwargs(dist, kwargs)

    # Explicit method override → always use the registry
    if method is not None:
        if data_kwargs:
            if observed is not None:
                raise ValueError(
                    "Cannot provide both positional `observed` and named "
                    f"data kwargs ({', '.join(data_kwargs)})"
                )
            observed = Record(data_kwargs)
        return inference_method_registry.execute(
            dist, observed, method=method, **inference_kwargs
        )

    # Exact conditioning (conjugate updates, joint marginalization, etc.)
    # All kwargs pass through to _condition_on — it handles its own
    # validation (e.g., ProductDistribution raises KeyError on unknown names).
    if isinstance(dist, SupportsConditioning):
        return dist._condition_on(observed, **data_kwargs, **inference_kwargs)

    # Registry auto-selects the best approximate inference algorithm.
    # Data kwargs are bundled into observed as a Record object.
    if data_kwargs:
        if observed is not None:
            raise ValueError(
                "Cannot provide both positional `observed` and named "
                f"data kwargs ({', '.join(data_kwargs)})"
            )
        observed = Record(data_kwargs)
    return inference_method_registry.execute(dist, observed, **inference_kwargs)

When no closed-form update applies, condition_on dispatches inference via the inference-method registry; override the auto-selection with method="<name>".

Conversion

from_distribution(source, target_type, *, key=None, check_support=True, **kwargs)

Convert source into an instance of target_type.

Delegates to the global converter registry.

Parameters:

  • source (Distribution, required): Source distribution to convert.
  • target_type (type, required): The target distribution class.
  • key (PRNGKey, default None): JAX PRNG key for sampling-based conversion. Auto-generated if None.
  • check_support (bool, default True): If True, verify that the source and target supports are compatible.
  • **kwargs (Any, default {}): Additional keyword arguments passed to the converter.

Source code in probpipe/core/ops.py
@workflow_function
def from_distribution(
    source: Distribution,
    target_type: type,
    *,
    key: Any | None = None,
    check_support: bool = True,
    **kwargs: Any,
) -> Any:
    """Convert *source* into an instance of *target_type*.

    Delegates to the global converter registry.

    Parameters
    ----------
    source : Distribution
        Source distribution to convert.
    target_type : type
        The target distribution class.
    key : PRNGKey, optional
        JAX PRNG key for sampling-based conversion.
    check_support : bool
        If ``True`` (default), verify the supports are compatible.
    **kwargs
        Additional keyword arguments passed to the converter.
    """
    from ..converters import converter_registry
    if key is None:
        key = _auto_key()
    return converter_registry.convert(
        source, target_type, key=key, check_support=check_support, **kwargs
    )

Backed by the converter registry.
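A converter registry of this kind can be pictured as a table of conversion functions keyed by (source type, target type). The sketch below is hypothetical throughout: ToyConverterRegistry, Empirical, and Gaussian are illustrative, the lookup walks the source's MRO so subclasses reuse a parent's converter, and the single converter does moment matching. It does not model key or check_support.

```python
import statistics
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Empirical:
    samples: List[float]


@dataclass
class Gaussian:
    loc: float
    scale: float


class ToyConverterRegistry:
    def __init__(self) -> None:
        self._converters: Dict[Tuple[type, type], Callable[..., Any]] = {}

    def register(self, source_type: type, target_type: type,
                 fn: Callable[..., Any]) -> None:
        self._converters[(source_type, target_type)] = fn

    def convert(self, source: Any, target_type: type, **kwargs: Any) -> Any:
        # Walk the MRO so a converter registered for a base class also applies.
        for cls in type(source).__mro__:
            fn = self._converters.get((cls, target_type))
            if fn is not None:
                return fn(source, **kwargs)
        raise TypeError(
            f"no converter from {type(source).__name__} to {target_type.__name__}"
        )


def empirical_to_gaussian(src: Empirical) -> Gaussian:
    # Moment matching: population mean and standard deviation of the samples.
    return Gaussian(loc=statistics.fmean(src.samples), scale=statistics.pstdev(src.samples))


converters = ToyConverterRegistry()
converters.register(Empirical, Gaussian, empirical_to_gaussian)
g = converters.convert(Empirical([1.0, 2.0, 3.0]), Gaussian)
```

Keeping converters in a registry rather than as methods on each class means a new target type can be supported without modifying the source classes; only a new (source, target) entry is needed.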