From 35719fcde0faa6079fe8c4572aa921ecaa051388 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 18:03:07 -0700 Subject: [PATCH 01/14] added support for old-style retries_on_error --- python/dolma/core/loggers.py | 9 +- python/dolma/core/mp_tools.py | 130 +++++++++ python/dolma/core/parallel.py | 476 ++++++++++++++++--------------- python/dolma/core/progressbar.py | 276 ++++++++++++++++++ python/dolma/core/utils.py | 82 +++++- tests/python/test_parallel.py | 113 ++++++++ tests/python/test_utils.py | 45 ++- tests/python/utils.py | 8 +- 8 files changed, 884 insertions(+), 255 deletions(-) create mode 100644 python/dolma/core/mp_tools.py create mode 100644 python/dolma/core/progressbar.py diff --git a/python/dolma/core/loggers.py b/python/dolma/core/loggers.py index f34ba864..f0ff05f3 100644 --- a/python/dolma/core/loggers.py +++ b/python/dolma/core/loggers.py @@ -5,15 +5,20 @@ DOLMA_PREFIX = "dolma" -def get_logger(name: str) -> logging.Logger: +def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger: if (proc_name := multiprocessing.current_process().name) == "MainProcess": proc_name = "main" proc_name = proc_name.replace(" ", "_") + # set the log level + level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN) + + # set name name = f"{proc_name}.dolma.{name}" logger = logging.getLogger(name) - logger.setLevel(logging.WARN) + logger.setLevel(level) + # add handler if not logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( diff --git a/python/dolma/core/mp_tools.py b/python/dolma/core/mp_tools.py new file mode 100644 index 00000000..f88477e4 --- /dev/null +++ b/python/dolma/core/mp_tools.py @@ -0,0 +1,130 @@ +import multiprocessing +import time +from contextlib import ExitStack +from multiprocessing.managers import SyncManager +from multiprocessing.pool import Pool +from queue import Queue +from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar, Union + +T = TypeVar("T") +R = TypeVar("R") + + +def get_manager(pool: Union[Pool, "PoolWithDebug"]) -> Union[SyncManager, "ManagerWithDebug"]: + if getattr(pool, "debug", False): + return ManagerWithDebug() + else: + return multiprocessing.Manager() + + +class ResultWithDebug(Generic[T]): + def __init__(self, result: T, *args, **kwargs): + self.result = result + + def get(self, timeout: Optional[float] = None) -> T: + return self.result + + def wait(self, timeout: Optional[float] = None) -> None: + time.sleep(timeout or 0) + + def successful(self) -> bool: + return True + + def ready(self) -> bool: + return True + + +class ManagerWithDebug: + def Queue(self): + return Queue() + + def shutdown(self) -> None: + pass + + +class PoolWithDebug: + """A wrapper around multiprocessing.Pool that allows for debugging (i.e., running without multiprocessing). 
+ Supports creating a manager for shared memory objects (mock in case of debugging).""" + + def __init__( + self, + processes: Optional[int] = None, + initializer: Optional[Callable[..., Any]] = None, + initargs: Iterable[Any] = (), + maxtasksperchild: Optional[int] = None, + debug: bool = False, + ): + self.processes = processes + self.initializer = initializer + self.initargs = initargs + self.maxtasksperchild = maxtasksperchild + self.debug = debug + + # we are gonna keep track of resources in stack; but also keeping them indexed + # separately for easy access + self.stack = ExitStack() + self._manager: Optional[SyncManager] = None + self._pool: Optional[Pool] = None + + # let's make sure that the start method is spawn for best performance + try: + multiprocessing.set_start_method("spawn") + except RuntimeError: + assert multiprocessing.get_start_method() == "spawn", "Multiprocessing start method must be spawn" + + def __enter__(self): + if self._pool is None and not self.debug: + self._pool = self.stack.enter_context( + Pool( + processes=self.processes, + initializer=self.initializer, + initargs=self.initargs, + maxtasksperchild=self.maxtasksperchild, + ) + ) + return self + + def Manager(self): + if self._manager is None: + self._manager = ( + ManagerWithDebug() # type: ignore + if self.debug + else self.stack.enter_context(multiprocessing.Manager()) + ) + return self._manager + + def __exit__(self, *exc): + return self.stack.close() + + def apply_async( + self, + func: Callable[..., R], + args: Iterable[Any] = (), + kwds: Dict[str, Any] = {}, + callback: Optional[Callable[[R], Any]] = None, + error_callback: Optional[Callable[[Any], Any]] = None, + ): + if self._pool is None: + if self.initializer: + # run the initializer once by calling it with the initargs and then setting it to None + self.initializer(*self.initargs) + self.initializer = None + try: + resp = func(*args, **kwds) + if callback is not None: + callback(resp) + return ResultWithDebug(resp) + except Exception as e: + if error_callback is not None: + error_callback(e) + raise e + else: + return self._pool.apply_async( + func=func, args=args, kwds=kwds, callback=callback, error_callback=error_callback + ) + + def close(self): + return self._pool and self._pool.close() + + def join(self): + return self._pool and self._pool.join() diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 0bbfc75f..1823b791 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -1,39 +1,39 @@ -import inspect import itertools import logging import multiprocessing import pickle import random import re -import time -from contextlib import ExitStack from datetime import datetime from functools import partial from queue import Queue -from threading import Thread -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union +import backoff import smart_open -import tqdm +from backoff.types import Details from typing_extensions import TypeAlias from .errors import DolmaError, DolmaRetryableFailure from .loggers import get_logger +from .mp_tools import PoolWithDebug, get_manager from .paths import ( add_suffix, + exists, glob_path, join_path, make_relative, mkdir_p, parent, split_path, - sub_prefix, ) +from .progressbar import BaseProgressBar +from .utils import batch_iterator METADATA_SUFFIX = ".done.txt" # we need to quote the type alias because we want to support Python 3.8 -QueueType: 
TypeAlias = "Queue[Union[None, Tuple[int, ...]]]" +QueueType: TypeAlias = Queue[Union[None, Tuple[int, ...]]] KwargsType: TypeAlias = Dict[str, Any] BPP = TypeVar("BPP", bound="BaseParallelProcessor") @@ -45,9 +45,28 @@ class AllPathsTuple(NamedTuple): kwargs: List[KwargsType] @classmethod - def empty(cls) -> "AllPathsTuple": + def new(cls) -> "AllPathsTuple": return AllPathsTuple([], [], [], []) + def __len__(self) -> int: + return len(self.src) + + @property + def empty(self) -> bool: + return len(self.src) == 0 + + def partition(self, k: int = 1) -> List["AllPathsTuple"]: + """Partition the paths into k / n slices containing k files each.""" + return [ + AllPathsTuple( + src=self.src[i : i + k], + dst=self.dst[i : i + k], + meta=self.meta[i : i + k], + kwargs=self.kwargs[i : i + k], + ) + for i in range(0, len(self.src), k) + ] + class BaseParallelProcessor: """A base parallel processor that supports applying the same process_single method to a list of files. @@ -60,6 +79,8 @@ class BaseParallelProcessor: See documentation of both methods for more details on how to implement them correctly. """ + PROGRESS_BAR_CLS: Type[BaseProgressBar] + def __init__( self, source_prefix: Union[str, List[str]], @@ -70,11 +91,16 @@ def __init__( seed: int = 0, pbar_timeout: float = 1e-3, ignore_existing: bool = False, + skip_source_glob: bool = False, + shuffle_src_paths: bool = True, include_paths: Optional[List[str]] = None, exclude_paths: Optional[List[str]] = None, files_regex_pattern: Optional[str] = None, - retries_on_error: int = 0, + batch_size: int = 1, process_single_kwargs: Union[None, KwargsType, List[KwargsType]] = None, + backoff_max_time: Optional[float] = None, + backoff_max_tries: int = 1, + backoff_exceptions: Optional[Union[Type[Exception], Tuple[Type[Exception], ...]]] = None, ): """Initialize the parallel processor. @@ -95,22 +121,27 @@ def __init__( seed (int, optional): The random seed to use when shuffling input files. Defaults to 0. pbar_timeout (float, optional): How often to update progress bars in seconds. Defaults to 0.01 seconds. + skip_source_glob (bool, optional): Do not glob source files. Off by default. ignore_existing (bool, optional): Whether to ignore files that have been already processed and re-run the processor on all files from scratch. Defaults to False. - include_paths (Optional[List[str]], optional): A list of paths to include. If provided, only files + shuffle_src_paths (bool, optional): Whether to shuffle the source paths before processing them. + Defaults to True. + include_paths (List[str], optional): A list of paths to include. If provided, only files that match one of the paths will be processed. Defaults to None. - exclude_paths (Optional[List[str]], optional): A list of paths to exclude. If provided, files that + exclude_paths (List[str], optional): A list of paths to exclude. If provided, files that match one of the paths will be skipped. Defaults to None. - files_regex_pattern (Optional[str], optional): A regex pattern to match files. If provided, only + files_regex_pattern (str, optional): A regex pattern to match files. If provided, only files that match the pattern will be processed. Defaults to None. - retries_on_error (int, optional): The number of retries to attempt if an error occurs. - Defaults to 0. 
- process_single_kwargs (Union[None, KwargsType, List[KwargsType]], optional): Additional kwargs to + batch_size: (int, optional): number of files to group in a single bat + process_single_kwargs (Union[None, KwargsType, List[KwargsType], optional): Additional kwargs to pass to the process_single method. If a single dict is provided, it will be used for all source prefixes. If a list of dicts is provided, each dict will be used for the corresponding source. By default, no additional kwargs are passed. + backoff_max_time (float, optional): The maximum time to backoff. Defaults to None. + backoff_max_tries (int, optional): The maximum number of tries to backoff. Defaults to 1. + backoff_exceptions (Union[Type[Exception], Tuple[Type[Exception], ...]], optional): The + exceptions to backoff on. Defaults to `dolma.core.errors.DolmaRetryableFailure`. """ - self.src_prefixes = [source_prefix] if isinstance(source_prefix, str) else source_prefix self.dst_prefixes = [destination_prefix] if isinstance(destination_prefix, str) else destination_prefix self.meta_prefixes = [metadata_prefix] if isinstance(metadata_prefix, str) else metadata_prefix @@ -120,10 +151,24 @@ def __init__( self.pbar_timeout = pbar_timeout self.ignore_existing = ignore_existing + self.logger = self.get_logger() + self.include_paths = set(include_paths) if include_paths is not None else None self.exclude_paths = set(exclude_paths) if exclude_paths is not None else None self.files_regex_pattern = re.compile(files_regex_pattern) if files_regex_pattern else None - self.retries_on_error = retries_on_error + self.shuffle_src_paths = shuffle_src_paths + + # this manages how many files to pass to a single processor + self.batch_size = batch_size + + # this controls backoff + self.backoff_max_time: float = float(backoff_max_time or "inf") + self.backoff_max_tries: int = int(backoff_max_tries) + self.backoff_exceptions: Tuple[Type[Exception], ...] = ( + (backoff_exceptions,) + if isinstance(backoff_exceptions, type) + else backoff_exceptions or (DolmaRetryableFailure,) + ) # this are additional kwargs to pass to the process_single method process_single_kwargs = process_single_kwargs or {} @@ -132,23 +177,8 @@ def __init__( else: self.process_single_kwargs = process_single_kwargs - # checking that the increment_progressbar method is subclassed correctly - sig = inspect.signature(self.increment_progressbar) - if "queue" not in sig.parameters or sig.parameters["queue"].kind != inspect.Parameter.POSITIONAL_ONLY: - raise AttributeError( - "increment_progressbar must have a positional-only argument named 'queue'; " - "Check that you have subclassed BaseParallelProcessor correctly!" - ) - if "kwargs" in sig.parameters and sig.parameters["kwargs"].kind == inspect.Parameter.VAR_KEYWORD: - raise AttributeError( - "increment_progressbar must not have a **kwargs argument; " - "Check that you have subclassed BaseParallelProcessor correctly!" - ) - if any(p.name != "queue" and p.default != 0 for p in sig.parameters.values()): - raise AttributeError( - "increment_progressbar must have a default value of 0 for all arguments except 'queue'; " - "Check that you have subclassed BaseParallelProcessor correctly!" 
- ) + if not hasattr(self, "PROGRESS_BAR_CLS"): + self.PROGRESS_BAR_CLS = BaseProgressBar.from_increment_function(self) if len(self.src_prefixes) != len(self.dst_prefixes): raise ValueError( @@ -169,13 +199,75 @@ def __init__( if len(self.src_prefixes) == 0: raise ValueError("At least one source prefix must be provided.") + self.skip_source_glob = skip_source_glob + if any("*" in p for p in itertools.chain(self.dst_prefixes, self.meta_prefixes)): raise ValueError("Destination and metadata prefixes cannot contain wildcards.") + if not hasattr(self, "PROGRESS_BAR_CLS"): + raise AttributeError("BaseParallelProcessor subclasses must define the PROGRESS_BAR_CLS attribute.") + + def __add__(self: BPP, other: BPP) -> BPP: + """Combine two parallel processors into one.""" + if not type(self) is type(other): + raise TypeError(f"Cannot add {type(self)} and {type(other)}") + + # we try combining the two list of include paths; if they are both None, then set the combo back to none + include_paths: Union[List[str], None] = [*(self.include_paths or []), *(other.include_paths or [])] + include_paths = sorted(set(include_paths or [])) if len(include_paths or []) else None + + # do the same for exclude paths + exclude_paths: Union[List[str], None] = [*(self.exclude_paths or []), *(other.exclude_paths or [])] + exclude_paths = sorted(set(exclude_paths or [])) if len(exclude_paths or []) else None + + # for the regex, do a simple or if both are set + regex_pattern: Union[str, None] = None + if self.files_regex_pattern and other.files_regex_pattern: + regex_pattern = "(" + self.files_regex_pattern.pattern + "|" + other.files_regex_pattern.pattern + ")" + elif self.files_regex_pattern: + regex_pattern = self.files_regex_pattern.pattern + elif other.files_regex_pattern: + regex_pattern = other.files_regex_pattern.pattern + + return type(self)( + source_prefix=[*self.src_prefixes, *other.src_prefixes], + destination_prefix=[*self.dst_prefixes, *other.dst_prefixes], + metadata_prefix=[*self.meta_prefixes, *other.meta_prefixes], + num_processes=max(self.num_processes, other.num_processes), + debug=self.debug or other.debug, + seed=self.seed, + pbar_timeout=max(self.pbar_timeout, other.pbar_timeout), + ignore_existing=self.ignore_existing or other.ignore_existing, + include_paths=include_paths, + exclude_paths=exclude_paths, + files_regex_pattern=regex_pattern, + batch_size=max(self.batch_size, other.batch_size), + process_single_kwargs=[*self.process_single_kwargs, *other.process_single_kwargs], + backoff_max_time=min(self.backoff_max_time, other.backoff_max_time), + backoff_max_tries=min(self.backoff_max_tries, other.backoff_max_tries), + backoff_exceptions=tuple(set(self.backoff_exceptions + other.backoff_exceptions)), + ) + + def __radd__(self: BPP, other: BPP) -> BPP: + """Combine two parallel processors into one.""" + return other.__add__(self) + @classmethod def get_logger(cls) -> logging.Logger: """Get the logger for the class.""" - return get_logger(cls.__name__) + return get_logger(cls.__name__, "info") + + @classmethod + def process_batch( + cls, + source_paths: List[str], + destination_paths: List[str], + queue: QueueType, + kwargs: List[Dict[str, Any]], + ): + """Process multiple files. 
Naively calls process_single for each file, but can be overridden.""" + for src_path, dst_path, single_kwargs in zip(source_paths, destination_paths, kwargs): + cls.process_single(source_path=src_path, destination_path=dst_path, queue=queue, **single_kwargs) @classmethod def process_single( @@ -199,36 +291,61 @@ def process_single( raise NotImplementedError() @classmethod - def _process_single_and_save_status( + def _log_backoff(cls, details: Details): + """Log backoff details.""" + message = ( + f"Backing off `{details['target'].__name__}` " + f"after {details['tries']:,} " + f"tries (wait: {details.get('wait', 0.0):.2f}s)" + ) + if ex := details.get("exception"): + # add details about the exception to the message + import traceback # pylint: disable=import-outside-toplevel + + message += " due to " + "\n".join(traceback.format_exception_only(ex)).strip() # type: ignore + + cls.get_logger().warning(message) + + @classmethod + def _process_batch_and_save_status( cls, - source_path: str, - destination_path: str, - metadata_path: str, + source_paths: List[str], + destination_paths: List[str], + metadata_paths: List[str], queue: QueueType, - serialized_kwargs: bytes, + serialized_kwargs: List[bytes], + backoff_max_time: float, + backoff_max_tries: int, + backoff_exceptions: Tuple[Type[Exception], ...], ): """A wrapper around process single that saves a metadata file if processing is successful.""" # make destination directory if it doesn't exist for the destination and metadata paths - mkdir_p(parent(destination_path)) - mkdir_p(parent(metadata_path)) - - kwargs = pickle.loads(serialized_kwargs) - retries_on_error = kwargs.get("retries_on_error", 0) + 1 - while True: - try: - cls.process_single( - source_path=source_path, destination_path=destination_path, queue=queue, **kwargs - ) - break - except DolmaRetryableFailure as exception: - retries_on_error -= 1 - if retries_on_error == 0: - raise DolmaError from exception + for path in itertools.chain(destination_paths, metadata_paths): + mkdir_p(parent(path)) + + # we unpickle the serialized kwargs + deserialized_kwargs = [pickle.loads(kw) for kw in serialized_kwargs] + + # use backoff library to retry on failure; function _log_backoff is called on backoff + # to inform the user of the backoff details. + fn_with_backoff = backoff.on_exception( + backoff.expo, + exception=backoff_exceptions, + max_tries=backoff_max_tries, + max_time=backoff_max_time, + on_backoff=cls._log_backoff, + )(cls.process_batch) + + # start processing the file here + fn_with_backoff( + source_paths=source_paths, destination_paths=destination_paths, queue=queue, kwargs=deserialized_kwargs + ) - # write the metadata file - with smart_open.open(metadata_path, "wt") as f: - f.write(datetime.now().isoformat()) + # write the metadata files + for path in metadata_paths: + with smart_open.open(path, "wt") as f: + f.write(datetime.now().isoformat()) @classmethod def increment_progressbar(cls, queue: QueueType, /, **kwargs: int) -> Dict[str, int]: @@ -247,126 +364,7 @@ def increment_progressbar(self, queue, /, files = 0, documents = 0): # we use queue.put(tuple(kwargs.get(k, 0) for k in kwargs)) return kwargs - @classmethod - def _run_threaded_progressbar( - cls, - queue: QueueType, - timeout: float, - ): - """Run a progress bar in a separate thread. - - Args: - queue (QueueType): The queue to increment the progress bars. - timeout (float): How often to update the progress bars in seconds. 
- """ - - sample_queue_output = cls.increment_progressbar(queue) - - with ExitStack() as stack: - pbars = [ - stack.enter_context( - tqdm.tqdm(desc=str(k), unit=str(k)[:1], position=i, unit_scale=True) # pyright: ignore - ) - for i, k in enumerate(sample_queue_output) - ] - - while True: - item = queue.get() - if item is None: - break - - for pbar, value in zip(pbars, item): - pbar.update(value) - - time.sleep(timeout) - - def _debug_run_all( - self, - all_source_paths: List[str], - all_destination_paths: List[str], - all_metadata_paths: List[str], - all_process_kwargs: Union[List[KwargsType], None] = None, - **process_single_kwargs: Any, - ): - """Run files one by one on the main process - - Args: - all_source_paths (List[MultiPath]): The list of source paths to process. - all_destination_paths (List[MultiPath]): The list of destination paths to save. - all_metadata_paths (List[MultiPath]): The locations where to save metadata. - all_process_kwargs (Union[List[KwargsType], None]): Additional kwargs to pass to the process_single - """ - - arguments_iterator = zip( - # source paths - all_source_paths, - # destination paths - all_destination_paths, - # this is where we save the metadata to keep track of which files have been processed - all_metadata_paths, - # additional kwargs to pass to the process_single; if not provided, we use an empty dict - # will be merged with the process_single_kwargs - all_process_kwargs or [{} for _ in all_source_paths], - ) - pbar_queue: QueueType = Queue() - thread = Thread(target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True) - thread.start() - - for source_path, destination_path, metadata_path, process_kwargs in arguments_iterator: - self._process_single_and_save_status( - source_path=source_path, - destination_path=destination_path, - metadata_path=metadata_path, - queue=pbar_queue, - serialized_kwargs=pickle.dumps({**process_kwargs, **process_single_kwargs}), - ) - - pbar_queue.put(None) - thread.join() - - def __add__(self: BPP, other: BPP) -> BPP: - """Combine two parallel processors into one.""" - if not type(self) is type(other): - raise TypeError(f"Cannot add {type(self)} and {type(other)}") - - # we try combining the two list of include paths; if they are both None, then set the combo back to none - include_paths: Union[List[str], None] = [*(self.include_paths or []), *(other.include_paths or [])] - include_paths = sorted(set(include_paths or [])) if len(include_paths or []) else None - - # do the same for exclude paths - exclude_paths: Union[List[str], None] = [*(self.exclude_paths or []), *(other.exclude_paths or [])] - exclude_paths = sorted(set(exclude_paths or [])) if len(exclude_paths or []) else None - - # for the regex, do a simple or if both are set - regex_pattern: Union[str, None] = None - if self.files_regex_pattern and other.files_regex_pattern: - regex_pattern = "(" + self.files_regex_pattern.pattern + "|" + other.files_regex_pattern.pattern + ")" - elif self.files_regex_pattern: - regex_pattern = self.files_regex_pattern.pattern - elif other.files_regex_pattern: - regex_pattern = other.files_regex_pattern.pattern - - return type(self)( - source_prefix=[*self.src_prefixes, *other.src_prefixes], - destination_prefix=[*self.dst_prefixes, *other.dst_prefixes], - metadata_prefix=[*self.meta_prefixes, *other.meta_prefixes], - num_processes=max(self.num_processes, other.num_processes), - debug=self.debug or other.debug, - seed=self.seed, - pbar_timeout=max(self.pbar_timeout, other.pbar_timeout), - 
ignore_existing=self.ignore_existing or other.ignore_existing, - include_paths=include_paths, - exclude_paths=exclude_paths, - files_regex_pattern=regex_pattern, - retries_on_error=max(self.retries_on_error, other.retries_on_error), - process_single_kwargs=[*self.process_single_kwargs, *other.process_single_kwargs], - ) - - def __radd__(self: BPP, other: BPP) -> BPP: - """Combine two parallel processors into one.""" - return other.__add__(self) - - def _multiprocessing_run_all( + def _run_all( self, all_source_paths: List[str], all_destination_paths: List[str], @@ -389,47 +387,51 @@ def _multiprocessing_run_all( all_process_kwargs = all_process_kwargs or [{} for _ in all_source_paths] - arguments_iterator = zip( - # source paths - all_source_paths, - # destination paths - all_destination_paths, - # this is where we save the metadata to keep track of which files have been processed - all_metadata_paths, - # additional kwargs to pass to the process_single; if not provided, we use an empty dict - # will be merged with the process_single_kwargs - all_process_kwargs, + batches = list( + batch_iterator( + # source paths + all_source_paths, + # destination paths + all_destination_paths, + # this is where we save the metadata to keep track of which files have been processed + all_metadata_paths, + # additional kwargs to pass to the process_single; if not provided, we use an empty dict + # will be merged with the process_single_kwargs + all_process_kwargs, + # batch size is equal to 1 by default + batch_size=self.batch_size, + ) ) + self.logger.info("Processing in %s batches", len(batches)) - # no need to be wasteful with processes: we only need as many cores a the minimum of the number of - # source paths, destination paths, metadata paths, and process kwargs. 
- num_processes = min( - self.num_processes, - len(all_source_paths), - len(all_destination_paths), - len(all_metadata_paths), - len(all_process_kwargs), - ) + # no need to be wasteful with processes: we only need as many cores a the number of batches + num_processes = min(self.num_processes, len(batches)) + self.logger.info("Using %s processes", num_processes) - with multiprocessing.Pool(processes=num_processes) as pool: - pbar_queue: QueueType = (manager := multiprocessing.Manager()).Queue() - thread = Thread( - target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True - ) - thread.start() + with PoolWithDebug(processes=num_processes, debug=self.debug) as pool: + pbar_queue: QueueType = (manager := get_manager(pool)).Queue() + (pbar := self.PROGRESS_BAR_CLS(pbar_queue)).start() process_single_fn = partial(self.process_single, queue=pbar_queue) results = [] - for source_path, destination_path, metadata_path, process_kwargs in arguments_iterator: + for source_paths, destination_paths, metadata_paths, process_kwargs in batches: + # we need to merge the process_single_kwargs with the additional kwargs + # mypy is confused by the type of process_kwargs; we need to ignore the error + serialized_kwargs = [ + pickle.dumps({**kw, **process_single_kwargs}) for kw in process_kwargs # type: ignore + ] + process_single_fn = partial( - self._process_single_and_save_status, + self._process_batch_and_save_status, queue=pbar_queue, - source_path=source_path, - destination_path=destination_path, - metadata_path=metadata_path, - # we need to merge the process_single_kwargs with the additional kwargs - serialized_kwargs=pickle.dumps({**process_kwargs, **process_single_kwargs}), + source_paths=source_paths, # pyright: ignore + destination_paths=destination_paths, # pyright: ignore + metadata_paths=metadata_paths, # pyright: ignore + serialized_kwargs=serialized_kwargs, + backoff_max_time=self.backoff_max_time, + backoff_max_tries=self.backoff_max_tries, + backoff_exceptions=self.backoff_exceptions, ) result = pool.apply_async(process_single_fn) results.append(result) @@ -439,9 +441,7 @@ def _multiprocessing_run_all( pool.close() pool.join() - - pbar_queue.put(None) - thread.join() + pbar.stop() manager.shutdown() def _valid_path(self, path: str) -> bool: @@ -453,14 +453,14 @@ def _valid_path(self, path: str) -> bool: return False return True - def _get_all_paths(self) -> AllPathsTuple: + def _get_all_paths(self) -> Tuple[AllPathsTuple, bool]: """Get all paths to process using prefixes provided""" - all_paths = AllPathsTuple.empty() + all_paths = AllPathsTuple.new() for src_prefix, dst_prefix, meta_prefix, kwargs_prefix in zip( self.src_prefixes, self.dst_prefixes, self.meta_prefixes, self.process_single_kwargs ): - current_source_prefixes = sorted(glob_path(src_prefix)) + current_source_prefixes = sorted([src_prefix] if self.skip_source_glob else glob_path(src_prefix)) if len(current_source_prefixes) > 1: # make relative only makes sense if there is more than one path; otherwise, it's unclear @@ -474,45 +474,49 @@ def _get_all_paths(self) -> AllPathsTuple: else: raise ValueError(f"Could not find any files matching {src_prefix}") - # shuffle the order of the files so time estimation in progress bars is more accurate - random.shuffle(rel_paths) + if self.shuffle_src_paths: + # shuffle the order of the files so time estimation in progress bars is more accurate + random.shuffle(rel_paths) - # get a list of which metadata files already exist - existing_metadata_names = set( - 
re.sub(rf"{METADATA_SUFFIX}$", "", sub_prefix(path, meta_prefix)) - for path in glob_path(meta_prefix) - ) + # # get a list of which metadata files already exist + some_already_processed = False for path in rel_paths: - if not self.ignore_existing and path in existing_metadata_names: - continue + metadata_path = add_suffix(meta_prefix, path) + METADATA_SUFFIX if not self._valid_path(path): + # invalid path; skip + continue + + if not self.ignore_existing and exists(metadata_path): + # metadata file exists, which indicates that the file has already been processed + some_already_processed = True continue # create new paths to pass to taggers all_paths.src.append(add_suffix(prefix, path)) all_paths.dst.append(add_suffix(dst_prefix, path)) - all_paths.meta.append(add_suffix(meta_prefix, path) + METADATA_SUFFIX) + all_paths.meta.append(metadata_path) all_paths.kwargs.append(kwargs_prefix or {}) - return all_paths + return all_paths, some_already_processed def __call__(self, **process_single_kwargs: Any): """Run the processor.""" random.seed(self.seed) - # in case the user wants to override the default kwargs for retries - process_single_kwargs.setdefault("retries_on_error", self.retries_on_error) - - all_paths = self._get_all_paths() + all_paths, some_already_processed = self._get_all_paths() + self.logger.info("Found %s files to process", len(all_paths.src)) - print(f"Found {len(all_paths.src):,} files to process") - - fn = self._debug_run_all if self.debug else self._multiprocessing_run_all + if all_paths.empty: + if some_already_processed: + self.logger.info("All files already processed; skipping.") + return + else: + raise DolmaError("No files found to process.") - fn( + self._run_all( all_source_paths=all_paths.src, all_destination_paths=all_paths.dst, all_metadata_paths=all_paths.meta, diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py new file mode 100644 index 00000000..26fbdf91 --- /dev/null +++ b/python/dolma/core/progressbar.py @@ -0,0 +1,276 @@ +import multiprocessing +import time +import warnings +from contextlib import ExitStack +from functools import reduce +from hashlib import sha1 +from inspect import Parameter, get_annotations +from inspect import signature as get_signature # type: ignore +from queue import Queue +from threading import Thread +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type + +import tqdm +from typing_extensions import TypeAlias, Union + +from .loggers import get_logger + +if TYPE_CHECKING: + from .parallel import BaseParallelProcessor + + +QueueType: TypeAlias = "Queue[Union[None, Tuple[int, ...]]]" + + +class BaseProgressBar: + """One or more progress bars that track progress of a process. + + This class is meant to be subclassed. The subclass must provide one or more attributes of type int, e.g. + + ```python + class MyProgressBar(BaseProgressBar): + files: int = 0 + documents: int = 0 + ``` + + This class can be used for both adding and running through the progress bars. To start: + + ```python + queue = Queue() + pb = MyProgressBar(queue) + pb.start() + + ... # do some work + + pb.stop() + ``` + + it can also be used in a multiprocessing context: + + ```python + with Pool(processes=4) as pool: + queue = mutliprocessing.Manager().Queue() + pb = MyProgressBar(queue) + pb.start() + + ... 
# do some work + + pool.close() + pool.join() + pb.stop() + ``` + + If you want to use this class to update a queue: + + ```python + pb = MyProgressBar(queue) + pb.files += 1 + pb.documents += 100 + ``` + """ + + def __init__(self, queue: QueueType, min_step: int = 1, min_time: float = 1e-1, thread: bool = False): + """ + Initialize the ProgressBar object. + + Args: + queue (QueueType): The queue object to track progress. + min_step (int, optional): The minimum step size for progress updates. Defaults to 1. + min_time (float, optional): The minimum time interval between progress updates. Defaults to 1e-1. + thread (bool, optional): Whether to start the progress bar or use object as client. Defaults to False. + """ + self._logger = get_logger(self.__class__.__name__, "warn") + self._queue = queue + self._last_update_delta_time = 0 + self._last_update_delta_step = 0 + + self._update_every_seconds = min_time + self._update_every_steps = min_step + + for field in self.fields(): + setattr(self, field, 0) + + self._thread = ( + Thread( + target=self._run, + kwargs={"queue": queue, "update_every_seconds": min_time, "fields": self.fields()}, + daemon=True, + ) + if thread + else None + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"{', '.join(f'{k}={getattr(self, k)}' for k in self.fields())};" + f" min_step={self._update_every_steps}, min_time={self._update_every_seconds})" + ")" + ) + + def __str__(self) -> str: + return self.__repr__() + + def __setattr__(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + if name in self.fields() and value > 0: + self.update() + + @classmethod + def from_increment_function(cls, processor: "BaseParallelProcessor") -> "Type[BaseProgressBar]": + # print deprecation warning + msg = ( + "Deriving progress bar from `increment_progressbar` is deprecated; add a `PROGRESS_BAR_CLS` " + f"attribute to {type(processor).__name__} instead." + ) + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) + + # checking that the increment_progressbar method is subclassed correctly + sig = get_signature(processor.increment_progressbar) + if "queue" not in sig.parameters or sig.parameters["queue"].kind != Parameter.POSITIONAL_ONLY: + raise AttributeError( + "increment_progressbar must have a positional-only argument named 'queue'; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + if "kwargs" in sig.parameters and sig.parameters["kwargs"].kind == Parameter.VAR_KEYWORD: + raise AttributeError( + "increment_progressbar must not have a **kwargs argument; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + if any(p.name != "queue" and p.default != 0 for p in sig.parameters.values()): + raise AttributeError( + "increment_progressbar must have a default value of 0 for all arguments except 'queue'; " + "Check that you have subclassed BaseParallelProcessor correctly!" + ) + params = sorted(k for k, p in sig.parameters.items() if k != "queue" and p.kind != Parameter.empty) + h = reduce(lambda h, e: h.update(e.encode()) or h, params, sha1()).hexdigest() # type: ignore + + # create a new class + cls_dict = {"__annotations__": {k: int for k in params}, **{p: 0 for p in params}} + new_cls = type(f"{cls.__name__}{h[-6:]}", (cls,), cls_dict) + return new_cls + + @classmethod + def fields(cls) -> Tuple[str, ...]: + """ + Returns a tuple of field names in the class that are of type int. + + Raises: + ValueError: If the class does not have at least one field of type int. 
+ + Returns: + Tuple[str, ...]: A tuple of field names. + """ + fields: Optional[Tuple[str, ...]] = cls.__dict__.get("__fields__") + + if fields is None: + annotations = get_annotations(cls) + fields = tuple(sorted(n for n, t in annotations.items() if issubclass(t, int))) + setattr(cls, "__fields__", fields) + + if len(fields) == 0: + raise ValueError(f"Class {cls.__name__} must have at least one field of type int.") + + return fields + + @classmethod + def parse(cls, values: Optional[Tuple[int, ...]]) -> Dict[str, int]: + """ + Parses the value from the queue and returns a dictionary mapping field names to their corresponding values. + + Args: + values (Optional[Tuple[int, ...]]): The values to be parsed for the queue. + + Returns: + Dict[str, int]: A dictionary mapping field names to their corresponding values. + """ + if not values: + return {k: 0 for k in cls.fields()} + return {k: v for k, v in zip(cls.fields(), values)} + + def _update(self): + # get the current values + update = tuple(getattr(self, k, 0) for k in self.fields()) + + # time to do an update + self._queue.put_nowait(update) + + # reset the steps + self._last_update_delta_step = 0 + + # reset the steps + for k in self.fields(): + setattr(self, k, 0) + + def update(self): + # update the number of steps since the last update + self._last_update_delta_step += 1 + + if self._update_every_steps > self._last_update_delta_step: + return + + self._update() + + # check if we wanna update frequency based on steps + if self._queue.qsize() >= multiprocessing.cpu_count(): + self._update_every_steps *= 2 + return + + # check if we wanna update frequency based on time + self._last_update_delta_time = -(time.time() - self._last_update_delta_time) + if self._last_update_delta_time < self._update_every_seconds: + self._update_every_steps *= 2 + return + + @staticmethod + def _run(queue: QueueType, update_every_seconds: float, fields: Tuple[str, ...]): + """ + Runs the progress bar. + + This method initializes and updates the progress bars based on the items in the queue. + It continuously retrieves items from the queue and updates the progress bars accordingly. + The method exits when a `None` item is retrieved from the queue. + + Returns: + None + """ + with ExitStack() as stack: + pbars = [ + stack.enter_context(tqdm.tqdm(desc=k, unit=k[:1], position=i, unit_scale=True)) # pyright: ignore + for i, k in enumerate(fields) + ] + + while True: + # loop until we get a None + item = queue.get() + if item is None: + break + + for pbar, value in zip(pbars, item): + pbar.update(value) + + time.sleep(update_every_seconds) + + def start(self): + """Run the progress bar in a separate thread.""" + if self._thread: + self._thread.start() + + def stop(self): + """Stop the progress bar. + + This method stops the progress bar by adding a `None` item to the queue and joining the thread. 
+ """ + self._update() + + if self._thread is not None: + self._queue.put(None) + time.sleep(self._update_every_seconds * 2) + self._thread.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() diff --git a/python/dolma/core/utils.py b/python/dolma/core/utils.py index 2f5c5eb6..16080ef4 100644 --- a/python/dolma/core/utils.py +++ b/python/dolma/core/utils.py @@ -3,17 +3,12 @@ import re import string import sys -from typing import List, Union, cast - -try: - import blingfire - - BLINGFIRE_AVAILABLE = True -except Exception: - BLINGFIRE_AVAILABLE = False +from itertools import islice +from typing import Generator, Iterable, List, Tuple, TypeVar, Union, cast import nltk import uniseg.wordbreak +from necessary import necessary from nltk.tokenize.punkt import PunktSentenceTokenizer from omegaconf import OmegaConf as om @@ -22,13 +17,26 @@ except LookupError: nltk.download("punkt") - -from .data_types import TextSlice +from .data_types import Span, TextSlice from .loggers import get_logger +try: + import blingfire + + BLINGFIRE_AVAILABLE = True +except (ImportError, OSError): + BLINGFIRE_AVAILABLE = False + + sent_tokenizer = PunktSentenceTokenizer() logger = get_logger(__name__) +T = TypeVar("T") + + +# digits after the decimal point +TAGGER_SCORE_PRECISION = 5 + def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str: # use underscores for any non-valid characters in variable name @@ -44,6 +52,16 @@ def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> return name +def format_span_output(span: Span) -> Tuple[int, int, float]: + """Formats a span for output.""" + return (span.start, span.end, round(float(span.score), TAGGER_SCORE_PRECISION)) + + +def format_span_key(experiment: str, tagger: str, span: Span) -> str: + """Formats a span key for output.""" + return f"{experiment}__{tagger}__{make_variable_name(span.type)}" + + def split_words(text: str, remove_empty: bool = True) -> List[TextSlice]: """ Split a string into words, as defined by the unicode standard. @@ -134,7 +152,7 @@ def import_modules(modules_path: Union[List[str], None]): sys.path.insert(0, module_parent) importlib.import_module(module_name) elif module_path in sys.modules[module_name].__path__: - logger.info(f"{module_path} has already been imported.") + logger.info("%s has already been imported.", module_path) else: raise ImportError( f"Failed to import {module_path} because the corresponding module name " @@ -148,3 +166,45 @@ def dataclass_to_dict(dataclass_instance) -> dict: # force typecasting because a dataclass instance will always be a dict return cast(dict, om.to_object(om.structured(dataclass_instance))) + + +def batch_iterator( + *iterables: Iterable[T], batch_size: int = 1, drop_last: bool = False +) -> Generator[List[Tuple[T, ...]], None, None]: + """ + Group one or more iterables into batches of size `batch_size`. + + Args: + iterables (Iterable[T]): One or more iterables to group into batches. + batch_size (int): The size of each batch. Defaults to 1. + drop_last (bool): Whether to drop the last batch if it is smaller than `batch_size`. Defaults to False. 
+ """ + grouped_iterator = iter(zip(*iterables)) + while True: + batch = list(islice(grouped_iterator, batch_size)) + if not batch: + break + if len(batch) < batch_size and drop_last: + break + yield list(zip(*batch)) + + +with necessary(("smart_open", "7.0.4"), soft=True) as SMART_OPEN_NO_ZSTD: + if SMART_OPEN_NO_ZSTD: + import io + + import zstandard + from smart_open import register_compressor + + def _handle_zstd(file_obj, mode): + result = zstandard.open(filename=file_obj, mode=mode) + # zstandard.open returns an io.TextIOWrapper in text mode, but otherwise + # returns a raw stream reader/writer, and we need the `io` wrapper + # to make FileLikeProxy work correctly. + if "b" in mode and "w" in mode: + result = io.BufferedWriter(result) + elif "b" in mode and "r" in mode: + result = io.BufferedReader(result) + return result + + register_compressor(".zst", _handle_zstd) diff --git a/tests/python/test_parallel.py b/tests/python/test_parallel.py index 1287247a..6ee0636b 100644 --- a/tests/python/test_parallel.py +++ b/tests/python/test_parallel.py @@ -3,12 +3,14 @@ import os from pathlib import Path from tempfile import TemporaryDirectory +from time import sleep from typing import Any from unittest import TestCase import smart_open from dolma.core.parallel import BaseParallelProcessor, QueueType +from dolma.core.progressbar import BaseProgressBar LOCAL_DATA = Path(__file__).parent.parent / "data" @@ -31,7 +33,80 @@ def process_single( queue.put((1,)) +class MockPbar(BaseProgressBar): + a: int = 0 + b: int = 0 + + +class NewStyleMockProcessor(BaseParallelProcessor): + PROGRESS_BAR_CLS = MockPbar + + @classmethod + def process_single( + cls, + source_path: str, + destination_path: str, + queue: QueueType, + **kwargs: Any, + ): + with MockPbar(queue) as pbar: + for _ in range(10): + pbar.a += 1 + pbar.b += 5 + + +class MockProcessorWithFail(MockProcessor): + @classmethod + def process_single( + cls, + source_path: str, + destination_path: str, + queue: QueueType, + **kwargs: Any, + ): + sleep(1) + raise ValueError(f"Failed on {source_path}") + + class TestParallel(TestCase): + def _read(self, path): + with smart_open.open(path, "rb") as f: + return f.read() + + def test_new_style(self): + with TemporaryDirectory() as d: + proc = NewStyleMockProcessor( + source_prefix=str(LOCAL_DATA / "expected"), + destination_prefix=f"{d}/destination", + metadata_prefix=f"{d}/metadata", + ignore_existing=False, + ) + proc() + + def test_debug(self): + with self.assertRaises(ValueError): + MockProcessor(source_prefix=[], destination_prefix=[], metadata_prefix=[]) + + with TemporaryDirectory() as d: + proc = MockProcessor( + source_prefix=str(LOCAL_DATA / "expected"), + destination_prefix=f"{d}/destination", + metadata_prefix=f"{d}/metadata", + ignore_existing=False, + debug=True, + ) + proc() + src = [p for p in os.listdir(LOCAL_DATA / "expected") if not p.startswith(".")] + meta = [p.rstrip(".done.txt") for p in os.listdir(f"{d}/metadata")] + dest = [p for p in os.listdir(f"{d}/destination") if not p.startswith(".")] + self.assertEqual(sorted(src), sorted(meta)) + self.assertEqual(sorted(src), sorted(dest)) + + for s, e in zip(src, dest): + s_ = LOCAL_DATA / "expected" / s + e_ = f"{d}/destination/{e}" + self.assertEqual(self._read(s_), self._read(e_)) + def test_base_parallel_processor(self): with self.assertRaises(ValueError): MockProcessor(source_prefix=[], destination_prefix=[], metadata_prefix=[]) @@ -42,6 +117,7 @@ def test_base_parallel_processor(self): destination_prefix=f"{d}/destination", 
metadata_prefix=f"{d}/metadata", ignore_existing=False, + num_processes=2, ) proc() src = [p for p in os.listdir(LOCAL_DATA / "expected") if not p.startswith(".")] @@ -50,6 +126,12 @@ def test_base_parallel_processor(self): self.assertEqual(sorted(src), sorted(meta)) self.assertEqual(sorted(src), sorted(dest)) + for s, e in zip(src, dest): + s_ = LOCAL_DATA / "expected" / s + e_ = f"{d}/destination/{e}" + self.assertEqual(self._read(s_), self._read(e_)) + + def test_two_stages(self): with TemporaryDirectory() as d: proc = MockProcessor( source_prefix=str(LOCAL_DATA / "expected" / "*-paragraphs.*"), @@ -63,3 +145,34 @@ def test_base_parallel_processor(self): dest = [p for p in os.listdir(f"{d}/destination")] self.assertEqual(sorted(src), sorted(meta)) self.assertEqual(sorted(src), sorted(dest)) + + proc = MockProcessor( + source_prefix=str(LOCAL_DATA / "expected" / "*"), + destination_prefix=f"{d}/destination", + metadata_prefix=f"{d}/metadata", + ignore_existing=False, + ) + proc() + + # the oldest two files are from the first stage + dest2 = sorted( + [p for p in os.listdir(f"{d}/destination")], key=lambda x: os.stat(f"{d}/destination/{x}").st_ctime + ) + self.assertEqual(sorted(dest), sorted(dest2[:2])) + + def test_failure(self): + with TemporaryDirectory() as d: + proc = MockProcessorWithFail( + source_prefix=str(LOCAL_DATA / "expected"), + destination_prefix=f"{d}/destination", + metadata_prefix=f"{d}/metadata", + ignore_existing=False, + backoff_exceptions=(ValueError,), + backoff_max_time=3, + backoff_max_tries=3, + debug=True, + ) + with self.assertRaises(ValueError): + proc() + self.assertEqual(len(os.listdir(f"{d}/destination")), 0) + self.assertEqual(len(os.listdir(f"{d}/metadata")), 0) diff --git a/tests/python/test_utils.py b/tests/python/test_utils.py index 38bf268d..9909f091 100644 --- a/tests/python/test_utils.py +++ b/tests/python/test_utils.py @@ -2,14 +2,14 @@ Tests for the utils module. 
-@kylel +@kylel, @soldni """ from unittest import TestCase from dolma.core.data_types import TextSlice -from dolma.core.utils import split_paragraphs, split_sentences +from dolma.core.utils import batch_iterator, split_paragraphs, split_sentences class TestUtils(TestCase): @@ -84,3 +84,44 @@ def test_split_sentences_with_newline_and_spaces(self): self.assertEqual(text[sentences[0].start : sentences[0].end], sentences[0].text) self.assertEqual(sentences[1].text, "This is another sentence.") self.assertEqual(text[sentences[1].start : sentences[1].end], sentences[1].text) + + +class TestBatching(TestCase): + def test_batching(self): + a = [1, 2, 3, 4, 5] + b = [6, 7, 8, 9, 0] + + output = list(batch_iterator(a, b, batch_size=2)) + self.assertEqual(len(output), 3) + self.assertEqual(output[0], [(1, 2), (6, 7)]) + self.assertEqual(output[1], [(3, 4), (8, 9)]) + self.assertEqual(output[2], [(5,), (0,)]) + + def test_single_batching(self): + a = [1, 2, 3, 4, 5] + + output = list(batch_iterator(a, batch_size=2)) + + self.assertEqual(len(output), 3) + self.assertEqual(output[0], [(1, 2)]) + self.assertEqual(output[1], [(3, 4)]) + self.assertEqual(output[2], [(5,)]) + + def test_longer_batch_than_slice(self): + a = list(range(3)) + b = list(range(3, 6)) + c = list(range(6, 9)) + + output = list(batch_iterator(a, b, c, batch_size=4)) + + self.assertEqual(len(output), 1) + self.assertEqual(output[0], [(0, 1, 2), (3, 4, 5), (6, 7, 8)]) + + def test_drop_last(self): + a = [1, 2, 3, 4, 5] + b = [6, 7, 8, 9, 0] + + output = list(batch_iterator(a, b, batch_size=2, drop_last=True)) + self.assertEqual(len(output), 2) + self.assertEqual(output[0], [(1, 2), (6, 7)]) + self.assertEqual(output[1], [(3, 4), (8, 9)]) diff --git a/tests/python/utils.py b/tests/python/utils.py index 47ddcd18..086f69ea 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -12,7 +12,7 @@ import boto3 import smart_open -from smart_open import open +import yaml from dolma.core.paths import glob_path, mkdir_p @@ -21,6 +21,7 @@ DOLMA_TESTS_S3_PREFIX_DEFAULT = "s3://dolma-tests" LOGGER = logging.getLogger(__name__) +LOGGER.setLevel(logging.INFO) def parse_s3_path(s3_path: str) -> Tuple[str, str]: @@ -64,9 +65,8 @@ def get_test_prefix() -> str: def skip_aws_tests() -> bool: - dolma_tests_skip = os.environ.get(DOLMA_TESTS_SKIP_AWS_ENV_VAR) - LOGGER.info(f"{DOLMA_TESTS_SKIP_AWS_ENV_VAR}: {dolma_tests_skip}") - return (dolma_tests_skip or "false").lower() == "true" + dolma_tests_skip = yaml.safe_load(os.environ.get(DOLMA_TESTS_SKIP_AWS_ENV_VAR) or "false") + return bool(dolma_tests_skip) def upload_test_documents(local_input: str, test_prefix: str) -> Tuple[str, str]: From 67b3bda154208d44079ffdf9798d9402c2a18031 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 18:03:09 -0700 Subject: [PATCH 02/14] added support for retries_on_error --- python/dolma/core/parallel.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 1823b791..b591faac 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -100,6 +100,7 @@ def __init__( process_single_kwargs: Union[None, KwargsType, List[KwargsType]] = None, backoff_max_time: Optional[float] = None, backoff_max_tries: int = 1, + retries_on_error: Optional[int] = None, backoff_exceptions: Optional[Union[Type[Exception], Tuple[Type[Exception], ...]]] = None, ): """Initialize the parallel processor. 
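
The hunks below wire the deprecated flag into the new backoff machinery: passing `retries_on_error=N` now logs a deprecation warning and is internally rewritten to `backoff_max_tries=N + 1`, with `DolmaRetryableFailure` remaining the default retryable exception. As a point of reference, here is a minimal sketch of how the two call styles line up; the `CopyProcessor` class and the paths are illustrative and not taken from this diff.

```python
# Minimal sketch (class name and paths are illustrative, not from the diff):
# how the deprecated retry flag maps onto the new backoff-based options.
import shutil
from typing import Any, Dict

from dolma.core.errors import DolmaRetryableFailure
from dolma.core.parallel import BaseParallelProcessor, QueueType


class CopyProcessor(BaseParallelProcessor):
    @classmethod
    def increment_progressbar(cls, queue: QueueType, /, files: int = 0) -> Dict[str, int]:
        return super().increment_progressbar(queue, files=files)

    @classmethod
    def process_single(cls, source_path: str, destination_path: str, queue: QueueType, **kwargs: Any):
        shutil.copy(source_path, destination_path)
        queue.put((1,))


# old style: retry up to 3 times on DolmaRetryableFailure; after this patch it
# logs a deprecation warning and is rewritten to backoff_max_tries = 3 + 1
old = CopyProcessor(
    source_prefix="/tmp/src",
    destination_prefix="/tmp/dst",
    metadata_prefix="/tmp/meta",
    retries_on_error=3,
)

# new style: the equivalent configuration expressed with the backoff parameters
new = CopyProcessor(
    source_prefix="/tmp/src",
    destination_prefix="/tmp/dst",
    metadata_prefix="/tmp/meta",
    backoff_max_tries=4,
    backoff_exceptions=(DolmaRetryableFailure,),
)
```

Both constructors produce the same retry behavior; only the new-style parameters will remain once the deprecated flag is dropped.
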
@@ -141,6 +142,8 @@ def __init__( backoff_max_tries (int, optional): The maximum number of tries to backoff. Defaults to 1. backoff_exceptions (Union[Type[Exception], Tuple[Type[Exception], ...]], optional): The exceptions to backoff on. Defaults to `dolma.core.errors.DolmaRetryableFailure`. + retries_on_error (int, optional): Deprecated. The number of retries to attempt on error. + Defaults to None. """ self.src_prefixes = [source_prefix] if isinstance(source_prefix, str) else source_prefix self.dst_prefixes = [destination_prefix] if isinstance(destination_prefix, str) else destination_prefix @@ -161,6 +164,13 @@ def __init__( # this manages how many files to pass to a single processor self.batch_size = batch_size + if retries_on_error is not None: + self.logger.warning( + "The `retries_on_error` parameter is deprecated and will be removed in a future release. " + "Please use `backoff_max_tries` instead." + ) + backoff_max_tries = retries_on_error + 1 + # this controls backoff self.backoff_max_time: float = float(backoff_max_time or "inf") self.backoff_max_tries: int = int(backoff_max_tries) From 155319c9f4267c6e9fd93ca0bd49413f36dba47c Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 18:14:37 -0700 Subject: [PATCH 03/14] data --- python/dolma/tokenizer/executor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/dolma/tokenizer/executor.py b/python/dolma/tokenizer/executor.py index 61c54854..4291670d 100644 --- a/python/dolma/tokenizer/executor.py +++ b/python/dolma/tokenizer/executor.py @@ -256,8 +256,7 @@ def __call__(self, num_readers: Optional[int] = None, **process_single_kwargs: A ) # finally run the processors - fn = self._debug_run_all if self.debug else self._multiprocessing_run_all - fn( + self._run_all( all_source_paths=source_indices, all_destination_paths=all_destination_paths, all_metadata_paths=all_metadata_path, From d8cb6811b8a0982db83d7c9df1a1e34b9b67b505 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 18:25:49 -0700 Subject: [PATCH 04/14] deps --- pyproject.toml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7be2a77e..978561e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "omegaconf>=2.3.0", # "pycld2==0.41", # "pycld3==0.22", # does not install correctly + "hyperscan>=0.7.0", "platformdirs>=4.2.0", "pyyaml", "requests", @@ -30,6 +31,8 @@ dependencies = [ "numpy", "necessary>=0.4.3", "charset-normalizer>=3.2.0", + "zstandard>=0.20.0", + "backoff>=2.0.0", ] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -99,7 +102,7 @@ dolma = "dolma.cli.__main__:main" [project.optional-dependencies] dev = [ - "black>=22.6.0", + "black[jupyter]>=22.6.0", "flake8>=5.0", "flake8-pyi>=22.8.1", "Flake8-pyproject>=1.1.0", @@ -127,7 +130,6 @@ warc = [ "fastwarc", "w3lib", "url-normalize", - ] trafilatura = [ # must include warc dependencies @@ -159,7 +161,7 @@ all = [ [build-system] requires = [ - "maturin[patchelf]>=1.1,<2.0", + "maturin>=1.1,<2.0", "setuptools >= 61.0.0", "wheel" ] From e6270dc16faae87ed37b5406c8bcb07f087814df Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 22:16:23 -0700 Subject: [PATCH 05/14] get_annotations not available --- python/dolma/core/progressbar.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index 26fbdf91..c4878f7d 100644 --- a/python/dolma/core/progressbar.py +++ 
b/python/dolma/core/progressbar.py @@ -4,7 +4,7 @@ from contextlib import ExitStack from functools import reduce from hashlib import sha1 -from inspect import Parameter, get_annotations +from inspect import Parameter from inspect import signature as get_signature # type: ignore from queue import Queue from threading import Thread @@ -164,8 +164,7 @@ def fields(cls) -> Tuple[str, ...]: fields: Optional[Tuple[str, ...]] = cls.__dict__.get("__fields__") if fields is None: - annotations = get_annotations(cls) - fields = tuple(sorted(n for n, t in annotations.items() if issubclass(t, int))) + fields = tuple(sorted(n for n, t in cls.__annotations__.items() if issubclass(t, int))) setattr(cls, "__fields__", fields) if len(fields) == 0: From 75a5b0db4266eb74e9178970a16b7d4b8f65e46f Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Wed, 22 May 2024 22:27:37 -0700 Subject: [PATCH 06/14] fixes --- .devcontainer/postInstall.sh | 2 +- Makefile | 2 +- pyproject.toml | 2 +- python/dolma/core/mp_tools.py | 2 +- python/dolma/core/progressbar.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.devcontainer/postInstall.sh b/.devcontainer/postInstall.sh index cf3761a9..f2b12ea5 100755 --- a/.devcontainer/postInstall.sh +++ b/.devcontainer/postInstall.sh @@ -2,4 +2,4 @@ PATH=/home/vscode/.cargo/bin:$PATH cd dolma -source /home/vscode/miniforge3/bin/activate && pip install cmake "maturin[patchelf]>=1.1,<2.0" +source /home/vscode/miniforge3/bin/activate && pip install cmake "maturin>=1.5,<2.0" diff --git a/Makefile b/Makefile index d7e2a73a..7a508485 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ setup: $(shell "${PROTOBUF_SETUP}") $(shell "${OPENSSL_SETUP}") which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - which maturin || pip install maturin[patchelf] + which maturin || pip install 'maturin>=1.5,<2.0' publish: maturin publish diff --git a/pyproject.toml b/pyproject.toml index 978561e2..fe57a13a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,7 +161,7 @@ all = [ [build-system] requires = [ - "maturin>=1.1,<2.0", + "maturin>=1.5,<2.0", "setuptools >= 61.0.0", "wheel" ] diff --git a/python/dolma/core/mp_tools.py b/python/dolma/core/mp_tools.py index f88477e4..98c8faf0 100644 --- a/python/dolma/core/mp_tools.py +++ b/python/dolma/core/mp_tools.py @@ -87,7 +87,7 @@ def __enter__(self): def Manager(self): if self._manager is None: self._manager = ( - ManagerWithDebug() # type: ignore + ManagerWithDebug() # pyright: ignore if self.debug else self.stack.enter_context(multiprocessing.Manager()) ) diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index c4878f7d..b7cfe181 100644 --- a/python/dolma/core/progressbar.py +++ b/python/dolma/core/progressbar.py @@ -5,7 +5,7 @@ from functools import reduce from hashlib import sha1 from inspect import Parameter -from inspect import signature as get_signature # type: ignore +from inspect import signature as get_signature from queue import Queue from threading import Thread from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type From 86371d62847e763a8a609709d2174bbe13028618 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 07:51:20 -0700 Subject: [PATCH 07/14] quoting type aliases --- python/dolma/core/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index b591faac..e5f2e404 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ 
-33,7 +33,7 @@ METADATA_SUFFIX = ".done.txt" # we need to quote the type alias because we want to support Python 3.8 -QueueType: TypeAlias = Queue[Union[None, Tuple[int, ...]]] +QueueType: TypeAlias = "Queue[Union[None, Tuple[int, ...]]]" KwargsType: TypeAlias = Dict[str, Any] BPP = TypeVar("BPP", bound="BaseParallelProcessor") From 73aad082333a1eb8c5e0dff00aef0efd2713488c Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 09:27:05 -0700 Subject: [PATCH 08/14] 3.8 compatibility --- python/dolma/core/parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index e5f2e404..72fdda98 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -311,8 +311,7 @@ def _log_backoff(cls, details: Details): if ex := details.get("exception"): # add details about the exception to the message import traceback # pylint: disable=import-outside-toplevel - - message += " due to " + "\n".join(traceback.format_exception_only(ex)).strip() # type: ignore + message += " due to " + "".join(traceback.format_exception_only(type(ex), ex)).strip() # type: ignore cls.get_logger().warning(message) From b9ec3ebe81f5fd94ba2da2f70cbe87cd8a93b7be Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 09:38:07 -0700 Subject: [PATCH 09/14] more style --- python/dolma/core/parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 72fdda98..084b6400 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -311,7 +311,8 @@ def _log_backoff(cls, details: Details): if ex := details.get("exception"): # add details about the exception to the message import traceback # pylint: disable=import-outside-toplevel - message += " due to " + "".join(traceback.format_exception_only(type(ex), ex)).strip() # type: ignore + + message += " due to " + "".join(traceback.format_exception_only(type(ex), ex)).strip() cls.get_logger().warning(message) From e42f9fcc31dcb149ce0cc20155b3a1de2e9f9949 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 10:35:51 -0700 Subject: [PATCH 10/14] pyi --- pyproject.toml | 2 +- python/dolma/core/mp_tools.pyi | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 python/dolma/core/mp_tools.pyi diff --git a/pyproject.toml b/pyproject.toml index fe57a13a..f8860d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,7 +177,7 @@ features = ["pyo3/extension-module"] where = ["src"] [tool.setuptools.package-data] -dolma = ["py.typed", "data/*"] +dolma = ["py.typed", "data/*", "*.pyi"] [tool.black] line-length = 115 diff --git a/python/dolma/core/mp_tools.pyi b/python/dolma/core/mp_tools.pyi new file mode 100644 index 00000000..30fc1f2e --- /dev/null +++ b/python/dolma/core/mp_tools.pyi @@ -0,0 +1,19 @@ +from collections.abc import Callable, Iterable +from multiprocessing.managers import SyncManager +from multiprocessing.pool import ApplyResult, Pool +from typing import Any + +class ResultWithDebug(ApplyResult): ... # noqa: E701,E302 +class ManagerWithDebug(SyncManager): ... # noqa: E701 + +class PoolWithDebug(Pool): # noqa: E302 + def __init__( # noqa: E704 + self, + processes: int | None = None, + initializer: Callable[..., Any] | None = None, + initargs: Iterable[Any] = (), + maxtasksperchild: int | None = None, + debug: bool = False, + ): ... + +def get_manager(pool: Pool) -> SyncManager: ... 
# noqa: E701, E704, E302 From be6c98432690342bd9183f69167d861ee25933dd Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 17:34:10 -0700 Subject: [PATCH 11/14] viz pbar --- python/dolma/core/parallel.py | 2 +- python/dolma/core/progressbar.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 084b6400..795d2ef1 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -420,7 +420,7 @@ def _run_all( with PoolWithDebug(processes=num_processes, debug=self.debug) as pool: pbar_queue: QueueType = (manager := get_manager(pool)).Queue() - (pbar := self.PROGRESS_BAR_CLS(pbar_queue)).start() + (pbar := self.PROGRESS_BAR_CLS(pbar_queue, thread=True)).start() process_single_fn = partial(self.process_single, queue=pbar_queue) results = [] diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index b7cfe181..332ecb37 100644 --- a/python/dolma/core/progressbar.py +++ b/python/dolma/core/progressbar.py @@ -81,7 +81,7 @@ def __init__(self, queue: QueueType, min_step: int = 1, min_time: float = 1e-1, """ self._logger = get_logger(self.__class__.__name__, "warn") self._queue = queue - self._last_update_delta_time = 0 + self._last_update_time = 0 self._last_update_delta_step = 0 self._update_every_seconds = min_time @@ -196,6 +196,7 @@ def _update(self): # reset the steps self._last_update_delta_step = 0 + self._last_update_time = time.time() # reset the steps for k in self.fields(): @@ -208,6 +209,7 @@ def update(self): if self._update_every_steps > self._last_update_delta_step: return + time_before_update = self._last_update_time self._update() # check if we wanna update frequency based on steps @@ -216,8 +218,7 @@ def update(self): return # check if we wanna update frequency based on time - self._last_update_delta_time = -(time.time() - self._last_update_delta_time) - if self._last_update_delta_time < self._update_every_seconds: + if (self._last_update_time - time_before_update) < self._update_every_seconds: self._update_every_steps *= 2 return From f5c696cfed42f417f7903a784c694a8ae7aee78b Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 17:42:42 -0700 Subject: [PATCH 12/14] fixing small regression in tests --- python/dolma/core/progressbar.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index 332ecb37..74e88ea3 100644 --- a/python/dolma/core/progressbar.py +++ b/python/dolma/core/progressbar.py @@ -81,8 +81,8 @@ def __init__(self, queue: QueueType, min_step: int = 1, min_time: float = 1e-1, """ self._logger = get_logger(self.__class__.__name__, "warn") self._queue = queue - self._last_update_time = 0 - self._last_update_delta_step = 0 + self._last_update_time = time.time() + self._last_update_step = 0 self._update_every_seconds = min_time self._update_every_steps = min_step @@ -195,7 +195,7 @@ def _update(self): self._queue.put_nowait(update) # reset the steps - self._last_update_delta_step = 0 + self._last_update_step = 0 self._last_update_time = time.time() # reset the steps @@ -204,9 +204,9 @@ def _update(self): def update(self): # update the number of steps since the last update - self._last_update_delta_step += 1 + self._last_update_step += 1 - if self._update_every_steps > self._last_update_delta_step: + if self._update_every_steps > self._last_update_step: return time_before_update = self._last_update_time From 
e941f055f01b5ae5aaad12481573e7f1e171cea2 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 17:47:41 -0700 Subject: [PATCH 13/14] order from user --- python/dolma/core/progressbar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index 74e88ea3..aa70c767 100644 --- a/python/dolma/core/progressbar.py +++ b/python/dolma/core/progressbar.py @@ -142,7 +142,7 @@ def from_increment_function(cls, processor: "BaseParallelProcessor") -> "Type[Ba "increment_progressbar must have a default value of 0 for all arguments except 'queue'; " "Check that you have subclassed BaseParallelProcessor correctly!" ) - params = sorted(k for k, p in sig.parameters.items() if k != "queue" and p.kind != Parameter.empty) + params = [k for k, p in sig.parameters.items() if k != "queue" and p.kind != Parameter.empty] h = reduce(lambda h, e: h.update(e.encode()) or h, params, sha1()).hexdigest() # type: ignore # create a new class @@ -164,7 +164,7 @@ def fields(cls) -> Tuple[str, ...]: fields: Optional[Tuple[str, ...]] = cls.__dict__.get("__fields__") if fields is None: - fields = tuple(sorted(n for n, t in cls.__annotations__.items() if issubclass(t, int))) + fields = tuple(n for n, t in cls.__annotations__.items() if issubclass(t, int)) setattr(cls, "__fields__", fields) if len(fields) == 0: From 1e292ff52b406f5520f58aa7332c63fd15829de0 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Fri, 31 May 2024 21:35:26 +0000 Subject: [PATCH 14/14] min timeout --- python/dolma/core/parallel.py | 2 +- python/dolma/core/progressbar.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 795d2ef1..4fb134c9 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -420,7 +420,7 @@ def _run_all( with PoolWithDebug(processes=num_processes, debug=self.debug) as pool: pbar_queue: QueueType = (manager := get_manager(pool)).Queue() - (pbar := self.PROGRESS_BAR_CLS(pbar_queue, thread=True)).start() + (pbar := self.PROGRESS_BAR_CLS(queue=pbar_queue, min_time=self.pbar_timeout, thread=True)).start() process_single_fn = partial(self.process_single, queue=pbar_queue) results = [] diff --git a/python/dolma/core/progressbar.py b/python/dolma/core/progressbar.py index aa70c767..9a5d9459 100644 --- a/python/dolma/core/progressbar.py +++ b/python/dolma/core/progressbar.py @@ -69,7 +69,7 @@ class MyProgressBar(BaseProgressBar): ``` """ - def __init__(self, queue: QueueType, min_step: int = 1, min_time: float = 1e-1, thread: bool = False): + def __init__(self, queue: QueueType, min_step: int = 1, min_time: float = 1e-3, thread: bool = False): """ Initialize the ProgressBar object.
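
Note on the progress-bar patches above (PATCH 11, 12, and 14): the `_last_update_step` / `_last_update_time` bookkeeping implements an adaptive throttle, where the queue is flushed only after enough steps have accumulated, and the step threshold doubles whenever flushes arrive faster than `min_time`. The snippet below is a minimal sketch of that idea, not the actual `BaseProgressBar` class; the class name `ThrottledUpdates` and the `_flush` helper are illustrative stand-ins, while `min_step` and `min_time` mirror the constructor arguments shown in the diff.

    import time

    class ThrottledUpdates:
        """Sketch of step- and time-based throttling, assuming min_step/min_time semantics."""

        def __init__(self, min_step: int = 1, min_time: float = 1e-3):
            self._update_every_steps = min_step      # flush at most once per this many steps
            self._update_every_seconds = min_time    # target minimum wall-clock gap between flushes
            self._last_update_step = 0
            self._last_update_time = time.time()

        def _flush(self) -> None:
            # in the real class this drains the per-field counts into the shared queue;
            # here we only reset the bookkeeping
            self._last_update_step = 0
            self._last_update_time = time.time()

        def update(self) -> None:
            self._last_update_step += 1
            if self._update_every_steps > self._last_update_step:
                return  # not enough steps accumulated yet

            time_before_update = self._last_update_time
            self._flush()

            # if flushes are firing faster than the minimum wall-clock interval,
            # double the step threshold so the queue is touched less often
            if (self._last_update_time - time_before_update) < self._update_every_seconds:
                self._update_every_steps *= 2

Under this scheme, lowering the default `min_time` from 1e-1 to 1e-3 (as the final patch does) makes the threshold stop doubling once flushes are roughly a millisecond apart, so the bar refreshes more often for fast workers while still backing off when updates are extremely frequent.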