
Workflows and orchestration

WorkflowFunction wraps every op and every user-written @workflow_function. Module is the stateful container with @workflow_method children.

Prefect orchestration is off by default. Set prefect_config.workflow_kind = WorkflowKind.TASK (or FLOW) globally, or export PROBPIPE_WORKFLOW_KIND=task in the environment.

Wrappers and decorators

WorkflowFunction(*, func, workflow_kind=WorkflowKind.DEFAULT, name=None, bind=None, module=None, n_broadcast_samples=None, vectorize='auto', parallel=False, seed=0, include_inputs=False, **kwargs)

Bases: Node

A single executable DAG node wrapping exactly one function.

Infers which parameters are dependencies and which are plain inputs from the function signature and type hints. Missing values can optionally be resolved from an attached Module.

Broadcasting: When a Distribution is passed for an argument whose type hint is not a Distribution subclass, the workflow automatically samples from the distribution and calls the wrapped function once per sample, returning an EmpiricalDistribution over the outputs (or a plain list when results are not numeric).
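The broadcasting rule can be illustrated with a small, self-contained sketch. It is not ProbPipe code: `Distribution` and `Empirical` are hypothetical stand-ins for ProbPipe's Distribution and EmpiricalDistribution, `call_with_broadcast` is an illustrative helper, sampling uses Python's `random` instead of JAX PRNG keys, only a single argument is handled, and dispatch is always a plain loop (no `jax.vmap`).

```python
from __future__ import annotations

import random
from numbers import Number
from typing import Any, Callable, get_type_hints


class Distribution:
    """Hypothetical stand-in for a ProbPipe Distribution (a Gaussian here)."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.mean, self.std)


class Empirical:
    """Hypothetical stand-in for EmpiricalDistribution: it only stores samples."""

    def __init__(self, samples: list[float]) -> None:
        self.samples = samples


def should_broadcast(func: Callable, param: str, value: Any) -> bool:
    """Broadcast only when a Distribution meets a non-Distribution type hint."""
    hint = get_type_hints(func).get(param)
    hint_is_dist = isinstance(hint, type) and issubclass(hint, Distribution)
    return isinstance(value, Distribution) and not hint_is_dist


def call_with_broadcast(func: Callable, param: str, value: Any,
                        n_samples: int = 100, seed: int = 0) -> Any:
    if not should_broadcast(func, param, value):
        return func(**{param: value})  # ordinary call, argument passed through
    rng = random.Random(seed)
    outputs = [func(**{param: value.sample(rng)}) for _ in range(n_samples)]
    # Numeric outputs become an empirical distribution, anything else a list.
    if all(isinstance(o, Number) for o in outputs):
        return Empirical([float(o) for o in outputs])
    return outputs


def double(x: float) -> float:
    return 2 * x


def prior_mean(d: Distribution) -> float:
    return d.mean


def sign(x: float) -> str:
    return "pos" if x >= 0 else "neg"


prior = Distribution(0.0, 1.0)
doubled = call_with_broadcast(double, "x", prior, n_samples=5)  # float hint: broadcast
mean = call_with_broadcast(prior_mean, "d", prior)  # Distribution hint: no broadcast
signs = call_with_broadcast(sign, "x", prior, n_samples=4)  # non-numeric: plain list
```

The type hint is what decides: `prior_mean` declares that it wants a Distribution, so it receives the distribution object itself, while `double` and `sign` are called once per draw.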

Vectorization and orchestration are orthogonal concerns:

  • Vectorization (vectorize) controls how samples are dispatched: jax.vmap for JAX-traceable functions, or a Python loop otherwise.
  • Orchestration (workflow_kind) controls whether the dispatch is wrapped in a Prefect task or flow for compute-graph tracing.

When both are active, the JAX-vectorized computation is executed inside a Prefect task/flow, giving the benefits of vmap performance with full Prefect lineage tracking.

Parameters:

Name Type Description Default
func Callable

The function to wrap.

required
workflow_kind WorkflowKind

Prefect orchestration mode. DEFAULT inherits from prefect_config.workflow_kind (shipped default: OFF; set via the PROBPIPE_WORKFLOW_KIND environment variable or explicit assignment). TASK / FLOW explicitly request Prefect orchestration. OFF disables orchestration. Legacy strings ("task", "flow") are converted to the matching WorkflowKind member, and None is converted to OFF.

DEFAULT
name str or None

Display name; defaults to func.__name__.

None
bind dict or None

Construction-time keyword bindings (defaults / config). Any extra keyword arguments (**kwargs) are merged into bind.

None
module Module or None

Parent module for input / dependency resolution.

None
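n_broadcast_samples int or None

Default number of samples drawn when broadcasting; None uses the class default DEFAULT_N_BROADCAST_SAMPLES. Can be overridden at call time by passing n_broadcast_samples=…. Because this name is reserved for the call-time override, a wrapped function that declares its own n_broadcast_samples parameter is rejected with a ValueError at construction.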

None
vectorize str

Vectorization strategy for broadcasting:

  • "auto" (default): probe with jax.make_jaxpr; on success use "jax", on failure fall back to "loop".
  • "jax": vectorize via jax.vmap. Requires the wrapped function to be JAX-traceable.
  • "loop": Python loop (optionally threaded via parallel).
'auto'
parallel bool or int

Controls parallel execution during broadcasting ("loop" vectorization only). False → sequential, True → ThreadPoolExecutor with the default number of workers, int → explicit max_workers.

False
seed int

Random seed for JAX PRNG key management during broadcasting.

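0
include_inputs bool

When True, broadcasting returns a BroadcastDistribution (a joint distribution over inputs and outputs) rather than a distribution over the outputs alone. The name is reserved for call-time overrides, so it can also be passed when calling the wrapped function.

False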
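The vectorize and parallel options can be sketched as two small helpers. This is an illustration under assumptions, not ProbPipe's implementation: `resolve_vectorize` and `loop_broadcast` are hypothetical names, and the `probe` callable stands in for tracing the function with `jax.make_jaxpr` (left out so the sketch runs without JAX). ProbPipe caches the auto-detected result (`_resolved_vectorize`), which this sketch does not do.

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Sequence


def resolve_vectorize(mode: str, probe: Callable[[], Any]) -> str:
    """Turn 'auto' | 'jax' | 'loop' into a concrete strategy.

    `probe` is a stand-in for tracing the function with jax.make_jaxpr:
    it should raise if the function is not JAX-traceable.
    """
    if mode not in ("auto", "jax", "loop"):
        raise ValueError(f"unknown vectorize mode: {mode!r}")
    if mode != "auto":
        return mode  # explicit choice, no probing
    try:
        probe()
    except Exception:
        return "loop"  # tracing failed: fall back to a Python loop
    return "jax"


def loop_broadcast(func: Callable, samples: Sequence,
                   parallel: bool | int = False) -> list:
    """'loop' strategy using the documented meaning of `parallel`."""
    if not parallel:
        return [func(s) for s in samples]  # False: sequential (0 is treated alike here)
    # True -> the executor's default worker count; int -> explicit max_workers
    max_workers = None if parallel is True else parallel
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(func, samples))  # results keep input order
```

Because `ThreadPoolExecutor.map` yields results in input order, the threaded path returns the same list as the sequential one; threads only help when the wrapped function releases the GIL or waits on I/O.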
Source code in probpipe/core/node.py
def __init__(
    self,
    *,
    func: Callable,
    workflow_kind: WorkflowKind | str | None = WorkflowKind.DEFAULT,  # TODO: remove str | None in follow-up issue
    name: str | None = None,
    bind: dict[str, Any] | None = None,         # construction-time bindings (defaults/config)
    module: Any | None = None,                  # typically a Module; kept as Any to avoid import cycles
    n_broadcast_samples: int | None = None,      # default number of samples for broadcasting
    vectorize: str = "auto",                     # "auto" | "jax" | "loop"
    parallel: bool | int = False,               # True/int for ThreadPoolExecutor, or Prefect .map()
    seed: int = 0,                              # JAX PRNG seed for broadcasting
    include_inputs: bool = False,                # True → return BroadcastDistribution (joint over inputs+outputs)
    **kwargs: Any,                              # convenience bindings (merged into bind)
):
    self._func = func
    self._sig = inspect.signature(func)
    self._hints = get_type_hints(func)
    # Convert legacy string / None values to WorkflowKind enum
    # TODO: remove this legacy conversion in follow-up issue
    if workflow_kind is None:
        self._workflow_kind_raw = WorkflowKind.OFF
    elif isinstance(workflow_kind, str) and not isinstance(workflow_kind, WorkflowKind):
        self._workflow_kind_raw = WorkflowKind(workflow_kind)
    else:
        self._workflow_kind_raw = workflow_kind
    self._name = name or getattr(func, "__name__", self.__class__.__name__)

    # Expose wrapped function's metadata for introspection (help(),
    # inspect.signature(), IDE tooltips, mkdocstrings).  We skip
    # __wrapped__ to prevent inspect.unwrap() from bypassing __call__.
    self.__doc__ = func.__doc__
    self.__name__ = self._name
    self.__qualname__ = getattr(func, "__qualname__", self._name)
    self.__signature__ = self._sig
    self.__module__ = getattr(func, "__module__", None)
    self._module = module
    self._n_broadcast_samples = n_broadcast_samples if n_broadcast_samples is not None else self.DEFAULT_N_BROADCAST_SAMPLES
    self._vectorize = vectorize
    self._parallel = parallel
    self._key = jax.random.PRNGKey(seed)
    self._include_inputs = include_inputs
    self._resolved_vectorize: str | None = None  # cached auto-detection result

    # bind = "construction-time inputs" (defaults/config). kwargs are also treated as bind.
    b = dict(bind or {})
    b.update(kwargs)
    self._bind = b

    super().__init__()

    # Precompute parameter metadata once
    self._param_names = [p for p in self._sig.parameters if p != "self"]
    self._has_var_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD
        for p in self._sig.parameters.values()
    )

    # Reserved names that would collide with WorkflowFunction call-time overrides
    _RESERVED = {"n_broadcast_samples", "seed", "include_inputs"}
    collision = _RESERVED & set(self._param_names)
    if collision:
        raise ValueError(
            f"Function '{self._name}' has parameter(s) {collision} which are "
            f"reserved by WorkflowFunction for call-time overrides. Rename them in "
            f"your function signature."
        )
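The reserved-name check at the end of the constructor means a function that already takes, say, a `seed` argument must rename it before it can be wrapped. The sketch below reproduces that check on its own so a signature can be tested before wrapping; `reserved_collisions`, `simulate` and `simulate_renamed` are hypothetical names, and the reserved set is copied from the source listing above.

```python
import inspect
from typing import Callable

# Names WorkflowFunction keeps for call-time overrides (from the source above).
RESERVED = {"n_broadcast_samples", "seed", "include_inputs"}


def reserved_collisions(func: Callable) -> set:
    """Return the parameters that WorkflowFunction would reject."""
    params = {p for p in inspect.signature(func).parameters if p != "self"}
    return RESERVED & params


def simulate(x: float, seed: int) -> float:  # would raise ValueError when wrapped
    return x + seed


def simulate_renamed(x: float, rng_seed: int) -> float:  # accepted after renaming
    return x + rng_seed
```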

effective_workflow_kind property

Resolve the orchestration mode for this instance.

Resolution order:

  1. Per-instance override (anything other than DEFAULT).
  2. Global prefect_config.workflow_kind.
  3. If global is also DEFAULT, fall back to OFF. Prefect orchestration is opt-in: set the global or per-instance workflow_kind to TASK / FLOW, or export PROBPIPE_WORKFLOW_KIND=task in the environment.

If Prefect is not installed but TASK or FLOW is requested (either per-instance or globally), a warning is emitted and the mode falls back to OFF.
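The resolution order above can be written as a short function. This is a sketch under assumptions: `Kind` is a hypothetical stand-in for WorkflowKind (its member values are assumed from the PROBPIPE_WORKFLOW_KIND spellings), `effective_kind` is an illustrative name, and Prefect availability is checked with `importlib.util.find_spec`, which may differ from how ProbPipe detects it.

```python
from __future__ import annotations

import importlib.util
import warnings
from enum import Enum


class Kind(Enum):
    """Hypothetical stand-in for WorkflowKind (member values are assumed)."""
    DEFAULT = "default"
    OFF = "off"
    TASK = "task"
    FLOW = "flow"


def effective_kind(instance: Kind, global_kind: Kind,
                   prefect_available: bool | None = None) -> Kind:
    """Resolve the orchestration mode using the documented order."""
    if prefect_available is None:
        prefect_available = importlib.util.find_spec("prefect") is not None
    # 1. per-instance override unless it is DEFAULT; 2. otherwise the global config
    kind = instance if instance is not Kind.DEFAULT else global_kind
    # 3. a DEFAULT global falls back to OFF: orchestration is opt-in
    if kind is Kind.DEFAULT:
        kind = Kind.OFF
    # TASK / FLOW requested without Prefect: warn and run as plain Python
    if kind in (Kind.TASK, Kind.FLOW) and not prefect_available:
        warnings.warn(f"Prefect is not installed; falling back from {kind.name} to OFF")
        kind = Kind.OFF
    return kind
```

Note that an explicit per-instance OFF wins over a global TASK, so individual functions can opt out of orchestration even when it is enabled globally.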

Module(*, workflow_kind=WorkflowKind.DEFAULT, **kwargs)

Bases: Node

Container for workflow nodes with shared inputs and child nodes.

New user-facing API: MyModule(data=data_node, horizon=30, alpha=0.1)

Internally:

  • kwargs whose values are Node instances become child_nodes;
  • everything else becomes inputs.
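The split between child nodes and inputs can be sketched with a plain function applied to the MyModule example above. `Node` here is a hypothetical empty stand-in for ProbPipe's Node base class, and `split_module_kwargs` is an illustrative name, not a ProbPipe API.

```python
from typing import Any


class Node:
    """Hypothetical stand-in for ProbPipe's Node base class."""


def split_module_kwargs(**kwargs: Any) -> tuple:
    """Node-valued kwargs become child nodes; everything else becomes inputs."""
    child_nodes = {k: v for k, v in kwargs.items() if isinstance(v, Node)}
    inputs = {k: v for k, v in kwargs.items() if not isinstance(v, Node)}
    return child_nodes, inputs


data_node = Node()
children, inputs = split_module_kwargs(data=data_node, horizon=30, alpha=0.1)
```

The child-node names are what dag() draws as ellipses outside the module cluster, and an edge is added when a workflow method's dependency parameter has the same name as a child node.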

Source code in probpipe/core/node.py
def __init__(self, *, workflow_kind: WorkflowKind | str | None = WorkflowKind.DEFAULT, **kwargs: Any):
    # Convert legacy string / None values to WorkflowKind enum
    if workflow_kind is None:
        self._workflow_kind = WorkflowKind.OFF
    elif isinstance(workflow_kind, str) and not isinstance(workflow_kind, WorkflowKind):
        self._workflow_kind = WorkflowKind(workflow_kind)
    else:
        self._workflow_kind = workflow_kind
    super().__init__(**kwargs)
    # validate abstract workflow implementations before wrapping
    self._validate_abstract_workflow_implementations()

    self._build_workflows()

dag()

Return a Graphviz DAG visualization of this module.

Source code in probpipe/core/node.py
def dag(self):
    """Return a Graphviz DAG visualization of this module."""
    if Digraph is None:
        raise ImportError(
            "graphviz is required for dag visualization. "
            "Install it with: pip install probpipe[viz]"
        )
    dot = Digraph(
        name=self.__class__.__name__,
        graph_attr={
            "rankdir": "LR",
            "fontsize": "12",
            "fontname": "Helvetica",
        },
        node_attr={
            "fontname": "Helvetica",
            "fontsize": "11",
        },
    )

    # -------------------------
    # Child nodes (outside)
    # -------------------------
    for name in self._child_nodes:
        dot.node(
            name,
            label=name,
            shape="ellipse",
            style="filled",
            fillcolor="#E8E8E8",
        )

    # -------------------------
    # Module cluster
    # -------------------------
    with dot.subgraph(name=f"cluster_{self.__class__.__name__}") as cluster:
        cluster.attr(
            label=self.__class__.__name__,
            style="rounded",
            color="#4F81BD",
            fontname="Helvetica-Bold",
            fontsize="12",
        )

        # WorkflowFunction nodes inside the module
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if not isinstance(attr, WorkflowFunction):
                continue

            wf_name = attr._name  # e.g. PM25ForecastingModule.fit
            wf_label = wf_name.split(".")[-1]

            cluster.node(
                wf_name,
                label=wf_label,
                shape="box",
                style="filled",
                fillcolor="#C6DBEF",
            )

    # -------------------------
    # Dependency edges
    # -------------------------
    for attr_name in dir(self):
        attr = getattr(self, attr_name)
        if not isinstance(attr, WorkflowFunction):
            continue

        wf_name = attr._name

        # Infer dependencies from workflow signature
        # (WorkflowFunctions don't store child_nodes; they resolve dependencies at runtime)
        for param_name in attr._param_names:
            if attr._is_dependency_param(param_name) and param_name in self._child_nodes:
                dot.edge(param_name, wf_name)

    return dot

workflow_function(_func=None, /, **kwargs)

Decorator to create a WorkflowFunction from a plain function.

Can be used with or without arguments:

@workflow_function
def my_func(x, y):
    return x + y

@workflow_function(n_broadcast_samples=100, vectorize="loop")
def my_func(x, y):
    return x + y

Source code in probpipe/core/node.py
def workflow_function(_func=None, /, **kwargs):
    """Decorator to create a :class:`WorkflowFunction` from a plain function.

    Can be used with or without arguments::

        @workflow_function
        def my_func(x, y):
            return x + y

        @workflow_function(n_broadcast_samples=100, vectorize="loop")
        def my_func(x, y):
            return x + y
    """
    def decorator(func):
        return WorkflowFunction(func=func, name=func.__name__, **kwargs)

    if _func is not None:
        # Bare @workflow_function (no parentheses)
        return decorator(_func)
    # @workflow_function(...) with arguments
    return decorator

workflow_method(func)

Mark a method as a workflow method for Module subclasses.

Methods decorated with @workflow_method are automatically converted to WorkflowFunction instances when the Module is instantiated.

Source code in probpipe/core/node.py
def workflow_method(func: Callable):
    """Mark a method as a workflow method for :class:`Module` subclasses.

    Methods decorated with ``@workflow_method`` are automatically
    converted to :class:`WorkflowFunction` instances when the
    ``Module`` is instantiated.
    """
    func._is_workflow = True
    return func

abstract_workflow_method(func)

Mark a method as an abstract workflow interface.

Combines @abstractmethod with @workflow_method so that AbstractModule subclasses can declare workflow-shaped interfaces without providing implementations.

Source code in probpipe/core/node.py
def abstract_workflow_method(func: Callable):
    """Mark a method as an abstract workflow interface.

    Combines ``@abstractmethod`` with ``@workflow_method`` so that
    :class:`AbstractModule` subclasses can declare workflow-shaped
    interfaces without providing implementations.
    """
    return abstractmethod(workflow_method(func))
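The two decorators only attach markers: `workflow_method` sets `_is_workflow`, and `abstractmethod` sets `__isabstractmethod__`, so a function decorated with `abstract_workflow_method` carries both. The sketch below reuses the two decorator bodies from the listings above; `ABC` stands in for AbstractModule, and `workflow_method_names`, `Forecaster` and `MeanForecaster` are hypothetical. It does not reproduce how Module turns marked methods into WorkflowFunction instances or validates abstract implementations.

```python
from abc import ABC, abstractmethod
from typing import Callable


def workflow_method(func: Callable) -> Callable:
    func._is_workflow = True  # same marker attribute as in the source above
    return func


def abstract_workflow_method(func: Callable) -> Callable:
    return abstractmethod(workflow_method(func))


def workflow_method_names(cls: type) -> list:
    """Hypothetical helper: names of class attributes carrying the marker."""
    return sorted(
        name for name in dir(cls)
        if getattr(getattr(cls, name), "_is_workflow", False)
    )


class Forecaster(ABC):  # stand-in for an AbstractModule subclass
    @abstract_workflow_method
    def fit(self, data):
        ...


class MeanForecaster(Forecaster):  # concrete implementation of the interface
    @workflow_method
    def fit(self, data):
        return sum(data) / len(data)
```

Forecaster cannot be instantiated because `fit` is still abstract, while MeanForecaster can; both classes expose `fit` as a workflow-marked method.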

Orchestration configuration

WorkflowKind

Bases: Enum

Orchestration mode for WorkflowFunction instances.

Members

  • DEFAULT: At the per-instance level, inherit from the global config. The shipped global default is OFF unless overridden via PROBPIPE_WORKFLOW_KIND or by explicit assignment to prefect_config.workflow_kind.
  • OFF: No Prefect orchestration; plain Python execution.
  • TASK: Wrap execution in a Prefect task (via task.map()). Raises ImportError if Prefect is not installed.
  • FLOW: Wrap execution in a Prefect flow. Raises ImportError if Prefect is not installed.

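prefect_config = PrefectConfig() (module attribute)

Global orchestration configuration. Its workflow_kind field supplies the global default used when a WorkflowFunction's workflow_kind is DEFAULT, and prefect_config.reset() re-reads PROBPIPE_WORKFLOW_KIND.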

PROBPIPE_WORKFLOW_KIND environment variable

PROBPIPE_WORKFLOW_KIND (case-insensitive: off / task / flow / default) sets the initial prefect_config.workflow_kind at import time. Unknown values raise ValueError. prefect_config.reset() re-reads the variable.
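Parsing the variable can be sketched as follows. `Kind` is a hypothetical stand-in for WorkflowKind whose values are assumed from the accepted spellings, `kind_from_env` is an illustrative name, and treating an unset variable as OFF follows the documented shipped default. Calling the function again plays the role that prefect_config.reset() plays in ProbPipe.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from enum import Enum


class Kind(Enum):
    """Hypothetical stand-in for WorkflowKind (member values are assumed)."""
    DEFAULT = "default"
    OFF = "off"
    TASK = "task"
    FLOW = "flow"


def kind_from_env(environ: Mapping[str, str] | None = None) -> Kind:
    """Read PROBPIPE_WORKFLOW_KIND the way the docs describe."""
    env = os.environ if environ is None else environ
    raw = env.get("PROBPIPE_WORKFLOW_KIND")
    if raw is None:
        return Kind.OFF  # shipped default when the variable is unset
    try:
        return Kind(raw.lower())  # case-insensitive match on the value
    except ValueError:
        raise ValueError(
            f"PROBPIPE_WORKFLOW_KIND={raw!r} is not one of off/task/flow/default"
        ) from None
```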