
Constraints

A Constraint describes the value set on which a distribution is supported — positivity, the unit interval, the simplex, a positive-definite cone, and so on. Built-in singletons cover the common cases; factories (interval, greater_than, integer_interval) parameterise the rest. bijector_for maps any Constraint to a TFP bijector that takes unconstrained ℝⁿ into the constrained set, which is the usual route for MAP estimation, reparameterized MCMC, and VI.

Base class

Constraint

Describes the support of a distribution (the set of valid values).

check(value)

Return a boolean array indicating which elements satisfy the constraint.

Source code in probpipe/core/constraints.py
def check(self, value: ArrayLike) -> Array:
    """Return a boolean array indicating which elements satisfy the constraint."""
    raise NotImplementedError
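A concrete constraint only has to override check and return an elementwise boolean mask. A minimal sketch, assuming a stand-in base class with just the interface shown above and a hypothetical _StrictlyPositive subclass (this is not probpipe's own implementation of positive):

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class Constraint:
    """Stand-in for probpipe's base class: only the interface shown above."""

    def check(self, value: ArrayLike) -> jax.Array:
        raise NotImplementedError


class _StrictlyPositive(Constraint):  # hypothetical subclass name
    def check(self, value: ArrayLike) -> jax.Array:
        # Elementwise test: the mask has the same shape as `value`, dtype bool.
        return jnp.asarray(value) > 0


mask = _StrictlyPositive().check(jnp.array([-1.0, 0.0, 2.5]))

# The base class itself is abstract: calling check raises.
try:
    Constraint().check(0.0)
    base_is_abstract = False
except NotImplementedError:
    base_is_abstract = True
```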

Built-in Constraints

The following constraint singletons are available:

  • probpipe.real -- any real number
  • probpipe.positive -- strictly positive
  • probpipe.non_negative -- non-negative
  • probpipe.non_negative_integer -- non-negative integers
  • probpipe.boolean -- 0 or 1
  • probpipe.unit_interval -- values in [0, 1]
  • probpipe.simplex -- vectors summing to 1 with non-negative entries
  • probpipe.positive_definite -- positive-definite matrices
  • probpipe.sphere -- unit-norm vectors
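Several of these sets are defined per vector or per matrix rather than per element. The standalone functions below are a hypothetical sketch of the corresponding membership tests in jax.numpy, returning one boolean per trailing vector or matrix. The function names and the 1e-6 tolerance are assumptions, not probpipe's API.

```python
import jax.numpy as jnp

ATOL = 1e-6  # assumed tolerance for the equality parts of each test


def in_simplex(x):
    x = jnp.asarray(x)
    # Non-negative entries and a sum of 1, one boolean per vector on the last axis.
    return jnp.all(x >= 0, axis=-1) & (jnp.abs(jnp.sum(x, axis=-1) - 1.0) <= ATOL)


def in_sphere(x):
    x = jnp.asarray(x)
    # Unit Euclidean norm along the last axis.
    return jnp.abs(jnp.linalg.norm(x, axis=-1) - 1.0) <= ATOL


def is_positive_definite(m):
    m = jnp.asarray(m)
    symmetric = jnp.all(jnp.abs(m - jnp.swapaxes(m, -1, -2)) <= ATOL, axis=(-2, -1))
    # eigvalsh reads only one triangle, so trust it only together with `symmetric`.
    min_eig = jnp.linalg.eigvalsh(m)[..., 0]  # eigenvalues come back ascending
    return symmetric & (min_eig > 0)
```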

Constraint Factories

interval(low, high)

Source code in probpipe/core/constraints.py
def interval(low: ArrayLike, high: ArrayLike) -> _Interval:
    return _Interval(low, high)

greater_than(lower_bound)

Source code in probpipe/core/constraints.py
def greater_than(lower_bound: ArrayLike) -> _GreaterThan:
    return _GreaterThan(lower_bound)

integer_interval(low, high)

Source code in probpipe/core/constraints.py
def integer_interval(low: ArrayLike, high: ArrayLike) -> _IntegerInterval:
    return _IntegerInterval(low, high)

Bijectors for Unconstrained Reparameterization

bijector_for(c) returns the canonical TFP bijector mapping unconstrained ℝⁿ to values satisfying the constraint c. Defaults follow Pyro / NumPyro conventions:

| Constraint | Bijector |
|---|---|
| real | tfb.Identity() |
| positive | tfb.Exp() |
| non_negative | tfb.Softplus() |
| unit_interval | tfb.Sigmoid() |
| interval(low, high) | tfb.Sigmoid(low, high) |
| greater_than(b) | tfb.Chain([tfb.Shift(b), tfb.Exp()]) |
| simplex | tfb.SoftmaxCentered() |
| positive_definite | tfb.Chain([tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()]) |

Constraints with no canonical smooth bijector (sphere, boolean, non_negative_integer, integer_interval) raise NotImplementedError with a specific reason.

register_bijector is the extension point for custom Constraint subclasses or for overriding defaults (e.g., preferring Softplus over Exp for positive). Instance registrations take precedence over type registrations.

Round-trip with TransformedDistribution.support

bijector_for and the forward map used by TransformedDistribution.support (which infers a support from a bijector) are not inverses of each other. TransformedDistribution(base, bijector_for(c)).support == c holds only for real, positive, and unit_interval. For the other constraints the round-trip drifts to a different support: non_negative comes back as positive (Softplus), interval(low, high) as unit_interval (the parameterized Sigmoid), greater_than and positive_definite as real (Chain), and simplex as real (SoftmaxCentered). The two maps answer different questions (which bijector reaches a given support, versus which support can be read off a given bijector) and have different reliability tiers.

bijector_for(constraint)

Return a canonical bijector mapping ℝⁿ to constraint's support.

Lookup precedence:

  1. Exact instance match (e.g., the singleton positive).
  2. Type match, walking the Constraint MRO (most-specific first).
  3. NotImplementedError if nothing matches.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| constraint | Constraint | The target support. | required |

Returns:

| Type | Description |
|---|---|
| Bijector | A bijector whose forward image lies in constraint's support. |

Raises:

| Type | Description |
|---|---|
| NotImplementedError | If no factory is registered for constraint's type, or if the constraint is one for which no smooth bijector exists (discrete constraints, the unit sphere). |

Source code in probpipe/distributions/_bijector_dispatch.py
def bijector_for(constraint: Constraint) -> tfb.Bijector:
    """Return a canonical bijector mapping ℝⁿ to *constraint*'s support.

    Lookup precedence:

    1. Exact instance match (e.g., the singleton ``positive``).
    2. Type match, walking the Constraint MRO (most-specific first).
    3. :class:`NotImplementedError` if nothing matches.

    Parameters
    ----------
    constraint : Constraint
        The target support.

    Returns
    -------
    tfb.Bijector
        A bijector whose forward image lies in *constraint*'s support.

    Raises
    ------
    NotImplementedError
        If no factory is registered for *constraint*'s type, or if the
        constraint is one for which no smooth bijector exists (discrete
        constraints, the unit sphere).
    """
    # 1. Instance match.  ``Constraint.__hash__`` hashes
    # ``(type, sorted-__dict__-items)``, which raises ``TypeError`` when
    # ``__dict__`` contains an unhashable value (e.g., a JAX array as
    # ``low``); catch that and fall through to type lookup.
    try:
        if constraint in _CONSTRAINT_BIJECTOR_REGISTRY:
            return _CONSTRAINT_BIJECTOR_REGISTRY[constraint](constraint)
    except TypeError:
        pass

    # 2. Type match via MRO.
    for cls in type(constraint).__mro__:
        if cls in _CONSTRAINT_BIJECTOR_REGISTRY:
            return _CONSTRAINT_BIJECTOR_REGISTRY[cls](constraint)

    raise NotImplementedError(
        f"No bijector registered for {constraint!r}. "
        f"Use ``probpipe.register_bijector`` to add one."
    )

register_bijector(key, factory)

Register a bijector factory for a Constraint type or singleton.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | type or Constraint | Either a Constraint subclass (applies to all instances of that type, including parameterized ones) or a specific Constraint value (applies to constraints equal to it). Instance keys take precedence over type keys at lookup time. | required |
| factory | callable | factory(constraint) -> tfb.Bijector. The constraint instance is passed in so the factory can read parameters (e.g., low / high from an _Interval). | required |

Notes

Re-registering an existing key silently overwrites the previous factory.

Avoid registering against the base Constraint class itself: every constraint has it in its MRO, so a base-class registration would catch every constraint that has no more specific registration.

Source code in probpipe/distributions/_bijector_dispatch.py
def register_bijector(
    key: type | Constraint,
    factory: BijectorFactory,
) -> None:
    """Register a bijector factory for a Constraint type or singleton.

    Parameters
    ----------
    key : type or Constraint
        Either a :class:`Constraint` subclass (applies to all instances of
        that type, including parameterized ones) or a specific
        :class:`Constraint` value (applies to constraints equal to it).
        Instance keys take precedence over type keys at lookup time.
    factory : callable
        ``factory(constraint) -> tfb.Bijector``.  The constraint instance
        is passed in so the factory can read parameters (e.g., ``low`` /
        ``high`` from an :class:`_Interval`).

    Notes
    -----
    Re-registering an existing key silently overwrites the previous
    factory.

    Avoid registering against the base :class:`Constraint` class itself:
    every constraint shares it in their MRO, so a base-class registration
    would catch every unmatched constraint.
    """
    _CONSTRAINT_BIJECTOR_REGISTRY[key] = factory