Operations
Standalone workflow functions for sampling, density evaluation, moments, conditioning, and conversion. Each op dispatches via the matching protocol, participates in broadcasting, and is subject to Prefect orchestration when configured.
Sampling
sample(dist, *, key=None, sample_shape=())
Draw samples from a distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | SupportsSampling | Distribution to sample from. | required |
| key | PRNGKey | JAX PRNG key. Auto-generated if | None |
| sample_shape | tuple of int | Shape prefix for independent draws. | () |
Source code in probpipe/core/ops.py
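To make the role of sample_shape concrete, the sketch below mimics the shape-prefix convention with a stand-in sampler. It does not call probpipe: draw_with_shape and standard_normal are hypothetical names, and real draws would presumably be arrays driven by a JAX PRNG key rather than nested Python lists.

```python
import random


def draw_with_shape(draw_one, sample_shape=()):
    """Return nested lists of independent draws with the given shape prefix.

    Hypothetical stand-in for illustration; not probpipe's implementation.
    """
    if not sample_shape:
        return draw_one()  # empty prefix: a single draw
    head, *rest = sample_shape
    return [draw_with_shape(draw_one, tuple(rest)) for _ in range(head)]


rng = random.Random(0)  # seeded so the sketch is reproducible


def standard_normal():
    return rng.gauss(0.0, 1.0)


single = draw_with_shape(standard_normal)         # one scalar draw
batch = draw_with_shape(standard_normal, (2, 3))  # 2 x 3 independent draws
```

With the default empty tuple the result is a single draw; each entry added to sample_shape adds one leading dimension of independent draws.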
Density evaluation
log_prob(dist, value)
Evaluate the normalized log-density at value.
Source code in probpipe/core/ops.py
prob(dist, value)
Evaluate the density at value (exp(log_prob)).
Source code in probpipe/core/ops.py
unnormalized_log_prob(dist, value)
Evaluate the unnormalized log-density at value.
Source code in probpipe/core/ops.py
unnormalized_prob(dist, value)
Evaluate the unnormalized density at value (exp(unnormalized_log_prob)).
Source code in probpipe/core/ops.py
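The four density ops above differ only in exponentiation and in whether the normalizing constant is included. A minimal sketch of that relationship, using a Gaussian kernel as a stand-in: the toy_ functions are hypothetical and take a point directly rather than a (dist, value) pair, so they illustrate the arithmetic, not the probpipe API.

```python
import math

SCALE = 2.0  # standard deviation of the toy Gaussian


def toy_unnormalized_log_prob(x):
    # Gaussian kernel without its normalizing constant.
    return -0.5 * (x / SCALE) ** 2


def toy_log_prob(x):
    # Subtract log Z, where Z = SCALE * sqrt(2 * pi) for this kernel.
    log_z = math.log(SCALE * math.sqrt(2.0 * math.pi))
    return toy_unnormalized_log_prob(x) - log_z


def toy_prob(x):
    return math.exp(toy_log_prob(x))


def toy_unnormalized_prob(x):
    return math.exp(toy_unnormalized_log_prob(x))


x = 1.5
ratio = toy_unnormalized_prob(x) / toy_prob(x)  # equals Z for every x
```

The ratio of the unnormalized to the normalized density is the constant Z, which is why the unnormalized variants suffice whenever only density ratios matter.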
random_log_prob(dist, value=None)
Return the random (normalized) log-density of a random measure.
For a RandomMeasure[T] M with draws D ~ M, the random
function x ↦ log D(x) is itself a callable returning a
distribution over scalars at every input.
When value is omitted, returns that callable as a
RandomFunction. When
value is provided, returns the Distribution[Array] over
log D(value) directly — equivalent to
random_log_prob(dist)(value). The two-argument form mirrors
log_prob for non-random distributions.
Concrete subclasses implement a single method
_random_log_prob() returning a RandomFunction; the optional
value dispatch lives entirely in this op, not on the protocol.
Source code in probpipe/core/ops.py
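The split described above, a single protocol method plus optional value handling in the op, can be sketched as follows. Everything here is a hypothetical stand-in: ToyRandomMeasure represents a distribution over scalars as a list of (value, weight) pairs, and toy_random_log_prob only imitates the documented behaviour.

```python
import math


def _normal_log_pdf(x, mean):
    # Log-density of a unit-variance Gaussian centred at `mean`.
    return -0.5 * (x - mean) ** 2 - 0.5 * math.log(2.0 * math.pi)


class ToyRandomMeasure:
    """Hypothetical random measure: a Gaussian whose mean is itself random."""

    def __init__(self, means, weights):
        self.means = means
        self.weights = weights

    def _random_log_prob(self):
        # Returns a callable x -> distribution over log D(x), represented
        # here as a list of (log-density value, weight) pairs.
        def random_function(x):
            return [
                (_normal_log_pdf(x, m), w)
                for m, w in zip(self.means, self.weights)
            ]

        return random_function


def toy_random_log_prob(dist, value=None):
    fn = dist._random_log_prob()  # the single protocol method
    if value is None:
        return fn                 # one-argument form: the callable itself
    return fn(value)              # two-argument form: evaluate at value


measure = ToyRandomMeasure(means=[-1.0, 1.0], weights=[0.5, 0.5])
via_callable = toy_random_log_prob(measure)(0.3)
direct = toy_random_log_prob(measure, 0.3)
```

Both call styles produce the same result, so subclasses only ever need to supply the zero-argument method.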
random_unnormalized_log_prob(dist, value=None)
Return the random unnormalized log-density of a random measure.
For a RandomMeasure[T] M with draws D ~ M, the random
function x ↦ log D̃(x) (where D̃ is the unnormalized density
of D) is itself a callable returning a distribution over
scalars at every input.
When value is omitted, returns that callable as a
RandomFunction. When
value is provided, returns the Distribution[Array] over
log D̃(value) directly — equivalent to
random_unnormalized_log_prob(dist)(value). The two-argument
form mirrors unnormalized_log_prob for non-random
distributions.
Concrete subclasses implement a single method
_random_unnormalized_log_prob() returning a RandomFunction;
the optional value dispatch lives entirely in this op, not on
the protocol.
Source code in probpipe/core/ops.py
Moments and expectations
mean(dist)
Compute E[X] where X ~ dist.
The return type is T-shaped where T is dist's sample type:
- Numeric distributions (T = Array): returns Array.
- Structured distributions (T = Record): returns Record.
- RandomMeasure[T] (T itself a Distribution[T]): returns the marginalised Distribution[T] with marginal D̄(A) = ∫ D(A) dM(D).
Requires the distribution to implement SupportsMean.
Source code in probpipe/core/ops.py
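For the random-measure case, the marginal D̄(A) = ∫ D(A) dM(D) reduces to a weighted average when M has finite support. The sketch below works through that special case with plain dictionaries; marginal and the coin example are hypothetical illustrations, not probpipe code.

```python
def marginal(random_measure):
    """Average a finitely supported random measure.

    `random_measure` is a list of (weight, distribution) pairs, where each
    distribution maps outcomes to probabilities. The result is
    D_bar(a) = sum_i w_i * D_i(a).
    """
    out = {}
    for weight, dist in random_measure:
        for outcome, p in dist.items():
            out[outcome] = out.get(outcome, 0.0) + weight * p
    return out


# M puts weight 0.25 on a fair coin and 0.75 on a biased coin.
M = [
    (0.25, {"heads": 0.5, "tails": 0.5}),
    (0.75, {"heads": 0.9, "tails": 0.1}),
]
d_bar = marginal(M)  # heads: 0.25 * 0.5 + 0.75 * 0.9, tails: the remainder
```

The result is itself a distribution over the same outcomes, which matches the T-shaped return type: the mean of a RandomMeasure[T] is a Distribution[T].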
variance(dist)
Compute Var[X].
Requires the distribution to implement SupportsVariance.
Source code in probpipe/core/ops.py
cov(dist)
Compute the covariance matrix.
Requires the distribution to implement SupportsCovariance.
Source code in probpipe/core/ops.py
expectation(dist, f, *, key=None, num_evaluations=None, return_dist=None)
Compute E[f(X)] where X ~ dist.
Source code in probpipe/core/ops.py
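The page does not say how expectation evaluates E[f(X)]. One common approach, and a plausible reading of the num_evaluations parameter, is a plain Monte Carlo average. The sketch below shows that approach only; mc_expectation is a hypothetical helper, and whether probpipe uses Monte Carlo, quadrature, or a closed form for a given distribution is an assumption here.

```python
import random


def mc_expectation(draw_one, f, num_evaluations=20_000):
    """Plain Monte Carlo estimate of E[f(X)]; hypothetical helper."""
    total = 0.0
    for _ in range(num_evaluations):
        total += f(draw_one())
    return total / num_evaluations


rng = random.Random(0)  # seeded so the sketch is reproducible


def standard_normal():
    return rng.gauss(0.0, 1.0)


# E[X^2] for a standard normal is its variance, which is 1.
estimate = mc_expectation(standard_normal, lambda x: x * x)
```

The estimate's standard error shrinks as one over the square root of num_evaluations, so the budget trades accuracy against cost.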
Conditioning
condition_on(dist, observed=None, *, method=None, **kwargs)
Condition a distribution on observed values.
Observed data can be passed positionally or as named keyword arguments:

    # Positional (backward compatible):
    condition_on(model, y_obs)

    # Named data kwargs, bundled into Record(X=..., y=...):
    condition_on(model, X=bootstrap["X"], y=bootstrap["y"],
                 n_broadcast_samples=16)

When named data kwargs are distribution views from the same parent, the workflow function broadcasting machinery samples the parent once and distributes the fields, preserving joint correlation.
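Why sampling the parent once matters can be seen with a toy parent whose fields are perfectly correlated. The sketch below is a hypothetical illustration in plain Python, not the broadcasting machinery itself: sample_parent stands in for a parent distribution with views X and y.

```python
import random


def sample_parent(rng):
    # Hypothetical parent: a record whose fields are perfectly correlated.
    x = rng.gauss(0.0, 1.0)
    return {"X": x, "y": 2.0 * x}


rng = random.Random(0)

# Joint: draw the parent once per sample, then distribute the fields.
joint_pairs = []
for _ in range(16):
    record = sample_parent(rng)
    joint_pairs.append((record["X"], record["y"]))

# Independent: each field comes from its own, separate parent draw.
indep_pairs = [
    (sample_parent(rng)["X"], sample_parent(rng)["y"]) for _ in range(16)
]

joint_ok = all(y == 2.0 * x for x, y in joint_pairs)  # relation preserved
indep_ok = all(y == 2.0 * x for x, y in indep_pairs)  # relation lost
```

Only the joint scheme keeps y tied to X; drawing the views separately would pair each X with an unrelated y.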
Dispatch priority:
- Explicit override: method="tfp_nuts" (or any registered name) routes directly to the named inference method.
- Exact conditioning: if dist implements SupportsConditioning, its _condition_on is called for a closed-form result (e.g., conjugate updates, joint marginalization).
- Registry auto-select: the inference method registry picks the highest-priority feasible algorithm (NUTS, HMC, RWMH, etc.).
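The three tiers above can be sketched as a small dispatcher. This is a minimal illustration under stated assumptions: the registry layout, the toy_ method names, and the hasattr check standing in for SupportsConditioning are all hypothetical, and the real registry may decide feasibility and priority differently.

```python
class ToyConjugateModel:
    """Hypothetical model with a closed-form update."""

    def _condition_on(self, observed):
        return ("exact", observed)


class ToyGenericModel:
    """Hypothetical model with no closed-form update."""


# Registry entries: name -> (priority, feasibility check, runner).
REGISTRY = {
    "toy_nuts": (2, lambda d: True, lambda d, obs, **kw: ("toy_nuts", obs)),
    "toy_rwmh": (1, lambda d: True, lambda d, obs, **kw: ("toy_rwmh", obs)),
}


def toy_condition_on(dist, observed=None, *, method=None, **kwargs):
    if method is not None:                 # tier 1: explicit override
        return REGISTRY[method][2](dist, observed, **kwargs)
    if hasattr(dist, "_condition_on"):     # tier 2: exact conditioning
        return dist._condition_on(observed)
    feasible = [                           # tier 3: registry auto-select
        (priority, run)
        for priority, is_feasible, run in REGISTRY.values()
        if is_feasible(dist)
    ]
    _, run = max(feasible, key=lambda entry: entry[0])
    return run(dist, observed, **kwargs)


exact = toy_condition_on(ToyConjugateModel(), 1.0)
forced = toy_condition_on(ToyConjugateModel(), 1.0, method="toy_rwmh")
auto = toy_condition_on(ToyGenericModel(), 1.0)
```

Note that the explicit override wins even when an exact update exists, which is what lets a caller compare a sampler against the closed-form answer.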
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist | Distribution | Distribution or model to condition. Need not implement | required |
| observed | Any | Observed values to condition on. | None |
| method | str or None | If provided, use the named inference method from the registry instead of the default dispatch. | None |
| **kwargs | Any | Inference parameters (e.g., | {} |
Source code in probpipe/core/ops.py
condition_on dispatches inference via the
inference-method registry; override the
auto-selection with method="<name>".
Conversion
from_distribution(source, target_type, *, key=None, check_support=True, **kwargs)
Convert source into an instance of target_type.
Delegates to the global converter registry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| source | Distribution | Source distribution to convert. | required |
| target_type | type | The target distribution class. | required |
| key | PRNGKey | JAX PRNG key for sampling-based conversion. | None |
| check_support | bool | If | True |
| **kwargs | Any | Additional keyword arguments passed to the converter. | {} |
Source code in probpipe/core/ops.py
Backed by the converter registry.
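A registry-backed conversion of this kind can be sketched as a lookup keyed on the source and target types. The code below is an assumption-laden illustration: CONVERTERS, register, the Toy classes, and the exact-type lookup are hypothetical, and the real converter registry may resolve converters differently (for example via subclass relationships) and also handles key and check_support.

```python
CONVERTERS = {}  # (source type, target type) -> converter function


def register(source_type, target_type):
    """Decorator that records a converter for a (source, target) pair."""

    def decorator(fn):
        CONVERTERS[(source_type, target_type)] = fn
        return fn

    return decorator


class ToyEmpirical:
    def __init__(self, samples):
        self.samples = list(samples)


class ToyNormal:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std


@register(ToyEmpirical, ToyNormal)
def empirical_to_normal(source, **kwargs):
    # Moment matching: reuse the sample mean and (population) std.
    n = len(source.samples)
    mean = sum(source.samples) / n
    var = sum((s - mean) ** 2 for s in source.samples) / n
    return ToyNormal(mean, var ** 0.5)


def toy_from_distribution(source, target_type, **kwargs):
    try:
        converter = CONVERTERS[(type(source), target_type)]
    except KeyError:
        raise TypeError(
            f"no converter from {type(source).__name__} "
            f"to {target_type.__name__}"
        ) from None
    return converter(source, **kwargs)


fitted = toy_from_distribution(ToyEmpirical([1.0, 2.0, 3.0]), ToyNormal)
```

Keeping converters in a registry means new source and target pairs can be added without touching the op itself.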