Commit: Add Checkpointer step (#1114)

plaguss authored Jan 28, 2025
1 parent 84ea198 commit f5ddbc6
Showing 8 changed files with 216 additions and 5 deletions.
59 changes: 59 additions & 0 deletions docs/sections/how_to_guides/advanced/checkpointing.md
@@ -0,0 +1,59 @@
# Push data to the hub while the pipeline is running

Long-running pipelines can be resource-intensive, and ensuring everything is functioning as expected is crucial. To make this process seamless, the [HuggingFaceHubCheckpointer][distilabel.steps.checkpointer.HuggingFaceHubCheckpointer] step has been designed to integrate directly into the pipeline workflow.

The [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) lets you periodically save the generated data as a Hugging Face Dataset, pushing a new file every `input_batch_size` examples generated.

Just add the [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) to your pipeline like any other step.

## Sample pipeline with dummy data to see the checkpoint strategy in action

The following pipeline starts from a dataset with dummy data, passes it through a `DoNothing` step (any other steps would work here, but this one is useful to explore the behavior), and uses the [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) step to push the data to the Hub.

```python
from datasets import Dataset

from distilabel.pipeline import Pipeline
from distilabel.steps import HuggingFaceHubCheckpointer
from distilabel.steps.base import Step, StepInput
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from distilabel.typing import StepOutput

dataset = Dataset.from_dict({"a": [1, 2, 3, 4] * 50, "b": [5, 6, 7, 8] * 50})

class DoNothing(Step):
    def process(self, *inputs: StepInput) -> "StepOutput":
        for input in inputs:
            yield input

with Pipeline(name="pipeline-with-checkpoints") as pipeline:
    text_generation = DoNothing(
        input_batch_size=60
    )
    checkpoint = HuggingFaceHubCheckpointer(
        repo_id="username/streaming_test_1",  # (1)
        private=True,
        input_batch_size=50  # (2)
    )
    text_generation >> checkpoint


if __name__ == "__main__":
    distiset = pipeline.run(
        dataset=dataset,
        use_cache=False
    )
    distiset.push_to_hub(repo_id="username/streaming_test")
```

1. The name of the dataset used for the checkpoints, which can be different from the final distiset. This dataset
   will contain less information than the final distiset to keep pushes fast while the pipeline is running.
2. The `input_batch_size` determines how often the data is pushed to the Hugging Face Hub. If the process is slow, say for a big model, a value like 100 may be on point; for smaller models or pipelines that generate data faster, 10,000 may be more relevant. It's best to explore the value for a given use case.

The final datasets can be found in the following links:

- Checkpoint dataset: [distilabel-internal-testing/streaming_test_1](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test_1)

- Final distiset: [distilabel-internal-testing/streaming_test](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test)
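Once the pipeline is running (or finished), the checkpoint files can be pulled back for inspection. Below is a minimal sketch, assuming the `username/streaming_test_1` repository and the `config-0/*.jsonl` layout produced by the example above:

```python
from datasets import load_dataset

# Each leaf step writes its checkpoints as JSONL files under `config-<i>/`,
# so they can be loaded with the generic "json" loader.
checkpoints = load_dataset(
    "json",
    data_files="hf://datasets/username/streaming_test_1/config-0/*.jsonl",
    split="train",
)
print(checkpoints)
```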
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -193,6 +193,7 @@ nav:
- Exporting data to Argilla: "sections/how_to_guides/advanced/argilla.md"
- Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
- Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
- Push data to the hub while the pipeline is running: "sections/how_to_guides/advanced/checkpointing.md"
- Specifying requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
- Load groups and execution stages: "sections/how_to_guides/advanced/load_groups_and_execution_stages.md"
- Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
12 changes: 9 additions & 3 deletions src/distilabel/distiset.py
@@ -286,9 +286,15 @@ def _extract_readme_metadata(
        Returns:
            The metadata extracted from the README.md file of the dataset repository as a dict.
        """
        readme_path = Path(
            hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token)
        )
        import requests

        try:
            readme_path = Path(
                hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token)
            )
        except requests.exceptions.HTTPError:
            # This can fail when using the checkpoint step
            return {}
        # Remove the '---' from the metadata
        metadata = re.findall(r"---\n(.*?)\n---", readme_path.read_text(), re.DOTALL)[0]
        metadata = yaml.safe_load(metadata)
7 changes: 7 additions & 0 deletions src/distilabel/pipeline/base.py
@@ -175,6 +175,7 @@ def __init__(
        cache_dir: Optional[Union[str, "PathLike"]] = None,
        enable_metadata: bool = False,
        requirements: Optional[List[str]] = None,
        dump_batch_size: int = 50,
    ) -> None:
        """Initialize the `BasePipeline` instance.
@@ -189,6 +190,9 @@ def __init__(
            requirements: List of requirements that must be installed to run the pipeline.
                Defaults to `None`, but can be helpful to inform in a pipeline to be shared
                that these requirements must be installed.
            dump_batch_size: Determines how often the buffer is written to a file, i.e. how many
                elements the buffer holds before its contents are dumped to disk.
                Defaults to 50 (a write is attempted every 50 buffered elements).
        """
        self.name = name or _PIPELINE_DEFAULT_NAME
        self.description = description
@@ -234,6 +238,8 @@ def __init__(

        self._log_queue: Union["Queue[Any]", None] = None

        self._dump_batch_size = dump_batch_size

    def __enter__(self) -> Self:
        """Set the global pipeline instance when entering a pipeline context."""
        _GlobalPipelineManager.set_pipeline(self)
@@ -1022,6 +1028,7 @@ def _setup_write_buffer(self, use_cache: bool = True) -> None:
                ].use_cache
                for step_name in self.dag
            },
            dump_batch_size=self._dump_batch_size,
        )

    def _print_load_stages_info(self) -> None:
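For reference, a minimal sketch of how the new `dump_batch_size` argument could be used, assuming the keyword is forwarded unchanged to the local `Pipeline` class (the step and values below are placeholders, not part of this commit):

```python
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts

# A smaller dump_batch_size flushes the write buffer to disk more often;
# the default of 50 keeps the previous behaviour.
with Pipeline(name="buffered-pipeline", dump_batch_size=20) as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "hello"}] * 100)

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
```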
7 changes: 5 additions & 2 deletions src/distilabel/pipeline/write_buffer.py
@@ -38,6 +38,7 @@ def __init__(
path: "PathLike",
leaf_steps: Set[str],
steps_cached: Optional[Dict[str, bool]] = None,
dump_batch_size: int = 50,
) -> None:
"""
Args:
@@ -48,6 +49,9 @@
                use_cache. We will use this to determine whether we have to read
                a previous parquet table to concatenate before saving the cached
                datasets.
            dump_batch_size: Determines how often the buffer is written to a file, i.e. how many
                elements the buffer holds before its contents are dumped to disk.
                Defaults to 50 (a write is attempted every 50 buffered elements).

        Raises:
            ValueError: If the path is not a directory.
@@ -64,9 +68,8 @@ def __init__(
        self._buffers: Dict[str, List[Dict[str, Any]]] = {
            step: [] for step in leaf_steps
        }
        # TODO: make this configurable
        self._buffers_dump_batch_size: Dict[str, int] = {
            step: 50 for step in leaf_steps
            step: dump_batch_size for step in leaf_steps
        }
        self._buffer_last_schema = {}
        self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps}
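To make the `dump_batch_size` semantics concrete, here is a standalone toy sketch of a flush-every-N buffer; it is illustrative only and not the write-buffer class modified above:

```python
from typing import Any, Dict, List


class ToyBuffer:
    """Illustrative only: accumulate rows and flush every `dump_batch_size` items."""

    def __init__(self, dump_batch_size: int = 50) -> None:
        self.dump_batch_size = dump_batch_size
        self._rows: List[Dict[str, Any]] = []

    def add(self, row: Dict[str, Any]) -> None:
        self._rows.append(row)
        if len(self._rows) >= self.dump_batch_size:
            self.flush()

    def flush(self) -> None:
        # The real buffer writes a parquet file at this point; we only report the flush.
        print(f"flushing {len(self._rows)} rows")
        self._rows.clear()


buffer = ToyBuffer(dump_batch_size=20)
for i in range(100):
    buffer.add({"a": i})  # prints a flush message every 20 rows
```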
2 changes: 2 additions & 0 deletions src/distilabel/steps/__init__.py
@@ -21,6 +21,7 @@
    StepInput,
    StepResources,
)
from distilabel.steps.checkpointer import HuggingFaceHubCheckpointer
from distilabel.steps.clustering.dbscan import DBSCAN
from distilabel.steps.clustering.text_clustering import TextClustering
from distilabel.steps.clustering.umap import UMAP
@@ -76,6 +77,7 @@
"GeneratorStepOutput",
"GlobalStep",
"GroupColumns",
"HuggingFaceHubCheckpointer",
"KeepColumns",
"LoadDataFromDicts",
"LoadDataFromDisk",
131 changes: 131 additions & 0 deletions src/distilabel/steps/checkpointer.py
@@ -0,0 +1,131 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tempfile
from typing import TYPE_CHECKING, Optional

from pydantic import PrivateAttr

from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
    from distilabel.typing import StepOutput

from huggingface_hub import HfApi


class HuggingFaceHubCheckpointer(Step):
    """Special type of step that uploads the data to a Hugging Face Hub dataset.

    The data is uploaded in JSONL format to a Hugging Face Dataset, which can be different from
    the one where the main distiset of the pipeline is saved. The data is checked every
    `input_batch_size` inputs, and a new file is created in the `repo_id` repository. There will
    be a different config per leaf step in the pipeline, and the files of each config are numbered
    sequentially. As there will be a write every `input_batch_size` inputs, it's advisable not to
    set this value too low, as that will slow down the process.

    Attributes:
        repo_id:
            The ID of the repository to push to in the following format: `<user>/<dataset_name>` or
            `<org>/<dataset_name>`. Also accepts `<dataset_name>`, which will default to the namespace
            of the logged-in user.
        private:
            Whether the dataset repository should be set to private or not. Only affects repository creation:
            a repository that already exists will not be affected by that parameter.
        token:
            An optional authentication token for the Hugging Face Hub. If no token is passed, will default
            to the token saved locally when logging in with `huggingface-cli login`. Will raise an error
            if no token is passed and the user is not logged-in.

    Categories:
        - helper

    Examples:
        Do checkpoints of the data generated in a Hugging Face Hub dataset:

        ```python
        from typing import TYPE_CHECKING

        from datasets import Dataset

        from distilabel.models import InferenceEndpointsLLM
        from distilabel.pipeline import Pipeline
        from distilabel.steps import HuggingFaceHubCheckpointer
        from distilabel.steps.base import Step, StepInput
        from distilabel.steps.tasks import TextGeneration

        if TYPE_CHECKING:
            from distilabel.typing import StepOutput

        # Create a dummy dataset
        dataset = Dataset.from_dict({"instruction": ["tell me lies"] * 100})

        with Pipeline(name="pipeline-with-checkpoints") as pipeline:
            text_generation = TextGeneration(
                llm=InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                ),
                template="Follow the following instruction: {{ instruction }}"
            )
            checkpoint = HuggingFaceHubCheckpointer(
                repo_id="username/streaming_checkpoint",
                private=True,
                input_batch_size=50  # Will write the data to the dataset every 50 inputs
            )
            text_generation >> checkpoint
        ```
    """

    repo_id: str
    private: bool = True
    token: Optional[str] = None

    _counter: int = PrivateAttr(0)

    def load(self) -> None:
        super().load()
        if self.token is None:
            from distilabel.utils.huggingface import get_hf_token

            self.token = get_hf_token(self.__class__.__name__, "token")

        self._api = HfApi(token=self.token)
        # Create the repo if it doesn't exist
        if not self._api.repo_exists(repo_id=self.repo_id, repo_type="dataset"):
            self._logger.info(f"Creating repo {self.repo_id}")
            self._api.create_repo(
                repo_id=self.repo_id, repo_type="dataset", private=self.private
            )

    def process(self, *inputs: StepInput) -> "StepOutput":
        for i, input in enumerate(inputs):
            # Each section of *inputs corresponds to a different configuration of the pipeline
            with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl") as temp_file:
                for item in input:
                    json_line = json.dumps(item, ensure_ascii=False)
                    temp_file.write(json_line + "\n")
                try:
                    self._api.upload_file(
                        path_or_fileobj=temp_file.name,
                        path_in_repo=f"config-{i}/train-{str(self._counter).zfill(5)}.jsonl",
                        repo_id=self.repo_id,
                        repo_type="dataset",
                        commit_message=f"Checkpoint {i}-{self._counter}",
                    )
                    self._logger.info(f"⬆️ Uploaded checkpoint {i}-{self._counter}")
                finally:
                    self._counter += 1

        yield from inputs
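As a quick sanity check of what the step has uploaded, the checkpoint files can be listed through the Hub API. A minimal sketch, with a placeholder repository name:

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the locally saved token by default
# The step above names its uploads config-<i>/train-<counter>.jsonl.
for path in api.list_repo_files(repo_id="username/streaming_test_1", repo_type="dataset"):
    if path.endswith(".jsonl"):
        print(path)
```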
2 changes: 2 additions & 0 deletions src/distilabel/utils/mkdocs/components_gallery.py
@@ -95,6 +95,7 @@
"save": ":material-content-save:",
"image-generation": ":material-image:",
"labelling": ":label:",
"helper": ":fontawesome-solid-kit-medical:",
}

_STEP_CATEGORY_TO_DESCRIPTION = {
@@ -116,6 +117,7 @@
"save": "Save steps are used to save the data.",
"image-generation": "Image generation steps are used to generate images based on a given prompt.",
"labelling": "Labelling steps are used to label the data.",
"helper": "Helper steps are used to do extra tasks during the pipeline execution.",
}

