Extending ProbPipe
ProbPipe's extension surface is small and grouped by capability. The table below maps each kind of extension to the contract you implement against and the registry (if any) you register with. Each row links to the section on this page that covers it in detail.
| To add a... | Implement | Register with |
|---|---|---|
| New distribution family | Subclass of Distribution, RecordDistribution, NumericRecordDistribution, or TFPDistribution | None; capability is detected by isinstance against the matching protocol |
| New op support on an existing distribution | The matching underscore method (_sample, _log_prob, _mean, ...) on the class | None; see Protocols for which method backs which op |
| New inference method (custom sampler, optimiser, ...) | Subclass of Method declaring supported_types, priority, check(), and execute() | inference_method_registry.register(...); see Custom inference methods |
| New distribution-to-distribution converter | Subclass of Converter with check() / convert() | converter_registry.register(...); see Custom converters |
| New canonical bijector for a Constraint | A factory returning a TFP bijector | register_bijector(constraint_or_class, factory); see Custom bijectors |
| New auxiliary-metadata adapter (custom array-like) | capture and restore callables | register_aux(leaf_type, capture, restore); see Custom auxiliary metadata |
The remaining material, Broadcasting internals on this page and the separate Internals page, documents classes that an extension rarely constructs directly but may need to reference.
Distribution base classes
Distribution is the abstract root. RecordDistribution and
NumericRecordDistribution specialise it for distributions whose
_sample() returns a Record or NumericRecord respectively.
TFPDistribution wraps an existing TFP Distribution.
Distribution(*, name)
Bases: ABC
Abstract base for all ProbPipe distributions, parameterized by
value type T.
Every distribution has a name. Leaf distributions (Normal, Gamma,
etc.) require an explicit name= argument; composite distributions
(ProductDistribution, EmpiricalDistribution, etc.) auto-generate a
name from their components when one is not provided.
Provides naming, provenance, conversion, and approximation tracking.
Sampling and expectation capabilities are provided by the
SupportsSampling and SupportsExpectation protocols.
Source code in probpipe/core/_distribution_base.py
is_approximate
property
Whether this distribution is an approximation (e.g., from sampling or MCMC).
validation_results
property
Results from predictive_check runs.
Each entry is a dict with at least "replicated_statistics"
and "test_fn_name". Posterior checks also include
"observed_statistic" and "p_value".
record_template
property
Structural template for this distribution's samples, or None.
Deprecated: this property is migrating to RecordDistribution.
Non-Record distributions should not rely on it.
auxiliary
property
An xarray DataTree of auxiliary information (diagnostics,
sample statistics, algorithm metadata), or None.
Populated by inference methods. Follows ArviZ group conventions
(posterior, sample_stats, warmup, etc.) with metadata
stored as DataTree attributes.
with_source(source)
Attach provenance to this distribution (write-once).
Source code in probpipe/core/_distribution_base.py
renamed(new_name)
Return a shallow copy with a different name.
The copy shares all internal state but has a new name.
Provenance is tracked: the copy's source records the rename
operation and points to the original as parent. Any cached
record_template is cleared so it regenerates with the new
name (relevant for TFPDistribution).
Source code in probpipe/core/_distribution_base.py
from_batched_params(*, name, batch_shape=None, **batched_params)
classmethod
Class-method alias for DistributionArray.from_batched_params.
Lets users write the ergonomic per-class form:
Normal.from_batched_params(loc=jnp.zeros(5), scale=1.0, name="x")
instead of the universal entry point:
DistributionArray.from_batched_params(
Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)
Both produce the same DistributionArray — the alias is a
thin classmethod that calls the universal factory with
cls bound. Subclasses inherit the alias automatically;
no per-family override is needed.
See DistributionArray.from_batched_params for the full
contract (dispatch on
SupportsArrayBackend,
batch_shape inference, per-cell name suffixing).
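A minimal sketch of the alias pattern in plain Python. It assumes nothing about ProbPipe's internals: ToyDistributionArray, ToyDistribution, and ToyNormalFamily are hypothetical stand-ins. A single classmethod on the base class binds cls and forwards everything else to the universal factory, so every subclass inherits the alias without an override.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyDistributionArray:
    """Hypothetical stand-in for DistributionArray: records what was requested."""
    family: type
    name: str
    batch_shape: Optional[tuple]
    params: dict = field(default_factory=dict)

    @staticmethod
    def from_batched_params(family, *, name, batch_shape=None, **batched_params):
        # Universal entry point: the distribution family is passed explicitly.
        return ToyDistributionArray(family, name, batch_shape, dict(batched_params))


class ToyDistribution:
    @classmethod
    def from_batched_params(cls, *, name, batch_shape=None, **batched_params):
        # Thin alias: bind cls and forward the remaining arguments unchanged.
        return ToyDistributionArray.from_batched_params(
            cls, name=name, batch_shape=batch_shape, **batched_params
        )


class ToyNormalFamily(ToyDistribution):
    pass  # inherits the alias; no per-family override needed
```

Both spellings build equal objects, which is the property the docstring above describes.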
Source code in probpipe/core/_distribution_base.py
RecordDistribution(*, name)
Bases: Distribution[Record]
Generic Record-based distribution.
Provides named component access (fields, __getitem__,
select()) and Record-aware flatten / unflatten. Does NOT impose
numeric shape / dtype conventions (dtype, support,
event_shape) — those belong on NumericRecordDistribution
and its consumers.
Concrete subclasses must set _record_template (a
RecordTemplate describing the named
structure) and implement the relevant sampling / log-prob protocols.
Source code in probpipe/core/_distribution_base.py
record_template
property
Structural template describing this distribution's samples.
Returns a RecordTemplate with
field names and per-field shapes, or None if no template
is set.
fields
property
Field names from the record_template, or an empty tuple if no template is set.
event_size
property
Total number of scalar elements in one sample.
Sums the sizes of every numeric leaf described by the template;
opaque leaves contribute zero. For a NumericRecordTemplate, the
already-cached flat_size is reused.
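The counting rule can be sketched in a few lines, assuming a simplified template that maps each field name to its event shape, or to None for an opaque leaf. This is not ProbPipe's RecordTemplate type. A scalar field has shape () and contributes math.prod(()) == 1.

```python
import math
from typing import Mapping, Optional, Tuple


def event_size(field_shapes: Mapping[str, Optional[Tuple[int, ...]]]) -> int:
    """Count the scalar elements in one sample.

    field_shapes maps each field to its event shape, or to None for an
    opaque (non-numeric) leaf, which contributes zero.
    """
    return sum(
        math.prod(shape) for shape in field_shapes.values() if shape is not None
    )
```

For example, {"r": (), "K": (3,), "phi": (2, 2), "label": None} gives 1 + 3 + 4 = 8.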
event_shapes
property
Per-field event shapes.
For untemplated distributions (no record_template) returns
{}; use event_shape for the scalar shape of a
single unnamed field. Nested sub-templates collapse to ()
at the top level.
shape
property
Shape of one draw (equals the sole field's event_shape).
ndim
property
Number of axes in one draw.
select(*fields, **mapping)
Select named fields as views for workflow function broadcasting.
Positional args use the field name as the argument name.
Keyword args remap: select(x="field_name").
Usage:
predict(**posterior.select("r", "K", "phi"), x=x_grid)
Source code in probpipe/core/_record_distribution.py
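The naming rules of select can be illustrated with a toy record. ToyRecord and predict are hypothetical. The real select returns views that the WorkflowFunction layer broadcasts, while this sketch returns plain values so that only the positional and keyword naming behaviour is shown.

```python
class ToyRecord:
    """Hypothetical record whose select() mimics the argument-naming rules."""

    def __init__(self, **fields):
        self._fields = dict(fields)

    def select(self, *names, **mapping):
        # Positional names keep the field name as the argument name.
        out = {name: self._fields[name] for name in names}
        # Keyword arguments rename: select(x="field_name") -> {"x": value}.
        out.update({arg: self._fields[src] for arg, src in mapping.items()})
        return out


def predict(r, K, phi, x):
    return r * K + phi * x  # placeholder model, purely for illustration
```

With rec = ToyRecord(r=1.0, K=2.0, phi=0.5), the call predict(**rec.select("r", "K", "phi"), x=4.0) splats the three fields by name next to the explicit x.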
select_all()
Return every component as a view, for splatting into function calls.
Sugar for select(*self.fields). Matches
Record.select_all / RecordArray.select_all so
the splat-all pattern works uniformly across the three field-bearing
container types. Preserves cross-field correlation via
the parent-identity machinery in the WorkflowFunction sweep
layer.
Source code in probpipe/core/_record_distribution.py
keys()
values()
items()
flatten_value(value)
Flatten a NumericRecord or NumericRecordArray sample to a flat array.
The flatten operation is numeric-only, so Record inputs must
be convertible to NumericRecord (all leaves numeric). Raw
arrays are returned unchanged.
Source code in probpipe/core/_record_distribution.py
unflatten_value(flat)
Reconstruct a NumericRecord or NumericRecordArray from a flat array.
Source code in probpipe/core/_record_distribution.py
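A minimal sketch of the flatten / unflatten round trip, assuming each field's leaf is already stored as a row-major Python list and the template is a plain {field: event_shape} dict. Neither assumption reflects ProbPipe's actual types. Flattening concatenates fields in template order, and unflattening slices them back out using the same order and sizes.

```python
import math
from typing import Dict, List, Tuple

Template = Dict[str, Tuple[int, ...]]  # field name -> event shape (simplified)


def flatten_value(value: Dict[str, List[float]], template: Template) -> List[float]:
    # Concatenate each field's row-major values in template order.
    flat: List[float] = []
    for name, shape in template.items():
        leaf = value[name]
        if len(leaf) != math.prod(shape):
            raise ValueError(f"field {name!r}: expected {math.prod(shape)} values")
        flat.extend(leaf)
    return flat


def unflatten_value(flat: List[float], template: Template) -> Dict[str, List[float]]:
    # Slice the flat vector back into fields using the same order and sizes.
    out: Dict[str, List[float]] = {}
    start = 0
    for name, shape in template.items():
        size = math.prod(shape)
        out[name] = flat[start:start + size]
        start += size
    if start != len(flat):
        raise ValueError("flat vector length does not match the template")
    return out
```

Because both directions walk the same field order, unflatten_value(flatten_value(v, t), t) == v for any value that matches the template.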
as_flat_distribution()
View this distribution as a FlatNumericRecordDistribution.
Returns a FlattenedDistributionView
with event_shape = (event_size,) for algorithms expecting
flat vectors (MCMC, optimizers, VI methods). Inverse:
as_record_distribution.
Source code in probpipe/core/_record_distribution.py
NumericRecordDistribution(*, name)
Bases: RecordDistribution
Distribution over numeric arrays with Record support.
Extends RecordDistribution with numeric-specific metadata.
The class is the most general numeric random variable in ProbPipe:
samples are a pytree of jax.Array leaves named via
RecordTemplate. Single-leaf distributions (Normal,
Beta, MultivariateNormal, …) are the trivial case; the
same machinery covers future multi-leaf joint distributions.
A Distribution represents one random variable. Collections of
independent distributions live in
DistributionArray.
Canonical / convenience accessor pairs
Per-field accessors (canonical) are the source of truth; scalar accessors (convenience) are derived shortcuts that raise on multi-leaf templates. Subclasses override the canonical side; convenience accessors are inherited and derived automatically.
| Concept | Canonical (per-leaf) | Convenience (single-leaf) |
|---|---|---|
| Structure | record_template | (none) |
| Pytree | treedef (from template) | (none) |
| Shapes | event_shapes : dict | event_shape : tuple (raises on multi-leaf) |
| Dtypes | dtypes : dict (raises if not declared) | dtype : dtype or None (the unique dtype, else None) |
| Supports | supports : dict (raises if not declared) | support : Constraint (raises on multi-leaf) |
| Flat dim | event_size : int | (none) |
Single-field auto-template
Any concrete subclass that declares an event_shape and is
constructed with a name= gets an auto-built single-field
RecordTemplate (RecordTemplate(**{name: event_shape}))
on first read of record_template. Subclasses that need a
multi-field template (joint distributions) override
record_template directly to skip the auto-build.
_sample contract (Story A)
- Single-leaf templates → _sample(key, sample_shape) returns a raw jax.Array of shape sample_shape + event_shape.
- Multi-leaf templates → _sample(key, sample_shape) returns a NumericRecord (or NumericRecordArray for non-empty sample_shape) keyed by record_template.fields.
The treedef property locks this relationship by deriving
from record_template.
Standard distributions (Normal, Gamma, Poisson, etc.) inherit from
this class via TFPDistribution.
Source code in probpipe/core/_distribution_base.py
record_template
property
Auto-build a single-field RecordTemplate from
name + event_shape when the subclass hasn't set one.
Cached via object.__setattr__ on first read.
Multi-field subclasses (joint distributions) override this
property to skip the auto-build.
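The lazy, cache-on-first-read behaviour can be sketched with a frozen dataclass. The assumption here, which the text does not state, is that the distribution blocks ordinary attribute assignment the way a frozen dataclass does. That is the situation in which object.__setattr__ is needed to write the cache once. ToyLeaf and ToyJoint are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class ToyLeaf:
    """Hypothetical single-field distribution with a lazily built template."""
    name: str
    event_shape: tuple
    _record_template: Optional[dict] = field(default=None, compare=False, repr=False)

    @property
    def record_template(self) -> dict:
        if self._record_template is None:
            # Frozen instance: bypass the guarded __setattr__ to cache once.
            object.__setattr__(self, "_record_template", {self.name: self.event_shape})
        return self._record_template


class ToyJoint(ToyLeaf):
    @property
    def record_template(self) -> dict:
        # Multi-field subclasses override the property and skip the auto-build.
        return {"loc": (2,), "scale": ()}
```

Repeated reads of ToyLeaf("x", (3,)).record_template return the same cached dict object.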
dtypes
property
Per-field dtypes (canonical; subclasses must override).
Returns a {field: dtype} dict aligned with record_template.fields.
The default raises NotImplementedError rather than silently returning a
default float dtype for every field, which would be wrong for integer-valued
distributions such as Bernoulli, Poisson, and Categorical.
supports
property
Per-field support constraints (canonical; subclasses must override
to provide meaningful constraints).
The default raises NotImplementedError.
dtype
property
Convenience: scalar dtype if all fields share one, else None.
Derived from dtypes. dtypes is the canonical
per-field accessor; subclasses override that, not this.
support
property
Convenience: support for a single-field distribution.
Derived from supports. supports is the canonical
per-field accessor; subclasses override that (or, for
TFPDistribution-backed classes that follow the existing
single-field override pattern, override support directly
to short-circuit this derivation).
Raises TypeError (via _single_field_name) on
multi-field distributions; use supports instead.
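The canonical / convenience split for dtypes and supports can be sketched as follows. ToyNumericBase, ToyPoisson, and ToyJointSketch are hypothetical. Dtypes and constraints are plain strings here rather than real dtype and Constraint objects. Subclasses override only the per-field dicts, and the scalar accessors are derived from them.

```python
from typing import Dict, Optional


class ToyNumericBase:
    """Hypothetical base mirroring the canonical / convenience split."""
    fields: tuple = ()

    @property
    def dtypes(self) -> Dict[str, str]:
        raise NotImplementedError("subclasses must declare per-field dtypes")

    @property
    def supports(self) -> Dict[str, str]:
        raise NotImplementedError("subclasses must declare per-field supports")

    @property
    def dtype(self) -> Optional[str]:
        # Convenience: the shared dtype if exactly one exists, else None.
        unique = set(self.dtypes.values())
        return unique.pop() if len(unique) == 1 else None

    @property
    def support(self) -> str:
        # Convenience: only defined for single-field distributions.
        if len(self.fields) != 1:
            raise TypeError("support is single-field only; use supports instead")
        return self.supports[self.fields[0]]


class ToyPoisson(ToyNumericBase):
    fields = ("counts",)
    dtypes = property(lambda self: {"counts": "int32"})
    supports = property(lambda self: {"counts": "nonnegative_integer"})


class ToyJointSketch(ToyNumericBase):
    fields = ("rate", "counts")
    dtypes = property(lambda self: {"rate": "float32", "counts": "int32"})
    supports = property(lambda self: {"rate": "positive", "counts": "nonnegative_integer"})
```

ToyJointSketch().dtype is None because its fields disagree, and ToyJointSketch().support raises TypeError because there is more than one field.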
treedef
property
Treedef of one sample, derived from record_template.
Locks the relationship between the structural template and the sample's pytree shape:
- Single-leaf template (len(fields) <= 1) → a leaf treedef (jax.tree.structure(None)). Matches the _sample contract that single-leaf distributions return a raw jax.Array.
- Multi-leaf template → the treedef of a NumericRecord skeleton with the same field names. Matches the _sample contract that multi-leaf distributions return a NumericRecord.
Cached on first read; the underlying template is immutable post-construction so the cache is always valid.
flat_event_shapes
property
List of per-field event shapes in template field order.
Tree-walk over event_shapes: list(event_shapes.values()).
For a single-field distribution this is [event_shape];
for a multi-leaf distribution it's one entry per leaf.
flatten_value(value)
Flatten a sample (Record, NumericRecordArray, or array) to a flat trailing axis.
Delegates to RecordDistribution.flatten_value for Record-like
inputs. For raw arrays, flattens the event dimensions while preserving
the leading batch/sample dimensions.
Source code in probpipe/core/_numeric_record_distribution.py
unflatten_value(flat)
Unflatten a flat trailing axis back to event dims, Record, or NumericRecordArray.
When record_template is set with multiple fields, delegates
to RecordDistribution.unflatten_value (returns NumericRecord
or NumericRecordArray). For single-field leaf distributions,
reshapes to (*batch, *event_shape) for compatibility with _log_prob.
Source code in probpipe/core/_numeric_record_distribution.py
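For raw arrays, the flatten and unflatten steps come down to shape arithmetic. The sketch below computes only the shapes. With NumPy or JAX, each direction would be a single reshape to the returned shape. The helper names are hypothetical. A scalar event flattens to a trailing axis of size 1, consistent with a flat event_shape of (event_size,).

```python
import math
from typing import Tuple


def flat_shape(shape: Tuple[int, ...], event_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # (*batch, *event_shape) -> (*batch, event_size): only trailing event axes merge.
    k = len(event_shape)
    if k and tuple(shape[-k:]) != tuple(event_shape):
        raise ValueError(f"trailing axes {shape[-k:]} do not match {event_shape}")
    batch = shape[: len(shape) - k]
    return (*batch, math.prod(event_shape))


def unflat_shape(shape: Tuple[int, ...], event_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Inverse: split the flat trailing axis back into event_shape.
    if shape[-1] != math.prod(event_shape):
        raise ValueError(f"trailing axis {shape[-1]} != event size {math.prod(event_shape)}")
    return (*shape[:-1], *event_shape)
```

For sample_shape (4, 10) and event_shape (2, 3), flat_shape((4, 10, 2, 3), (2, 3)) is (4, 10, 6), and unflat_shape reverses it.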
as_flat_distribution()
View this distribution as a flat distribution.
Returns a FlattenedDistributionView wrapping this
distribution. The view satisfies the
FlatNumericRecordDistribution contract regardless of
self's structure (multi-field, multi-dim event, …) — its
event_shape is always (self.event_size,).
Inverse: FlatNumericRecordDistribution.as_record_distribution.
Source code in probpipe/core/_numeric_record_distribution.py
as_record_distribution(*, template, name=None)
Lift this distribution to a Record-keyed view under template.
Only available on FlatNumericRecordDistribution subclasses.
Calling this on a non-flat NumericRecordDistribution
raises TypeError with a hint to call
as_flat_distribution first.
See FlatNumericRecordDistribution.as_record_distribution
for the actual implementation and parameters.
Source code in probpipe/core/_numeric_record_distribution.py
TFPDistribution(*, name)
Bases: NumericRecordDistribution, SupportsSampling, SupportsLogProb, SupportsMean, SupportsVariance, SupportsCovariance
Base class for distributions backed by a tfd.Distribution instance.
Subclasses set self._tfp_dist in __init__. The private
protocol methods _sample, _expectation, _log_prob,
_mean, and _variance all delegate to TFP (or use MC
fallback for expectations).
Inherits from SupportsSampling, SupportsLogProb (which provides
_prob, _unnormalized_log_prob, and _unnormalized_prob defaults),
SupportsMean, SupportsVariance, and SupportsCovariance. Because it
implements _expectation, it also satisfies SupportsExpectation structurally.
Rejects batched parameters
Per the framework hierarchy "one random variable per
Distribution" rule (CONTRIBUTING.md), the constructor raises
ValueError when the underlying tfd.Distribution has
a non-empty batch_shape. Wrap multiple distributions in a
DistributionArray instead; to migrate batched-parameter code, use
DistributionArray.from_batched_params
(or the per-class alias Distribution.from_batched_params).
The check fires in __init__ after super().__init__(name=name)
completes, so concrete subclasses that set self._tfp_dist
before calling super().__init__ (the standard pattern used
by Normal, Beta, Gamma, …) are validated. Subclasses
that set _tfp_dist after super().__init__ (e.g.
KDEDistribution) are skipped
via the hasattr guard — those classes are responsible for
their own shape invariants and don't go through TFP's batched
parameter convention.
Internal infrastructure that legitimately needs the batched form
(the _TFPArrayBackend fused storage, converters, sequential
joints, GRF predictions) opts into the bypass via
_allow_batched_tfp_init.
Final-stage initializer for TFP-backed distributions.
Concrete subclasses (Normal, Beta, …) set
self._tfp_dist in their own __init__ before calling
super().__init__(name=name), so by the time we get here
the TFP backend is fully constructed and we can validate its
batch_shape.
Source code in probpipe/distributions/_tfp_base.py
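The ordering rule and the hasattr guard can be sketched without TFP. In this sketch, _tfp_dist is a types.SimpleNamespace carrying only a batch_shape, not a real tfd.Distribution, and all class names are hypothetical. The opt-in bypass (_allow_batched_tfp_init) is deliberately left out because its mechanism is not described here.

```python
from types import SimpleNamespace


class ToyTFPBase:
    """Hypothetical stand-in for TFPDistribution's final-stage __init__."""

    def __init__(self, *, name: str):
        self.name = name
        # hasattr guard: validate only if a subclass already attached a backend.
        if hasattr(self, "_tfp_dist") and tuple(self._tfp_dist.batch_shape) != ():
            raise ValueError(
                f"{name!r}: backend batch_shape {tuple(self._tfp_dist.batch_shape)} "
                "is not empty; use a DistributionArray instead"
            )


class ToyTFPNormal(ToyTFPBase):
    def __init__(self, loc, scale, *, name: str):
        # Standard pattern: build the backend first, then call super().__init__.
        batch = () if isinstance(loc, (int, float)) else (len(loc),)
        self._tfp_dist = SimpleNamespace(loc=loc, scale=scale, batch_shape=batch)
        super().__init__(name=name)


class ToyLateBackend(ToyTFPBase):
    def __init__(self, *, name: str):
        super().__init__(name=name)  # no _tfp_dist yet, so the check is skipped
        self._tfp_dist = SimpleNamespace(batch_shape=(5,))
```

ToyTFPNormal([0.0, 1.0], 1.0, name="xs") is rejected. ToyLateBackend is never checked, which is why such classes must enforce their own shape invariants.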
dtypes
property
Per-field dtypes — the canonical accessor for
TFPDistribution. Reads self._tfp_dist.dtype and
spreads it across every field of the auto-built
single-field template. dtype (the convenience) is
inherited from the base and derives from this dict.
support
property
The support of this distribution. Override in subclasses.
supports
property
Per-field support constraints — spreads the single-field
support (overridden by each concrete TFP-backed
subclass) across the auto-built template.
Protocols
Protocols define capabilities that distributions may support. Each
protocol is @runtime_checkable; compliance is checked via isinstance
at dispatch time. External types satisfy a protocol structurally by
implementing the underscore method (_sample, _log_prob, ...) — no
inheritance required.
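The structural check can be shown with the standard library alone. SupportsSamplingLike, CoinFlip, and sample are hypothetical mirrors rather than ProbPipe's classes. The integer key is a stdlib random seed standing in for a JAX PRNG key. CoinFlip never inherits from the protocol, yet isinstance accepts it because it defines _sample.

```python
import random
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsSamplingLike(Protocol):
    """Hypothetical mirror of SupportsSampling: a single private method."""

    def _sample(self, key, sample_shape): ...


class CoinFlip:
    """Satisfies the protocol structurally; it never inherits from it."""

    def __init__(self, p: float):
        self.p = p

    def _sample(self, key, sample_shape):
        rng = random.Random(key)  # integer seed standing in for a JAX PRNG key
        if sample_shape == ():
            return int(rng.random() < self.p)
        (n,) = sample_shape  # this toy handles only 1-D sample shapes
        return [int(rng.random() < self.p) for _ in range(n)]


def sample(dist, *, key, sample_shape=()):
    # Op-layer dispatch: check the capability with isinstance at call time.
    if not isinstance(dist, SupportsSamplingLike):
        raise TypeError(f"{type(dist).__name__} does not support sampling")
    return dist._sample(key, sample_shape)
```

runtime_checkable only checks that the named attributes exist. It does not validate signatures, so the method's contract still has to be honoured by convention.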
SupportsSampling
Bases: Protocol
Distribution that can produce samples via _sample(key, sample_shape).
Only requires _sample(key, sample_shape); concrete classes choose
their own implementation strategy (TFP batched sampling, index
resampling, vmap over a local single-draw helper, etc.).
Does NOT extend SupportsExpectation — not all samplable
distributions support array-valued expectations (e.g., random functions).
Classes that support both should inherit both protocols.
Return-type convention
The shape of the return value depends on whether the distribution emits structured samples and whether the caller asks for a batch:
| Distribution kind | sample_shape == () | sample_shape == (S1, S2, ...) |
|---|---|---|
| Numeric (raw array) | Array[*event_shape] | Array[*sample_shape, *event_shape] |
| RecordDistribution | Record / NumericRecord | NumericRecordArray(batch_shape=sample_shape) |
To draw a single sample, call _sample(key, ()). Implementations
that find it clearer to factor out a single-draw helper should
define it as a private method (e.g. _one_bootstrap) and have
_sample dispatch on sample_shape internally.
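A sketch of the single-draw-helper pattern, using a hypothetical bootstrap distribution. Nested Python lists stand in for arrays, and an integer seed stands in for a JAX key, so this only illustrates the dispatch on sample_shape. The () case returns one bare draw, and any other shape returns draws nested row-major to that shape.

```python
import math
import random


class ToyBootstrap:
    """Hypothetical resampling distribution over fixed observations."""

    def __init__(self, data):
        self._data = list(data)

    def _one_bootstrap(self, rng):
        # One draw: the mean of a with-replacement resample of the data.
        resample = [rng.choice(self._data) for _ in self._data]
        return sum(resample) / len(resample)

    def _sample(self, key, sample_shape):
        rng = random.Random(key)  # integer seed standing in for a JAX PRNG key
        if sample_shape == ():
            return self._one_bootstrap(rng)  # single draw, no batch axes
        draws = [self._one_bootstrap(rng) for _ in range(math.prod(sample_shape))]
        return _nest(draws, sample_shape)


def _nest(flat, shape):
    # Row-major reshape of a flat list into nested lists of the given shape.
    if len(shape) == 1:
        return flat
    step = math.prod(shape[1:])
    return [_nest(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]
```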
SupportsExpectation
Bases: Protocol
Distribution that can compute E[f(X)].
SupportsLogProb
Bases: SupportsUnnormalizedLogProb, Protocol
Distribution with a (normalized) log-density.
Extends SupportsUnnormalizedLogProb because any distribution
with a normalized density also has an unnormalized one (they coincide).
The base Distribution class
provides _unnormalized_log_prob defaulting to _log_prob.
SupportsUnnormalizedLogProb
Bases: Protocol
Distribution with an unnormalized log-density.
Provides _unnormalized_log_prob(value).
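The "normalized implies unnormalized" relationship can be expressed as a protocol hierarchy with a default method. The names below are hypothetical mirrors. One Python detail matters here: a default defined on a Protocol is inherited only by classes that subclass it explicitly. A class that satisfies the protocol purely structurally has to define the method itself.

```python
import math
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsUnnormalizedLogProbLike(Protocol):
    def _unnormalized_log_prob(self, value: float) -> float: ...


@runtime_checkable
class SupportsLogProbLike(SupportsUnnormalizedLogProbLike, Protocol):
    def _log_prob(self, value: float) -> float: ...

    def _unnormalized_log_prob(self, value: float) -> float:
        # A normalized log-density is also a valid unnormalized one.
        return self._log_prob(value)


class StdNormal(SupportsLogProbLike):
    # Explicit subclass: inherits the _unnormalized_log_prob default.
    def _log_prob(self, value):
        return -0.5 * value * value - 0.5 * math.log(2 * math.pi)


class OnlyUnnormalized:
    # Structural match for the weaker protocol only (no normalizing constant).
    def _unnormalized_log_prob(self, value):
        return -0.5 * value * value
```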
SupportsRandomLogProb
Bases: Protocol
Distribution over distributions with a random (normalized) log-density.
For a RandomMeasure[T] M, _random_log_prob returns, as a
RandomFunction, the random function x ↦ log D(x), where D ~ M.
The op
layer (random_log_prob) optionally
forwards an input value by calling the returned random function;
that two-argument convenience is purely op-layer sugar — concrete
subclasses implement only the zero-argument method here.
Mirrors SupportsLogProb for the random-measure setting.
SupportsRandomUnnormalizedLogProb
Bases: Protocol
Distribution over distributions with a random unnormalized log-density.
Mirrors SupportsUnnormalizedLogProb for the random-measure
setting. _random_unnormalized_log_prob returns the random
function x ↦ log D̃(x) where D̃ is the unnormalized density
of a draw D ~ M, as a
RandomFunction. The op
layer (random_unnormalized_log_prob)
optionally forwards an input value by calling the returned random
function; that two-argument convenience is purely op-layer sugar —
concrete subclasses implement only the zero-argument method here.
SupportsMean
Bases: Protocol
Distribution with an exact mean via _mean().
The return type is T-shaped where T is the distribution's
sample type. For the common cases this is:
- NumericRecordDistribution and friends (T = Array): returns Array.
- RecordDistribution and friends (T = Record): returns Record.
- RandomMeasure[T] (T = Distribution[T]): returns the marginalised Distribution[T] with marginal D̄(A) = ∫ D(A) dM(D).
The protocol is sample-type-polymorphic by design: the array-valued
and structured paths are unchanged; RandomMeasure opts in by
implementing _mean to return its expected distribution.
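A worked instance of the marginal D̄(A) = ∫ D(A) dM(D) for a random measure with finitely many atoms, where the integral becomes a weighted sum. The two-atom M over Bernoulli distributions is a hypothetical example. Fractions keep the arithmetic exact: 1/4 · 1/5 + 3/4 · 3/5 = 1/2.

```python
from fractions import Fraction

# Hypothetical random measure M with two atoms, each given as
# (probability under M, success probability of the Bernoulli draw D).
atoms = [(Fraction(1, 4), Fraction(1, 5)), (Fraction(3, 4), Fraction(3, 5))]


def marginal_prob(event_value: int) -> Fraction:
    # Dbar(A) = integral of D(A) dM(D), here a weighted sum over M's atoms.
    total = Fraction(0)
    for weight, p in atoms:
        d_of_a = p if event_value == 1 else 1 - p
        total += weight * d_of_a
    return total
```

The marginal is itself a Bernoulli distribution (here with success probability 1/2), which is the kind of object a RandomMeasure's _mean would return.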
Independent of SupportsExpectation. The ops layer falls
back to MC estimation via SupportsExpectation when this protocol
is absent. Concrete classes that want the MC default can apply
@compute_expectation to their _mean implementation (only
valid when T is array-like).
SupportsVariance
Bases: Protocol
Distribution with an exact variance via _variance().
Independent of SupportsExpectation. The ops layer falls
back to MC estimation via SupportsExpectation when this protocol
is absent. Concrete classes that want the MC default can apply
@compute_expectation to their _variance implementation.
SupportsCovariance
Bases: Protocol
Distribution with an exact covariance via _cov().
Independent of SupportsExpectation. The ops layer falls
back to MC estimation via SupportsExpectation when this protocol
is absent. Concrete classes that want the MC default can apply
@compute_expectation to their _cov implementation.
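The exact-first, Monte Carlo fallback described for SupportsMean, SupportsVariance, and SupportsCovariance can be sketched for the mean. HasMean, HasExpectation, mean, and the _expectation(f) signature are all assumptions for illustration. ProbPipe's actual ops layer and its expectation signature may differ.

```python
import random
from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class HasMean(Protocol):
    def _mean(self) -> float: ...


@runtime_checkable
class HasExpectation(Protocol):
    def _expectation(self, f: Callable[[float], float]) -> float: ...


def mean(dist) -> float:
    # Prefer the exact path; otherwise estimate E[X] through the expectation protocol.
    if isinstance(dist, HasMean):
        return dist._mean()
    if isinstance(dist, HasExpectation):
        return dist._expectation(lambda x: x)
    raise TypeError(f"{type(dist).__name__} supports neither _mean nor _expectation")


class ExactUniform:
    def __init__(self, low: float, high: float):
        self.low, self.high = low, high

    def _mean(self) -> float:
        return (self.low + self.high) / 2  # closed form


class MCUniform:
    def __init__(self, low: float, high: float, n: int = 20_000, seed: int = 0):
        self.low, self.high, self.n, self.seed = low, high, n, seed

    def _expectation(self, f):
        # Monte Carlo: average f over n draws from Uniform(low, high).
        rng = random.Random(self.seed)
        return sum(f(rng.uniform(self.low, self.high)) for _ in range(self.n)) / self.n
```

Variance would follow the same shape, passing something like lambda x: (x - m) ** 2 once a mean m is known.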
SupportsConditioning
Bases: Protocol
Distribution that has a fast, built-in condition_on path.
Implemented by distributions whose _condition_on produces a
posterior without calling into the inference registry — either
closed-form (conjugate updates, joint Gaussian marginalization)
or amortized (e.g., a pre-trained SBI posterior that just runs a
forward pass). When condition_on(dist, observed) is called
and dist implements this protocol, the built-in path is used
directly; otherwise the inference method registry selects an
algorithm (NUTS, RWMH, variational, ...).
Probabilistic models whose conditioning requires on-the-fly MCMC or variational inference should not implement this protocol — let the registry handle algorithm selection instead.
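A sketch of the dispatch rule, using a conjugate Beta-Bernoulli prior as the built-in fast path. SupportsConditioningLike, BetaBernoulliPrior, condition_on, and the registry_fallback callable are hypothetical. The fallback stands in for whatever the inference method registry would select.

```python
from dataclasses import dataclass
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupportsConditioningLike(Protocol):
    def _condition_on(self, observed): ...


@dataclass(frozen=True)
class BetaBernoulliPrior:
    alpha: float
    beta: float

    def _condition_on(self, observed: Sequence[int]) -> "BetaBernoulliPrior":
        # Conjugate update: successes add to alpha, failures add to beta.
        k = sum(observed)
        return BetaBernoulliPrior(self.alpha + k, self.beta + len(observed) - k)


def condition_on(dist, observed, registry_fallback):
    if isinstance(dist, SupportsConditioningLike):
        return dist._condition_on(observed)  # built-in closed-form path
    return registry_fallback(dist, observed)  # e.g. NUTS or VI picked by a registry
```

A Beta(1, 1) prior conditioned on [1, 0, 1] becomes Beta(3, 2) without touching the registry.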
SupportsArrayBackend is the only class-level protocol: its declared
method (_make_array_backend) is a @classmethod, so the runtime check
is isinstance(MyDistribution, SupportsArrayBackend) against the class
itself, not an instance.
SupportsArrayBackend
Bases: Protocol
Distribution class that supports efficient batched construction.
Used by DistributionArray.from_batched_params to fuse
storage when the caller's components are homogeneous instances of
the same class. Implementations construct an internal
_DistributionArrayBackend that owns the batched parameters
and the vectorised ops; DistributionArray becomes a thin
consumer.
Distribution classes that don't implement this protocol still work
in a DistributionArray via the literal-array fallback (one
Distribution instance per cell) — slower but correct.
The protocol attaches to the class, not to instances. The
runtime check is isinstance(MyDistribution, SupportsArrayBackend)
(i.e. the class itself implements _make_array_backend).
isinstance(an_instance, SupportsArrayBackend) returns
True too — instances inherit class attributes, and
runtime_checkable just looks for the named attribute — but
the result is misleading because the contract is at class
scope.
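The class-level check can be reproduced with a classmethod-only runtime protocol. SupportsArrayBackendLike, FusedFamily, PlainFamily, and can_fuse are hypothetical, and the "backend" is just a dict. Both the class object and its instances pass isinstance, which is why the class-scoped form is the meaningful one.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsArrayBackendLike(Protocol):
    """Hypothetical mirror of SupportsArrayBackend (classmethod-only protocol)."""

    @classmethod
    def _make_array_backend(cls, *, name, batch_shape, **batched_params): ...


class FusedFamily:
    @classmethod
    def _make_array_backend(cls, *, name, batch_shape, **batched_params):
        # Hypothetical backend: a dict describing the fused storage.
        return {"cls": cls, "name": name, "batch_shape": batch_shape, **batched_params}


class PlainFamily:
    pass  # would fall back to one Distribution instance per cell


def can_fuse(family: type) -> bool:
    # The contract is class-scoped, so test the class object itself.
    return isinstance(family, SupportsArrayBackendLike)
```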
The protocol is internal to the library; user code never calls
_make_array_backend directly. DistributionArray is the
sole consumer.
Examples:
A distribution class declares the capability by implementing the classmethod:
class MyDistribution(Distribution[T]):
    @classmethod
    def _make_array_backend(
        cls,
        *,
        name: str,
        batch_shape: tuple[int, ...],
        **batched_params,
    ) -> _DistributionArrayBackend:
        return _MyArrayBackend(
            cls=cls, name=name, batch_shape=batch_shape,
            **batched_params,
        )
Custom inference methods
Method subclasses register with inference_method_registry and
declare supported_types, a priority, and check() / execute()
methods. When condition_on runs, the
registry tries methods in descending priority order and the first whose
check() reports feasibility wins. The built-in methods table is on
Modeling and inference → Inference methods.
Setting priority for a new method
The integer returned by priority carries semantics: it tells the registry
whether your method should auto-dispatch, and if so, where it ranks
against the alternatives.
- priority > 50: exact. Auto-dispatched; higher is preferred among exact alternatives.
- 0 < priority <= 50: inexact. Auto-dispatched; higher is preferred among inexact alternatives. The break at 50 is documentary; the registry walks every positive priority uniformly.
- priority == 0: opt-in only. The registry skips the method during auto-dispatch; it is reachable by name via method="...". This is the default: a Method subclass that doesn't override priority gets opt-in behaviour automatically.
Selection criteria
Choose a number with these axes in mind, roughly in order of weight:
- Robustness when applicable: how often the method gives a usable answer without per-model tuning, conditional on check() passing.
- Computational cost per effective sample (or per converged result). Two kinds of cost advantage deserve separate consideration: algorithmic specialisation that exploits model structure for an asymptotic speedup (Kalman, INLA, conjugate updates), and engineering specialisation, meaning the same algorithm on a faster backend (nutpie's Rust-backed NUTS vs. TFP's; Stan's compiled gradients vs. JAX traces).
- Approximation quality: analytical exact > controlled-error approximations > asymptotically-exact MCMC > intrinsic approximations.
- Diagnostic richness: methods that fail silently rank below methods with built-in failure signals, all else equal.
- Model-class breadth, as a tiebreaker only. A broadly applicable method does not need a higher priority than a narrow one: when the narrow method's check() fails, the walk falls through to the broad one.
Tier ranges: exact (51–100)
Five tiers, each 10 wide. Criteria are stated as positive properties of the method.
| Range | Criterion |
|---|---|
| 91–100 | Per-call cost in a strictly better complexity class than general-purpose alternatives; the speedup comes from exploiting model structure. |
| 81–90 | Optimised implementation of a more general algorithm; lower constant-factor cost than the reference implementation within its applicable model class. |
| 71–80 | Self-tuning; converges robustly without per-model hyperparameter selection. |
| 61–70 | Well-understood with strong convergence theory; performs well once hand-tuned. |
| 51–60 | Slow per effective sample or unreliable in typical use. |
Tier ranges: inexact (1–50)
Four named tiers ordered by the strength of the asymptotic-to-exact story. The slot at 11–20 is intentionally reserved for methods with intermediate guarantees that don't fit cleanly into a named tier.
| Range | Criterion |
|---|---|
| 41–50 | Asymptotically exact under algorithmic refinement; bias is a knob the user can tighten (step size, mini-batch size). |
| 31–40 | Particle-based approximation refinable by particle count; quality improves with more particles, though convergence may be slow or unstable. |
| 21–30 | Parametric posterior approximation; error bounded by family expressiveness or by regularity conditions on the posterior shape. |
| 11–20 | (reserved for methods with intermediate guarantees not covered by neighbouring tiers) |
| 1–10 | No asymptotic-to-exact guarantee in practice; quality bounded by intrinsic information loss (summary statistics, fixed tolerance, learned representations). |
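As a quick self-check when picking a number, the two tables can be encoded as data. describe_priority is a hypothetical helper, not part of ProbPipe. It maps a candidate priority to its band and tier range, and the 11–20 slot is returned like any other range even though it is reserved.

```python
EXACT_TIERS = [(91, 100), (81, 90), (71, 80), (61, 70), (51, 60)]
INEXACT_TIERS = [(41, 50), (31, 40), (21, 30), (11, 20), (1, 10)]


def describe_priority(priority: int) -> str:
    """Map a priority to its band and tier range from the tables above."""
    if priority == 0:
        return "opt-in only"
    for band, tiers in (("exact", EXACT_TIERS), ("inexact", INEXACT_TIERS)):
        for low, high in tiers:
            if low <= priority <= high:
                return f"{band}, tier {low}-{high}"
    raise ValueError(f"priority {priority} is outside the documented 0-100 convention")
```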
Setting priority on a Method subclass
class MyNutsMethod(Method):
    @property
    def priority(self) -> int:
        # Tier 71-80: self-tuning, converges without per-model hyperparameter selection.
        return 75
A method that should not auto-dispatch — perhaps it's experimental, has
sharp failure modes, or exists only for method= testing — leaves
priority at the inherited default of 0. The registry will exclude
it from the auto-dispatch walk; users can still invoke it explicitly by
name.
MethodRegistry()
Generic priority-based method registry.
Methods are tried in descending priority order. The first method
whose check() returns feasible=True wins. Users can also
select a specific method by name.
Methods whose effective priority equals OPT_IN_ONLY_PRIORITY
(0) are excluded from the auto-dispatch walk and are reachable
only by name via method="...".
Source code in probpipe/core/_registry.py
register(method)
Register a method (invalidates the lookup cache).
Source code in probpipe/core/_registry.py
set_priorities(**name_to_priority)
Override the priority of one or more methods.
Higher priority methods are tried first during auto-selection.
Overrides are unrestricted: the new value can be any integer,
including the opt-in-only sentinel 0. When an override
moves a method into or out of 0, the registry emits a
UserWarning because that crossing changes whether the
method participates in auto-dispatch at all.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **name_to_priority | int | Keyword arguments mapping method names to new priorities. | {} |
Raises:
| Type | Description |
|---|---|
| KeyError | If a method name is not registered. |
Source code in probpipe/core/_registry.py
get_method(name)
Look up a method by name. Raises KeyError if not found.
Source code in probpipe/core/_registry.py
list_methods()
check(*args, method=None, **kwargs)
Check feasibility. Auto-selects or uses the named method.
Source code in probpipe/core/_registry.py
execute(*args, method=None, **kwargs)
Execute using the best (or named) method.
Raises TypeError if no method is applicable.
Raises KeyError if the named method is not registered.
Source code in probpipe/core/_registry.py
Method
Bases: ABC
Abstract base for a pluggable method in a registry.
Subclasses declare a unique name, which supported_types
they handle (for fast filtering), a priority (higher =
tried first), and implement check/execute.
name
abstractmethod
property
Unique identifier for this method (e.g., 'tfp_nuts').
priority
property
Auto-dispatch ordering.
Higher priority methods are tried first during auto-selection.
The default of 0 is the opt-in-only sentinel — a method
that doesn't override this property is reachable only by name
via method="...". See Setting priority for a new method for the priority > 50 / <= 50 / == 0 convention that contributors should follow when choosing a number.
supported_types()
abstractmethod
check(*args, **kwargs)
abstractmethod
MethodInfo(feasible, method_name='', description='')
dataclass
Metadata describing whether a method is applicable.
Returned by a method's check() to describe feasibility
without performing the actual computation.
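The registry behaviour described in this section can be pieced together into a toy sketch. It covers descending-priority walks, first feasible check() wins, priority 0 skipped during auto-dispatch, named lookup, and a UserWarning when an override crosses 0. ToyMethodInfo, ToyMethod, ToyRegistry, and the three toy methods are hypothetical and simplified, not ProbPipe's MethodRegistry.

```python
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass

OPT_IN_ONLY = 0  # mirrors the documented opt-in-only sentinel


@dataclass
class ToyMethodInfo:
    feasible: bool
    method_name: str = ""
    description: str = ""


class ToyMethod(ABC):
    @property
    @abstractmethod
    def name(self) -> str: ...

    @property
    def priority(self) -> int:
        return OPT_IN_ONLY  # default: reachable only by name

    @abstractmethod
    def check(self, *args, **kwargs) -> ToyMethodInfo: ...

    @abstractmethod
    def execute(self, *args, **kwargs): ...


class ToyRegistry:
    def __init__(self):
        self._methods = {}
        self._overrides = {}

    def register(self, method: ToyMethod) -> None:
        self._methods[method.name] = method

    def _priority(self, method: ToyMethod) -> int:
        return self._overrides.get(method.name, method.priority)

    def set_priorities(self, **name_to_priority: int) -> None:
        for name, new in name_to_priority.items():
            old = self._priority(self._methods[name])  # KeyError if unregistered
            if (old == OPT_IN_ONLY) != (new == OPT_IN_ONLY):
                warnings.warn(f"{name!r} crosses the opt-in-only sentinel", UserWarning)
            self._overrides[name] = new

    def execute(self, *args, method=None, **kwargs):
        if method is not None:
            return self._methods[method].execute(*args, **kwargs)  # KeyError if unknown
        for m in sorted(self._methods.values(), key=self._priority, reverse=True):
            if self._priority(m) == OPT_IN_ONLY:
                continue  # opt-in-only methods never auto-dispatch
            if m.check(*args, **kwargs).feasible:
                return m.execute(*args, **kwargs)
        raise TypeError("no applicable method")


class ConjugateUpdate(ToyMethod):
    name = "conjugate"
    priority = 95  # exact tier 91-100: exploits model structure

    def check(self, model, data):
        return ToyMethodInfo(model == "beta-bernoulli", self.name)

    def execute(self, model, data):
        return "closed-form posterior"


class GenericSampler(ToyMethod):
    name = "sampler"
    priority = 75  # exact tier 71-80: self-tuning

    def check(self, model, data):
        return ToyMethodInfo(True, self.name)

    def execute(self, model, data):
        return "mcmc posterior"


class Experimental(ToyMethod):
    name = "experimental"  # keeps the default priority 0: opt-in only

    def check(self, model, data):
        return ToyMethodInfo(True, self.name)

    def execute(self, model, data):
        return "experimental posterior"
```

With all three registered, a "beta-bernoulli" model resolves to the conjugate method, anything else falls through to the sampler, and Experimental runs only when requested with method="experimental".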
Custom converters
Subclass Converter, implement check() / convert(), and register
with converter_registry.register(...). The built-in priorities and the
registry handle itself are documented under
Conversion and interop.
Custom bijectors
register_bijector overrides the canonical bijector returned by
bijector_for(c) for a given Constraint instance or class. See
Constraints → Bijectors.
Custom auxiliary metadata
register_aux extends the Record ↔ NumericRecord round-trip to a
new array-like type. See
Records → Auxiliary-metadata registry.
Broadcasting internals (exposed for extension)
DistributionArray is the shape-indexed container produced by
parameter-sweep workflow functions whose inner call returns a
Distribution. BroadcastDistribution is the joint container produced
by @workflow_function when include_inputs=True.
DistributionArray(components, *, batch_shape=None, name=None)
Bases: Distribution[T]
Ordered collection of independent scalar distributions
addressed by a (multi-d) batch_shape.
Exposes only the container surface (indexing, iteration,
components, batch_shape, event_shape). Vectorized
ops (sample, mean, variance, log_prob, …) are
delivered by the WorkflowFunction
sweep layer, which treats the array as Array[Distribution] and
dispatches cell-by-cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| components | sequence of Distribution | The n component distributions. Must be non-empty and share | required |
| batch_shape | tuple of int | Leading batch shape. Defaults to | None |
| name | str | Name for provenance / introspection. Defaults to | None |
Notes
How ops work on a DistributionArray. The class deliberately
does not implement _sample / _mean / _log_prob /
etc.; those would couple the array to specific component
capabilities. Vectorization is handled at a different layer:
- sample(da, ...) calls the sample WorkflowFunction, whose dispatch sees a DistributionArray argument where the op's annotation expects a scalar SupportsSampling.
- The WorkflowFunction dispatches cell-by-cell: each da[i] is sampled, and the results are stacked along batch_shape and returned as a NumericRecordArray (or RecordArray for non-numeric components). For ops whose inner return is itself a Distribution (e.g. posterior-predictive sweeps), the result is a nested DistributionArray.
- Multiple swept arguments combine by the product rule: passing two DistributionArray args of shapes (m,) and (n,) produces an output of shape (m, n).
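The product rule for swept arguments can be illustrated with NumPy object arrays. `sweep` is a hypothetical helper, not the WorkflowFunction dispatch code, and strings stand in for distribution cells. It assumes only what the notes state: each cell combination is visited once, and the output shape concatenates the argument shapes.

```python
import numpy as np


def sweep(op, *arrays):
    """Apply `op` to every combination of cells (the product rule).

    Each argument is an object array of cells; the output shape is the
    concatenation of the argument shapes, e.g. (m,) and (n,) -> (m, n).
    """
    out_shape = tuple(d for a in arrays for d in a.shape)
    out = np.empty(out_shape, dtype=object)
    for idx in np.ndindex(*out_shape):  # row-major walk over output cells
        cells, offset = [], 0
        for a in arrays:
            # Each argument owns its own consecutive block of output axes.
            cells.append(a[idx[offset:offset + a.ndim]])
            offset += a.ndim
        out[idx] = op(*cells)
    return out


# Strings stand in for distribution cells.
left = np.array(["A0", "A1"], dtype=object)         # shape (m,) = (2,)
right = np.array(["B0", "B1", "B2"], dtype=object)  # shape (n,) = (3,)
grid = sweep(lambda a, b: f"{a}*{b}", left, right)
print(grid.shape)  # (2, 3)
print(grid[1, 2])  # A1*B2
```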
Consequences of this design:
- Calling da._sample(key) directly raises AttributeError, because DistributionArray doesn't have _sample. Always use the public op (sample(da, key=...)).
- isinstance(da, SupportsSampling) is False even when every component supports sampling. The protocol attaches to individual cells, not to the array.
- Component capabilities don't have to be uniform: an array where some cells are SupportsLogProb and some are not will fail at op-dispatch on the first non-supporting cell, rather than being rejected at construction.
Source code in probpipe/core/_distribution_array.py
components
property
Flat tuple of component distributions, in row-major order
across the leading batch_shape.
For backend-delegated arrays the tuple is materialised lazily
on first access via backend.cell(i) for each flat index
and then cached. Cells are still freshly constructed inside
cell() (no de-duplication), so repeated components
accesses return the same cached tuple, but indexing via
__getitem__ / _flat_component always returns a
freshly constructed scalar cell.
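The caching behaviour described above can be mimicked with `functools.cached_property`. `ToyBackend` and `ToyArray` are hypothetical and only model the documented semantics: one cached tuple for `components`, and a fresh cell on every index. The real class may cache differently.

```python
from functools import cached_property


class ToyBackend:
    """Hypothetical backend: cell(i) builds a new object on every call."""

    def __init__(self, n):
        self.n = n
        self.calls = 0

    def cell(self, i):
        self.calls += 1
        return {"index": i}  # fresh scalar "cell" each time


class ToyArray:
    def __init__(self, backend):
        self._backend = backend

    @cached_property
    def components(self):
        # Materialised on first access, then cached on the instance.
        return tuple(self._backend.cell(i) for i in range(self._backend.n))

    def __getitem__(self, i):
        # Indexing goes back to the backend, so the result is always fresh.
        return self._backend.cell(i)


arr = ToyArray(ToyBackend(3))
print(arr.components is arr.components)  # True: one cached tuple
print(arr[0] is arr[0])                  # False: fresh cell per index
print(arr[0] == arr.components[0])       # True: equal contents, different objects
```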
batch_shape
property
Batch axes of this DistributionArray.
Components are scalar (one random variable per
Distribution), so this is simply self._batch_shape;
it is not composed from the components' own shapes. Multi-d
broadcasting outputs pass the full sweep shape; the default
1-D form is (n,).
event_shape
property
¶
Shared event_shape across components.
dtype
property
Per-cell dtype.
Cells share an event shape and (in practice) a dtype because homogeneous backends produce uniformly-typed cells and literal-array constructions inherit from the source. Backend-delegated arrays read it from the backend; literal arrays read it from the first component.
size
property
Total number of cells (prod(batch_shape)).
Mirrors np.ndarray.size / jax.Array.size: len(da)
is the leading-axis dim, da.size is the total cell count.
from_batched_params(dist_cls, *, name, batch_shape=None, **batched_params)
classmethod
Construct a DistributionArray of homogeneous components.
The recommended way to build a DistributionArray whose
cells are all instances of the same class — most often a
TFP-backed family like Normal — without manually
constructing each cell::
DistributionArray.from_batched_params(
Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)
When dist_cls implements
SupportsArrayBackend (every
TFP-backed concrete class does — Normal, Beta,
Gamma, MultivariateNormal, …), the factory dispatches
onto the backend's fused-storage path: a single batched
TFP backend owns the parameters; cells are materialised lazily
on demand. Otherwise the factory falls back to the
literal-array path: one dist_cls instance per cell with
per-cell parameters auto-sliced and names auto-suffixed
f"{name}_{flat_index}".
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_cls | type | A | required |
| name | str | Base name; per-cell scalars are named | required |
| batch_shape | tuple of int | Leading shape of the batched parameters. Inferred from | None |
| **batched_params | | Constructor kwargs for | {} |
Returns:
| Type | Description |
|---|---|
| DistributionArray | Backend-delegated when |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
Examples:
Backend-delegated TFP path::
da = DistributionArray.from_batched_params(
Normal, loc=jnp.zeros(5), scale=1.0, name="x",
)
da.batch_shape # (5,)
da[0].name # "x_0"
da[0].loc # 0.0
da._backend # _TFPArrayBackend(...)
Literal-array fallback (any class without the protocol)::
da = DistributionArray.from_batched_params(
MyCustomDist, param=jnp.arange(4), name="z",
)
da._backend # None — fallback path
da[0].name # "z_0"
Source code in probpipe/core/_distribution_array.py
__getitem__(key)
Index a single component, slice, or multi-d tuple.
Supported key forms:
- int → index along the leading batch axis.
- slice → slice along the leading batch axis; returns a new DistributionArray containing the sliced subset.
- tuple → multi-axis index (int or slice per axis); collapses along int axes and slices along slice axes.
Indexing uses row-major order across batch_shape.
The pure-int path translates the key directly to a flat index
via np.ravel_multi_index without materialising a
shape=batch_shape object array.
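The pure-int path can be illustrated with `np.ravel_multi_index`, which maps a multi-d index to a row-major flat position without building an object array. The `components` tuple of strings and the `flat_lookup` helper are hypothetical stand-ins; this sketch accepts only non-negative, in-range ints.

```python
import numpy as np

batch_shape = (2, 3)
# Flat, row-major tuple of cells (strings stand in for distributions).
components = tuple(f"cell_{i}" for i in range(int(np.prod(batch_shape))))


def flat_lookup(key):
    """All-int multi-d key -> one cell, without building an object array.

    Out-of-range or negative keys make np.ravel_multi_index raise
    ValueError in this sketch.
    """
    flat = np.ravel_multi_index(key, batch_shape)  # row-major ('C') order
    return components[int(flat)]


# Same answer as indexing a materialised (2, 3) object array.
grid = np.array(components, dtype=object).reshape(batch_shape)
print(flat_lookup((1, 2)), grid[1, 2])  # cell_5 cell_5
```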
Source code in probpipe/core/_distribution_array.py
__iter__()
Iterate the leading axis (numpy / jax convention).
len(self) items are yielded:
- ndim == 1 (the common case): each item is a scalar Distribution cell.
- ndim >= 2: each item is a DistributionArray of shape batch_shape[1:], a leading-axis slice, mirroring iter(np.zeros((2, 3))) yielding two (3,)-shaped views.
- ndim == 0 (batch_shape == ()): raises TypeError to match iter(np.zeros(())). Reach for _flat_component (or components) to access the single cell; those work uniformly across every batch_shape, including ().
For flat row-major access over every cell (the pre-#178
behaviour), use components or
range(self.size) with _flat_component.
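Because __iter__ follows the NumPy convention, the three cases can be checked on plain NumPy arrays. This snippet uses only NumPy; `reshape(-1)` is a NumPy analogue of flat row-major access through components, not a ProbPipe call.

```python
import numpy as np

a2 = np.zeros((2, 3))
rows = list(a2)  # ndim >= 2: two (3,)-shaped views along the leading axis

a0 = np.zeros(())  # ndim == 0
try:
    iter(a0)
    zero_d_iterable = True
except TypeError:
    zero_d_iterable = False  # NumPy refuses to iterate 0-d arrays

# Flat row-major access works for every shape, including ().
flat_2d = list(a2.reshape(-1))  # 6 items
flat_0d = list(a0.reshape(-1))  # 1 item
print(len(rows), rows[0].shape, zero_d_iterable, len(flat_2d), len(flat_0d))
# 2 (3,) False 6 1
```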
Source code in probpipe/core/_distribution_array.py
BroadcastDistribution(input_samples, output_samples, weights=None, *, log_weights=None, output_distributions=None, broadcast_args, name=None)
Bases: Distribution[dict], SupportsSampling
Joint distribution over broadcast inputs and function output.
Stores the paired input–output samples from a
WorkflowFunction broadcast. Supports
joint sampling (resampling paired input–output tuples) and named
component access.
Call marginalize to obtain the output-only marginal, which
supports moment protocols (mean, variance, etc.) when the output
data permits.
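Joint resampling has to draw one set of row indices and apply it to both inputs and outputs, so that each resampled input stays paired with its own output. `resample_pairs` is a hypothetical sketch that uses a NumPy `Generator` rather than JAX PRNG keys. It is not BroadcastDistribution's implementation.

```python
import numpy as np


def resample_pairs(seed, input_samples, output_samples, weights, size):
    """Draw `size` paired (input, output) rows with the given weights."""
    rng = np.random.default_rng(seed)
    outputs = np.asarray(output_samples)
    # One index draw shared by inputs and outputs keeps each pair intact.
    idx = rng.choice(len(outputs), size=size, replace=True, p=weights)
    inputs = {k: np.asarray(v)[idx] for k, v in input_samples.items()}
    return inputs, outputs[idx]


x = np.arange(5.0)
inputs, outputs = resample_pairs(0, {"x": x}, x**2, np.full(5, 0.2), size=100)
print(np.allclose(outputs, inputs["x"] ** 2))  # True: pairs never mix
```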
Note:
BroadcastDistribution does not inherit from
JointDistribution.
JointDistribution requires all leaves to be
NumericRecordDistribution instances with TFP shape semantics
(batch_shape, event_shape), but a broadcast output can be
any type — arrays, distributions, strings, etc. — and input
samples are plain arrays without distribution metadata. The two
hierarchies serve different roles: JointDistribution models
structured probabilistic variables; BroadcastDistribution
records the empirical input–output mapping of a function
evaluation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_samples | dict[str, Array] | | required |
| output_samples | Array or list | | required |
| weights | array-like, Weights, or None | Non-negative weights (normalized internally). A pre-built | None |
| log_weights | array-like, Weights, or None | Log-unnormalized weights. A pre-built | None |
| output_distributions | list of Distribution or None | When each function evaluation returns a | None |
| broadcast_args | list of str | Ordered names of the broadcast arguments. | required |
| name | str or None | Distribution name for provenance. | None |
Source code in probpipe/core/_broadcast_distributions.py
n
property
Number of input–output pairs.
weights
property
Normalised weights, shape (n,).
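How normalised weights can be obtained from either constructor argument is sketched below. `normalised_weights` is hypothetical. Two behaviours are assumptions of this sketch rather than documented facts: falling back to uniform weights when neither argument is given, and rejecting both at once. For log-weights, the maximum is subtracted before exponentiating so that large values do not overflow.

```python
import numpy as np


def normalised_weights(n, weights=None, log_weights=None):
    """Return weights of shape (n,) that sum to 1."""
    if weights is not None and log_weights is not None:
        raise ValueError("give weights or log_weights, not both")  # assumption
    if weights is not None:
        w = np.asarray(weights, dtype=float)
        if np.any(w < 0):
            raise ValueError("weights must be non-negative")
        return w / w.sum()
    if log_weights is not None:
        lw = np.asarray(log_weights, dtype=float)
        # Shift by the max before exp() so large log-weights don't overflow.
        w = np.exp(lw - lw.max())
        return w / w.sum()
    return np.full(n, 1.0 / n)  # assumption: no weights means uniform


print(normalised_weights(3, weights=[1.0, 1.0, 2.0]).tolist())       # [0.25, 0.25, 0.5]
print(normalised_weights(2, log_weights=[1000.0, 1000.0]).tolist())  # [0.5, 0.5]
```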
input_samples
property
Broadcast input samples: {arg_name: (n, *event_shape)}.
samples
property
Output samples (forwarded to the output marginal for backward compatibility).
output
property
Alias for marginalize.
marginalize()
Return the output marginal distribution.
Lazy — the marginal is constructed on first call and cached.
The marginal inherits this distribution's provenance (if any)
so the lineage is preserved without a direct reference to the
BroadcastDistribution.
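A minimal sketch of a lazily built, cached marginal with an `output` alias. `ToyBroadcast` is hypothetical and uses a dict in place of a real marginal distribution. The marginal carries over the provenance value but holds no reference back to the container, which mirrors the lineage behaviour described above.

```python
class ToyBroadcast:
    """Hypothetical container with a lazily built, cached output marginal."""

    def __init__(self, output_samples, provenance=None):
        self._output_samples = list(output_samples)
        self.provenance = provenance
        self._marginal = None
        self.builds = 0  # counts how often the marginal is constructed

    def marginalize(self):
        if self._marginal is None:
            self.builds += 1
            # Carry over the provenance, but never store `self` here.
            self._marginal = {
                "samples": self._output_samples,
                "provenance": self.provenance,
            }
        return self._marginal

    @property
    def output(self):
        # Alias: returns the same cached object as marginalize().
        return self.marginalize()


bd = ToyBroadcast([1.0, 4.0, 9.0], provenance="sweep of f(x)")
print(bd.output is bd.marginalize(), bd.builds)  # True 1
```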
Source code in probpipe/core/_broadcast_distributions.py
The truly private machinery (_RecordDistributionView, _vmap_sample,
_mc_expectation) is documented on the Internals page, alongside the
public-but-rarely-constructed FlattenedDistributionView and
NumericRecordDistributionView.