Source code for functionalizer.filters.definitions

"""Query interface for Neuron dataframe / graph."""

import hashlib
import importlib
import inspect
import os
import sys
from abc import abstractmethod
from datetime import datetime
from glob import glob
from pathlib import Path

import sparkmanager as sm
from functionalizer.circuit import Circuit
from functionalizer.utils import get_logger
from functionalizer.utils.checkpointing import checkpoint_resume

# Module-level logger, used for the progress / warning messages emitted while
# initializing and applying the filters defined in this module.
logger = get_logger(__name__)


def _plugin_module_name(filename: str) -> str:
    """Return the dotted module name under which ``filename`` can be imported.

    The name is the shortest path of ``filename`` (without the ``.py`` suffix)
    relative to an entry of :data:`sys.path`.  Entries that ``filename`` does not
    lie below are ignored, since a path climbing out of them (``../x``) is not an
    importable module name.

    Args:
        filename: Path to a ``.py`` file.

    Raises:
        ImportError: if ``filename`` is not located below any entry of ``sys.path``.
    """
    stem = os.path.splitext(filename)[0]
    candidates = []
    for entry in sys.path:
        try:
            relative = os.path.relpath(stem, entry)
        except ValueError:
            # No relative path exists (e.g. paths on different drives on Windows).
            continue
        if relative in (os.curdir, os.pardir) or relative.startswith(os.pardir + os.sep):
            continue
        candidates.append(relative)
    if not candidates:
        raise ImportError(
            f"Cannot load plugin {filename}: not located below any entry of sys.path"
        )
    return min(candidates, key=len).replace(os.sep, ".")


def load(*dirnames: str) -> None:
    """Load plugins from a list of directories.

    Every ``*.py`` file directly inside each directory is imported as a module,
    named after its path relative to the closest enclosing :data:`sys.path`
    entry.  Presumably the plugin modules define filter classes, which register
    themselves with the filter registry when their module is imported.

    If no directories are given, load a default set of plugins (the
    ``implementations`` directory next to this file).

    Args:
        dirnames: A list of directories to load plugins from.

    Raises:
        ImportError: if a plugin file is not located below any entry of
            ``sys.path``, or if importing it fails.
    """
    if not dirnames:
        dirnames = (os.path.join(os.path.dirname(__file__), "implementations"),)
    for dirname in dirnames:
        # Sorted so that plugins are imported (and registered) in a stable order.
        for filename in sorted(glob(f"{dirname}/*.py")):
            importlib.import_module(_plugin_module_name(filename))


class FilterInitializationError(RuntimeError):
    """Signal from a filter constructor that the filter cannot be set up.

    When raised while filters are being initialized, an optional filter
    (``_required`` set to `False`) is disabled with a warning, while a
    required filter re-raises the error and aborts.
    """


# ---------------------------------------------------
# Dataset operations
# ---------------------------------------------------
class __DatasetOperationType(type):
    """Forced unification of classes.

    The structure of the constructor and application function to circuits
    is pre-defined to be able to construct and apply filters automatically.

    Classes implementing this type are automatically added to a registry,
    keyed by their name with every occurrence of "Filter" removed (so
    ``FooFilter`` is registered as ``Foo``).  Set the attribute `_visible`
    to `False` to exclude a filter from appearing in the list.
    """

    __filters = {}

    def __init__(cls, name, bases, attrs) -> None:
        """Validate the signatures of a new class and register it.

        Raises:
            AttributeError: if the class does not itself define
                ``apply(self, circuit)`` with exactly that signature, or if its
                ``__init__`` is not ``__init__(self, recipe, source, target)``.
        """
        if "apply" not in attrs:
            raise AttributeError(f'class {cls} does not implement "apply(circuit)"')
        try:
            spec = inspect.getfullargspec(attrs["apply"])
            if not (
                spec.varargs is None
                and spec.varkw is None
                and spec.defaults is None
                and spec.args == ["self", "circuit"]
            ):
                raise TypeError
        except TypeError as e:
            # Also covers getfullargspec failing to introspect the object.
            raise AttributeError(f'class {cls} does not implement "apply(circuit)" properly') from e
        spec = inspect.getfullargspec(cls.__init__)
        if not (
            spec.varargs is None
            and spec.varkw is None
            and spec.defaults is None
            and spec.args == ["self", "recipe", "source", "target"]
        ):
            raise AttributeError(
                f"class {cls} does not implement " '"__init__(recipe, source, target)" properly'
            )
        type.__init__(cls, name, bases, attrs)
        # `attrs` is the class's own namespace: an inherited `_visible` is ignored.
        if attrs.get("_visible", True):
            cls.__filters[name.replace("Filter", "")] = cls

    @classmethod
    def _lookup(mcs, name):
        """Return the registered filter class for `name`, or `None` if unknown.

        Accepts both the registry key (``"Foo"``) and the full class name
        (``"FooFilter"``).
        """
        registry = mcs.__filters
        fcls = registry.get(name)
        if fcls is None:
            # Registry keys have "Filter" removed (see __init__); mirror that here.
            fcls = registry.get(name.replace("Filter", ""))
        return fcls

    @classmethod
    def initialize(mcs, names, *args):
        """Create filters from a list of names.

        Each filter class gets a `_checkpoint_name` built from its name and a
        hash over the names of all filter classes requested up to and including
        itself, so the name depends on the whole preceding sequence of filters.
        Filters before the last one with an existing checkpoint on disk are
        dropped from the result.

        :param names: A list of filter class names to invoke
        :param args: Arguments to pass through to the filters
        :return: A list of filter instances, starting at the last filter with an
            existing checkpoint (all instances if no checkpoint was found)
        :raises ValueError: if a name is unknown or a filter is requested twice
        """
        # Forget checkpoint names assigned by a previous call.  Only look at each
        # class's own namespace: a subclass would otherwise see its parent's one.
        for fcls in mcs.__filters.values():
            if "_checkpoint_name" in vars(fcls):
                delattr(fcls, "_checkpoint_name")
        key = hashlib.sha256()
        key.update(b"foobar3000")
        filters = []
        for name in names:
            fcls = mcs._lookup(name)
            if fcls is None:
                raise ValueError(f"Cannot find filter '{name}'")
            key.update(fcls.__name__.encode())
            # Own namespace only: a name inherited from a parent filter that was
            # requested earlier in this call does not make this a duplicate.
            if "_checkpoint_name" in vars(fcls):
                raise ValueError(f"Cannot have more than one {fcls.__name__}")
            fcls._checkpoint_name = (
                f"{fcls.__name__.replace('Filter', '').lower()}" f"_{key.hexdigest()[:8]}"
            )
            try:
                filters.append(fcls(*args))
            except FilterInitializationError as e:
                if fcls._required:
                    logger.fatal("Could not instantiate %s", fcls.__name__)
                    raise
                logger.warning("Disabling optional %s: %s", fcls.__name__, e)

        # Walk backwards to the last filter whose output already exists on disk
        # (as `.parquet` or `.ptable` containing a `_SUCCESS` marker).
        start = 0  # no checkpoint found (also covers an empty filter list)
        for i in range(len(filters) - 1, -1, -1):
            base = Path(checkpoint_resume.directory)
            parquet = filters[i]._checkpoint_name + ".parquet"
            table = filters[i]._checkpoint_name + ".ptable"
            fn = "_SUCCESS"
            if (base / parquet / fn).exists() or (base / table / fn).exists():
                classname = filters[i].__class__.__name__
                logger.info("Found checkpoint for %s", classname)
                start = i
                break
        for f in filters[:start]:
            classname = f.__class__.__name__
            logger.info("Removing %s", classname)
        return filters[start:]

    @classmethod
    def modules(mcs):
        """List the registry keys of all registered filters, sorted."""
        return sorted(mcs.__filters.keys())


class DatasetOperation(metaclass=__DatasetOperationType):
    """Basis for synapse filters.

    Every filter should derive from :class:`~functionalizer.filters.DatasetOperation`,
    which will enforce the right format for the constructor and
    :meth:`~functionalizer.filters.DatasetOperation.apply` functions.  The former is
    optional, but should be used to extract relevant information from the recipe.
    The two node populations are passed to the constructor to enable cross-checks
    between the recipe information and the population properties.

    If the constructor raises a :class:`FilterInitializationError` and the
    :attr:`._required` attribute is set to `False`, the filter will be skipped.

    If filters add or remove columns from the dataframe, this should be
    communicated via the :attr:`._columns` attribute, otherwise the general
    invocation of the filters will fail, as column consistency is checked.
    """

    _checkpoint = False
    """Store the results on disk, allows to skip computation on subsequent runs.
    """

    _checkpoint_buckets = None
    """Partition the data when checkpointing, avoids sort on load.
    """

    # Only the class that sets `_visible` itself is affected: the registry reads
    # the class's own namespace, so subclasses are visible unless they opt out.
    _visible = False
    """Determines the visibility of the filter to the user.
    """

    _reductive = True
    """Indicates if the filter is expected to reduce the touch count.
    """

    _required = True
    """If set to `False`, the filter will be skipped if recipe components are not found.
    """

    _columns = []
    """A list of columns to be consumed and produced.

    Each item should be a tuple of two strings, giving the column consumed/dropped,
    and the column produced.  If no column is dropped, `None` can be used.  Likewise,
    if a column is only dropped, `None` can be the second element.

    Examples::

        (None, "synapse_id")   # will produce the column "synapse_id"
        ("synapse_id", None)   # will drop the column "synapse_id"
        ("ham", "spam")        # will produce the column "spam" while also
                               # dropping "ham".  If the latter is not
                               # present, the former will not be
                               # added.
    """

    def __init__(self, recipe, source, target):
        """Empty constructor supposed to be overridden.

        Args:
            recipe: Wrapper around an XML document with parametrization information
            source: The source node population
            target: The target node population
        """

    def __call__(self, circuit):
        """Apply the operation to `circuit`.

        Columns declared in `_columns` as produced that already exist are dropped
        first (with a warning).  After :meth:`apply` ran (or its result was loaded
        from a checkpoint when `_checkpoint` is set), the columns added and removed
        are checked against `_columns`.

        Raises:
            RuntimeError: if the columns changed differently than declared in
                `_columns`, or if a non-reductive filter changed the touch count.
        """
        classname = self.__class__.__name__
        logger.info("Applying %s", classname)
        with sm.jobgroup(classname):
            ran_filter = False  # assume loading from disk by default
            start = datetime.now()
            old_count = len(circuit)

            olds = set(circuit.df.columns)
            # A produced column is expected if the entry consumes nothing or its
            # consumed column is present.  A `None` producer marks a drop-only
            # entry and must not end up in the set of expected new columns.
            to_add = {a for (c, a) in self._columns if a is not None and (not c or c in olds)}
            to_drop = {c for (c, _) in self._columns if c in olds}
            to_remove = olds & to_add
            if to_remove:
                logger.warning("Removing columns %s", ", ".join(to_remove))
                circuit.df = circuit.df.drop(*to_remove)
                olds -= to_remove

            if not self._checkpoint:
                ran_filter = True
                circuit.df = self.apply(circuit)
            else:

                @checkpoint_resume(
                    self._checkpoint_name,
                    bucket_cols=self._checkpoint_buckets,
                )
                def fun():
                    # Presumably only invoked by checkpoint_resume when no stored
                    # result can be loaded; record that the filter actually ran.
                    nonlocal ran_filter
                    ran_filter = True
                    return self.apply(circuit)

                circuit.df = fun()

            news = set(circuit.df.columns)
            dropped = olds - news
            added = news - olds
            if to_drop - dropped:
                raise RuntimeError(f"Undropped columns: {to_drop - dropped}")
            if dropped - to_drop:
                raise RuntimeError(f"Dropped columns: {dropped - to_drop}")
            if to_add - added:
                raise RuntimeError(f"Missing columns: {to_add - added}")
            if added - to_add:
                raise RuntimeError(f"Additional columns: {added - to_add}")

            if ran_filter:
                new_count = len(circuit)
                diff = old_count - new_count
                if self._reductive:
                    logger.info(  # pylint: disable=logging-fstring-interpolation
                        f"{classname} removed {diff:,d} touches, "
                        f"circuit now contains {new_count:,d}"
                    )
                elif diff != 0:
                    raise RuntimeError(f"{classname} removed touches, but should not")

            logger.info("%s application took %s", classname, datetime.now() - start)
            return circuit

    # NOTE: the metaclass derives from `type`, not `ABCMeta`, so `abstractmethod`
    # is not enforced at instantiation; the metaclass instead requires every
    # subclass to define `apply(self, circuit)` itself.
    @abstractmethod
    def apply(self, circuit: Circuit):
        """Needs actual implementation of the operation.

        Takes a `Circuit`, applies some operations to it, and returns Spark
        dataframe representing the updated circuit.
        """

    @staticmethod
    def pathway_functions(columns, counts):
        """Construct pathway adding functions given columns and a value counts.

        The returned function adds two columns to a dataframe:

        * ``pathway_i``: ``(...((c0 * n1 + c1) * n2 + c2)...)`` for the values
          ``c0, c1, ...`` of `columns` and ``n0, n1, ...`` of `counts`.  This is
          unique per combination if every value of column k is below counts[k]
          (presumably `counts` holds the number of distinct values — confirm).
        * ``pathway_str``: a string listing each column name with its value.

        `columns` and `counts` are paired with :func:`zip`, so extra entries in
        the longer one are ignored.
        """

        def _rename_maybe_numeric(col):
            # src_/dst_ attributes are referenced through their "_i" variant
            # (presumably the integer-encoded column).
            if col.startswith(("src_", "dst_")) and not col.endswith("_i"):
                return f"{col}_i"
            return col

        def add_pathway(df):
            from pyspark.sql import functions as F

            pathway_column = F.lit(0)
            pathway_column_format = []
            pathway_column_values = []
            for col, factor in zip(columns, counts):
                name = _rename_maybe_numeric(col)
                pathway_column *= factor
                pathway_column += F.col(name)
                pathway_column_format.append(f"{name}(%d)")
                pathway_column_values.append(name)
            return df.withColumn("pathway_i", pathway_column).withColumn(
                "pathway_str",
                F.format_string(", ".join(pathway_column_format), *pathway_column_values),
            )

        return add_pathway