Source code for dplutils.pipeline.task

import inspect
from copy import copy
from dataclasses import dataclass, field, replace
from typing import Any, Callable

import pandas as pd


[docs] @dataclass class PipelineTask: """Container representing a task and its runtime configuration and dependencies. The PipelineTask represents metadata about a task within a particular pipeline, the execution of which is handled by a :class:`PipelineExecutor<dplutils.pipeline.executor.PipelineExecutor>`. Each task function is expected to take a pandas Dataframe as its first positional argument and produce a dataframe. Any additional arguments should be keyword arguments (required or otherwise) and are expected to be passed to the function by the executor via ``kwargs`` and ``context_kwargs``. The values set in a PipelineTask represent the defaults for a task, which may be updated by a particular pipeline invocation. The ``__call__`` method can be used to return a new task with updated default parameters, for example to customize the name: >>> MyTask = PipelineTask('task', func=myfunc) >>> MyNewTask = MyTask('reconf', kwargs=dict(arg=value)) args: name: name of task for reference in pipeline and configuration. func: the callable to execute task operations. context_kwargs: a mapping from pipeline context key to keyword arguments passed to func. Used for reusing arguments that are not task-specific. For example if the several steps access a key ``data``, then setting context_kwargs to {``data_in``: ``data``} would indicate the executor should pass the value of the context element ``data`` to the ``data_in`` of func at runtime. kwargs: a dict of task-specific kwargs to pass to function as ``**kwargs`` num_cpus: CPU allocations requested for task num_gpus: GPU allocations requested for task resources: dict of any additional resources to pass to the executor batch_size: ideal batch size for this workload. """ name: str func: Callable[[pd.DataFrame, ...], pd.DataFrame] context_kwargs: dict[str, str] = field(default_factory=dict) kwargs: dict[str, Any] = field(default_factory=dict) num_gpus: int = None num_cpus: int = 1 resources: dict[str, Any] = field(default_factory=dict) batch_size: int = None def __call__(self, name: str = None, **kwargs): if name is None: name = self.name return replace(self, name=name, **kwargs) def resolve_kwargs(self, context: dict): """Return a dict of final keyword arguments for the given context. This method consults context_kwargs to build an updated list of kwargs based on the given context to pass to ``func``. """ kwargs = copy(self.kwargs) kwargs.update({k: context[v] for k, v in self.context_kwargs.items() if v in context}) return kwargs def validate(self, context: dict): """Validate the arguments of ``func`` given context prior to run. To enable a pre-execution check of the final configuration including context, this method consults the signature of ``func`` and the ``kwargs`` and ``context_kwargs`` to determine if the call is missing or has too many parameters. A pandas Dataframe is expected as the first parameter and thus ignored in this validation. """ all_kwargs = self.resolve_kwargs(context) # we expect a dataframe as the first argument, so skip validation for that params = list(inspect.signature(self.func).parameters.items())[1:] # Because the signature and params therein do not indicate varadics, we have to # consult getfullargspec for those. Similarly, since fullargspec doesn't indicate # positional only, we utilize both. argspec = inspect.getfullargspec(self.func) if argspec.args[0] in all_kwargs: raise ValueError("first position argument reserved for input dataframe but found in kwargs") for key, param in params: if param.name in [argspec.varargs, argspec.varkw]: continue if param.kind == inspect.Parameter.POSITIONAL_ONLY: raise ValueError(f"only one positional only argument supported, found also {param.name}") if param.default == inspect._empty: if key not in all_kwargs: msg = f"missing required argument {key} for task {self.name}" if key in self.context_kwargs: msg = f"{msg} - expected from context {self.context_kwargs[key]}" raise ValueError(msg) if not argspec.varkw: extra = set(all_kwargs.keys()) - {k for k, v in params} if len(extra) > 0: raise ValueError(f"unkown arguments {extra} for task {self.name}") def __hash__(self): return hash(self.name)