Skip to content

Commit

Permalink
feat(core): add Population.map parallel processing support
Browse files Browse the repository at this point in the history
  • Loading branch information
yzx9 committed Feb 29, 2024
1 parent 46a30b0 commit a1ee5bb
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions swcgeom/core/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,36 @@

import os
import warnings
from concurrent.futures import ProcessPoolExecutor
from functools import reduce
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
TypeVar,
cast,
overload,
)

import numpy as np
import numpy.typing as npt
from tqdm.contrib.concurrent import process_map
from typing_extensions import Self

from swcgeom.core.swc import eswc_cols
from swcgeom.core.tree import Tree

__all__ = ["LazyLoadingTrees", "ChainTrees", "Population", "Populations"]


T = TypeVar("T")


class Trees(Protocol):
"""Trees protocol support index and len."""

Expand All @@ -48,6 +56,7 @@ def __init__(self, swcs: Iterable[str], **kwargs) -> None:
kwargs : Dict[str, Any]
Forwarding to `Tree.from_swc`
"""

super().__init__()
self.swcs = list(swcs)
self.trees = [None for _ in swcs]
Expand All @@ -61,6 +70,9 @@ def __getitem__(self, key: int, /) -> Tree:
def __len__(self) -> int:
return len(self.swcs)

def __iter__(self) -> Iterator[Tree]:
return (self[i] for i in range(self.__len__()))

def load(self, key: int) -> None:
if self.trees[key] is None:
self.trees[key] = Tree.from_swc(self.swcs[key], **self.kwargs)
Expand Down Expand Up @@ -92,6 +104,9 @@ def __getitem__(self, key: int, /) -> Tree:
def __len__(self) -> int:
return self.cumsum[-1].item()

def __iter__(self) -> Iterator[Tree]:
return (self[i] for i in range(self.__len__()))


class Population:
"""Neuron population."""
Expand Down Expand Up @@ -153,15 +168,40 @@ def __iter__(self) -> Iterator[Tree]:
def __repr__(self) -> str:
return f"Neuron population in '{self.root}'"

def map(
self,
fn: Callable[[Tree], T],
*,
max_worker: Optional[int] = None,
verbose: bool = False,
) -> Iterator[T]:
"""Map a function to all trees in the population.
This is a straightforward interface for parallelizing
computations. The parameters are intentionally kept simple and
user-friendly. For more advanced control, consider using
`concurrent.futures` directly.
"""

trees = (t for t in self.trees)

if verbose:
results = process_map(fn, trees, max_workers=max_worker)
else:
with ProcessPoolExecutor(max_worker) as p:
results = p.map(fn, trees)

return results

@classmethod
def from_swc(cls, root: str, ext: str = ".swc", **kwargs) -> "Population":
def from_swc(cls, root: str, ext: str = ".swc", **kwargs) -> Self:
if not os.path.exists(root):
raise FileNotFoundError(
f"the root does not refers to an existing directory: {root}"
)

swcs = cls.find_swcs(root, ext)
return Population(LazyLoadingTrees(swcs, **kwargs), root=root)
return cls(LazyLoadingTrees(swcs, **kwargs), root=root)

@classmethod
def from_eswc(
Expand All @@ -170,7 +210,7 @@ def from_eswc(
ext: str = ".eswc",
extra_cols: Optional[Iterable[str]] = None,
**kwargs,
) -> "Population":
) -> Self:
extra_cols = list(extra_cols) if extra_cols is not None else []
extra_cols.extend(k for k, t in eswc_cols)
return cls.from_swc(root, ext, extra_cols=extra_cols, **kwargs)
Expand Down Expand Up @@ -235,15 +275,15 @@ def to_population(self) -> Population:
return Population(ChainTrees(p.trees for p in self.populations))

@classmethod
def from_swc( # pylint: disable=too-many-arguments
def from_swc(
cls,
roots: Iterable[str],
ext: str = ".swc",
intersect: bool = True,
check_same: bool = False,
labels: Optional[Iterable[str]] = None,
**kwargs,
) -> "Populations":
) -> Self:
"""Get population from dirs.
Parameters
Expand Down Expand Up @@ -275,7 +315,7 @@ def from_swc( # pylint: disable=too-many-arguments
)
for i, d in enumerate(roots)
]
return Populations(populations, labels=labels)
return cls(populations, labels=labels)

@classmethod
def from_eswc(
Expand All @@ -285,7 +325,7 @@ def from_eswc(
*,
ext: str = ".eswc",
**kwargs,
) -> "Populations":
) -> Self:
extra_cols = list(extra_cols) if extra_cols is not None else []
extra_cols.extend(k for k, t in eswc_cols)
return cls.from_swc(roots, extra_cols=extra_cols, ext=ext, **kwargs)
Expand Down

0 comments on commit a1ee5bb

Please sign in to comment.