
Composite and joint

Distributions combining named components into a joint over a Record. Component access: dist["name"] returns a view (see Internals for the correlation semantics); dist.select("x", "y") splats into a workflow function.

ProductDistribution(*positional, name=None, **components)

Bases: RecordDistribution, SupportsSampling, SupportsConditioning

Joint distribution with independent leaf components.

Inherits from RecordDistribution. All leaf components are sampled independently. _sample() returns Record.

Dynamic protocol support: SupportsLogProb, SupportsMean, and SupportsVariance are included only when ALL leaf components support them. isinstance(product, SupportsLogProb) is True only when every component has _log_prob.
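Because the leaves are independent, the joint log-density is the sum of the leaf log-densities, and it exists only when every leaf has one. The stdlib sketch below illustrates both rules. Normal, Gamma, SamplesOnly, and product_log_prob are hypothetical stand-ins written for this example and are not probpipe classes. Gamma is assumed to take (concentration, rate).

```python
import math

class Normal:
    """Hypothetical stand-in for a leaf with a log-density."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

class Gamma:
    """Hypothetical Gamma(concentration, rate) leaf with a log-density."""
    def __init__(self, concentration, rate):
        self.a, self.b = concentration, rate

    def log_prob(self, x):
        return (self.a * math.log(self.b) - math.lgamma(self.a)
                + (self.a - 1) * math.log(x) - self.b * x)

class SamplesOnly:
    """Hypothetical leaf that can be sampled but has no log-density."""

def product_log_prob(components, record):
    # The joint log-density exists only if *every* leaf has one
    # (the all-components rule for SupportsLogProb).
    if not all(hasattr(c, "log_prob") for c in components.values()):
        raise TypeError("not every component supports log_prob")
    # Independence: log p(x, y) = log p(x) + log p(y).
    return sum(c.log_prob(record[k]) for k, c in components.items())

components = {"x": Normal(0.0, 1.0), "y": Gamma(2.0, 1.0)}
lp = product_log_prob(components, {"x": 0.0, "y": 1.0})
```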

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| *positional | NumericRecordDistribution | Named distributions. Each distribution's .name is used as the component key. | () |
| name | str | Distribution name for the joint. When None, defaults to "product(k1,k2,...)" built from the component keys. | None |
| **components | NumericRecordDistribution or dict | Named independent component distributions. Values may be NumericRecordDistribution instances (leaves) or nested dicts whose leaves are NumericRecordDistribution instances. When a keyword key differs from the distribution's name, the distribution is automatically renamed (via renamed()) to match the key. | {} |

Examples:

::

# Positional — uses each distribution's name as the key:
ProductDistribution(Normal(0, 1, name="x"), Gamma(2, 1, name="y"))

# Keyword — auto-renames if the key differs:
ProductDistribution(growth_rate=Normal(0, 1, name="x"))

# Mixed:
ProductDistribution(Normal(0, 1, name="x"), scale=Gamma(2, 1, name="y"))
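A minimal sketch of the keying rules shown in these examples. It assumes a toy Dist class whose renamed() returns a copy under a new name. Dist and merge_components are hypothetical. The real _merge_positional_and_keyword helper is not shown here, so its handling of colliding keys, which the sketch rejects, is an assumption.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Dist:
    """Hypothetical named distribution; only the name matters here."""
    family: str
    name: str

    def renamed(self, new_name):
        # Return a copy under a new name, leaving the original untouched.
        return replace(self, name=new_name)

def merge_components(*positional, **keyword):
    merged = {}
    # Positional arguments are keyed by their own .name ...
    for d in positional:
        if d.name in merged:
            raise ValueError(f"duplicate component name {d.name!r}")
        merged[d.name] = d
    # ... keyword arguments are keyed by the keyword, renaming on mismatch.
    for key, d in keyword.items():
        if key in merged:
            raise ValueError(f"duplicate component name {key!r}")
        merged[key] = d if d.name == key else d.renamed(key)
    return merged

comps = merge_components(Dist("Normal", "x"), scale=Dist("Gamma", "y"))
print({k: v.name for k, v in comps.items()})  # {'x': 'x', 'scale': 'scale'}
```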
Source code in probpipe/distributions/_product.py
def __init__(self, *positional, name: str | None = None, **components):
    components = _merge_positional_and_keyword(positional, components)
    if not components:
        raise ValueError("ProductDistribution requires at least one component.")
    resolved: dict[str, Any] = {}
    for key, comp in components.items():
        if isinstance(comp, dict):
            resolved[key] = _resolve_nested_names(key, comp)
        elif comp.name != key:
            resolved[key] = comp.renamed(key)
        else:
            resolved[key] = comp
    for leaf in jax.tree.leaves(resolved):
        if not isinstance(leaf, NumericRecordDistribution):
            raise TypeError(
                f"All leaf components must be NumericRecordDistribution, "
                f"got {type(leaf).__name__}"
            )
    self._components = resolved
    if name is None:
        name = "product(" + ",".join(resolved.keys()) + ")"
    super().__init__(name=name)
    self._record_template = _build_record_template(self._components)

components property

Read-only view of the component distributions.

SequentialJointDistribution(*, name=None, **components)

Bases: RecordDistribution, SupportsSampling, SupportsConditioning

Joint distribution with autoregressive (sequential) dependence.

Components can be Distribution instances (roots) or callables that receive previously-sampled values and return a Distribution (conditionals).

Example::

joint = SequentialJointDistribution(
    z=Normal(loc=0.0, scale=1.0, name="z"),
    x=lambda z: Normal(loc=z, scale=0.5, name="x"),
    y=lambda z, x: Normal(loc=z + x, scale=0.1, name="y"),
)
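Sampling such a joint is ancestral. Each component is drawn in order, and already-drawn parents are passed to its callable by name. The sketch below reproduces the z -> x -> y example using only the standard library. The (loc, scale) tuples and random.gauss stand in for Normal distributions, whereas probpipe itself samples with JAX PRNG keys.

```python
import inspect
import random

# Roots are (loc, scale) pairs; conditionals are callables of earlier names.
# This mirrors the z -> x -> y example; it is not probpipe's implementation.
spec = {
    "z": (0.0, 1.0),
    "x": lambda z: (z, 0.5),
    "y": lambda z, x: (z + x, 0.1),
}

def sample_sequential(spec, rng):
    draws = {}
    for name, comp in spec.items():  # dicts preserve insertion order
        if callable(comp):
            # Pass exactly the parents named in the callable's signature.
            parents = inspect.signature(comp).parameters
            loc, scale = comp(**{p: draws[p] for p in parents})
        else:
            loc, scale = comp
        draws[name] = rng.gauss(loc, scale)
    return draws

draw = sample_sequential(spec, random.Random(0))
```

Under this model y = 2z + 0.5e1 + 0.1e2 with independent standard normals, so y has mean 0 and variance 4 + 0.25 + 0.01 = 4.26.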

Callable signatures are inspected: parameter names must match earlier component names.
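The ordering check compares each callable's parameter names, read via inspect.signature, against the names seen so far. check_order is a hypothetical standalone version of that loop, and strings stand in for root distributions. Unlike the real constructor, it does not exclude distribution instances that happen to be callable.

```python
import inspect

def check_order(components):
    """Raise if a callable component names a parent that is not defined earlier."""
    seen = []
    parents = {}
    for name, comp in components.items():
        if callable(comp):
            params = tuple(inspect.signature(comp).parameters)
            missing = [p for p in params if p not in seen]
            if missing:
                raise ValueError(
                    f"Component {name!r} depends on {missing}, "
                    f"which are not defined before it. Available: {seen}"
                )
            parents[name] = params
        seen.append(name)
    return parents

# Valid: every parent appears earlier. Strings stand in for root distributions.
print(check_order({"z": "root", "x": lambda z: None, "y": lambda z, x: None}))
# {'x': ('z',), 'y': ('z', 'x')}

# Invalid: 'x' refers to 'y', which comes later.
try:
    check_order({"z": "root", "x": lambda y: None, "y": lambda z: None})
except ValueError as err:
    print(err)
```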

Dynamic protocol support: SupportsLogProb / SupportsMean / SupportsVariance are included only when every resolved leaf component supports them. Sampling and conditioning are always available.
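One way to implement capability-dependent protocols is to build a subclass whose bases include only the capabilities that every leaf supports, cache it, and reparent the instance to it. The sketch below uses plain marker classes in place of probpipe's protocol types. class_for, Joint, Dense, and Empirical are hypothetical. The real selection logic lives in _sequential_class_for_components and may differ.

```python
from functools import lru_cache

# Hypothetical marker classes standing in for probpipe's protocol types.
class SupportsLogProb:
    pass

class SupportsMean:
    pass

# Capability -> method every leaf must provide for the joint to advertise it.
CAPABILITIES = {SupportsLogProb: "log_prob", SupportsMean: "mean"}

@lru_cache(maxsize=None)
def class_for(capabilities):
    """Return a cached Joint subclass that also inherits the given capabilities."""
    if not capabilities:
        return Joint
    name = "Joint[" + ",".join(c.__name__ for c in capabilities) + "]"
    return type(name, (Joint, *capabilities), {})

class Joint:
    def __init__(self, components):
        self.components = components
        supported = tuple(
            cap for cap, method in CAPABILITIES.items()
            if all(hasattr(c, method) for c in components.values())
        )
        # Reparent the instance; probpipe's constructor does the equivalent
        # with object.__setattr__(self, "__class__", dynamic_cls).
        self.__class__ = class_for(supported)

class Dense:  # has a log-density and a mean
    def log_prob(self, x):
        return 0.0

    def mean(self):
        return 0.0

class Empirical:  # mean only, no log-density
    def mean(self):
        return 0.0

a = Joint({"x": Dense(), "y": Dense()})
b = Joint({"x": Dense(), "y": Empirical()})
print(isinstance(a, SupportsLogProb), isinstance(b, SupportsLogProb))  # True False
```

Caching the generated classes keeps one class object per capability set, so two joints with the same capabilities share a type.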

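Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Distribution name. When None, defaults to "sequential(k1,k2,...)" built from the component names. | None |
| **components | Distribution or Callable[..., Distribution] | Named components in topological (dependency) order: each callable may only depend on components listed before it. | {} |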
Source code in probpipe/distributions/_sequential_joint.py
def __init__(
    self,
    *,
    name: str | None = None,
    **components: NumericRecordDistribution | callable,
):
    if not components:
        raise ValueError("SequentialJointDistribution requires at least one component.")

    self._raw_components: dict[str, NumericRecordDistribution | callable] = dict(components)
    if name is None:
        name = "sequential(" + ",".join(components.keys()) + ")"
    super().__init__(name=name)
    self._conditioned_names: frozenset[str] = frozenset()
    self._conditioned_values: dict[str, Array] = {}
    self._sampleable_error: str | None = None
    # Map callable component names to their dependency parameter names
    self._callable_parents: dict[str, tuple[str, ...]] = {}

    # Validate ordering: callable args must reference earlier names
    seen: list[str] = []
    for cname, comp in self._raw_components.items():
        if callable(comp) and not isinstance(comp, NumericRecordDistribution):
            params = list(inspect.signature(comp).parameters.keys())
            for p in params:
                if p not in seen:
                    raise ValueError(
                        f"Component '{cname}' depends on '{p}', which "
                        f"is not defined before it. "
                        f"Available: {seen}"
                    )
            self._callable_parents[cname] = tuple(params)
        seen.append(cname)

    # Do a prototype forward pass to determine component distributions
    # and compute event shapes / slices
    proto_key = jax.random.PRNGKey(0)
    proto_structured = self._sample_sequential(proto_key, ())
    self._proto_components: dict[str, NumericRecordDistribution] = {}

    resolved: dict[str, NumericRecordDistribution] = {}
    for cname, comp in self._raw_components.items():
        if isinstance(comp, NumericRecordDistribution):
            resolved[cname] = comp
        else:
            # Resolve the callable with prototype-sampled parent values to get shape info
            parent_vals = {}
            for prev_name in list(self._raw_components.keys()):
                if prev_name == cname:
                    break
                parent_vals[prev_name] = proto_structured[prev_name]
            sig = inspect.signature(comp)
            call_kw = {p: parent_vals[p] for p in sig.parameters if p in parent_vals}
            resolved[cname] = comp(**call_kw)
    self._proto_components = resolved

    # Build _components dict from resolved prototypes (for shape introspection)
    self._components = resolved
    self._record_template = _build_record_template(self._components)

    # Reparent to the dynamic subclass whose protocol bases match the
    # resolved components' capabilities. Done after component
    # resolution because the capabilities depend on the distributions
    # that callable components resolve to.
    dynamic_cls = _sequential_class_for_components(self._components)
    if dynamic_cls is not type(self):
        object.__setattr__(self, "__class__", dynamic_cls)

fields property

Component names in topological (insertion) order.

event_shapes property

Per-component event shapes from component distributions.

components property

Read-only view of the component distributions.

TransformedDistribution(base, bijector, *, name=None)

Bases: NumericRecordDistribution

Distribution formed by applying a TFP bijector to a base distribution.

When base is a TFPDistribution, sampling and density evaluation delegate to a tfd.TransformedDistribution built from the base's underlying TFP distribution, so TFP handles the transform natively. Otherwise (e.g. EmpiricalDistribution), the bijector's forward, inverse, and inverse_log_det_jacobian are applied manually.
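For the manual path, samples are pushed through forward, and the density follows from the change-of-variables formula log p_Y(y) = log p_X(f^-1(y)) + log|d f^-1(y) / dy|, which is the quantity inverse_log_det_jacobian supplies. The stdlib sketch below writes an Exp bijector out by hand over a Normal base. These functions are assumptions standing in for tfb.Exp() and are not TFP's API. The density step needs a base log-density, which is why SupportsLogProb depends on the base.

```python
import math
import random

# Hand-written Exp bijector: y = exp(x), x = log(y).
# (Stand-in for tfb.Exp(); these functions are not TFP's API.)
def forward(x):
    return math.exp(x)

def inverse(y):
    return math.log(y)

def inverse_log_det_jacobian(y):
    # d/dy log(y) = 1/y, so log|dx/dy| = -log(y).
    return -math.log(y)

def normal_log_prob(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

# Sampling: push base samples (here an empirical set) through forward.
rng = random.Random(0)
base_samples = [rng.gauss(0.0, 1.0) for _ in range(5)]
transformed_samples = [forward(x) for x in base_samples]

# Density: pull y back through inverse and add the log-Jacobian correction.
def transformed_log_prob(y):
    return normal_log_prob(inverse(y)) + inverse_log_det_jacobian(y)
```

The result is the LogNormal(0, 1) log-density, -log(y) - 0.5 log(2 pi) - (log y)^2 / 2.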

Dynamic protocol support: SupportsLogProb, SupportsMean, and SupportsVariance are included only when the base distribution supports them.

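Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| base | Distribution | The untransformed base distribution. | required |
| bijector | Bijector | A TFP bijector (e.g. tfb.Exp(), tfb.Sigmoid()). | required |
| name | str | Distribution name for provenance. When None, defaults to "transformed(<base name>)". | None |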
Source code in probpipe/distributions/transformed.py
def __init__(
    self,
    base: NumericRecordDistribution,
    bijector: tfb.Bijector,
    *,
    name: str | None = None,
):
    self._base = base
    self._bijector = bijector
    if name is None:
        name = f"transformed({base.name})"
    super().__init__(name=name)

    if isinstance(base, TFPDistribution):
        self._tfp_transformed = tfd.TransformedDistribution(
            distribution=base._tfp_dist,
            bijector=bijector,
            name=name or "TransformedDistribution",
        )
    else:
        self._tfp_transformed = None

    self._approximate = base.is_approximate

    self.with_source(Provenance(
        "transform",
        parents=(base,),
        metadata={"bijector": type(bijector).__name__},
    ))

base property

The untransformed base distribution.

bijector property

The TFP bijector applied to base.

dtypes property

Per-field dtype, derived from the transformed TFP distribution when available and otherwise taken from the base distribution's dtype, then applied across the auto-built single-field record template.

support property

Derive the output support from the bijector when possible.

Bijectors such as tfb.Exp() (real line to positive reals) or tfb.Sigmoid() (real line to the unit interval) make TransformedDistribution a natural fit for unconstrained-to-constrained reparameterization.
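For example, pushing a standard Normal through a sigmoid gives a distribution supported on (0, 1). The sketch below writes the sigmoid's inverse (logit) and its log-Jacobian by hand rather than taking them from tfb.Sigmoid(). It checks numerically that the change-of-variables density integrates to about 1 over the constrained interval.

```python
import math

def logit(y):
    # Inverse of the sigmoid: maps (0, 1) back to the real line.
    return math.log(y) - math.log1p(-y)

def sigmoid_inverse_log_det_jacobian(y):
    # d/dy logit(y) = 1 / (y (1 - y)), so log|dx/dy| = -log(y) - log(1 - y).
    return -math.log(y) - math.log1p(-y)

def normal_log_prob(x):
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)

def logit_normal_log_prob(y):
    return normal_log_prob(logit(y)) + sigmoid_inverse_log_det_jacobian(y)

# Midpoint rule on (0, 1): all of the mass lives on the constrained support.
n = 100_000
total = sum(math.exp(logit_normal_log_prob((i + 0.5) / n)) for i in range(n)) / n
```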

JointGaussian(*, mean, cov, name=None, **component_shapes)

Bases: RecordDistribution, SupportsSampling, SupportsLogProb, SupportsMean, SupportsVariance, SupportsCovariance, SupportsConditioning

Joint Gaussian distribution with named components and cross-covariance.

Supports exact analytical conditioning via condition_on.
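Exact conditioning rests on the standard Gaussian identities. For a free block a and an observed block b with value v, the conditional mean is mean_a + S_ab S_bb^-1 (v - mean_b), and the conditional covariance is S_aa - S_ab S_bb^-1 S_ba. The NumPy sketch below works on flat indices. condition_gaussian is a hypothetical helper. condition_on presumably maps component names to index ranges via the stored component slices rather than taking raw indices.

```python
import numpy as np

def condition_gaussian(mean, cov, obs_idx, obs_value):
    """Condition N(mean, cov) on x[obs_idx] = obs_value; return the free block's mean and cov."""
    d = mean.shape[0]
    obs_idx = np.asarray(obs_idx)
    free_idx = np.setdiff1d(np.arange(d), obs_idx)
    mu_a, mu_b = mean[free_idx], mean[obs_idx]
    S_aa = cov[np.ix_(free_idx, free_idx)]
    S_ab = cov[np.ix_(free_idx, obs_idx)]
    S_bb = cov[np.ix_(obs_idx, obs_idx)]
    # gain = S_ab S_bb^-1, computed with a solve instead of an explicit inverse.
    gain = np.linalg.solve(S_bb, S_ab.T).T
    cond_mean = mu_a + gain @ (obs_value - mu_b)
    cond_cov = S_aa - gain @ S_ab.T
    return cond_mean, cond_cov

mean = np.array([0.0, 1.0])
cov = np.array([[1.0, 0.8], [0.8, 2.0]])
m, c = condition_gaussian(mean, cov, obs_idx=[0], obs_value=np.array([1.0]))
# m ≈ [1.8], c ≈ [[1.36]]
```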

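Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mean | array-like, shape (d,) | Full (flat) mean vector. | required |
| cov | array-like, shape (d, d) | Full (flat) covariance matrix. | required |
| name | str | Distribution name. When None, defaults to "joint_gaussian(k1,k2,...)" built from the component names. | None |
| **component_shapes | int | Named components with their dimensionality, laid out in keyword order along the flat vector. The dimensions must sum to d. | {} |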

Examples:

>>> joint = JointGaussian(
...     mean=jnp.array([0.0, 0.0, 1.0, 2.0]),
...     cov=jnp.eye(4),
...     x=1,    # x is 1-dimensional
...     yz=3,   # yz is 3-dimensional
... )
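The component shapes partition the flat vector in keyword order. In this example, x owns index 0 and yz owns indices 1 to 3, and each component's marginal is the matching slice of mean and cov. component_blocks is a hypothetical NumPy helper that mirrors the slice loop in the constructor:

```python
import numpy as np

def component_blocks(mean, cov, **component_shapes):
    total = sum(component_shapes.values())
    if mean.shape != (total,) or cov.shape != (total, total):
        raise ValueError("component dimensions must sum to d")
    blocks, offset = {}, 0
    for name, dim in component_shapes.items():  # keyword order is preserved
        sl = slice(offset, offset + dim)
        # The marginal of a Gaussian block is just the sliced mean and covariance.
        blocks[name] = (mean[sl], cov[sl, sl])
        offset += dim
    return blocks

blocks = component_blocks(np.array([0.0, 0.0, 1.0, 2.0]), np.eye(4), x=1, yz=3)
print(blocks["yz"][0])  # [0. 1. 2.]
```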
Source code in probpipe/distributions/_joint_gaussian.py
def __init__(
    self,
    *,
    mean: ArrayLike,
    cov: ArrayLike,
    name: str | None = None,
    **component_shapes: int,
):
    if not component_shapes:
        raise ValueError("JointGaussian requires at least one component.")

    _, (mean, cov) = _promote_floats(mean, cov)

    total_dim = sum(component_shapes.values())
    if mean.shape != (total_dim,):
        raise ValueError(
            f"mean shape {mean.shape} does not match total dimension "
            f"({total_dim},) from component shapes {component_shapes}."
        )
    if cov.shape != (total_dim, total_dim):
        raise ValueError(
            f"cov shape {cov.shape} does not match ({total_dim}, {total_dim})."
        )

    self._mean_vec = mean
    self._cov_mat = cov
    if name is None:
        name = "joint_gaussian(" + ",".join(component_shapes.keys()) + ")"
    super().__init__(name=name)
    self._component_shapes = dict(component_shapes)

    # Build slices and component MultivariateNormal distributions
    from .multivariate import MultivariateNormal as MVN

    slices = {}
    components = {}
    offset = 0
    for cname, dim in self._component_shapes.items():
        sl = slice(offset, offset + dim)
        slices[cname] = sl
        components[cname] = MVN(
            loc=mean[sl],
            cov=cov[sl, sl],
            name=cname,
        )
        offset += dim

    self._components = components
    self._component_slices = slices  # still needed for Gaussian conditioning
    self._record_template = _build_record_template(self._components)
    self._total_dim = total_dim  # still needed for Gaussian conditioning

mean_vector property

Full mean vector.

covariance property

Full covariance matrix.

fields property

Component names in insertion order.

event_shapes property

Per-component event shapes.

components property

Read-only view of the component distributions.