from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Any, Callable
import networkx as nx
import numpy as np
import pandas as pd
from dplutils.pipeline import OutputBatch, PipelineExecutor, PipelineTask
from dplutils.pipeline.utils import deque_extract, split_dataframe
@dataclass
class StreamBatch:
"""Container for task output tracking
Args:
length: length of dataframe that is referenced by ``data``. This field is
required as in many cases ``data`` will be something that eventually
resolves to a dataframe, but not available to driver.
data: data should contain a reference to a DataFrame in whatever way is
meaningful to implementation. This field is not introspected, only
passed by the framework.
"""
length: int
data: Any
@dataclass
class StreamTask:
"""Internal task wrapper for :class:`StreamingGraphExecutor`"""
task: PipelineTask
data_in: list[StreamBatch] = field(default_factory=deque)
pending: list = field(default_factory=deque)
counter: int = 0
split_pending: list = field(default_factory=deque)
def __hash__(self):
return hash(self.task)
@property
def name(self):
return self.task.name
@property
def all_pending(self):
return self.pending + self.split_pending
def total_pending(self):
return sum(len(i) for i in [self.data_in, self.pending, self.split_pending])
[docs]
class StreamingGraphExecutor(PipelineExecutor, ABC):
"""Base class for implementing streaming remote graph execution
This class implements the :meth:`execute` method of
:class:`PipelineExecutor` and contains logic necessary to schedule tasks,
prioritizing completing those that are closer to terminals. It supports
arbitrary pipeline graphs with branches, multiple inputs and outputs. By
default, for each run, it generates a indefinite stream of input dataframes
tagged with a monotonically incrementing batch id.
Args:
max_batches: maximum number of batches from the source generator to feed
to the input task(s). Default is None, which means either exhaust the
source generator or run indefinitely.
generator: A callable that when called returns a generator which yields
dataframes. The driver will call ``len()`` on the yielded dataframes to
obtain the number of rows and will split and batch according to the input
task settings. Each generated dataframe, regardless of size, counts as
a single source batch with respect to ``max_batches``.
Implementations must override abstract methods for (remote) task submission
and polling. The following must be overriden, see their docs for more:
- :meth:`is_task_ready`
- :meth:`poll_tasks`
- :meth:`split_batch_submit`
- :meth:`task_resolve_output`
- :meth:`task_submit`
- :meth:`task_submittable`
"""
[docs]
def __init__(
self, graph, max_batches: int = None, generator: Callable[[], Generator[pd.DataFrame, None, None]] = None
):
super().__init__(graph)
self.max_batches = max_batches
# make a local copy of the graph with each node wrapped in a tracker
# object
self.stream_graph = nx.relabel_nodes(self.graph, StreamTask)
self.generator_fun = generator or self.source_generator_fun
def pre_execute(self):
pass
def output_batch_transform(self, batch: OutputBatch) -> OutputBatch:
return batch
def execute(self):
self.n_sourced = 0
self.source_exhausted = False
self.source_generator = self.generator_fun()
self.pre_execute()
while True:
batch = self.execute_until_output()
if batch is None:
return
yield self.output_batch_transform(batch)
def source_generator_fun(self):
bid = 0
while True:
yield pd.DataFrame({"run_id": [self.run_id], "id": [bid]})
bid += 1
def get_pending(self):
return [p for tn in self.stream_graph for p in tn.all_pending]
def task_exhausted(self, task=None):
if task is not None and len(task.split_pending) > 0:
return False
for upstream in self.stream_graph.walk_back(task):
if upstream.total_pending() > 0:
return False
return True
def resolve_completed(self):
# Walk graph forward to promote completed tasks to next task
# queue. Dataframes for completed sink tasks are returned here in order
# to prioritize flushing.
for task in self.stream_graph.walk_fwd():
for ready in deque_extract(task.pending, self.is_task_ready):
block_info = self.task_resolve_output(ready)
if block_info.length == 0:
continue
if task in self.stream_graph.sink_tasks:
self.logger.debug(f"Batch <{task.name}>[l={block_info.length}] completed as output")
return OutputBatch(block_info.data, task=task.name)
else:
for next_task in self.stream_graph.neighbors(task):
self.logger.debug(f"Moving <{task.name}>[l={block_info.length}] to <{next_task.name}>")
next_task.data_in.appendleft(block_info)
for ready in deque_extract(task.split_pending, self.is_task_ready):
self.logger.debug(f"Splits <{task.name}> completed, moving to input queue")
task.data_in.extendleft(self.task_resolve_output(ready))
return None
def _feed_source(self, source):
if self.source_exhausted:
return
total_length = sum(i.length for i in source.data_in)
while total_length < (source.task.batch_size or 1):
try:
next_df = next(self.source_generator)
except StopIteration:
self.logger.debug("Source generator exhausted")
self.source_exhausted = True
break
# We feed any generated source to all source tasks similar the way
# upstream forked outputs broadcast. We add to data_in so that any
# necessary batching and splitting can be handled by normal procedure.
for task in self.stream_graph.source_tasks:
task.data_in.append(StreamBatch(data=next_df, length=len(next_df)))
self.n_sourced += 1
if self.n_sourced == self.max_batches:
self.logger.debug("Max batches reached, cancelling source generation")
self.source_exhausted = True
break
total_length += len(next_df)
def enqueue_tasks(self):
# helper to make submission decision of a single task based on the batch
# size, exhaustion conditions, and whether the implementation deems it
# submittable. Returns flags (eligible, submitted) to indicate whether
# it was eligible to be submitted based on input queue and batch size,
# and whether it was actually submitted.
def _handle_one_task(task, rank):
eligible = submitted = False
if len(task.data_in) == 0:
return (eligible, submitted)
batch_size = task.task.batch_size
if batch_size is not None:
for batch in deque_extract(task.data_in, lambda b: b.length > batch_size):
self.logger.debug(f"Enqueueing split for <{task.name}>[bs={batch_size}]")
task.split_pending.appendleft(self.split_batch_submit(batch, batch_size))
num_to_merge = deque_num_merge(task.data_in, batch_size)
if num_to_merge == 0:
# If the feed is terminated and there are no more tasks that
# will feed to this one, submit everything
if self.source_exhausted and self.task_exhausted(task):
num_to_merge = len(task.data_in)
else:
return (eligible, submitted)
eligible = True
if not self.task_submittable(task.task, rank):
return (eligible, submitted)
merged = [task.data_in.pop().data for _ in range(num_to_merge)]
self.logger.debug(f"Enqueueing merged batches <{task.name}>[n={len(merged)};bs={batch_size}]")
task.pending.appendleft(self.task_submit(task.task, merged))
task.counter += 1
submitted = True
return (eligible, submitted)
# proceed through all non-source tasks, which will be handled separately
# below due to the need to feed from generator. We walk backwards,
# re-evaluating the sort order of tasks of same depth after each single
# submission, implementing a kind of "fair" submission, while still
# prioritizing tasks closer to the sink.
submitted = True
while submitted:
rank = 0
submitted = False
for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
if task in self.stream_graph.source_tasks:
continue
eligible, submitted = _handle_one_task(task, rank)
if eligible: # update rank of this task if it _could_ be done, whether or not it was
rank += 1
if submitted:
break
# Source as many inputs as can fit on source tasks. We prioritize flushing the
# input queue and secondarily on number of invocations in case batch sizes differ.
while True:
task_scheduled = False
for task in sorted(self.stream_graph.source_tasks, key=lambda x: (-len(x.data_in), x.counter)):
if self.task_submittable(task.task, rank):
self._feed_source(task)
_, task_scheduled = _handle_one_task(task, rank)
if task_scheduled: # we want to re-evalute the sort order
break
if not task_scheduled:
break
def execute_until_output(self):
while True:
if (completed := self.resolve_completed()) is not None:
return completed
if self.source_exhausted and self.task_exhausted():
self.logger.debug("All tasks exhausted, pipeline run ends")
return None
self.enqueue_tasks()
self.poll_tasks(self.get_pending())
[docs]
@abstractmethod
def task_submit(self, task: PipelineTask, df_list: list[pd.DataFrame]) -> Any:
"""Run or arrange for the running of task
Implementations must override this method and arrange for the function
of ``task`` to be called on a dataframe made from the concatenation of
``df_list``. The return value will be maintained in a pending queue, and
both ``task_resolve_output`` and ``is_task_ready`` will take these as
input, but will otherwise not be inspected. Typically the return value
would be a handle to the remote result or a future, or equivalent.
Note:
``PipelineTask`` expects a single DataFrame as input, while this
function receives a batch of such. It MUST concatenate these into a
single DataFrame prior to execution (e.g. with
``pd.concat(df_list)``). This is not done in the driver code as the
dataframes in ``df_list`` may not be local.
"""
pass
[docs]
@abstractmethod
def task_resolve_output(self, pending_task: Any) -> StreamBatch:
"""Return a :class:`StreamBatch` from completed task
This function takes the output produced by either :meth:`task_submit` or
:meth:`split_batch_submit`, and returns a :class:`StreamBatch` object which
tracks the length of returned dataframe(s) and the object which
references the underlying DataFrame.
The ``data`` member of returned :class:`StreamBatch` will be either:
- passed to another call of :meth:`task_submit` in a list container, or
- yielded in the :meth:`execute` call (which yields in the user-called
``run`` method). If any handling must be done prior to yield,
implementation should do so in overloaded :meth:`execute`.
"""
pass
[docs]
@abstractmethod
def is_task_ready(self, pending_task: Any) -> bool:
"""Return true if pending task is ready
This method takes outputs from :meth:`task_submit` and
:meth:`split_batch_submit` and must return ``True`` if the task is complete
and can be passed to :meth:`task_resolve_output` or ``False`` otherwise.
"""
pass
[docs]
@abstractmethod
def task_submittable(self, task: PipelineTask, rank: int) -> bool:
"""Preflight check if task can be submitted
Return ``True`` if current conditions enable the ``task`` to be
submitted. The ``rank`` argument is an indicator of relative importance,
and is incremented whenever the pending data for a given tasks meets the
batching requirements as driver walks the task graph backward. Thus
``Rank=0`` represents the task furthest along and so the highest
priority for submission.
"""
pass
[docs]
@abstractmethod
def split_batch_submit(self, batch: StreamBatch, max_rows: int) -> Any:
"""Submit a task to split batch into at most ``max_rows``
Similart to task_submit, implementations should arrange by whatever
means make sense to take the dataframe reference in ``batch.data`` of
:class:`StreamBatch`, given its length in ``batch.length`` and split
into a number of parts that result in no more than ``max_rows`` per
part. The return value should be a list of objects that can be processed
by :meth:`is_task_ready` and :meth:`task_resolve_output`.
"""
pass
[docs]
@abstractmethod
def poll_tasks(self, pending_task_list: list[Any]) -> None:
"""Wait for any change in status to ``pending_task_list``
This method will be called after submitting as many tasks as
possible. It gives a chance for implementations to wait in a io-friendly
way, for example by waiting on async futures. The input is a list of
objects as returned by :meth:`task_submit` or :meth:`split_batch_submit`. The
return value is unused.
"""
pass
class LocalSerialExecutor(StreamingGraphExecutor):
"""Implementation for reference and testing purposes
This reference implementation demonstrates expected outputs for abstract
methods, feeding a single batch at a time source to sink in the main thread.
"""
sflag = 0
def task_submit(self, pt, df_list):
self.sflag = 1
return pt.func(pd.concat(df_list))
def split_batch_submit(self, stream_batch, max_rows):
df = stream_batch.data
return split_dataframe(df, max_rows=max_rows)
def task_resolve_output(self, to):
if isinstance(to, list):
return [StreamBatch(len(i), i) for i in to]
return StreamBatch(len(to), to)
def task_submittable(self, t, rank):
return self.sflag == 0
def is_task_ready(self, t):
return True
def poll_tasks(self, pending):
self.sflag = 0
def deque_num_merge(queue, batch_size):
if batch_size is None:
return 1 if len(queue) > 0 else 0
else:
# So long as batch size is set, try to merge if necessary. Proceed in
# fifo order and take up to batch_size rows, but no more.
s_accum = np.cumsum([i.length for i in reversed(queue)])
(idxs,) = np.where(s_accum >= batch_size)
if len(idxs) == 0:
return 0
return idxs[0] + 1