Records and data

Named, immutable containers for structured non-random data, plus the batched (RecordArray) and parameter-sweep (Design) variants built on top.

Field access is bracket-only: record["x"], array["x"]. Slash-delimited strings index nested paths: record["params/intercept"].
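
As a rough illustration of how a slash-delimited path could resolve against nested containers, here is a minimal sketch that uses plain dicts in place of Records. The helper name `get_path` is hypothetical and is not part of the library.

```python
from typing import Any, Mapping


def get_path(tree: Mapping[str, Any], path: str) -> Any:
    """Resolve a slash-delimited path against nested mappings."""
    node: Any = tree
    for part in path.split("/"):
        node = node[part]  # raises KeyError if a segment is missing
    return node


# Plain dicts stand in for a Record holding a nested Record.
nested = {"params": {"intercept": 1.5, "slope": 0.3}, "label": "run-1"}
intercept = get_path(nested, "params/intercept")
label = get_path(nested, "label")
```

A path with no slash degenerates to ordinary single-field lookup, which is why field names themselves may not contain `/`.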

Records

Record(_dict=None, /, *, name=None, **fields)

Named, immutable, pytree-registered container for structured values.

Fields iterate in insertion order and are returned verbatim; Record performs no coercion between backends (numpy, JAX, xarray, Python scalars, strings, nested Records are all accepted). Use NumericRecord when you want a uniform jax.Array leaf type and flatten / unflatten support.

Parameters:

  • **fields (_FieldValue, default {}): Named values. Values may be JAX or numpy arrays, Python scalars, strings, xarray / pandas objects, nested Record, or any other opaque object. Nothing is converted at construction. Field names must not contain / (reserved as the nested-path separator).
  • name (str, default None): Name for provenance / introspection. Auto-generated from field names if not provided.

Source code in probpipe/core/record.py
def __init__(
    self,
    _dict: dict[str, _FieldValue] | None = None,
    /,
    *,
    name: str | None = None,
    **fields: _FieldValue,
):
    if _dict is not None:
        if fields:
            raise ValueError("Cannot pass both positional dict and keyword arguments")
        fields = _dict
    if not fields:
        raise ValueError("Record requires at least one named field")
    for field_name in fields:
        _check_no_path_sep(field_name)
    store = dict(fields)
    object.__setattr__(self, "_store", store)
    # Auto-generate name from field names if not provided
    if name is None:
        name = "record(" + ",".join(store.keys()) + ")"
    object.__setattr__(self, "_name", name)
    object.__setattr__(self, "_source", None)
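
The constructor rules in the listing above can be exercised with a small stand-in class. This is a sketch only: `MiniRecord` is hypothetical, it stores fields in an ordinary attribute rather than guarding immutability, and the exception raised for a `/` in a field name is an assumption (the listing does not show what `_check_no_path_sep` raises).

```python
class MiniRecord:
    """Hypothetical stand-in mirroring the constructor checks shown above."""

    def __init__(self, _dict=None, /, *, name=None, **fields):
        if _dict is not None:
            if fields:
                raise ValueError(
                    "Cannot pass both positional dict and keyword arguments"
                )
            fields = _dict
        if not fields:
            raise ValueError("Record requires at least one named field")
        for field_name in fields:
            if "/" in field_name:
                # Assumption: the real check may raise a different exception.
                raise ValueError(f"Field name {field_name!r} contains '/'")
        self._store = dict(fields)
        # Auto-generate the name from the field names when none is given.
        if name is None:
            name = "record(" + ",".join(self._store) + ")"
        self._name = name


by_kwargs = MiniRecord(x=1.0, y=2.0)
by_dict = MiniRecord({"x": 1.0, "y": 2.0}, name="point")
```

The positional-dict form is useful when field names are computed at runtime or are not valid Python identifiers.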

name property

Name of this Record.

source property

Provenance describing how this Record was created, or None.

fields property

Field names in insertion order.

with_source(source)

Attach provenance to this Record (write-once).

Mirrors Distribution.with_source: _source is set once, and subsequent calls raise RuntimeError. Semantic transformations (replace, merge, without, map, map_with_names) return a new Record with an empty source; the caller attaches fresh provenance there if desired.

Notes

_source is runtime-only metadata — it is not serialised into the JAX pytree aux (a Provenance parent is a Distribution or Record, neither of which is hashable by structure). Round-tripping through jax.tree_util.tree_flatten / tree_unflatten therefore drops the source; re-attach it on the reconstructed Record if you need to preserve the chain.

Source code in probpipe/core/record.py
def with_source(self, source: Provenance) -> Record:
    """Attach provenance to this Record (write-once).

    Mirrors ``Distribution.with_source`` — `_source` is set once and
    subsequent calls raise. Semantic transformations (``replace``,
    ``merge``, ``without``, ``map``, ``map_with_names``) return a
    *new* Record with an empty source; the caller attaches fresh
    provenance there if desired.

    Notes
    -----
    ``_source`` is runtime-only metadata — it is not serialised into
    the JAX pytree aux (a ``Provenance`` parent is a ``Distribution``
    or ``Record``, neither of which is hashable by structure).
    Round-tripping through ``jax.tree_util.tree_flatten`` /
    ``tree_unflatten`` therefore drops the source; re-attach it on
    the reconstructed Record if you need to preserve the chain.
    """
    if self._source is not None:
        raise RuntimeError(
            f"Source already set on {self!r}. "
            "Provenance is write-once; create a new Record instead."
        )
    object.__setattr__(self, "_source", source)
    return self
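
The write-once behaviour can be seen in isolation with a tiny stand-in. `Tracked` is a hypothetical class used only for illustration; it copies the guard from the listing above and nothing else.

```python
class Tracked:
    """Hypothetical stand-in showing the write-once provenance pattern."""

    def __init__(self):
        self._source = None

    def with_source(self, source):
        if self._source is not None:
            raise RuntimeError("Source already set; provenance is write-once.")
        self._source = source
        return self  # returns the same object, so the call can be chained


rec = Tracked().with_source("loaded from disk")

try:
    rec.with_source("something else")
    second_call_raised = False
except RuntimeError:
    second_call_raised = True
```

Because the method returns `self`, attaching provenance does not create a copy; only the transformations listed above produce a fresh object with an empty source.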

items()

Iterate over (name, value) pairs.

Source code in probpipe/core/record.py
def items(self) -> Iterator[tuple[str, _FieldValue]]:
    """Iterate over (name, value) pairs."""
    return iter(self._store.items())

keys()

Iterate over field names.

Source code in probpipe/core/record.py
def keys(self) -> Iterator[str]:
    """Iterate over field names."""
    return iter(self._store)

values()

Iterate over values.

Source code in probpipe/core/record.py
def values(self) -> Iterator[_FieldValue]:
    """Iterate over values."""
    return iter(self._store.values())

select(*fields, **mapping)

Select fields as a dict, for splatting into function calls.

Positional args use the field name as the key (identity mapping). Keyword args remap: select(x="field_name") → {"x": self["field_name"]}.

Usage:

predict(**params.select("r", "K"), x=x_grid)
predict(**params.select(growth_rate="r"), x=x_grid)
Source code in probpipe/core/record.py
def select(self, *fields: str, **mapping: str) -> dict[str, _FieldValue]:
    """Select fields as a dict, for splatting into function calls.

    Positional args use the field name as the key (identity mapping).
    Keyword args remap: ``select(x="field_name")`` → ``{"x": self.field_name}``.

    Usage::

        predict(**params.select("r", "K"), x=x_grid)
        predict(**params.select(growth_rate="r"), x=x_grid)
    """
    result: dict[str, _FieldValue] = {}
    for f in fields:
        if f not in self._store:
            raise KeyError(f"No field {f!r} in Record")
        result[f] = self[f]
    for arg_name, field_name in mapping.items():
        if field_name not in self._store:
            raise KeyError(f"No field {field_name!r} in Record")
        result[arg_name] = self[field_name]
    return result
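
To make the identity and remapping forms concrete, the sketch below applies the same logic to a plain dict. The free function `select` and the `predict` function are hypothetical stand-ins, not library code.

```python
def select(store, /, *fields, **mapping):
    """Pick fields from a dict; keyword arguments rename them."""
    result = {}
    for f in fields:
        if f not in store:
            raise KeyError(f"No field {f!r}")
        result[f] = store[f]
    for arg_name, field_name in mapping.items():
        if field_name not in store:
            raise KeyError(f"No field {field_name!r}")
        result[arg_name] = store[field_name]
    return result


def predict(r, K, x):
    """Hypothetical model function used only to show the splat."""
    return [K * r * xi for xi in x]


params = {"r": 0.5, "K": 100.0, "sigma": 0.1}
identity = select(params, "r", "K")
renamed = select(params, growth_rate="r")
curve = predict(**select(params, "r", "K"), x=[1, 2])
```

Fields that are not selected (here `sigma`) never reach the callee, so the target function does not need to accept `**kwargs`.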

select_all()

Return every field as a dict, for splatting into function calls.

Sugar for select(*self.fields). Subclasses whose __getitem__ returns a view (RecordArray → _RecordArrayView, RecordDistribution → _RecordDistributionView) inherit this method and return per-field views, so f(**ra.select_all()) triggers the parent-identity zip sweep in WorkflowFunction, and f(**dist.select_all()) similarly preserves cross-field correlation.

Source code in probpipe/core/record.py
def select_all(self) -> dict[str, _FieldValue]:
    """Return every field as a dict, for splatting into function calls.

    Sugar for ``select(*self.fields)``. Subclasses whose
    ``__getitem__`` returns a view (``RecordArray`` →
    ``_RecordArrayView``, ``RecordDistribution`` →
    ``_RecordDistributionView``) inherit this method and return
    per-field views — so ``f(**ra.select_all())`` triggers the
    parent-identity zip sweep in ``WorkflowFunction``, and
    ``f(**dist.select_all())`` similarly preserves cross-field
    correlation.
    """
    return self.select(*self.fields)

replace(**updates)

Return a new Record with specified fields replaced.

Returns an instance of type(self) so that subclasses (NumericRecord) preserve their class through the update.

Source code in probpipe/core/record.py
def replace(self, **updates: ArrayLike | Record) -> Record:
    """Return a new Record with specified fields replaced.

    Returns an instance of ``type(self)`` so that subclasses
    (``NumericRecord``) preserve their class through the update.
    """
    new = dict(self._store)
    for k, v in updates.items():
        if k not in new:
            raise KeyError(f"Cannot replace non-existent field {k!r}")
        new[k] = v
    return type(self)(new)

merge(other)

Return a new Record combining fields from self and other.

Raises ValueError if any field names overlap. Returns an instance of type(self).

Source code in probpipe/core/record.py
def merge(self, other: Record) -> Record:
    """Return a new Record combining fields from self and other.

    Raises ``ValueError`` if any field names overlap. Returns an
    instance of ``type(self)``.
    """
    overlap = set(self._store) & set(other._store)
    if overlap:
        raise ValueError(f"Overlapping field names: {overlap}")
    combined = dict(self._store)
    combined.update(other._store)
    return type(self)(combined)

without(*names)

Return a new Record with the specified fields removed.

Returns an instance of type(self).

Source code in probpipe/core/record.py
def without(self, *names: str) -> Record:
    """Return a new Record with the specified fields removed.

    Returns an instance of ``type(self)``.
    """
    new = {k: v for k, v in self._store.items() if k not in names}
    if not new:
        raise ValueError("Cannot remove all fields from Record")
    return type(self)(new)
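
The three update methods (replace, merge, without) share a copy-then-validate shape. The sketch below reproduces that shape on plain dicts so the error conditions are easy to see; the free functions are illustrative stand-ins and do not preserve a subclass the way `type(self)(...)` does.

```python
def replace(store, /, **updates):
    new = dict(store)  # copy first; the input is never mutated
    for k, v in updates.items():
        if k not in new:
            raise KeyError(f"Cannot replace non-existent field {k!r}")
        new[k] = v
    return new


def merge(store, other):
    overlap = set(store) & set(other)
    if overlap:
        raise ValueError(f"Overlapping field names: {overlap}")
    return {**store, **other}


def without(store, *names):
    new = {k: v for k, v in store.items() if k not in names}
    if not new:
        raise ValueError("Cannot remove all fields")
    return new


base = {"r": 0.5, "K": 100.0}
updated = replace(base, r=0.7)
combined = merge(base, {"sigma": 0.1})
trimmed = without(combined, "K")
```

Note that `replace` only overwrites existing fields; adding a field goes through `merge`, which in turn refuses to overwrite.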

to_dict()

Return a dict of stored values (recursive for nested Record).

Leaves are returned verbatim; no coercion to numpy or JAX.

Source code in probpipe/core/record.py
def to_dict(self) -> dict[str, Any]:
    """Return a dict of stored values (recursive for nested Record).

    Leaves are returned verbatim; no coercion to numpy or JAX.
    """
    result: dict[str, Any] = {}
    for name, val in self._store.items():
        if isinstance(val, Record):
            result[name] = val.to_dict()
        else:
            result[name] = val
    return result

to_numpy()

Return a dict of numpy arrays (recursive for nested Record).

Each numeric leaf is converted via np.asarray. Non-numeric leaves (strings, opaque objects) are returned as-is. Backend metadata (xarray dims / coords, pandas index) is stripped — use to_numeric followed by NumericRecord.to_native if you need a metadata-preserving round-trip.

Source code in probpipe/core/record.py
def to_numpy(self) -> dict[str, Any]:
    """Return a dict of numpy arrays (recursive for nested Record).

    Each numeric leaf is converted via ``np.asarray``. Non-numeric
    leaves (strings, opaque objects) are returned as-is. Backend
    metadata (xarray dims / coords, pandas index) is stripped — use
    :meth:`to_numeric` followed by :meth:`NumericRecord.to_native`
    if you need a metadata-preserving round-trip.
    """
    result: dict[str, Any] = {}
    for name, val in self._store.items():
        if isinstance(val, Record):
            result[name] = val.to_numpy()
        elif hasattr(val, "shape") or isinstance(val, (int, float, complex)):
            result[name] = np.asarray(val)
        else:
            result[name] = val
    return result

to_numeric()

Convert to a NumericRecord with every leaf a jax.Array.

Per-field metadata that jnp.asarray would drop (xarray dims / coords / attrs, pandas index / columns / dtypes) is captured via the aux registry in probpipe.core._array_backend and stored on the resulting NumericRecord. Calling NumericRecord.to_native on the result reverses the conversion, restoring each leaf to its original backend type. Nested Record children recurse — every level becomes a NumericRecord.

Equivalent to NumericRecord.from_record.

Raises:

  • TypeError: If any leaf is not coercible via jnp.asarray (e.g. strings, opaque Python objects).

Source code in probpipe/core/record.py
def to_numeric(self) -> "NumericRecord":  # type: ignore[name-defined]
    """Convert to a :class:`NumericRecord` with every leaf a ``jax.Array``.

    Per-field metadata that ``jnp.asarray`` would drop (xarray
    dims / coords / attrs, pandas index / columns / dtypes) is
    captured via the aux registry in
    :mod:`probpipe.core._array_backend` and stored on the resulting
    ``NumericRecord``. Calling :meth:`NumericRecord.to_native`
    on the result reverses the conversion, restoring each leaf to
    its original backend type. Nested ``Record`` children recurse
    — every level becomes a ``NumericRecord``.

    Equivalent to :meth:`NumericRecord.from_record`.

    Raises
    ------
    TypeError
        If any leaf is not coercible via ``jnp.asarray`` (e.g.
        strings, opaque Python objects).
    """
    # Lazy import to avoid the module-level circular dep:
    # _numeric_record.py imports Record from this module.
    from ._numeric_record import NumericRecord
    return NumericRecord.from_record(self)

ensure(x) classmethod

Coerce x to Record if it isn't already.

  • Record → pass through
  • dict → Record(**x)
  • array-like → Record(data=x)
Source code in probpipe/core/record.py
@classmethod
def ensure(cls, x: Any) -> Record:
    """Coerce *x* to Record if it isn't already.

    - ``Record`` → pass through
    - ``dict`` → ``Record(**x)``
    - array-like → ``Record(data=x)``
    """
    if isinstance(x, cls):
        return x
    if isinstance(x, dict):
        return cls(x)
    return cls(data=x)

from_dict(d) classmethod

Construct Record from a dict of arrays.

Source code in probpipe/core/record.py
@classmethod
def from_dict(cls, d: dict[str, ArrayLike | Record]) -> Record:
    """Construct Record from a dict of arrays."""
    return cls(d)

map(fn)

Apply fn to each leaf, returning a new Record.

Nested Record objects are traversed and rebuilt with the same class. fn sees leaves as stored (no coercion).

Source code in probpipe/core/record.py
def map(self, fn: Callable[[Any], Any]) -> Record:
    """Apply *fn* to each leaf, returning a new Record.

    Nested ``Record`` objects are traversed and rebuilt with the same
    class. ``fn`` sees leaves as stored (no coercion).
    """
    fields: dict[str, Any] = {}
    for name, val in self._store.items():
        if isinstance(val, Record):
            fields[name] = val.map(fn)
        else:
            fields[name] = fn(val)
    return type(self)(fields)

map_with_names(fn)

Apply fn(name, value) to each leaf, returning a new Record.

Source code in probpipe/core/record.py
def map_with_names(self, fn: Callable[[str, Any], Any]) -> Record:
    """Apply *fn(name, value)* to each leaf, returning a new Record."""
    fields: dict[str, Any] = {}
    for name, val in self._store.items():
        if isinstance(val, Record):
            fields[name] = val.map_with_names(fn)
        else:
            fields[name] = fn(name, val)
    return type(self)(fields)
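
The recursion in map and map_with_names can be sketched on nested dicts. The helper names below are hypothetical. As in the listings above, the callback in the named variant receives each leaf's own field name, not its full slash-delimited path.

```python
def map_leaves(tree, fn):
    """Apply fn to every leaf, rebuilding the nested structure."""
    return {
        k: map_leaves(v, fn) if isinstance(v, dict) else fn(v)
        for k, v in tree.items()
    }


def map_leaves_with_names(tree, fn):
    """Apply fn(name, value) to every leaf, rebuilding the structure."""
    return {
        k: map_leaves_with_names(v, fn) if isinstance(v, dict) else fn(k, v)
        for k, v in tree.items()
    }


tree = {"a": 1, "inner": {"b": 2, "c": 3}}
doubled = map_leaves(tree, lambda v: v * 2)
tagged = map_leaves_with_names(tree, lambda name, v: f"{name}={v}")
```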

NumericRecord(_dict=None, /, *, name=None, **fields)

Bases: Record

Record where every leaf is a jax.Array.

Adds flatten / unflatten / flat_size for serialising the record to / from a flat 1-D vector. Construction validates that every leaf is a numeric value (or a nested NumericRecord) and coerces scalar / numpy / xarray / pandas leaves to jnp.ndarray so downstream code sees a uniform JAX array type. Backend-specific metadata (xarray dims / coords / attrs, pandas index / columns / dtypes) is captured via the aux registry in probpipe.core._array_backend and stored on the instance; to_native reverses the conversion.

Parameters:

  • **fields (ArrayLike | NumericRecord, default {}): Named values. Every leaf must be a numeric array (jax.numpy, numpy, xarray.DataArray, pandas.Series / DataFrame with numeric dtype), a numeric Python scalar (int, float, complex, bool), or a nested NumericRecord. Non-numeric values raise TypeError at construction time.

Notes

NumericRecord(**fields) and Record(**fields).to_numeric() are semantically identical — both consult the aux registry to capture metadata, both coerce leaves via jnp.asarray, both raise TypeError on non-coercible leaves.

Validation and coercion happen before the underlying Record is constructed, so _store is populated exactly once and remains immutable from the moment __init__ returns — consistent with the __slots__ + __setattr__ guard on the base class.

Source code in probpipe/core/_numeric_record.py
def __init__(
    self,
    _dict: dict[str, ArrayLike | NumericRecord] | None = None,
    /,
    *,
    name: str | None = None,
    **fields: ArrayLike | NumericRecord,
):
    # Build the validated + coerced field dict *before* Record's
    # __init__ runs, so ``_store`` is populated exactly once and the
    # "constructed once, never touched" invariant implied by
    # ``__slots__`` + the ``__setattr__`` guard holds.
    if _dict is not None:
        if fields:
            raise ValueError(
                "Cannot pass both positional dict and keyword arguments"
            )
        raw_fields = _dict
    else:
        raw_fields = fields
    validated, aux = self._validate_and_coerce(raw_fields)
    super().__init__(validated, name=name)
    # Cache flat_size — leaves are immutable arrays after construction.
    total = 0
    for val in self._store.values():
        if isinstance(val, NumericRecord):
            total += val.flat_size
        else:
            total += int(val.size)
    object.__setattr__(self, "_flat_size", total)
    # Aux is ``None`` if no field had a registered hook — keeps the
    # common all-jax case allocation-free and lets ``to_native``
    # short-circuit.
    object.__setattr__(self, "_aux", aux if aux else None)

flat_size property

Total number of scalar elements across all numeric leaves.

aux property

Captured backend metadata blobs, keyed by field name (or None).

Each entry is the opaque aux_blob returned by the registered capture hook for that field's original leaf type. Fields whose leaf type wasn't in the registry (plain numpy / jax / Python scalars) are absent.

The hook pair is intentionally not exposed here — call to_native to materialise the original backend objects.

flatten()

Concatenate all leaf arrays into a single 1-D vector.

Fields are traversed in insertion order; nested NumericRecord are traversed depth-first. Each leaf is raveled before concatenation.

Source code in probpipe/core/_numeric_record.py
def flatten(self) -> jnp.ndarray:
    """Concatenate all leaf arrays into a single 1-D vector.

    Fields are traversed in insertion order; nested ``NumericRecord``
    are traversed depth-first. Each leaf is raveled before
    concatenation.
    """
    parts: list[jnp.ndarray] = []
    for val in self._store.values():
        if isinstance(val, NumericRecord):
            parts.append(val.flatten())
        else:
            parts.append(jnp.ravel(val))
    return jnp.concatenate(parts)

unflatten(flat, *, template) classmethod

Reconstruct a NumericRecord from a flat array.

Parameters:

  • flat (array, required): 1-D array of concatenated scalars.
  • template (RecordTemplate, required): Provides field names and shapes for reconstruction.

Source code in probpipe/core/_numeric_record.py
@classmethod
def unflatten(
    cls,
    flat: jnp.ndarray,
    *,
    template: RecordTemplate,
) -> NumericRecord:
    """Reconstruct a ``NumericRecord`` from a flat array.

    Parameters
    ----------
    flat : array
        1-D array of concatenated scalars.
    template : RecordTemplate
        Provides field names and shapes for reconstruction.
    """
    fields: dict[str, jnp.ndarray | NumericRecord] = {}
    offset = 0

    for field_name in template.fields:
        spec = template[field_name]
        size = _spec_size(spec)
        chunk = flat[offset : offset + size]
        if isinstance(spec, RecordTemplate):
            fields[field_name] = cls.unflatten(chunk, template=spec)
        else:
            fields[field_name] = chunk.reshape(spec)
        offset += size

    return cls(fields)
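
The offset walk used by unflatten, and its inverse, can be reproduced without JAX. The sketch below is an assumption-laden stand-in: nested lists replace arrays, plain dicts replace both the record and the template, and `spec_size` and `reshape` are hypothetical helpers rather than the library's `_spec_size`.

```python
from math import prod


def flatten(value):
    """Depth-first, insertion-order flattening into a flat list."""
    if isinstance(value, dict):
        return [x for v in value.values() for x in flatten(v)]
    if isinstance(value, list):
        return [x for v in value for x in flatten(v)]
    return [value]


def spec_size(spec):
    """Number of scalars described by a shape tuple or nested template."""
    if isinstance(spec, dict):
        return sum(spec_size(s) for s in spec.values())
    return prod(spec)  # prod(()) == 1, so a scalar counts as one element


def reshape(chunk, shape):
    """Rebuild nested lists of the given shape from a flat chunk."""
    if not shape:
        return chunk[0]
    step = prod(shape[1:])
    return [
        reshape(chunk[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


def unflatten(flat, template):
    fields, offset = {}, 0
    for name, spec in template.items():
        size = spec_size(spec)
        chunk = flat[offset:offset + size]
        if isinstance(spec, dict):
            fields[name] = unflatten(chunk, spec)
        else:
            fields[name] = reshape(chunk, spec)
        offset += size
    return fields


template = {"x": (), "y": (3,), "m": {"a": (2, 2)}}
record = unflatten(list(range(8)), template)
round_trip = unflatten(flatten(record), template)
```

The round trip only works because both directions agree on field order, which is why the template iterates fields in insertion order.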

from_record(record) classmethod

Convert a Record to NumericRecord, validating leaves.

Equivalent to record.to_numeric(); both paths consult the aux registry, coerce every leaf via jnp.asarray, and raise TypeError on non-coercible leaves. Nested Record children recurse, preserving structure.

Source code in probpipe/core/_numeric_record.py
@classmethod
def from_record(cls, record: Record) -> NumericRecord:
    """Convert a ``Record`` to ``NumericRecord``, validating leaves.

    Equivalent to ``record.to_numeric()``; both paths consult the
    aux registry, coerce every leaf via ``jnp.asarray``, and raise
    ``TypeError`` on non-coercible leaves. Nested ``Record``
    children recurse, preserving structure.
    """
    return cls({
        field_name: cls.from_record(val) if isinstance(val, Record) else val
        for field_name, val in record._store.items()
    })

to_native()

Restore each leaf to its original backend type, returning a Record.

Fields whose original leaf type was registered in probpipe.core._array_backend are restored via hooks.restore(jax_array, aux). Fields without captured aux pass through as their stored jax.Array. Nested NumericRecord fields recurse.

The result is a permissive Record, not a NumericRecord — restored xarray / pandas leaves are no longer jax.Array and would fail the numeric invariant.

Source code in probpipe/core/_numeric_record.py
def to_native(self) -> Record:
    """Restore each leaf to its original backend type, returning a :class:`Record`.

    Fields whose original leaf type was registered in
    :mod:`probpipe.core._array_backend` are restored via
    ``hooks.restore(jax_array, aux)``. Fields without captured aux
    pass through as their stored ``jax.Array``. Nested
    :class:`NumericRecord` fields recurse.

    The result is a permissive :class:`Record`, not a
    ``NumericRecord`` — restored xarray / pandas leaves are no
    longer ``jax.Array`` and would fail the numeric invariant.
    """
    fields: dict[str, Any] = {}
    aux = self._aux or {}
    for field_name, val in self._store.items():
        if isinstance(val, NumericRecord):
            fields[field_name] = val.to_native()
            continue
        entry = aux.get(field_name)
        if entry is None:
            fields[field_name] = val
        else:
            hooks, blob = entry
            fields[field_name] = hooks.restore(val, blob)
    return Record(fields)
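
The capture / restore round trip can be pictured with a toy registry. Everything in this sketch is hypothetical: `Labeled` stands in for an xarray-like leaf, `Hooks` and `REGISTRY` are invented for illustration, and the real registry in probpipe.core._array_backend may be organised differently.

```python
from dataclasses import dataclass
from typing import Callable, NamedTuple


@dataclass
class Labeled:
    """Hypothetical leaf type carrying metadata alongside its values."""
    values: list
    dims: tuple


class Hooks(NamedTuple):
    capture: Callable  # leaf -> (plain values, aux blob)
    restore: Callable  # (plain values, aux blob) -> leaf


REGISTRY = {
    Labeled: Hooks(
        capture=lambda leaf: (list(leaf.values), {"dims": leaf.dims}),
        restore=lambda values, blob: Labeled(values, blob["dims"]),
    )
}


def to_numeric(fields):
    """Strip metadata from registered leaf types, remembering it in aux."""
    store, aux = {}, {}
    for name, leaf in fields.items():
        hooks = REGISTRY.get(type(leaf))
        if hooks is None:
            store[name] = leaf  # unregistered types pass through
        else:
            values, blob = hooks.capture(leaf)
            store[name] = values
            aux[name] = (hooks, blob)
    return store, (aux or None)


def to_native(store, aux):
    """Reverse to_numeric using the captured aux entries."""
    aux = aux or {}
    out = {}
    for name, val in store.items():
        entry = aux.get(name)
        if entry is None:
            out[name] = val
        else:
            hooks, blob = entry
            out[name] = hooks.restore(val, blob)
    return out


fields = {"temps": Labeled([1.0, 2.0], ("time",)), "scale": 3.0}
store, aux = to_numeric(fields)
restored = to_native(store, aux)
```

Storing the hook pair next to the blob means restoration does not need to look the type up again, which matches the `hooks, blob = entry` unpacking in the listing above.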

RecordTemplate(_dict=None, /, **field_specs)

Structural description of a Record: field names, leaf shapes, nesting.

Stores the skeleton of a Record without data — field names, per-field shapes (for numeric leaves) or None (for opaque leaves), and optional nested RecordTemplate for hierarchical structure.

Inspired by JAX's PyTreeDef: a template can reconstruct a Record from flat data, and describes the expected structure for type-checking and flattening.

Parameters:

  • **field_specs (_FieldSpec, default {}): Named fields. Each value is one of:
      • tuple[int, ...]: shape of a numeric array leaf (e.g. () for a scalar, (3,) for a 3-vector).
      • None: opaque (non-array) leaf.
      • RecordTemplate: nested sub-structure.

Examples:

RecordTemplate(x=(), y=(3,))                    # -> NumericRecordTemplate
RecordTemplate(label=None, x=())                 # -> RecordTemplate (mixed)
RecordTemplate(physics=RecordTemplate(force=(), mass=()), obs=())
Notes

Calling RecordTemplate(...) directly auto-promotes to a NumericRecordTemplate when every spec is numeric (and every nested sub-template is itself all-numeric). That keeps flat_size and numeric_leaf_shapes reachable in the common all-numeric case without requiring the caller to name the subclass. Mixed templates (any None spec) stay as plain RecordTemplate and do not expose flat_size — it isn't a meaningful quantity once opaque leaves are in the mix.
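
The promotion rule amounts to a recursive "every spec is numeric" check. A minimal sketch of that predicate, with plain dicts standing in for templates and a hypothetical function name, might look like this; it says nothing about how the library actually performs the class switch.

```python
def is_all_numeric(specs):
    """True when every spec is a shape tuple or an all-numeric sub-template."""
    for spec in specs.values():
        if isinstance(spec, dict):
            if not is_all_numeric(spec):
                return False
        elif not isinstance(spec, tuple):
            return False  # None marks an opaque leaf
    return True


all_numeric = is_all_numeric({"x": (), "y": (3,)})
mixed = is_all_numeric({"label": None, "x": ()})
nested = is_all_numeric({"physics": {"force": (), "mass": ()}, "obs": ()})
nested_mixed = is_all_numeric({"physics": {"force": (), "tag": None}})
```

A single opaque leaf anywhere in the tree is enough to block promotion, since a flat size cannot be computed for it.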

Source code in probpipe/core/record.py
def __init__(
    self,
    _dict: dict[str, _FieldSpec] | None = None,
    /,
    **field_specs: _FieldSpec,
):
    if _dict is not None:
        if field_specs:
            raise ValueError(
                "Cannot pass both positional dict and keyword arguments"
            )
        field_specs = _dict
    if not field_specs:
        raise ValueError(f"{type(self).__name__} requires at least one field")
    # Validate specs
    for name, spec in field_specs.items():
        _check_no_path_sep(name)
        if spec is not None and not isinstance(spec, (tuple, RecordTemplate)):
            raise TypeError(
                f"Field {name!r}: spec must be a shape tuple, None, "
                f"or RecordTemplate, got {type(spec).__name__}"
            )
        if isinstance(spec, tuple):
            if not all(isinstance(d, int) and d >= 0 for d in spec):
                raise TypeError(
                    f"Field {name!r}: shape must be a tuple of "
                    f"non-negative ints, got {spec!r}"
                )
    self._post_validate(field_specs)
    specs = dict(field_specs)
    object.__setattr__(self, "_specs", specs)

fields property

Field names in insertion order.

leaf_shapes property

Per-field leaf shapes. None for opaque (non-array) leaves.

For nested RecordTemplate fields, returns the nested template's leaf_shapes (not the template itself), keyed by /-delimited paths so the keys round-trip with Record.__getitem__'s path syntax.

from_record(record, *, batch_shape=()) classmethod

Infer a template from an existing Record.

Numeric leaves are recorded with their shape (after stripping the leading batch_shape). Python numeric scalars are treated as shape-() leaves. Non-numeric leaves (strings, opaque objects) are recorded as None.

Parameters:

  • record (Record, required): Source record whose fields define the template structure.
  • batch_shape (tuple of int, default ()): Leading dimensions to strip from field shapes to get event shapes. For a single-sample Record, use () (the default).

Notes

A Python list or tuple leaf has no .shape / .dtype and is treated as opaque (None) even if it contains numbers. Wrap it in np.asarray(...) or jnp.asarray(...) before putting it in the Record if you want a numeric template entry. Downstream operations that call NumericRecord.unflatten will otherwise raise on the opaque field.

Source code in probpipe/core/record.py
@classmethod
def from_record(
    cls,
    record: Record,
    *,
    batch_shape: tuple[int, ...] = (),
) -> RecordTemplate:
    """Infer a template from an existing Record.

    Numeric leaves are recorded with their shape (after stripping the
    leading ``batch_shape``). Python numeric scalars are treated as
    shape-``()`` leaves. Non-numeric leaves (strings, opaque objects)
    are recorded as ``None``.

    Parameters
    ----------
    record : Record
        Source record whose fields define the template structure.
    batch_shape : tuple of int
        Leading dimensions to strip from field shapes to get event
        shapes.  For a single-sample Record, use ``()`` (default).

    Notes
    -----
    A Python ``list`` or ``tuple`` leaf has no ``.shape`` / ``.dtype``
    and is treated as opaque (``None``) even if it contains numbers.
    Wrap it in ``np.asarray(...)`` or ``jnp.asarray(...)`` before
    putting it in the Record if you want a numeric template entry.
    Downstream operations that call ``NumericRecord.unflatten`` will
    otherwise raise on the opaque field.
    """
    # Promote plain ``RecordTemplate.from_record`` to
    # ``NumericRecordTemplate`` when the source signals it is all-numeric
    # (a ``NumericRecord`` or any Record whose recursive leaves are
    # numeric). That keeps ``flat_size`` reachable for the common
    # all-numeric case without requiring callers to name the subclass.
    target_cls = cls
    if cls is RecordTemplate:
        from ._numeric_record import NumericRecord
        if isinstance(record, NumericRecord):
            target_cls = NumericRecordTemplate
    n_batch = len(batch_shape)
    specs: dict[str, _FieldSpec] = {}
    for name in record.fields:
        val = record[name]
        if isinstance(val, Record):
            specs[name] = target_cls.from_record(val, batch_shape=batch_shape)
            continue
        # Numeric scalar / numeric array → strip leading batch dims.
        if isinstance(val, (bool, int, float, complex, np.integer, np.floating, np.bool_)):
            full_shape: tuple[int, ...] = ()
        elif hasattr(val, "shape") and hasattr(val, "dtype"):
            kind = getattr(val.dtype, "kind", None)
            if kind in {"b", "i", "u", "f", "c"}:
                full_shape = tuple(val.shape)
            else:
                specs[name] = None
                continue
        else:
            specs[name] = None
            continue
        event_shape = full_shape[n_batch:] if n_batch else full_shape
        specs[name] = event_shape
    return target_cls(specs)
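
The per-leaf classification in the listing above can be tried without numpy by faking the two attributes it inspects. In this sketch `infer_spec` is a hypothetical helper, `SimpleNamespace` objects stand in for arrays, and the numpy scalar types handled by the real code are omitted.

```python
from types import SimpleNamespace

NUMERIC_KINDS = {"b", "i", "u", "f", "c"}


def infer_spec(val, batch_shape=()):
    """Return the event shape for numeric leaves, None for opaque ones."""
    if isinstance(val, (bool, int, float, complex)):
        full_shape = ()
    elif hasattr(val, "shape") and hasattr(val, "dtype"):
        if getattr(val.dtype, "kind", None) not in NUMERIC_KINDS:
            return None  # e.g. a string or object array
        full_shape = tuple(val.shape)
    else:
        return None  # lists, tuples, strings and other opaque objects
    # Drop the leading batch dimensions to leave the per-element shape.
    return full_shape[len(batch_shape):]


batched = SimpleNamespace(shape=(10, 3), dtype=SimpleNamespace(kind="f"))
text = SimpleNamespace(shape=(4,), dtype=SimpleNamespace(kind="U"))

event = infer_spec(batched, batch_shape=(10,))
scalar = infer_spec(2.0)
from_list = infer_spec([1, 2, 3])
from_text = infer_spec(text)
```

The list case is the one called out in the Notes: it holds numbers but has no shape or dtype, so it is classified as opaque.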

NumericRecordTemplate(_dict=None, /, **field_specs)

Bases: RecordTemplate

RecordTemplate where every leaf is numeric.

Extends RecordTemplate by requiring each spec to be a shape tuple (or a nested NumericRecordTemplate) — no opaque None leaves are allowed. That restriction is what makes flat_size and numeric_leaf_shapes meaningful: flat_size is the total number of scalar elements across every numeric leaf, and the unflatten machinery (NumericRecord.unflatten / NumericRecordArray.unflatten) requires a template of this class so that every field can be reconstructed from a slice of the flat buffer.

Use RecordTemplate.from_record on a NumericRecord (it auto-promotes) or call this constructor directly when you have the shape specs in hand.

Source code in probpipe/core/record.py
def __init__(
    self,
    _dict: dict[str, _FieldSpec] | None = None,
    /,
    **field_specs: _FieldSpec,
):
    super().__init__(_dict, **field_specs)
    object.__setattr__(self, "_flat_size", self._compute_flat_size())

numeric_leaf_shapes property

Per-field shapes for numeric leaves.

On NumericRecordTemplate every leaf is numeric, so this is equivalent to leaf_shapes. Kept as a distinct name for symmetry with historical callers that used it as a filter.

flat_size property

Total number of scalar elements across all numeric leaves.

Record arrays

RecordArray(_dict=None, /, *, batch_shape, template, name=None, **fields)

Bases: Record

Batch of Records with consistent field structure.

Each field stores values with shape (*batch_shape, *leaf_shape). A RecordArray is a Record — the batched variant, parallel to the way DistributionArray is a Distribution. Consolidating the two in a single hierarchy means:

  • isinstance(x, Record) accepts both scalar and batched Records. Code that needs to distinguish uses isinstance(x, RecordArray) for the batched case, or isinstance(x, Record) and not isinstance(x, RecordArray) for scalar-only.
  • .source / .with_source / .name are inherited from Record (stored on the _name / _source slots declared on Record).
  • replace / merge / without / map / map_with_names are overridden here because the base constructor signature doesn't carry batch_shape / template; RecordArray versions preserve those.

Parameters:

  • batch_shape (tuple of int, required): Shape of the batch dimensions.
  • template (RecordTemplate, required): Structural description of each element.
  • name (str, default None): Human-readable name for provenance / introspection. Defaults to "{class_name}({field list in template order})", with the class name lowercased (e.g. "recordarray(x,y)").
  • **fields (Any, default {}): Named values, each with shape (*batch_shape, *leaf_shape).

Notes

Construct from a list of Records with RecordArray.stack. Indexing is either integer (arr[i] → single Record) or field name (arr["x"] → batched leaf array).

Source code in probpipe/core/_record_array.py
def __init__(
    self,
    _dict: dict[str, Any] | None = None,
    /,
    *,
    batch_shape: tuple[int, ...],
    template: RecordTemplate,
    name: str | None = None,
    **fields: Any,
):
    if _dict is not None:
        if fields:
            raise ValueError(
                "Cannot pass both positional dict and keyword arguments"
            )
        fields = _dict
    if not fields:
        raise ValueError("RecordArray requires at least one field")
    if set(fields.keys()) != set(template.fields):
        raise ValueError(
            f"Field names {sorted(fields)} do not match template "
            f"fields {sorted(template.fields)}"
        )
    # Reorder to match the template so iteration order is canonical
    # regardless of kwarg order.
    store: "OrderedDict[str, Any]" = OrderedDict(
        (name, fields[name]) for name in template.fields
    )
    # Subclass validation hook. Runs after sort / name-check so
    # subclasses (e.g. NumericRecordArray) see a canonicalised view
    # of the leaves. Raises from ``_validate_fields`` propagate.
    store = type(self)._validate_fields(store, batch_shape, template)
    # Inherit the Record plumbing for _store / _name / _source.
    # We bypass Record's normal constructor path because RecordArray
    # requires its own field-validation hook and an auto-name that
    # reflects the class name, not the "record(...)" default.
    if name is None:
        name = f"{type(self).__name__.lower()}({','.join(store.keys())})"
    object.__setattr__(self, "_store", store)
    object.__setattr__(self, "_name", name)
    object.__setattr__(self, "_source", None)
    object.__setattr__(self, "_batch_shape", batch_shape)
    object.__setattr__(self, "_template", template)

batch_shape property

Shape of the batch dimensions.

template property

Structural description of each element.

view(field)

Return a single-field view carrying parent identity.

Unlike ra[field] (which returns the raw column), a view remembers the parent RecordArray. When multiple views of the same parent land in a single WorkflowFunction call, the sweep layer groups them by parent identity and iterates them in lockstep (zip) rather than cartesian-producting.

Used internally by select_all. Direct construction by end-users is supported but rarely needed.

Source code in probpipe/core/_record_array.py
def view(self, field: str) -> "_RecordArrayView":
    """Return a single-field view carrying parent identity.

    Unlike ``ra[field]`` (which returns the raw column), a view
    remembers the parent ``RecordArray``. When multiple views of
    the same parent land in a single ``WorkflowFunction`` call,
    the sweep layer groups them by parent identity and iterates
    them in lockstep (zip) rather than cartesian-producting.

    Used internally by :meth:`~probpipe.record.Design.select_all`.
    Direct construction by end-users is supported but rarely needed.
    """
    return _RecordArrayView(self, field)
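The difference between a lockstep (zip) sweep and a Cartesian sweep can be illustrated without ProbPipe. The sketch below uses plain Python lists in place of views; fit, r_values, and k_values are hypothetical names chosen for the example, and the sweep layer itself is not modelled.

```python
from itertools import product


def fit(r, K):
    """Hypothetical per-row function."""
    return r * K


# Two columns taken from the same three-row parent.
r_values = [1.0, 2.0, 3.0]
k_values = [10.0, 20.0, 30.0]

# Lockstep (zip): columns from one parent advance together, one call per row.
zipped = [fit(r, K) for r, K in zip(r_values, k_values)]

# Cartesian product: independent axes, one call per combination.
crossed = [fit(r, K) for r, K in product(r_values, k_values)]
```

With three rows, the zip sweep makes three calls while the Cartesian sweep makes nine.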

keys()

Iterate over field names.

Source code in probpipe/core/_record_array.py
def keys(self) -> Iterator[str]:
    """Iterate over field names."""
    return iter(self._store)

values()

Iterate over field values (batched).

Source code in probpipe/core/_record_array.py
def values(self) -> Iterator[Any]:
    """Iterate over field values (batched)."""
    for name in self._store:
        yield self._store[name]

items()

Iterate over (name, batched_value) pairs.

Source code in probpipe/core/_record_array.py
def items(self) -> Iterator[tuple[str, Any]]:
    """Iterate over (name, batched_value) pairs."""
    for name in self._store:
        yield name, self._store[name]

select(*fields, **mapping)

Select fields as a dict of single-field views.

Mirrors Record.select but each entry is a _RecordArrayView rather than the raw column. The views carry this RecordArray as their parent, so splatting the result into a @workflow_function triggers the parent-identity zip sweep (one inner call per row, matching f(p=self)) instead of cartesian-producting the fields as independent axes.

For raw-column access, use self["field"] per field or iterate self.items().

Source code in probpipe/core/_record_array.py
def select(self, *fields: str, **mapping: str) -> dict[str, Any]:
    """Select fields as a dict of single-field **views**.

    Mirrors :meth:`Record.select` but each entry is a
    :class:`_RecordArrayView` rather than the raw column. The views
    carry this ``RecordArray`` as their parent, so splatting the
    result into a ``@workflow_function`` triggers the
    parent-identity **zip sweep** (one inner call per row, matching
    ``f(p=self)``) instead of cartesian-producting the fields as
    independent axes.

    For raw-column access, use ``self["field"]`` per field or
    iterate ``self.items()``.
    """
    result: dict[str, Any] = {}
    for f in fields:
        if f not in self._store:
            raise KeyError(f"No field {f!r} in {type(self).__name__}")
        result[f] = self.view(f)
    for arg_name, field_name in mapping.items():
        if field_name not in self._store:
            raise KeyError(f"No field {field_name!r} in {type(self).__name__}")
        result[arg_name] = self.view(field_name)
    return result

stack(records, *, template=None) classmethod

Stack a list of Records into a RecordArray with batch_shape=(n,).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| records | list of Record | Records with consistent field structure. | required |
| template | RecordTemplate | If not provided, inferred from the first record. | None |

Notes

Any backend metadata captured on the source NumericRecord instances (xarray dims / coords, pandas index) is dropped — the stacked leaves are plain jax.Array objects. RecordArray does not currently carry per-row aux.

Source code in probpipe/core/_record_array.py
@classmethod
def stack(cls, records: list[Record], *, template: RecordTemplate | None = None) -> RecordArray:
    """Stack a list of Records into a RecordArray with batch_shape=(n,).

    Parameters
    ----------
    records : list of Record
        Records with consistent field structure.
    template : RecordTemplate, optional
        If not provided, inferred from the first record.

    Notes
    -----
    Any backend metadata captured on the source ``NumericRecord``
    instances (xarray dims / coords, pandas index) is dropped — the
    stacked leaves are plain ``jax.Array`` objects. ``RecordArray``
    does not currently carry per-row aux.
    """
    if not records:
        raise ValueError("Cannot stack empty list of Records")
    if template is None:
        template = RecordTemplate.from_record(records[0])
    fields: dict[str, Any] = {}
    for name in template.fields:
        field_vals = [r[name] for r in records]
        fields[name] = jnp.stack(field_vals, axis=0)
    return cls(fields, batch_shape=(len(records),), template=template)

NumericRecordArray(_dict=None, /, *, batch_shape, template, name=None, **fields)

Bases: RecordArray

Batch of NumericRecords — all leaves are numeric arrays.

Adds flatten/unflatten, mean, var operations. Construction validates that every leaf has a numeric dtype and shape (*batch_shape, *event_shape) matching the template, so pytree round-trips (jax.tree.map) cannot silently produce a NumericRecordArray with non-numeric or ill-shaped leaves.

Each field has shape (*batch_shape, *event_shape).

flatten()

Flatten event dimensions into a single trailing axis.

Returns array of shape (*batch_shape, flat_event_size).

Source code in probpipe/core/_record_array.py
def flatten(self) -> jnp.ndarray:
    """Flatten event dimensions into a single trailing axis.

    Returns array of shape ``(*batch_shape, flat_event_size)``.
    """
    n_batch = len(self._batch_shape)
    parts = []
    for name in self._store:
        val = self._store[name]
        event_shape = val.shape[n_batch:]
        field_size = prod(event_shape)
        new_shape = self._batch_shape + (field_size,)
        parts.append(jnp.reshape(val, new_shape))
    return jnp.concatenate(parts, axis=-1)

unflatten(flat, *, template, batch_shape=None) classmethod

Reconstruct from a flat array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| flat | array | Shape `(*batch_shape, flat_event_size)`. | required |
| template | RecordTemplate | Structural description providing field names and event shapes. | required |
| batch_shape | tuple of int | If not provided, inferred as flat.shape[:-1]. | None |

Source code in probpipe/core/_record_array.py
@classmethod
def unflatten(
    cls,
    flat: jnp.ndarray,
    *,
    template: RecordTemplate,
    batch_shape: tuple[int, ...] | None = None,
) -> NumericRecordArray:
    """Reconstruct from a flat array.

    Parameters
    ----------
    flat : array
        Shape ``(*batch_shape, flat_event_size)``.
    template : RecordTemplate
        Structural description providing field names and event shapes.
    batch_shape : tuple of int, optional
        If not provided, inferred as ``flat.shape[:-1]``.
    """
    if batch_shape is None:
        batch_shape = flat.shape[:-1]

    fields: dict[str, jnp.ndarray] = {}
    offset = 0
    for name in template.fields:
        spec = template[name]
        size = _spec_size(spec)
        chunk = flat[..., offset : offset + size]
        if isinstance(spec, RecordTemplate):
            fields[name] = cls.unflatten(
                chunk, template=spec, batch_shape=batch_shape,
            )
        else:
            fields[name] = jnp.reshape(chunk, batch_shape + spec)
        offset += size

    return cls(fields, batch_shape=batch_shape, template=template)
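A flatten / unflatten round trip can be sketched with NumPy alone. The code below is an illustration under simplifying assumptions, not the library code: a template is modelled as a plain dict mapping field names to event shapes (no nesting), and flatten_fields / unflatten_fields are hypothetical helper names.

```python
from math import prod

import numpy as np


def flatten_fields(fields, batch_shape):
    """Concatenate fields into shape (*batch_shape, flat_event_size)."""
    parts = [
        np.reshape(val, batch_shape + (prod(val.shape[len(batch_shape):]),))
        for val in fields.values()
    ]
    return np.concatenate(parts, axis=-1)


def unflatten_fields(flat, template, batch_shape=None):
    """Split the trailing axis back into fields with their event shapes."""
    if batch_shape is None:
        batch_shape = flat.shape[:-1]
    fields, offset = {}, 0
    for name, event_shape in template.items():
        size = prod(event_shape)  # prod(()) == 1 for scalar events
        chunk = flat[..., offset:offset + size]
        fields[name] = np.reshape(chunk, batch_shape + event_shape)
        offset += size
    return fields


template = {"loc": (), "cov": (2, 2)}
fields = {"loc": np.arange(3.0), "cov": np.arange(12.0).reshape(3, 2, 2)}

flat = flatten_fields(fields, batch_shape=(3,))   # one scalar + four matrix entries per row
restored = unflatten_fields(flat, template)
```

The round trip only works when both directions visit the fields in the same order, which is why the library keeps fields in template order.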

mean(axis=0)

Mean over a batch axis.

Returns NumericRecord if no batch dims remain, else NumericRecordArray.

Source code in probpipe/core/_record_array.py
def mean(self, axis: int = 0) -> Any:
    """Mean over a batch axis.

    Returns ``NumericRecord`` if no batch dims remain, else
    ``NumericRecordArray``.
    """
    return self._reduce(jnp.mean, axis)

var(axis=0)

Variance over a batch axis.

Returns NumericRecord if no batch dims remain, else NumericRecordArray.

Source code in probpipe/core/_record_array.py
def var(self, axis: int = 0) -> Any:
    """Variance over a batch axis.

    Returns ``NumericRecord`` if no batch dims remain, else
    ``NumericRecordArray``.
    """
    return self._reduce(jnp.var, axis)

Weights

Weights(*, n=None, weights=None, log_weights=None)

Normalized probability weights over n items.

Weights stores log-unnormalized weights internally for numerical stability and provides lazy-cached access to normalized weights and log-weights.

It implements the JAX array protocol (__jax_array__), so a Weights object can be passed directly to any JAX operation that expects an array — it will automatically convert to its normalized weight vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n | int | Number of items. When provided alone (without weights or log_weights), creates uniform weights. When provided alongside an array, validates that the array length matches. | None |
| weights | ArrayLike, Weights, or None | Non-negative weights. n is inferred from len(weights) if not given explicitly. A pre-built Weights object is accepted and adopted without re-validation. Mutually exclusive with log_weights. | None |
| log_weights | ArrayLike, Weights, or None | Log-unnormalized weights. Preferred when weights span many orders of magnitude (e.g. importance sampling). n is inferred from len(log_weights) if not given explicitly. A pre-built Weights object is accepted and adopted without re-validation. Mutually exclusive with weights. | None |

Examples:

Create from raw weights or log-weights:

>>> w = Weights(weights=jnp.array([1.0, 2.0, 1.0]))
>>> w = Weights(log_weights=jnp.array([-1.0, 0.0, -1.0]))

Create uniform weights:

>>> w = Weights(n=5)

Use as a JAX array — returns normalized weights automatically:

>>> jnp.sum(w)                          # -> ~1.0
>>> jnp.einsum("n,n...->...", w, vals)  # weighted sum
>>> w * values                          # element-wise product

This means Weights can be passed anywhere a weight array is expected, including jax.random.choice(..., p=w).

Access different representations explicitly when needed:

>>> w.normalized         # Array, shape (n,) — probabilities summing to 1
>>> w.log_normalized     # Array — log-probabilities (always an array)
>>> w.log_unnormalized   # Array | None — raw stored log-weights (None if uniform)
>>> w.is_uniform         # bool — True when all items are equally weighted

Compute weighted statistics directly:

>>> w.mean(values)                  # weighted mean along leading axis
>>> w.variance(values)              # weighted variance
>>> w.covariance(values)            # weighted covariance matrix
>>> w.choice(key, shape=(10,))      # draw 10 weighted random indices

Passing to distribution constructors: all ProbPipe distribution constructors that accept weights or log_weights also accept a pre-built Weights object for either parameter. When a Weights object is passed, it is used as-is (no re-validation). The behavior is the same regardless of which parameter receives it, since the Weights object already encapsulates its representation:

w = Weights(log_weights=log_w)
EmpiricalDistribution(samples, weights=w)       # OK
EmpiricalDistribution(samples, log_weights=w)   # also OK, same result

JAX compatibility: Weights is registered as a JAX pytree whose single leaf is the normalized weight array, so it works transparently inside jax.jit, jax.vmap, jnp.sum, jnp.einsum, and other JAX operations.

Notes

Zero weights and -inf in log-space. When a weight array contains zeros (e.g. [0.0, 1.0, 0.0]), the internal log-representation stores -inf for those entries. All Weights operations handle this correctly:

  • normalized produces 0.0 for those items (via softmax).
  • log_normalized contains -inf entries (mathematically correct: log(0) = -inf).
  • effective_sample_size is unaffected (logsumexp handles -inf inputs).
  • choice never selects zero-weight items.

Code that consumes log_normalized directly should be aware that -inf values may be present.
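The zero-weight behaviour can be checked with a small log-space normalization. This is a sketch in plain Python, not the code in probpipe/_weights.py; log_normalize is a hypothetical helper that subtracts a logsumexp from each log-weight.

```python
import math


def log_normalize(log_weights):
    """Return normalized log-weights; -inf entries stay -inf."""
    peak = max(log_weights)
    if peak == -math.inf:
        raise ValueError("at least one weight must be positive")
    # logsumexp with the maximum factored out; exp(-inf) contributes 0.0.
    total = peak + math.log(sum(math.exp(lw - peak) for lw in log_weights))
    return [lw - total for lw in log_weights]


# Weights [0.0, 1.0, 0.0] correspond to log-weights [-inf, 0.0, -inf].
log_w = [-math.inf, 0.0, -math.inf]
log_norm = log_normalize(log_w)
normalized = [math.exp(lw) for lw in log_norm]
```

Here the -inf entries survive in log_norm and map to exactly 0.0 after exponentiation, which is the situation the note above warns consumers about.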

Source code in probpipe/_weights.py
def __init__(
    self,
    *,
    n: int | None = None,
    weights: ArrayLike | Weights | None = None,
    log_weights: ArrayLike | Weights | None = None,
):
    # --- Fast path: adopt an existing Weights object ---
    source = None
    if isinstance(weights, Weights):
        if log_weights is not None:
            raise ValueError(
                "Provide either weights or log_weights, not both."
            )
        source = weights
    elif isinstance(log_weights, Weights):
        if weights is not None:
            raise ValueError(
                "Provide either weights or log_weights, not both."
            )
        source = log_weights

    if source is not None:
        if n is not None and source._n != n:
            raise ValueError(
                f"Weights length {source._n} does not match n={n}."
            )
        self._n = source._n
        self._log_weights = source._log_weights
        self._is_uniform = source._is_uniform
        self._cache = source._cache
        return

    # --- Build from raw inputs ---
    # Infer n from array length when not given explicitly.
    if n is None:
        if weights is not None:
            weights = _as_float_array(weights)
            if weights.ndim != 1 or weights.shape[0] == 0:
                raise ValueError("weights must be a non-empty 1-D array.")
            n = weights.shape[0]
        elif log_weights is not None:
            log_weights = _as_float_array(log_weights)
            if log_weights.ndim != 1 or log_weights.shape[0] == 0:
                raise ValueError("log_weights must be a non-empty 1-D array.")
            n = log_weights.shape[0]
        else:
            raise ValueError(
                "At least one of n, weights, or log_weights must be "
                "provided."
            )

    self._log_weights, self._is_uniform = _validate_to_log_weights(
        n, weights, log_weights=log_weights,
    )
    self._n = n

    self._cache: Array | None = None

n property

Number of items.

is_uniform property

True when all items are equally weighted.

normalized property

Normalized weights, shape (n,). Cached after first access.

log_normalized property

Normalized log-weights, shape (n,).

For uniform weights every entry is -log(n). Like normalized, this property always returns an array, even when no log-weights are stored.

log_unnormalized property

Raw log-unnormalized weights as stored. None when uniform.

effective_sample_size property

Kish's effective sample size (ESS).

$$
n_{\mathrm{eff}} = \frac{1}{\sum_i w_i^2}
$$

where $w_i$ are the normalized weights. For uniform weights this equals n exactly.

Computed in log-space for numerical stability:

$$
n_{\mathrm{eff}}
= \exp\!\bigl(-\log \sum_i \exp(2 \log w_i)\bigr)
= \exp\!\bigl(-\mathrm{logsumexp}(2\,\log\mathbf{w})\bigr)
$$

Returns:

Type Description
Array

Scalar effective sample size (1 <= n_eff <= n).
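The log-space formula can be reproduced in a few lines of plain Python. This is a sketch of the arithmetic only, assuming unnormalized log-weights as input; effective_sample_size here is a hypothetical free function, not the property defined in probpipe/_weights.py.

```python
import math


def effective_sample_size(log_weights):
    """Kish ESS from unnormalized log-weights, evaluated in log-space."""

    def logsumexp(xs):
        peak = max(xs)
        return peak + math.log(sum(math.exp(x - peak) for x in xs))

    # Normalize: log w_i = log_weights_i - logsumexp(log_weights).
    log_z = logsumexp(log_weights)
    log_w = [lw - log_z for lw in log_weights]
    # n_eff = exp(-logsumexp(2 * log w)).
    return math.exp(-logsumexp([2.0 * lw for lw in log_w]))


ess_uniform = effective_sample_size([0.0, 0.0, 0.0, 0.0])   # four equal weights
ess_skewed = effective_sample_size([math.log(3.0), 0.0])    # weights in ratio 3 : 1
```

For the skewed case the normalized weights are 0.75 and 0.25, so the direct formula gives 1 / (0.5625 + 0.0625) = 1.6, which the log-space route matches up to rounding.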

shape property

Shape of the weight vector: (n,).

dtype property

Data type of the underlying log-weights array.

For uniform weights, returns JAX's current default float dtype.

uniform(n) staticmethod

Create uniform weights over n items.

Equivalent to Weights(n=n) but avoids keyword overhead in hot internal paths.

Source code in probpipe/_weights.py
@staticmethod
def uniform(n: int) -> Weights:
    """Create uniform weights over *n* items.

    Equivalent to ``Weights(n=n)`` but avoids keyword overhead in
    hot internal paths.
    """
    w = object.__new__(Weights)
    w._n = n
    w._log_weights = None
    w._is_uniform = True
    w._cache = None
    return w

__jax_array__()

Return normalized weights as a JAX array.

This allows Weights to be used directly in JAX operations (jnp.sum(w), jnp.einsum(..., w, ...), w * arr, etc.).

Source code in probpipe/_weights.py
def __jax_array__(self) -> Array:
    """Return normalized weights as a JAX array.

    This allows ``Weights`` to be used directly in JAX operations
    (``jnp.sum(w)``, ``jnp.einsum(..., w, ...)``, ``w * arr``, etc.).
    """
    return self.normalized

mean(values)

Compute weighted mean: sum_i w_i * values[i].

Source code in probpipe/_weights.py
def mean(self, values: Array) -> Array:
    """Compute weighted mean: ``sum_i w_i * values[i]``."""
    return weighted_mean(
        None if self._is_uniform else self.normalized, values,
    )

variance(values, mean=None)

Compute weighted variance over the leading axis.

Source code in probpipe/_weights.py
def variance(self, values: Array, mean: Array | None = None) -> Array:
    """Compute weighted variance over the leading axis."""
    return weighted_variance(
        None if self._is_uniform else self.normalized, values, mean=mean,
    )

covariance(values, mean=None)

Compute weighted covariance matrix over the leading axis.

Source code in probpipe/_weights.py
def covariance(self, values: Array, mean: Array | None = None) -> Array:
    """Compute weighted covariance matrix over the leading axis."""
    return weighted_covariance(
        None if self._is_uniform else self.normalized, values, mean=mean,
    )
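The three weighted statistics can be sketched with NumPy. The helpers weighted_mean, weighted_variance, and weighted_covariance are not shown in this section, so the code below is an assumption about their general form: it uses weights that sum to 1 and the plain (biased) moment estimators, and weighted_moments is a hypothetical name.

```python
import numpy as np


def weighted_moments(weights, values):
    """Weighted mean, variance, covariance over the leading axis.

    Assumes ``weights`` sum to 1 and uses the plain (biased) estimators;
    the library helpers may apply a different correction.
    """
    w = np.asarray(weights, dtype=float)
    x = np.asarray(values, dtype=float)          # shape (n, d)
    mean = np.einsum("n,nd->d", w, x)            # sum_i w_i * x_i
    centered = x - mean
    var = np.einsum("n,nd->d", w, centered**2)   # per-dimension variance
    cov = np.einsum("n,nd,ne->de", w, centered, centered)
    return mean, var, cov


values = np.array([[0.0, 0.0], [2.0, 4.0]])
mean, var, cov = weighted_moments([0.5, 0.5], values)
```

The diagonal of cov equals var, which is a quick consistency check for any implementation of these estimators.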

choice(key, *, shape=())

Draw weighted random indices from 0..n-1.

Source code in probpipe/_weights.py
def choice(self, key: PRNGKey, *, shape: tuple[int, ...] = ()) -> Array:
    """Draw weighted random indices from ``0..n-1``."""
    return weighted_choice(
        key, self._n,
        weights=None if self._is_uniform else self.normalized,
        shape=shape,
    )

subsample(indices)

Return a new Weights for a subset, re-normalized.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| indices | Array | Integer indices selecting a subset of items. | required |

Returns:

Type Description
Weights

New Weights over len(indices) items with weights proportional to the original weights at indices.

Source code in probpipe/_weights.py
def subsample(self, indices: Array) -> Weights:
    """Return a new ``Weights`` for a subset, re-normalized.

    Parameters
    ----------
    indices : Array
        Integer indices selecting a subset of items.

    Returns
    -------
    Weights
        New ``Weights`` over ``len(indices)`` items with weights
        proportional to the original weights at *indices*.
    """
    if self._is_uniform:
        return Weights.uniform(len(indices))
    sub_log = self._log_weights[indices]
    return Weights(log_weights=sub_log)
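Re-normalization after subsetting amounts to indexing the stored log-weights and normalizing again. The sketch below shows that arithmetic in plain Python; subsample_log_weights is a hypothetical helper that returns probabilities directly instead of a Weights object.

```python
import math


def subsample_log_weights(log_weights, indices):
    """Pick log-weights at ``indices`` and return renormalized probabilities."""
    sub = [log_weights[i] for i in indices]
    peak = max(sub)
    total = sum(math.exp(lw - peak) for lw in sub)
    return [math.exp(lw - peak) / total for lw in sub]


# Original weights proportional to 1 : 2 : 1 : 4.
log_w = [math.log(w) for w in (1.0, 2.0, 1.0, 4.0)]
probs = subsample_log_weights(log_w, [1, 3])   # keep the items weighted 2 and 4
```

The kept items retain their 2 : 4 ratio, so the new probabilities are one third and two thirds.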

tree_flatten()

Flatten for JAX pytree.

The single leaf is the normalized weight array so that JAX operations (jnp.sum, jnp.einsum, etc.) receive a plain array when unpacking the pytree.

Source code in probpipe/_weights.py
def tree_flatten(self):
    """Flatten for JAX pytree.

    The single leaf is the **normalized** weight array so that JAX
    operations (``jnp.sum``, ``jnp.einsum``, etc.) receive a plain
    array when unpacking the pytree.
    """
    return (self.normalized,), (self._n, self._is_uniform)

tree_unflatten(aux, children) classmethod

Unflatten from JAX pytree.

Source code in probpipe/_weights.py
@classmethod
def tree_unflatten(cls, aux, children):
    """Unflatten from JAX pytree."""
    (normalized,) = children
    n, is_uniform = aux
    w = object.__new__(cls)
    w._n = n
    w._is_uniform = is_uniform
    if is_uniform:
        w._log_weights = None
    else:
        w._log_weights = jnp.log(jnp.clip(normalized, 1e-45))
    w._cache = normalized
    return w

Parameter-sweep designs

FullFactorialDesign(**marginals) materialises the Cartesian product of per-field marginals as a sweep-ready RecordArray.

Design(_dict=None, /, *, batch_shape, template, name=None, **fields)

Bases: RecordArray

RecordArray that carries its per-field marginals.

A Design is not meant to be instantiated directly — concrete subclasses (FullFactorialDesign) assemble the underlying rows in __init__ and stash the originating marginals for introspection.

Two equivalent ways to drive a sweep through a @workflow_function:

@workflow_function
def fit(p): ...
result = fit(p=design)              # one row per call

@workflow_function
def fit(r, K): ...
result = fit(**design.select_all()) # zip across sibling views

select_all() returns one view per field; views that share the Design as their parent zip across rows in the WF sweep layer (so the two shapes above produce identical outputs). For raw columns (no sweep — just JAX broadcasting), index with design["r"].

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| marginals | Mapping[str, Any] | The per-field marginals this design was built from, in construction (insertion) order. Kept for introspection; read-only. |

marginals property

Per-field marginals this design was built from.

FullFactorialDesign(**marginals)

Bases: Design

Cartesian product over all marginals — one row per combination.

Each marginal is a Python sequence (list, tuple, numpy / jax array). Numeric marginals become jnp.ndarray columns and categorical / string marginals become numpy.ndarray(dtype=object) columns. Row order is row-major over the marginals in insertion order — i.e., the last-listed marginal varies fastest.
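The row order described above is the same order that itertools.product yields. The sketch reproduces the ordering only, using plain dicts rather than a FullFactorialDesign; the marginal values are taken from the example below.

```python
from itertools import product

marginals = {"r": [1.5, 1.8], "K": [60.0, 80.0]}

# itertools.product iterates in row-major order: the last iterable varies fastest.
rows = [dict(zip(marginals, combo)) for combo in product(*marginals.values())]
```

The first two rows share r = 1.5 and differ in K, confirming that the last-listed marginal changes fastest.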

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**marginals` | Sequence | Candidate values for each field. Must pass at least one marginal; each must be non-empty. | {} |

Examples:

Cartesian grid of two numeric fields:

>>> ff = FullFactorialDesign(r=[1.5, 1.8], K=[60.0, 80.0])
>>> ff.batch_shape
(4,)
>>> ff.fields
('r', 'K')

Mixed numeric / categorical marginals are supported — columns fall out as object-dtype arrays for the categorical fields:

>>> ff2 = FullFactorialDesign(method=['nutpie', 'pymc'], scale=[0.5, 1.0])
>>> ff2.batch_shape
(4,)

Source code in probpipe/record/design.py
def __init__(self, **marginals: Sequence) -> None:
    if not marginals:
        raise ValueError(
            "FullFactorialDesign requires at least one marginal"
        )
    names = list(marginals)
    lists = [list(marginals[n]) for n in names]
    sizes = [len(v) for v in lists]
    if any(s == 0 for s in sizes):
        raise ValueError(
            "FullFactorialDesign marginals must each be non-empty; "
            f"got sizes {dict(zip(names, sizes))}"
        )
    n_total = prod(sizes)
    # ``meshgrid(..., indexing='ij')`` then flatten: each axis
    # iterates at its own stride; C-order flatten then yields a
    # lexicographic row-major traversal over the marginals in
    # insertion order.
    grids = np.meshgrid(
        *(np.arange(s) for s in sizes), indexing="ij",
    )
    flat_indices = {
        name: grid.reshape(-1) for name, grid in zip(names, grids)
    }

    fields: dict[str, Any] = {}
    template_spec: dict[str, Any] = {}
    for name, values in zip(names, lists):
        col, leaf_shape = _seq_to_column(
            values, indices=flat_indices[name],
        )
        fields[name] = col
        template_spec[name] = leaf_shape

    RecordArray.__init__(
        self, fields,
        batch_shape=(n_total,),
        template=RecordTemplate(template_spec),
        name=f"FullFactorialDesign({','.join(names)})",
    )
    object.__setattr__(self, "_marginals", dict(marginals))

Auxiliary-metadata registry

Record → NumericRecord conversion drops backend-specific metadata (xarray dims / coords / attrs, pandas index / columns / dtypes); the auxiliary registry round-trips it so to_native() reproduces the original container.

register_aux(leaf_type, *, capture, restore)

Register (capture, restore) hooks for a backend leaf type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| leaf_type | type | The Python type of the leaves whose metadata should be preserved across a Record/NumericRecord round-trip. | required |
| capture | callable | capture(leaf) -> aux. | required |
| restore | callable | restore(arr, aux) -> leaf. | required |

Notes

Re-registering an existing leaf_type overwrites the previous hook silently. Lookup uses aux_for which walks the MRO of type(obj), so registering a base class also covers its subclasses.

Source code in probpipe/core/_array_backend.py
def register_aux(
    leaf_type: type,
    *,
    capture: Callable[[Any], Any],
    restore: Callable[[jax.Array, Any], Any],
) -> None:
    """Register `(capture, restore)` hooks for a backend leaf type.

    Parameters
    ----------
    leaf_type : type
        The Python type of the leaves whose metadata should be
        preserved across a ``Record``/``NumericRecord`` round-trip.
    capture : callable
        ``capture(leaf) -> aux``.
    restore : callable
        ``restore(arr, aux) -> leaf``.

    Notes
    -----
    Re-registering an existing ``leaf_type`` overwrites the previous
    hook silently. Lookup uses :func:`aux_for` which walks the MRO of
    ``type(obj)``, so registering a base class also covers its
    subclasses.
    """
    aux_registry[leaf_type] = AuxHooks(capture=capture, restore=restore)

aux_for(obj)

Return the registered hooks for obj, or None if absent.

Walks the MRO of type(obj) so subclass instances pick up base-class registrations.

Notes

Exact-type lookup is checked before the MRO walk so the common np.ndarray / jax.Array leaves on the NumericRecord.__init__ / pytree-unflatten hot path skip a multi-step MRO traversal.

Source code in probpipe/core/_array_backend.py
def aux_for(obj: Any) -> AuxHooks | None:
    """Return the registered hooks for ``obj``, or ``None`` if absent.

    Walks the MRO of ``type(obj)`` so subclass instances pick up
    base-class registrations.

    Notes
    -----
    Exact-type lookup is checked before the MRO walk so the common
    ``np.ndarray`` / ``jax.Array`` leaves on the
    ``NumericRecord.__init__`` / pytree-unflatten hot path skip a
    multi-step MRO traversal.
    """
    if not aux_registry:
        return None
    cls = type(obj)
    # Exact-type fast path — covers the common case where the
    # registered key is the leaf's concrete class.
    hooks = aux_registry.get(cls)
    if hooks is not None:
        return hooks
    for base in cls.__mro__[1:]:
        hooks = aux_registry.get(base)
        if hooks is not None:
            return hooks
    return None
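The registration and lookup pattern can be exercised end to end with a toy registry. Everything below is a hypothetical stand-in written for illustration (ToyHooks, toy_registry, toy_register, toy_lookup, Labeled, LabeledChild); it mirrors the MRO walk and the capture / restore contract but does not import or call ProbPipe.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyHooks:
    capture: Callable[[Any], Any]
    restore: Callable[[Any, Any], Any]


toy_registry: dict[type, ToyHooks] = {}


def toy_register(leaf_type, *, capture, restore):
    toy_registry[leaf_type] = ToyHooks(capture=capture, restore=restore)


def toy_lookup(obj):
    """Walk the MRO so subclasses inherit a base-class registration."""
    for base in type(obj).__mro__:
        hooks = toy_registry.get(base)
        if hooks is not None:
            return hooks
    return None


class Labeled(list):
    """Hypothetical leaf: a list of numbers carrying a label."""

    def __init__(self, data, label):
        super().__init__(data)
        self.label = label


class LabeledChild(Labeled):
    pass


toy_register(
    Labeled,
    capture=lambda leaf: {"label": leaf.label},
    restore=lambda arr, aux: Labeled(arr, aux["label"]),
)

leaf = LabeledChild([1.0, 2.0], label="height")
hooks = toy_lookup(leaf)      # found through the MRO of LabeledChild
aux = hooks.capture(leaf)     # metadata that a bare array would lose
bare = list(leaf)             # stand-in for the coerced array
restored = hooks.restore(bare, aux)
```

Registering only the base class is enough for the subclass instance to be found, and the label survives the trip through a bare list.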

AuxHooks(capture, restore) dataclass

A pair of (capture, restore) hooks for one backend type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| capture | callable | capture(leaf) -> aux: extract backend-specific metadata that would otherwise be lost when jnp.asarray coerces the leaf to a JAX array. May return any pickle-friendly value. | required |
| restore | callable | restore(arr, aux) -> leaf: reconstruct an instance of the original backend type from a JAX array and the previously captured aux blob. | required |