Source code for functionalizer.io.circuit

"""Data loading for nodes, edges."""

import glob
import hashlib
import os
import re
from pathlib import Path
from typing import List

import pandas as pd
import pyarrow.parquet as pq
from packaging.version import VERSION_PATTERN, Version
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

import sparkmanager as sm
from functionalizer import schema
from functionalizer.schema import OUTPUT_MAPPING
from functionalizer.utils import get_logger
from functionalizer.utils.filesystem import adjust_for_spark

from .morphologies import MorphologyDB

# Spark schema fragments ("<column> <type>") for the columns every edge or node
# dataframe carries.  The per-population columns generated by `_types` are
# appended to these before the schema string is handed to `mapInPandas`.
BASIC_EDGE_SCHEMA = ["source_node_id long", "target_node_id long", "synapse_id long"]
BASIC_NODE_SCHEMA = ["id long"]

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Mapping from NumPy dtype names to Spark SQL type names, used by `_column_type`.
#
# Unsigned data types are widened to the next larger signed Spark type to
# prevent potential data loss during conversions.  "uint64" is mapped to "long"
# without widening, so values above the signed 64 bit range are not representable.
#
# FIXME we blatantly assume that objects are strings
_numpy_to_spark = {
    "int8": "byte",
    "int16": "short",
    "int32": "int",
    "int64": "long",
    "uint8": "short",
    "uint16": "int",
    "uint32": "long",
    "uint64": "long",
    "float32": "float",
    "float64": "double",
    "object": "string",
}

# Basic unit to load at one time from edge or node files, counted in elements
# (node ids per chunk, or the length of the edge index interval per chunk).  Since
# Spark has its own ideas about partitions, several chunks may be used per
# partition.  Thus set a small-ish size here.
_CHUNK_SIZE = 250_000

# Internal calculations rely on branch types being 0-based. Input should
# follow the SONATA convention, inherited from MorphIO, where values are
# 1-based. Thus this offset...
BRANCH_OFFSET: int = 1
# Largest branch type value accepted for 1-based input; `branch_type_shifted`
# uses it to tell 0-based from 1-based data and to reject out-of-range values.
BRANCH_MAX_VALUE_SONATA: int = 4
# Columns holding branch/section types; they are shifted only when present.
BRANCH_COLUMNS: List[str] = ["afferent_section_type", "efferent_section_type"]
# Parquet input whose "touch2parquet_version" metadata is at least this version
# is treated as 1-based and shifted when loaded.
BRANCH_SHIFT_MINIMUM_CONVERTER_VERSION: Version = Version("0.6.1")


def shift_branch_type(df: DataFrame, shift: int = BRANCH_OFFSET) -> DataFrame:
    """Add `shift` to every branch/section type column present in `df`.

    To convert 1-based (SONATA) section types into the 0-based values used
    internally, pass ``-BRANCH_OFFSET``; the default adds ``BRANCH_OFFSET``.

    Args:
        df: the dataframe to process; columns listed in `BRANCH_COLUMNS` that
            are not part of `df` are skipped
        shift: the value added to each section type
    Returns:
        The dataframe with the shifted columns.  Every shifted column is
        re-created under its original name after being renamed to a temporary
        one, which is dropped again.
    """
    for column in BRANCH_COLUMNS:
        if not hasattr(df, column):
            continue
        scratch = f"__tmp__{column}"
        renamed = df.withColumnRenamed(column, scratch)
        shifted = renamed.withColumn(column, F.col(scratch) + F.lit(shift))
        df = shifted.drop(scratch)
    return df


def branch_type_shifted(df: DataFrame) -> bool:
    """Determine if branch/section types should be shifted.

    Returns `True` if the branch/section types should be shifted by
    `-BRANCH_OFFSET`, which is the case when the branch/section types are
    1-based.  The decision is based on the minimum and maximum of every column
    of `BRANCH_COLUMNS` present in `df`:

    * minimum 0 and maximum below `BRANCH_MAX_VALUE_SONATA`: 0-based
    * minimum above 0 and maximum equal to `BRANCH_MAX_VALUE_SONATA`: 1-based

    If the section type columns are not present, returns `False`.  Will raise a
    `RuntimeError` if the data in the columns is inconclusive: the range spans
    both conventions, lies outside of the valid values, the columns disagree,
    or no column allows for a decision.
    """
    verdicts = set()
    observed = set()
    present = False
    for column in BRANCH_COLUMNS:
        if not hasattr(df, column):
            continue
        present = True
        bounds = df.select(
            F.min(column).alias("attr_min"), F.max(column).alias("attr_max")
        ).take(1)[0]
        lowest, highest = bounds["attr_min"], bounds["attr_max"]
        # Only used to report the offending values in the error message below
        observed.update((lowest, highest))
        if lowest == 0 and highest < BRANCH_MAX_VALUE_SONATA:
            verdicts.add(False)
        elif lowest > 0 and highest == BRANCH_MAX_VALUE_SONATA:
            verdicts.add(True)
        elif lowest == 0 and highest == BRANCH_MAX_VALUE_SONATA:
            raise RuntimeError("Cannot determine branch type offset, range too broad")
        elif lowest < 0 or highest > BRANCH_MAX_VALUE_SONATA:
            raise RuntimeError("Invalid minimum and maximum branch type.")
    if not present:
        return False
    # Either no column was conclusive, or the columns contradict each other
    if len(verdicts) != 1:
        raise RuntimeError(f"Cannot determine branch type offset with {observed}")
    return next(iter(verdicts))


def _accept_node_enumeration(attr: str) -> bool:
    """Select which enumerations to use, will fall back to attributes otherwise."""
    if attr == "morphology":
        return False
    if attr == "morph_class":
        return False
    return True


def _accept_node_attribute(attr: str) -> bool:
    """Accepts all node attributes except for position, rotation ones."""
    if attr.startswith("rotation_"):
        return False
    if attr.startswith("orientation_"):
        return False
    if attr in ("x", "y", "z"):
        return False
    return True


def _get_enumerations(population):
    """Returns the sorted names of the enumerations to load for a population."""
    accepted = [
        name for name in population.enumeration_names if _accept_node_enumeration(name)
    ]
    accepted.sort()
    return accepted


def _get_pure_attributes(population):
    """Returns the sorted attribute names of a population, minus the enumerations.

    Attributes that are loaded as enumerations, and those rejected by
    `_accept_node_attribute`, are left out.
    """
    candidates = population.attribute_names
    enumerated = set(_get_enumerations(population))
    remaining = candidates - enumerated
    return sorted(name for name in remaining if _accept_node_attribute(name))


def _column_type(population, column, accessor):
    """Helper to determine column types.

    Calls the method named `accessor` of `population` for `column` with an
    empty selection, and translates the dtype of the returned data into the
    corresponding Spark type name.

    Raises:
        KeyError: if the dtype has no entry in `_numpy_to_spark`
    """
    import libsonata

    getter = getattr(population, accessor)
    nothing = libsonata.Selection([])
    sample = getter(column, nothing)
    return _numpy_to_spark[sample.dtype.name]


def _add_all_attributes(dataframe, population, selection):
    """Helper to add SONATA attributes to a column mapping.

    Stores the values of all pure attributes of `population` for `selection`
    under the attribute name, and the values of all enumerations under the
    enumeration name suffixed with "_i".

    Args:
        dataframe: the mapping to add the columns to, modified in place
        population: the SONATA population to read from
        selection: the elements of the population to read
    Returns:
        The `dataframe` passed in.
    """
    for attribute in _get_pure_attributes(population):
        dataframe[attribute] = population.get_attribute(attribute, selection)
    for enumeration in _get_enumerations(population):
        dataframe[f"{enumeration}_i"] = population.get_enumeration(enumeration, selection)
    return dataframe


def _types(population):
    """Helper to create Spark type conventions.

    Yields one "<name> <type>" schema string per column of `population`: first
    for all pure attributes, then for all enumerations, the latter with "_i"
    appended to the name.  Names and order match the columns created by
    `_add_all_attributes`.
    """
    for attribute in _get_pure_attributes(population):
        spark_type = _column_type(population, attribute, "get_attribute")
        yield f"{attribute} {spark_type}"
    for enumeration in _get_enumerations(population):
        spark_type = _column_type(population, enumeration, "get_enumeration")
        yield f"{enumeration}_i {spark_type}"


def _create_neuron_loader(filename, population):
    """Create a UDF to load neurons from SONATA.

    Args:
        filename: The name of the circuit file
        population: The population to load
    Returns:
        A Pandas UDF to be used over a group by
    """

    def loader(dfs):
        import libsonata

        pop = libsonata.NodeStorage(filename).open_population(population)

        for df in dfs:
            ids = []
            for row in df.itertuples():
                ids.extend(row.ids)
            selection = libsonata.Selection(ids)
            data = _add_all_attributes({"id": ids}, pop, selection)
            yield pd.DataFrame(data)

    return loader


def _create_touch_loader(filename: str, population: str):
    """Create a UDF to load touches from SONATA.

    Args:
        filename: The name of the touches file
        population: The population to load
    Returns:
        A Pandas UDF to be used in a mapInPandas
    """

    def loader(dfs):
        import libsonata

        pop = libsonata.EdgeStorage(filename).open_population(population)

        for df in dfs:
            intervals = []
            for row in df.itertuples():
                intervals.append((row.start, row.end))
            selection = libsonata.Selection(intervals)
            data = {
                "source_node_id": pop.source_nodes(selection),
                "target_node_id": pop.target_nodes(selection),
                "synapse_id": selection.flatten(),
            }
            data = _add_all_attributes(data, pop, selection)
            yield pd.DataFrame(data)

    return loader


class NodeData:
    """Neuron data loading facilities.

    This class represents neuron populations, lazily loaded.  After the
    construction, general properties of the neurons, such as the unique values
    of the :attr:`.NodeData.mtype_values`, :attr:`.NodeData.etype_values`, or
    :attr:`.NodeData.sclass_values` present can be accessed.  One such
    ``<name>_values`` attribute is created for every enumeration of the
    population that is loaded.
    """

    def __init__(self, circuit_config: str, population: str, nodeset: str, cache: str):
        """Construct a new neuron loader.

        To load neuron-specific information, access the property
        :attr:`.NodeData.df`, data will be loaded lazily.

        Args:
            circuit_config: the circuit configuration with node storage details
            population: the population name; if empty, the circuit
                configuration has to contain exactly one node population
            nodeset: the nodeset name to use; if empty, all nodes are loaded
            cache: a directory name to use for caching generated Parquet,
                created if it does not exist
        Raises:
            ValueError: if no population is given and the circuit
                configuration does not contain exactly one node population
        """
        self._cache = cache
        self._df = None

        import libsonata

        cfg = libsonata.CircuitConfig.from_file(circuit_config)
        self._ns_filename = cfg.node_sets_path
        self._ns_nodeset = nodeset

        if population:
            self._population = population
        elif len(cfg.node_populations) != 1:
            raise ValueError("cannot determine node population")
        else:
            self._population = next(iter(cfg.node_populations))

        pop = cfg.node_population(self._population)
        self._size = len(pop)
        # Schema string for the dataframe produced by `_create_neuron_loader`
        self._columns = ", ".join(BASIC_NODE_SCHEMA + list(_types(pop)))
        for attr in _get_enumerations(pop):
            setattr(self, f"{attr}_values", pop.enumeration_values(attr))

        pop_prop = cfg.node_population_properties(self._population)
        self._filename = pop_prop.elements_path
        # Use the "h5v1" alternate morphology format if configured, otherwise
        # fall back to the default morphology directory
        self.morphologies = MorphologyDB(
            pop_prop.alternate_morphology_formats.get("h5v1", pop_prop.morphologies_dir)
        )
        if pop_prop.spine_morphologies_dir:
            self.spine_morphology_path = Path(pop_prop.spine_morphologies_dir)
        else:
            self.spine_morphology_path = None

        # `exist_ok` avoids failing when the directory is created between a
        # check and the creation, e.g., by a concurrent process
        os.makedirs(self._cache, exist_ok=True)

    def __len__(self):
        """The number of nodes in the population, independent of the nodeset."""
        return self._size

    @property
    def population(self):
        """The population name."""
        return self._population

    @property
    def df(self):
        """The PySpark dataframe with the neuron data."""
        # Compare against `None` explicitly: the truth value of the cached
        # object must not trigger loading the data again
        if self._df is None:
            self._df = self._load_neurons()
        return self._df

    def _load_neurons(self):
        """Load the neuron data, using cached Parquet if available.

        The name of the cached Parquet contains a digest of the resolved path,
        size, and modification time of the node file, the population name,
        and, if both are set, the nodeset file name and nodeset name.

        Returns:
            A PySpark dataframe with the neuron data.
        """
        fn = self._filename
        stat = os.stat(fn)
        sha = hashlib.sha256()
        sha.update(os.path.realpath(fn).encode())
        sha.update(self._population.encode())
        sha.update(str(stat.st_size).encode())
        sha.update(str(stat.st_mtime).encode())
        if self._ns_filename and self._ns_nodeset:
            sha.update(self._ns_filename.encode())
            sha.update(self._ns_nodeset.encode())
        digest = sha.hexdigest()[:8]

        logger.info("Total neurons: %d", len(self))
        df_parquet = os.path.join(
            self._cache, f"neurons_{len(self) / 1000.0:.1f}k_{digest}.parquet"
        )

        if os.path.exists(df_parquet):
            logger.info("Loading circuit from parquet")
            df = sm.read.parquet(adjust_for_spark(df_parquet, local=True))
            df.count()  # force materialize
        else:
            logger.info("Building nodes from SONATA")

            # Create a default selection, or load it from the NodeSets
            if not self._ns_nodeset:
                ids = list(range(0, len(self)))
            else:
                import libsonata

                nodesets = libsonata.NodeSets.from_file(self._ns_filename)
                population = libsonata.NodeStorage(self._filename).open_population(
                    self._population
                )
                selection = nodesets.materialize(self._ns_nodeset, population)
                ids = selection.flatten().tolist()

            # Number of chunks of at most `_CHUNK_SIZE` ids, rounded up
            total_parts = len(ids) // _CHUNK_SIZE
            if len(ids) % _CHUNK_SIZE > 0:
                total_parts += 1
            logger.debug("Partitions: %d", total_parts)

            def generate_ids():
                for n in range(total_parts):
                    start = _CHUNK_SIZE * n
                    end = min(_CHUNK_SIZE * (n + 1), len(ids))
                    yield n, ids[start:end]

            parts = sm.createDataFrame(generate_ids(), ["row", "ids"])

            # Create DF
            logger.info("Creating neuron data frame...")
            raw_df = parts.repartition(total_parts).mapInPandas(
                _create_neuron_loader(self._filename, self._population), self._columns
            )

            # Evaluate (build partial NameMaps) and store
            raw_df.write.mode("overwrite").parquet(adjust_for_spark(df_parquet, local=True))
            # NOTE(review): the data is written with `local=True`, but read back
            # without it, unlike when loading existing Parquet above -- confirm
            # that this is intended
            df = sm.read.parquet(adjust_for_spark(df_parquet)).cache()
        return df
def _get_size(files): """Returns the total size of a list of filenames or directories.""" size = 0 def _add_size(fn): nonlocal size if fn.endswith(".parquet") or fn.endswith(".h5"): size += os.path.getsize(fn) for path in files: if os.path.isfile(path): _add_size(path) else: for root, _, filenames in os.walk(path): for fn in filenames: _add_size(os.path.join(root, fn)) return size def _grab_parquet(files): """Returns as many parquet files from the front of `files` as possible.""" parquets = [] while files and files[0].endswith(".parquet"): if os.path.isdir(files[0]): if parquets: return parquets return [files.pop(0)] parquets.append(files.pop(0)) return parquets def _grab_sonata_population(filename): """Retrieve the default population in a SONATA files. Raise an exception if no population present or more than one population is found. """ import libsonata populations = libsonata.EdgeStorage(filename).population_names if len(populations) == 1: return next(iter(populations)) if len(populations) > 1: raise ValueError(f"More than one population in '{filename}'") raise ValueError(f"No population in '{filename}'") def _grab_sonata(files): """Returns a possible SONATA file from the front of `files`.""" if not files: return None if not files[0].endswith(".h5"): return None filename = files.pop(0) if files and not any(files[0].endswith(ext) for ext in (".h5", ".parquet")): population = files.pop(0) else: population = _grab_sonata_population(filename) return (filename, population) class EdgeData: """Edge data loading facilities. This class represent the connectivity between cell populations, lazily loaded. Access the property :attr:`.EdgeData.df`, to load the data. """ def __init__(self, *paths): # noqa: PLR0912 """Initialize the loader. Args: paths: A list of edge files. 
""" files = [] for path in paths: files.extend(glob.glob(path) or [path]) metadata = [] self._size = _get_size(files) self._loaders = [] while files: if parquet := _grab_parquet(files): local_metadata = self._load_parquet_metadata(*parquet) metadata.append(local_metadata) self._loaders.append(self._load_parquet(local_metadata, *parquet)) elif sonata := _grab_sonata(files): metadata.append(self._load_sonata_metadata(*sonata)) self._loaders.append(self._load_sonata(*sonata)) else: raise ValueError(f"cannot process file(s) {files[0]}") if len(set(frozenset(m.items()) for m in metadata)) == 1: self._metadata = metadata[0] elif metadata: logger.debug("Detected multiple different inputs, prefixing metadata") self._metadata = {} for key in schema.METADATA_FIXED_KEYS: for m in metadata: if key not in m: continue value = m.pop(key) if self._metadata.setdefault(key, value) != value: raise ValueError( "conflicting values for metadata " f"{key}: {self._metadata[key]}, {value}" ) for n, m in enumerate(metadata): self._metadata.update({f"merge{n}_{k}": v for k, v in m.items()}) else: raise ValueError("need to provide at least one file with edges") @property def df(self): """The PySpark dataframe with the edge data.""" df = self._loaders[0]() for loader in self._loaders[1:]: df = df.union(loader()) return df.withColumnRenamed("source_node_id", "src").withColumnRenamed( "target_node_id", "dst" ) @property def input_size(self): """The initial size (in bytes) of the input data.""" return self._size @property def metadata(self): """The metadata associated with the input data.""" return self._metadata @staticmethod def _load_sonata_metadata(filename, population): # Could be (and save us the dependency on h5py): # import libsonata # p = libsonata.EdgeStorage(filename).open_population(population) # return {n: p.get_metadata(n) for n in p.metadata_names} import h5py with h5py.File(filename) as f: return dict(f[f"/edges/{population}"].attrs) @staticmethod def _load_sonata(filename, 
population): def _loader(): import libsonata p = libsonata.EdgeStorage(filename).open_population(population) total_parts = p.size // _CHUNK_SIZE if p.size % _CHUNK_SIZE > 0: total_parts += 1 logger.debug("Partitions: %d", total_parts) parts = sm.createDataFrame( ( (n, _CHUNK_SIZE * n, min(_CHUNK_SIZE * (n + 1), p.size)) for n in range(total_parts) ), "row: long, start: long, end: long", ) columns = ", ".join(BASIC_EDGE_SCHEMA + list(_types(p))) logger.info("Creating edge data frame...") edges = parts.repartition(total_parts).mapInPandas( _create_touch_loader(filename, population), columns ) edges = shift_branch_type(edges, -BRANCH_OFFSET) for new, (old, _) in OUTPUT_MAPPING.items(): if old in edges.columns: edges = edges.withColumnRenamed(old, new) return edges.cache() return _loader @staticmethod def _load_parquet_metadata(path, *_): path = Path(path) if path.is_dir(): path /= "_metadata" meta = pq.ParquetDataset(path, use_legacy_dataset=False).schema.metadata return { k.decode(): v.decode() for (k, v) in (meta or {}).items() if not k.startswith(b"org.apache.spark") } @staticmethod def _load_parquet(metadata, *args): def _loader(): files = [adjust_for_spark(f) for f in args] edges = sm.read.parquet(*files) for old, new in schema.LEGACY_MAPPING.items(): if old in edges.columns: edges = edges.withColumnRenamed(old, new) if raw_version := metadata.get("touch2parquet_version"): if m := re.compile(VERSION_PATTERN, re.VERBOSE | re.IGNORECASE).search(raw_version): t2p_version = Version(m.group(0)) else: raise RuntimeError(f"Can't determine touch2parquet version from {raw_version}") shift = t2p_version >= BRANCH_SHIFT_MINIMUM_CONVERTER_VERSION else: shift = branch_type_shifted(edges) if shift: return shift_branch_type(edges, -BRANCH_OFFSET) return edges return _loader