Internals
These classes and helpers are not part of the stable public API — they're
documented here for contributors, advanced users who need to understand the
broadcasting / protocol-dispatch machinery, and anyone debugging a failing
isinstance(obj, SupportsX) check. Names start with _ to reflect this.
Signatures may change without deprecation warnings between PRs.
If you find yourself reaching for something on this page in user code, there's probably a public replacement — check Operations, Distributions, Records, or Extending ProbPipe first, and open an issue if there isn't.
Record-distribution views
dist["field"] on a RecordDistribution
returns a _RecordDistributionView — a lightweight reference that preserves
correlation when multiple views from the same parent are used in
@workflow_function broadcasting. The
view's protocol membership (SupportsSampling, SupportsMean,
SupportsVariance, SupportsLogProb, SupportsCovariance) is computed
dynamically from the parent's capabilities, so
isinstance(view, SupportsMean) is true if and only if the parent itself satisfies SupportsMean.
_RecordDistributionView(parent, key)
Bases: Distribution
Lightweight reference to a single named field of a Record-based distribution.
The Record-world analog of
DistributionView. Preserves
correlation when multiple views from the same parent are used in
WorkflowFunction broadcasting.
Dynamic protocol support: this base class intentionally does
not inherit any SupportsFoo protocols. Each concrete instance
is a cached subclass built by _view_class_for_parent, which
mixes in only the protocols the parent actually implements. Calling
_RecordDistributionView(parent, key) routes through __new__
and picks the right subclass automatically, so
isinstance(view, SupportsSampling) is True if and only if the parent itself satisfies SupportsSampling.
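The dynamic-protocol pattern described above can be sketched in plain Python. This is an illustration only, not ProbPipe's implementation: the protocol classes, mixins, FieldView, and view_class_for_parent below are hypothetical stand-ins, and the real _view_class_for_parent may cache and build its subclasses differently.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsSampling(Protocol):  # hypothetical stand-in protocol
    def sample(self, seed): ...


@runtime_checkable
class SupportsLogProb(Protocol):  # hypothetical stand-in protocol
    def log_prob(self, value): ...


class _SamplingMixin:
    def sample(self, seed):
        # Delegate to the parent and pick out this view's field.
        return self.parent.sample(seed)[self.field]


class _LogProbMixin:
    def log_prob(self, value):
        # Placeholder delegation; a real view needs the field's own density.
        return self.parent.log_prob({self.field: value})


_CAPABILITIES = (
    (SupportsSampling, _SamplingMixin),
    (SupportsLogProb, _LogProbMixin),
)
_CLASS_CACHE = {}


def view_class_for_parent(parent):
    """Return a cached FieldView subclass with only the parent's capabilities."""
    mixins = tuple(m for proto, m in _CAPABILITIES if isinstance(parent, proto))
    if mixins not in _CLASS_CACHE:
        _CLASS_CACHE[mixins] = type("FieldView", mixins + (FieldView,), {})
    return _CLASS_CACHE[mixins]


class FieldView:
    """Base view: deliberately implements no capability protocol itself."""

    def __new__(cls, parent, field):
        if cls is FieldView:
            # Route construction to the subclass matching the parent.
            cls = view_class_for_parent(parent)
        return super().__new__(cls)

    def __init__(self, parent, field):
        self.parent = parent
        self.field = field


class SamplerOnly:
    """Toy parent that can sample but has no log_prob."""

    def sample(self, seed):
        return {"x": float(seed)}


view = FieldView(SamplerOnly(), "x")
```

Because the base class carries no capability methods, a view over a sampling-only parent passes the sampling check and fails the log-prob check, which is the behaviour the paragraph above describes.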
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| parent | Distribution | A distribution with … | required |
| key | str | Field name in the parent's … | required |
Source code in probpipe/core/_record_distribution.py
parent
property
The RecordDistribution this view points at.
Shared-identity signal for the WorkflowFunction sweep layer:
views with the same parent co-sample (preserve correlation)
when passed as sibling broadcast args to a workflow function.
Matches the _RecordArrayView.parent surface.
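A minimal sketch of the shared-parent idea, assuming a sweep layer that groups sibling arguments by parent identity. JointParent, SiblingView, and co_sample are hypothetical names for illustration; the actual WorkflowFunction sweep logic is not shown on this page.

```python
import random


class JointParent:
    """Hypothetical joint distribution whose fields are correlated."""

    def sample(self, rng):
        x = rng.gauss(0.0, 1.0)
        return {"x": x, "y": 2.0 * x}  # y is fully determined by x


class SiblingView:
    """Hypothetical view exposing the same parent / field surface."""

    def __init__(self, parent, field):
        self.parent = parent
        self.field = field


def co_sample(views, rng):
    """Draw one value per view, sampling each distinct parent only once."""
    draws = {}
    out = []
    for view in views:
        parent_id = id(view.parent)
        if parent_id not in draws:
            draws[parent_id] = view.parent.sample(rng)
        out.append(draws[parent_id][view.field])
    return out


joint = JointParent()
x, y = co_sample([SiblingView(joint, "x"), SiblingView(joint, "y")], random.Random(0))
```

Sampling each view independently would draw the parent twice and lose the relationship between x and y; keying on the shared parent keeps both values from the same joint draw.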
field
property
Name of the viewed field (the top-level key into the parent).
shape
property
Shape of one draw from this view — equals event_shape.
dtype
property
Dtype of a single draw, if the parent exposes dtypes.
ndim
property
Number of axes in a single draw (len(event_shape)).
Flat / Record view helpers
FlattenedDistributionView and NumericRecordDistributionView are the public
view classes for the flat ↔ Record-keyed bridge. Both follow the same
dynamic-protocol pattern as _RecordDistributionView — only the protocols
the base distribution supports are attached to the view.
FlattenedDistributionView is a FlatNumericRecordDistribution:
it exposes any distribution as flat (single field, event_shape=(N,)),
for interop with algorithms that expect a flat parameter vector.
Construct via as_flat_distribution.
NumericRecordDistributionView is the inverse: it takes a
FlatNumericRecordDistribution and a NumericRecordTemplate, and presents
the distribution under the template's named-field structure. Construct via
as_record_distribution.
FlatNumericRecordDistribution(*, name)
Bases: NumericRecordDistribution
A NumericRecordDistribution whose samples are flat 1-D vectors.
The flat contract:
- exactly one field (len(fields) == 1)
- event_shape == (N,) for some N
- samples shaped sample_shape + (N,)
Algorithms that operate on a flat parameter vector — MCMC kernels,
optimisers, Hessian / curvature builders, variational families,
Pathfinder / Laplace surrogates — should declare their input as
FlatNumericRecordDistribution. The natively-multivariate
parametrics (MultivariateNormal,
Dirichlet, Multinomial,
VonMisesFisher) and
FlattenedDistributionView all satisfy this contract.
Scalar parametrics (Normal, Beta, …) have
event_shape == () and do not satisfy the contract directly;
call as_flat_distribution to get
a FlattenedDistributionView (whose event_shape is (1,)).
This class is also the home of
as_record_distribution, the inverse of
as_flat_distribution. Because the method is defined only on this
receiver type, calling it on a non-flat distribution fails at the
type level rather than at a runtime shape check.
Source code in probpipe/core/_distribution_base.py
flat_size
property
Number of scalar elements — equal to event_shape[0].
Validates the flat contract on access: subclasses with
non-1-D event_shape raise TypeError here rather than
silently truncating to the first dimension.
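The validate-on-access behaviour can be sketched as follows. FlatLike is a hypothetical stand-in that models only event_shape and flat_size; the error message and the exact check in ProbPipe may differ.

```python
class FlatLike:
    """Hypothetical stand-in for a distribution with an event_shape."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    @property
    def flat_size(self):
        # Refuse anything that is not 1-D rather than returning event_shape[0]
        # for a shape such as (2, 3), which would silently drop a dimension.
        if len(self.event_shape) != 1:
            raise TypeError(
                f"flat contract requires a 1-D event_shape, got {self.event_shape}"
            )
        return self.event_shape[0]


size = FlatLike((5,)).flat_size
```

Accessing flat_size on FlatLike((2, 3)) raises TypeError under this sketch.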
as_record_distribution(*, template, name=None)
Lift this flat distribution to a Record-keyed view under template.
Inverse of as_flat_distribution.
Samples come back as NumericRecord /
NumericRecordArray keyed by template.fields.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| template | NumericRecordTemplate | Target structural skeleton. Must be a … | required |
| name | str | Name for the lifted distribution. Defaults to … | None |
Returns:
| Type | Description |
|---|---|
| NumericRecordDistribution | A thin view over … |
Raises:
| Type | Description |
|---|---|
| TypeError | If … |
| ValueError | If … |
Source code in probpipe/core/_numeric_record_distribution.py
FlattenedDistributionView(base)
Bases: FlatNumericRecordDistribution
Wraps a distribution as a flat FlatNumericRecordDistribution.
Sampling produces flat vectors of shape (event_size,), and
_log_prob accepts flat vectors and delegates to the wrapped
distribution after unflattening.
This is the primary interoperability mechanism: any algorithm written
against FlatNumericRecordDistribution works with an
arbitrary RecordDistribution /
NumericRecordDistribution via
dist.as_flat_distribution().
Dynamic protocol support: the view's isinstance compliance
matches the base's capabilities — a log-prob-only base produces a
view that is not SupportsSampling, and a sampling-only base
produces one that is not SupportsLogProb.
Source code in probpipe/core/_numeric_record_distribution.py
NumericRecordDistributionView(base, template, *, name=None)
Bases: NumericRecordDistribution
View that lifts a flat distribution to a Record-keyed structure.
Inverse of FlattenedDistributionView. self._base is a
FlatNumericRecordDistribution (single-field, event_shape
== (N,)); self.record_template is the user-supplied
NumericRecordTemplate (not the source's auto-template).
Sampling, log-prob, and moments delegate to self._base and
reshape via the template's flatten / unflatten machinery.
Capability protocols match the source via
_numeric_record_distribution_view_class_for_base.
Constructed via
FlatNumericRecordDistribution.as_record_distribution.
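The flatten / unflatten round trip that such a view relies on can be illustrated with plain lists. This is a sketch under assumptions: the template here is just a dict mapping field name to shape, and the size-mismatch ValueError is an illustrative choice, not a statement about what NumericRecordTemplate actually does.

```python
from math import prod


def _reshape(values, shape):
    """Reshape a flat list into nested lists of the given shape."""
    if not shape:
        return values[0]  # scalar field
    if len(shape) == 1:
        return list(values)
    step = prod(shape[1:])
    return [
        _reshape(values[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


def _ravel(value):
    """Flatten nested lists (or a scalar) back into a flat list."""
    if isinstance(value, list):
        return [v for item in value for v in _ravel(item)]
    return [value]


def unflatten(flat, template):
    """Split a flat vector into named fields; template maps name -> shape."""
    total = sum(prod(shape) for shape in template.values())
    if len(flat) != total:
        raise ValueError(f"expected {total} elements, got {len(flat)}")
    record, offset = {}, 0
    for name, shape in template.items():
        size = prod(shape)
        record[name] = _reshape(flat[offset:offset + size], shape)
        offset += size
    return record


def flatten(record, template):
    """Inverse of unflatten: concatenate fields in template order."""
    flat = []
    for name in template:
        flat.extend(_ravel(record[name]))
    return flat


template = {"loc": (2,), "scale": (), "weights": (2, 2)}
record = unflatten([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], template)
```

A view built this way only needs the flat base distribution plus the template: sampling unflattens each draw, and log-prob flattens the incoming record before delegating.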
Source code in probpipe/core/_numeric_record_distribution.py
event_shape
property
Single-field shortcut: the lone field's shape.
Raises TypeError via _single_field_name for
multi-field templates; reach for event_shapes (dict)
in that case.
event_shapes
property
Per-field event shapes from the user-supplied template.
dtypes
property
Per-field dtypes — all fields inherit the source's single dtype.
supports
property
Per-field supports — all fields inherit the source's single support.
base_distribution
property
The underlying single-field flat distribution.
Sampling primitive
_vmap_sample(dist, key, sample_shape=())
Draw samples via jax.vmap over dist._sample(key, ()).
Convenience for distributions whose _sample implementation is
naturally a single-draw function: call this helper from _sample
and it will handle the sample_shape prefix by splitting keys
and vmap-ing over the single-draw path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | | Distribution whose … | required |
| key | PRNGKey | JAX PRNG key. | required |
| sample_shape | tuple of int | Shape prefix for independent draws. | () |
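The split-keys-then-vmap approach can be sketched with JAX directly. This is not the body of _vmap_sample: single_draw and vmap_sample are hypothetical names, and single_draw stands in for a distribution's dist._sample(key, ()) with an assumed event shape of (3,).

```python
import math

import jax


def single_draw(key):
    """One draw with event shape (3,); stands in for dist._sample(key, ())."""
    return jax.random.normal(key, (3,))


def vmap_sample(key, sample_shape=()):
    """Draw prod(sample_shape) independent samples by vmapping single_draw."""
    num_draws = math.prod(sample_shape)       # equals 1 when sample_shape == ()
    keys = jax.random.split(key, num_draws)   # one independent key per draw
    flat = jax.vmap(single_draw)(keys)        # shape (num_draws, 3)
    # Restore the requested sample_shape prefix in front of the event shape.
    return flat.reshape(tuple(sample_shape) + flat.shape[1:])


samples = vmap_sample(jax.random.PRNGKey(0), (4, 2))
```

Splitting the key first means every draw uses its own PRNG stream, so the single-draw function never has to reason about sample_shape.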
Source code in probpipe/core/_numeric_record_distribution.py
_mc_expectation(dist, f, *, key=None, num_evaluations=None, return_dist=None)
Estimate E[f(X)] where X ~ dist via Monte Carlo.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | | Distribution with a … | required |
| f | callable | Function mapping a single sample to an array (or pytree of arrays). | required |
| key | PRNGKey | JAX PRNG key for sampling. Auto-generated if … | None |
| num_evaluations | int | Number of samples to draw. If … | None |
| return_dist | bool | If … | None |
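The underlying estimator is the sample mean of f over independent draws. The sketch below uses only the standard library and is not _mc_expectation itself: mc_expectation, its sample callable, the seed argument, and the returned standard error are all assumptions made for illustration.

```python
import random
import statistics


def mc_expectation(sample, f, num_evaluations=10_000, seed=0):
    """Estimate E[f(X)] as the mean of f over independent draws of X.

    sample is a callable taking a random.Random and returning one draw.
    Returns the estimate and its Monte Carlo standard error.
    """
    rng = random.Random(seed)
    values = [f(sample(rng)) for _ in range(num_evaluations)]
    mean = statistics.fmean(values)
    stderr = statistics.stdev(values) / num_evaluations ** 0.5
    return mean, stderr


def standard_normal(rng):
    return rng.gauss(0.0, 1.0)


# E[X^2] for a standard normal is its variance, so the estimate should be near 1.
estimate, stderr = mc_expectation(standard_normal, lambda x: x * x, 20_000)
```

The standard error shrinks as 1 / sqrt(num_evaluations), which is the usual guide for choosing how many evaluations to request.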