diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml
index 12830c9d5..ba6baa1ea 100644
--- a/.github/workflows/codspeed.yml
+++ b/.github/workflows/codspeed.yml
@@ -13,12 +13,12 @@ concurrency:
jobs:
benchmarks:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: "3.12"
# Looks like it's not working very well for other people:
@@ -26,7 +26,7 @@ jobs:
# cache: "pip"
# cache-dependency-path: pyproject.toml
- - uses: actions/cache@v3
+ - uses: actions/cache@v4
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -37,7 +37,7 @@ jobs:
run: ./scripts/install_dependencies.sh
- name: Run benchmarks
- uses: CodSpeedHQ/action@v2
+ uses: CodSpeedHQ/action@v3
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: pytest tests/ --codspeed
diff --git a/.github/workflows/docs-pr-close.yml b/.github/workflows/docs-pr-close.yml
index 4f9a307b8..71f4e5ff9 100644
--- a/.github/workflows/docs-pr-close.yml
+++ b/.github/workflows/docs-pr-close.yml
@@ -19,12 +19,12 @@ jobs:
fetch-depth: 0
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.python-version }}
+ python-version: "3.11"
- name: Install dependencies
- run: pip install -e .[docs]
+ run: ./scripts/install_docs_dependencies.sh
- name: Set git credentials
run: |
diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml
index 306af85ea..48c7236a5 100644
--- a/.github/workflows/docs-pr.yml
+++ b/.github/workflows/docs-pr.yml
@@ -22,15 +22,11 @@ jobs:
- uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.python-version }}
- # Looks like it's not working very well for other people:
- # https://github.com/actions/setup-python/issues/436
- # cache: "pip"
- # cache-dependency-path: pyproject.toml
+ python-version: "3.11"
- - uses: actions/cache@v3
+ - uses: actions/cache@v4
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -38,7 +34,7 @@ jobs:
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
- run: pip install -e .[docs]
+ run: ./scripts/install_docs_dependencies.sh
- name: Set git credentials
run: |
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 62c141248..dd59a5129 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -24,15 +24,11 @@ jobs:
- uses: actions/checkout@v4
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.python-version }}
- # Looks like it's not working very well for other people:
- # https://github.com/actions/setup-python/issues/436
- # cache: "pip"
- # cache-dependency-path: pyproject.toml
+ python-version: "3.11"
- - uses: actions/cache@v3
+ - uses: actions/cache@v4
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -40,7 +36,10 @@ jobs:
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
- run: pip install -e .[docs]
+ run: ./scripts/install_docs_dependencies.sh
+
+ - name: Check no warnings
+ run: mkdocs build --strict
- name: Set git credentials
run: |
diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md
index 7f8cb0579..5826756a2 100644
--- a/docs/api/exceptions.md
+++ b/docs/api/exceptions.md
@@ -1,6 +1,6 @@
# Exceptions
-This section contains the `distilabel` custom exceptions. Unlike [errors][../errors.md], exceptions in `distilabel` are used to handle specific situations that can be anticipated and that can be handled in a controlled way internally by the library.
+This section contains the `distilabel` custom exceptions. Unlike [errors](errors.md), exceptions in `distilabel` are used to handle specific situations that can be anticipated and that can be handled in a controlled way internally by the library.
:::distilabel.exceptions.DistilabelException
:::distilabel.exceptions.DistilabelGenerationException
diff --git a/docs/assets/images/sections/caching/caching_1.png b/docs/assets/images/sections/caching/caching_1.png
new file mode 100644
index 000000000..cde228769
Binary files /dev/null and b/docs/assets/images/sections/caching/caching_1.png differ
diff --git a/docs/assets/images/sections/caching/caching_2.png b/docs/assets/images/sections/caching/caching_2.png
new file mode 100644
index 000000000..8f0d9d4d5
Binary files /dev/null and b/docs/assets/images/sections/caching/caching_2.png differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_1.png b/docs/assets/images/sections/caching/caching_pipe_1.png
deleted file mode 100644
index f41f38a60..000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_1.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_2.png b/docs/assets/images/sections/caching/caching_pipe_2.png
deleted file mode 100644
index 22adebc1a..000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_2.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_3.png b/docs/assets/images/sections/caching/caching_pipe_3.png
deleted file mode 100644
index b41a3a6c8..000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_3.png and /dev/null differ
diff --git a/docs/assets/images/sections/caching/caching_pipe_4.png b/docs/assets/images/sections/caching/caching_pipe_4.png
deleted file mode 100644
index 12ea2c7f2..000000000
Binary files a/docs/assets/images/sections/caching/caching_pipe_4.png and /dev/null differ
diff --git a/docs/assets/images/sections/how_to_guides/tasks/task_print.png b/docs/assets/images/sections/how_to_guides/tasks/task_print.png
new file mode 100644
index 000000000..95498c8c6
Binary files /dev/null and b/docs/assets/images/sections/how_to_guides/tasks/task_print.png differ
diff --git a/docs/assets/pipelines/clair.png b/docs/assets/pipelines/clair.png
new file mode 100644
index 000000000..c80e801f9
Binary files /dev/null and b/docs/assets/pipelines/clair.png differ
diff --git a/docs/assets/tutorials-assets/overview-apigen.jpg b/docs/assets/tutorials-assets/overview-apigen.jpg
new file mode 100644
index 000000000..61deefac9
Binary files /dev/null and b/docs/assets/tutorials-assets/overview-apigen.jpg differ
diff --git a/docs/sections/getting_started/quickstart.md b/docs/sections/getting_started/quickstart.md
index 04d84d978..7af9bca8f 100644
--- a/docs/sections/getting_started/quickstart.md
+++ b/docs/sections/getting_started/quickstart.md
@@ -30,7 +30,7 @@ pip install distilabel[hf-inference-endpoints] --upgrade
## Define a pipeline
-In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task.
+In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task.
> You can check the available models in the [Hugging Face Model Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) and filter by `Inference status`.
@@ -53,12 +53,14 @@ with Pipeline( # (1)
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
), # (5)
+ system_prompt="You are a creative AI Assistant writer.",
+ template="Follow the following instruction: {{ instruction }}" # (6)
)
- load_dataset >> text_generation # (6)
+ load_dataset >> text_generation # (7)
if __name__ == "__main__":
- distiset = pipeline.run( # (7)
+ distiset = pipeline.run( # (8)
parameters={
load_dataset.name: {
"repo_id": "distilabel-internal-testing/instruction-dataset-mini",
@@ -74,7 +76,7 @@ if __name__ == "__main__":
},
},
)
- distiset.push_to_hub(repo_id="distilabel-example") # (8)
+ distiset.push_to_hub(repo_id="distilabel-example") # (9)
```
1. We define a [`Pipeline`][distilabel.pipeline.Pipeline] with the name `simple-text-generation-pipeline` and a description `A simple text generation pipeline`. Note that the `name` is mandatory and will be used to calculate the `cache` signature path, so changing the name will change the cache path and will be identified as a different pipeline.
@@ -83,12 +85,14 @@ if __name__ == "__main__":
3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field.
-4. We define a [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
+4. We define a [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
-5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
+5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
-6. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.
+6. Both `system_prompt` and `template` are optional fields. The `template` must be provided as a string following the [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) template format, and the fields that appear in it ("instruction" in this case, which corresponds to the default) must be declared in the `columns` attribute, as shown in the sketch after these notes. The component gallery for [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) has examples to get you started.
-7. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.
+7. We connect the `load_dataset` step to the `text_generation` task using the `rshift` operator, meaning that the output from the `load_dataset` step will be used as input for the `text_generation` task.
-8. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.
+8. We run the pipeline with the parameters for the `load_dataset` and `text_generation` steps. The `load_dataset` step will use the repository `distilabel-internal-testing/instruction-dataset-mini` and the `test` split, and the `text_generation` task will use the `generation_kwargs` with the `temperature` set to `0.7` and the `max_new_tokens` set to `512`.
+
+9. Optionally, we can push the generated [`Distiset`][distilabel.distiset.Distiset] to the Hugging Face Hub repository `distilabel-example`. This will allow you to share the generated dataset with others and use it in other pipelines.
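+
+For instance, a `TextGeneration` task using a custom `template` with placeholders other than the default `instruction` could look like the following hypothetical summarization sketch, where every placeholder used in the template is declared via the `columns` attribute:
+
+```python
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.steps.tasks import TextGeneration
+
+# Hypothetical example: the template uses two placeholders, so both column
+# names are listed in `columns` for the task to expect them as inputs.
+summarization = TextGeneration(
+    llm=InferenceEndpointsLLM(
+        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    ),
+    system_prompt="You are a helpful summarizer.",
+    template="Summarize the following {{ document_type }}:\n\n{{ document }}",
+    columns=["document_type", "document"],
+)
+```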
diff --git a/docs/sections/how_to_guides/advanced/caching.md b/docs/sections/how_to_guides/advanced/caching.md
index 1fc941494..d4a03c09f 100644
--- a/docs/sections/how_to_guides/advanced/caching.md
+++ b/docs/sections/how_to_guides/advanced/caching.md
@@ -1,135 +1,60 @@
-# Cache and recover pipeline executions
+# Pipeline cache
-Distilabel `Pipelines` automatically save all the intermediate steps to avoid losing any data in case of error.
+`distilabel` will automatically save all the intermediate outputs generated by each [`Step`][distilabel.steps.base.Step] of a [`Pipeline`][distilabel.pipeline.local.Pipeline], so these outputs can be reused to recover the state of a pipeline execution that was stopped before finishing, or to avoid re-executing steps of a pipeline after adding a new downstream step.
-## Cache directory
+## How to enable/disable the cache
-Out of the box, the `Pipeline` will use the `~/.cache/distilabel/pipelines` directory to store the different pipelines[^1]:
+The use of the cache can be toggled using the `use_cache` parameter of the [`Pipeline.run`][distilabel.pipeline.base.BasePipeline.run] method. If `True`, then `distilabel` will reuse the outputs of previous executions for the new execution. If `False`, then `distilabel` will re-execute all the steps of the pipeline to generate new outputs for all the steps.
```python
-from distilabel.pipeline.local import Pipeline
-
-with Pipeline(name="cache_testing") as pipeline:
+with Pipeline(name="my-pipeline") as pipeline:
...
-```
-This directory can be modified by setting the `DISTILABEL_CACHE_DIR` environment variable (`export DISTILABEL_CACHE_DIR=my_cache_dir`) or by explicitly passing the `cache_dir` variable to the `Pipeline` constructor like so:
-
-```python
-with Pipeline(name="cache_testing", cache_dir="~/my_cache_dir") as pipeline:
- ...
+if __name__ == "__main__":
+ distiset = pipeline.run(use_cache=False) # (1)
```
-[^1]:
-
- The pipelines will be organized according to the pipeline's name attribute, and then by the hash, in case you want to look for something manually, like the following example:
-
- ```bash
- $ tree ~/.cache/distilabel/pipelines/
- ├── cache_testing
- │ └── 13da04d2cc255b2180d6bebb50fb5be91124f70d
- │ ├── batch_manager.json
- │ ├── batch_manager_steps
- │ │ └── succeed_always_0.json
- │ ├── data
- │ │ └── succeed_always_0
- │ │ └── 00001.parquet
- │ ├── pipeline.log
- │ └── pipeline.yaml
- └── test-pipe
- └── f23b95d7ad4e9301a70b2a54c953f8375ebfcd5c
- ├── batch_manager.json
- ├── batch_manager_steps
- │ └── text_generation_0.json
- ├── data
- │ └── text_generation_0
- │ └── 00001.parquet
- ├── pipeline.log
- └── pipeline.yaml
- ```
-
-## How does it work?
-
-Let's take a look at the logging messages from a sample pipeline.
-
-When we run a `Pipeline` for the first time
-
-![Pipeline 1](../../../assets/images/sections/caching/caching_pipe_1.png)
-
-If we decide to stop the pipeline (say we kill the run altogether via `CTRL + C` or `CMD + C` in *macOS*), we will see the signal sent to the different workers:
-
-![Pipeline 2](../../../assets/images/sections/caching/caching_pipe_2.png)
-
-After this step, when we run again the pipeline, the first log message we see corresponds to "Load pipeline from cache", which will restart processing from where it stopped:
-
-![Pipeline 3](../../../assets/images/sections/caching/caching_pipe_3.png)
+1. Pipeline cache is disabled
-Finally, if we decide to run the same `Pipeline` after it has finished completely, it won't start again but resume the process, as we already have all the data processed:
+In addition, the cache can be enabled/disabled at [`Step`][distilabel.steps.base.Step] level using its `use_cache` attribute. If `True`, then the outputs of the step will be reused in the new pipeline execution. If `False`, then the step will be re-executed to generate new outputs. If the cache of one step is disabled and the outputs have to be regenerated, then the outputs of the steps that depend on this step will also be regenerated.
-![Pipeline 4](../../../assets/images/sections/caching/caching_pipe_4.png)
-
-### Serialization
-
-Let's see what gets serialized by looking at a sample `Pipeline`'s cached folder:
-
-```bash
-$ tree ~/.cache/distilabel/pipelines/73ca3f6b7a613fb9694db7631cc038d379f1f533
-├── batch_manager.json
-├── batch_manager_steps
-│ ├── generate_response.json
-│ └── rename_columns.json
-├── data
-│ └── generate_response
-│ ├── 00001.parquet
-│ └── 00002.parquet
-└── pipeline.yaml
+```python
+with Pipeline(name="writing-assistant") as pipeline:
+ load_data = LoadDataFromDicts(
+ data=[
+ {
+ "instruction": "How much is 2+2?"
+ }
+ ]
+ )
+
+ generation = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="Qwen/Qwen2.5-72B-Instruct",
+ generation_kwargs={
+ "temperature": 0.8,
+ "max_new_tokens": 512,
+ },
+ ),
+ use_cache=False # (1)
+ )
+
+ load_data >> generation
+
+if __name__ == "__main__":
+ distiset = pipeline.run()
```
-The `Pipeline` will have a signature created from the arguments that define it so we can find it afterwards, and the contents are the following:
-
-- `batch_manager.json`
-
- Folder that stores the content of the internal batch manager to keep track of the data. Along with the `batch_manager_steps/` they store the information to restart the `Pipeline`. One shouldn't need to know about it.
-
-- `pipeline.yaml`
-
- This file contains a representation of the `Pipeline` in *YAML* format. If we push a `Distiset` to the Hugging Face Hub as obtained from calling `Pipeline.run`, this file will be stored at our datasets' repository, allowing to reproduce the `Pipeline` using the `CLI`:
+1. Step cache is disabled and every time the pipeline is executed, this step will be re-executed
- ```bash
- distilabel pipeline run --config "path/to/pipeline.yaml"
- ```
+## How a cache hit is triggered
-- `data/`
+`distilabel` groups the information and data generated by a `Pipeline` using the name of the pipeline, so the first factor that triggers a cache hit is the name of the pipeline. The second factor is the [`Pipeline.signature`][distilabel.pipeline.local.Pipeline.signature] property, which returns a hash generated from the names of the steps used in the pipeline and their connections. The third factor is the [`Pipeline.aggregated_steps_signature`][distilabel.pipeline.local.Pipeline.aggregated_steps_signature] property, which is used to determine whether the new pipeline execution is exactly the same as a previous one, i.e. the pipeline contains exactly the same steps, with exactly the same connections, and the steps are using exactly the same parameters. If these three factors match, then the cache hit is triggered, the pipeline won't be re-executed, and instead the [`create_distiset`][distilabel.distiset.create_distiset] function will be used to create the resulting [`Distiset`][distilabel.distiset.Distiset] from the outputs of the previous execution, as can be seen in the following image:
- Folder that stores the data generated, with a special folder to keep track of each `leaf_step` separately. We can recreate a `Distiset` from the contents of this folder (*Parquet* files), as we will see next.
-
-- `pipeline.log`
-
- This file stores the logs that the `Pipeline` generated while processing. Just as with the `pipeline.yaml` file, it will be pushed to the Hugging Face Hub datasets` repository to keep track of the information.
-
-## create_distiset
-
-In case we wanted to regenerate the dataset from the `cache`, we can do it using the [`create_distiset`][distilabel.distiset.create_distiset] function and passing the path to the `/data` folder inside our `Pipeline`:
-
-```python
-from pathlib import Path
-from distilabel.distiset import create_distiset
-
-path = Path("~/.cache/distilabel/pipelines/73ca3f6b7a613fb9694db7631cc038d379f1f533/data")
-ds = create_distiset(path)
-ds
-# Distiset({
-# generate_response: DatasetDict({
-# train: Dataset({
-# features: ['instruction', 'response'],
-# num_rows: 80
-# })
-# })
-# })
-```
+![Complete cache hit](../../../assets/images/sections/caching/caching_1.png)
-!!! Note
+If the new pipeline execution has a different `Pipeline.aggregated_steps_signature`, i.e. at least one step has changed its parameters, `distilabel` will reuse the outputs of the steps that have not changed and re-execute the steps that have changed, as can be seen in the following image:
- Internally, the function will try to inject the `pipeline_path` variable if it's not passed via argument, assuming it's in the parent directory of the current one, called `pipeline.yaml`. If the file doesn't exist, it won't raise any error, but take into account that if the `Distiset` is pushed to the Hugging Face Hub, the `pipeline.yaml` won't be generated. The same happens with the `pipeline.log` file, it can be passed via `log_filename_path`, but it will try to locate it automatically.
+![Partial cache hit](../../../assets/images/sections/caching/caching_2.png)
- Lastly, there is the option of including the `distilabel_metadata` column in the final dataset. This column can contain custom metadata generated automatically by the pipeline, like the raw output from an `LLM` without formatting in case of failure, and we can decide whether to include it using the `enable_metadata` argument.
+The same pipeline from above gets executed a third time, but this time the last step `text_generation_1` has changed, so it needs to be re-executed. The other steps, as they have not changed, don't need to be re-executed and their outputs are reused.
diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md
index 510459970..817e32815 100644
--- a/docs/sections/how_to_guides/basic/task/index.md
+++ b/docs/sections/how_to_guides/basic/task/index.md
@@ -57,6 +57,75 @@ As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] ta
)
```
+### Task.print
+
+!!! Info
+    New since version `1.4.0`, the [`Task.print`][distilabel.steps.tasks.base._Task.print] method.
+
+The `Task`s include a handy method to show what the prompt formatted for an `LLM` would look like. Let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`.
+
+```python
+from distilabel.steps.tasks import UltraFeedback
+from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+uf = UltraFeedback(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+)
+uf.load()
+uf.print()
+```
+
+The result will be a rendered prompt, with the system prompt (if the task defines one) and the user prompt, rendered with `rich` (it will look exactly the same in a Jupyter notebook).
+
+![task-print](../../../../assets/images/sections/how_to_guides/tasks/task_print.png)
+
+In case you want to test with a custom input, you can pass an example to the task's `format_input` method (or generate it on your own depending on the task), and pass it to the `print` method so that it shows your example:
+
+
+```python
+uf.print(
+ uf.format_input({"instruction": "test", "generations": ["1", "2"]})
+)
+```
+
+??? "Using a DummyLLM to avoid loading one"
+
+    In case you don't want to load an LLM to render the template, you can create a dummy one, like the ones used for testing.
+
+ ```python
+    from typing import Any
+
+    from distilabel.llms.base import AsyncLLM
+    from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+
+ class DummyLLM(AsyncLLM, MagpieChatTemplateMixin):
+ structured_output: Any = None
+ magpie_pre_query_template: str = "llama3"
+
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate(
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+ ```
+
+ You can use this `LLM` just as any of the other ones to `load` your task and call `print`:
+
+ ```python
+ uf = UltraFeedback(llm=DummyLLM())
+ uf.load()
+ uf.print()
+ ```
+
+!!! Note
+ When creating a custom task, the `print` method will be available by default, but it is limited to the most common scenarios for the inputs. If you test your new task and find it's not working as expected (for example, if your task contains one input consisting of a list of texts instead of a single one), you should override the `_sample_input` method. You can inspect the [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback] source code for this.
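+
+For instance, for a hypothetical task whose input is a list of texts, the override could look similar to the following sketch (check the exact signature against the `UltraFeedback` source code mentioned above):
+
+```python
+from typing import TYPE_CHECKING
+
+from distilabel.steps.tasks import Task
+
+if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import ChatType
+
+
+class MyListTask(Task):
+    # ... `inputs`, `outputs`, `format_input` and `format_output` defined as usual ...
+
+    def _sample_input(self) -> "ChatType":
+        # Build a representative example so that `print` can render the template
+        # with a list of texts instead of a single string.
+        return self.format_input({"texts": ["A first sample text.", "A second one."]})
+```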
+
## Specifying the number of generations and grouping generations
All the `Task`s have a `num_generations` attribute that allows defining the number of generations that we want to have per input. We can update the example above to generate 3 completions per input:
@@ -142,35 +211,63 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
-```python
-from typing import Any, Dict, List, Union, TYPE_CHECKING
+=== "Inherit from `Task`"
+
+    When using the `Task` class inheritance method for creating a custom task, we can also optionally override the `Task.process` method to define more complex processing logic involving an `LLM`, as the default one just calls the `LLM.generate` method once, formatting the input beforehand and the output afterwards. For example, the [`EvolInstruct`][distilabel.steps.tasks.EvolInstruct] task overrides this method to call `LLM.generate` multiple times (once for each evolution).
-from distilabel.steps.tasks.base import Task
+ ```python
+ from typing import Any, Dict, List, Union, TYPE_CHECKING
-if TYPE_CHECKING:
- from distilabel.steps.typing import StepColumns
- from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.tasks import Task
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns
+ from distilabel.steps.tasks.typing import ChatType
-class MyCustomTask(Task):
- @property
- def inputs(self) -> "StepColumns":
- return ["input_field"]
- def format_input(self, input: Dict[str, Any]) -> "ChatType":
- return [
- {
- "role": "user",
- "content": input["input_field"],
- },
- ]
+ class MyCustomTask(Task):
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["input_field"]
- @property
- def outputs(self) -> "StepColumns":
- return ["output_field", "model_name"]
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ return [
+ {
+ "role": "user",
+ "content": input["input_field"],
+ },
+ ]
- def format_output(
- self, output: Union[str, None], input: Dict[str, Any]
- ) -> Dict[str, Any]:
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["output_field", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ return {"output_field": output}
+ ```
+
+=== "Using the `@task` decorator"
+
+ If your task just needs a system prompt, a user message template and a way to format the output given by the `LLM`, then you can use the `@task` decorator to avoid writing too much boilerplate code.
+
+ ```python
+ from typing import Any, Dict, Union
+ from distilabel.steps.tasks import task
+
+
+ @task(inputs=["input_field"], outputs=["output_field"])
+ def MyCustomTask(output: Union[str, None], input: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ My custom system prompt
+
+ user_message_template: |
+ My custom user message template: {input_field}
+ ---
+ """
+ # Format the `LLM` output here
return {"output_field": output}
-```
+ ```
diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md
index ed2f51686..0cff03101 100644
--- a/docs/sections/pipeline_samples/index.md
+++ b/docs/sections/pipeline_samples/index.md
@@ -37,6 +37,14 @@ hide: toc
[:octicons-arrow-right-24: Tutorial](tutorials/GenerateSentencePair.ipynb)
+- __Generate text classification data__
+
+ ---
+
+ Learn about how synthetic data generation for text classification can help address data imbalance or scarcity.
+
+ [:octicons-arrow-right-24: Tutorial](tutorials/generate_textcat_dataset.ipynb)
+
## Paper Implementations
@@ -83,6 +91,22 @@ hide: toc
[:octicons-arrow-right-24: Paper](papers/ultrafeedback.md)
+- __APIGen__
+
+ ---
+
+    Learn how to create verifiable high-quality datasets for function-calling applications.
+
+ [:octicons-arrow-right-24: Paper](papers/apigen.md)
+
+- __CLAIR__
+
+ ---
+
+    Learn about Contrastive Learning from AI Revisions (CLAIR), a data-creation method which leads to more contrastive preference pairs.
+
+ [:octicons-arrow-right-24: Paper](papers/clair.md)
+
## Examples
@@ -113,6 +137,14 @@ hide: toc
[:octicons-arrow-right-24: Example](examples/mistralai_with_instructor.md)
+- __Create a social network with FinePersonas__
+
+ ---
+
+ Learn how to leverage FinePersonas to create a synthetic social network and fine-tune adapters for Multi-LoRA.
+
+ [:octicons-arrow-right-24: Example](examples/fine_personas_social_network.md)
+
diff --git a/docs/sections/pipeline_samples/papers/apigen.md b/docs/sections/pipeline_samples/papers/apigen.md
new file mode 100644
index 000000000..5d3522c1b
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/apigen.md
@@ -0,0 +1,239 @@
+---
+hide: toc
+---
+
+# Create Function-Calling datasets with APIGen
+
+This example will introduce [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518), a data generation pipeline designed to synthesize verifiable high-quality datasets for function-calling applications.
+
+## Replication
+
+The following figure showcases the APIGen framework:
+
+![APIGen framework](../../../assets/tutorials-assets/overview-apigen.jpg)
+
+Now, let's walk through the key steps illustrated in the figure:
+
+- [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/step/datasampler/): With the help of this step and the original [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) we are getting the Seed QA Data Sampler for the prompt template.
+
+- [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/): This step does the job of the *Query-Answer Generator*, including the format checker from *Stage 1: Format Checker* thanks to the structured output generation.
+
+- [`APIGenExecutionChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigenexecutionchecker/): This step is in charge of the *Stage 2: Execution Checker*.
+
+- [`APIGenSemanticChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigensemanticchecker/): Step in charge of running *Stage 3: Semantic Checker*. It can use the same or a different LLM; here we are using the same one as in the [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) step.
+
+The current implementation hasn't utilized the *Diverse Prompt Library*. To incorporate it, one could either adjust the prompt template within the [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) or develop a new sampler specifically for this purpose. As for the *API Sampler*, while no specific data is shared here, we've created illustrative examples to demonstrate the pipeline's functionality. These examples represent a mix of data that could be used to replicate the sampler's output.
+
+## Data preparation
+
+The original paper describes the data the authors used and gives some hints, but nothing was shared. In this example, we will write a bunch of examples by hand to showcase how this pipeline can be built.
+
+Assume we have the following function names, and corresponding descriptions of their behaviour:
+
+```python
+data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ {
+ "func_name": "binary_addition",
+ "func_desc": "Adds two binary numbers and returns the result as a binary string.",
+ },
+ {
+ "func_name": "swapi_planet_resource",
+ "func_desc": "get a specific planets resource",
+ },
+ {
+ "func_name": "disney_character",
+ "func_desc": "Find a specific character using this endpoint",
+ }
+]
+```
+
+The original paper refers to both Python functions and APIs, but we will make use of Python functions exclusively for simplicity. In order to execute and check these functions/APIs, we need access to the code, which we have moved to a Python file: [lib_apigen.py](https://github.com/argilla-io/distilabel/blob/main/examples/lib_apigen.py). All these functions are executable, but we also need access to their *tool* representation. For this, we will make use of transformers' *get_json_schema* function[^1].
+
+[^1]: Read this nice blog post for more information on tools and the reasoning behind `get_json_schema`: [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
+
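+As a rough sketch of what that entails (the actual implementations live in the `lib_apigen.py` file linked above), the *tool* representation of one of the functions can be obtained from its signature and docstring:
+
+```python
+from transformers.utils import get_json_schema
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
+    """Calculates the final velocity of an object.
+
+    Args:
+        initial_velocity: The initial velocity of the object.
+        acceleration: The acceleration of the object.
+        time: The time elapsed.
+    """
+    return initial_velocity + acceleration * time
+
+
+# `get_json_schema` parses the signature and the docstring to build the JSON
+# schema that will be stored in the "tools" column.
+tools = {"final_velocity": get_json_schema(final_velocity)}
+```
+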
+We have all the machinery prepared in our libpath, except for the *tool* definition. With the help of our helper function `load_module_from_path` we will load this Python module, collect all the tools, and add them to each row in our `data` variable.
+
+```python
+from distilabel.steps.tasks.apigen.utils import load_module_from_path
+
+libpath_module = load_module_from_path(libpath)
+tools = getattr(libpath_module, "get_tools")() # call get_tools()
+
+for row in data:
+ # The tools should have a mix where both the correct and irrelevant tools are present.
+ row.update({"tools": [tools[row["func_name"]]]})
+```
+
+Now we have all the necessary data for our prompt. Additionally, we will make use of the original dataset as few-shot examples to enhance the model:
+
+```python
+ds_og = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+)
+```
+
+We have just loaded a subset and transformed it to a list of dictionaries, as we will use it in the [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/steps/datasampler/) `GeneratorStep`, grabbing random examples from the original dataset.
+
+## Building the Pipeline
+
+Now that we've walked through each component, it's time to see how it all comes together. Here's the Pipeline code:
+
+```python
+with Pipeline(name="apigen-example") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data) # (1)
+
+ sampler = DataSampler( # (2)
+ data=ds_og,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+
+ prep_examples = PrepareExamples() # This step will add the 'examples' column
+
+ combine_steps = CombineOutputs() # (3)
+
+ model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+ llm=InferenceEndpointsLLM( # (4)
+ model_id=model_id,
+ tokenizer_id=model_id,
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 2048,
+ },
+ )
+ apigen = APIGenGenerator( # (5)
+ llm=llm,
+ use_default_structured_output=True,
+ )
+
+ execution_checker = APIGenExecutionChecker(libpath=str(libpath)) # (6)
+ semantic_checker = APIGenSemanticChecker(llm=llm) # (7)
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ >> apigen
+ >> execution_checker
+ >> semantic_checker
+ )
+```
+
+1. Load the data seeds we are going to use to generate our function calling dataset.
+
+2. The `DataSampler` together with `PrepareExamples` will be used to help us create the few-shot
+examples from the original dataset to be fed into our prompt.
+
+3. Combine the outputs of both branches to obtain a single stream of data.
+
+4. We will reuse the same LLM for the generation and the semantic checks.
+
+5. Creates the `query` and `answers` that will be used together with the `tools` to fine-tune a new model. It will use structured output generation to ensure we get valid JSON-formatted answers.
+
+6. Adds columns `keep_row_after_execution_check` and `execution_result`.
+
+7. Adds columns `keep_row_after_semantic_check` and `thought`.
+
+## Script and final dataset
+
+To see all the pieces in place, take a look at the full pipeline, as well as an example row that would be generated from this pipeline.
+
+??? Run
+
+ ```python
+ python examples/pipeline_apigen.py
+ ```
+
+```python title="pipeline_apigen.py"
+--8<-- "examples/pipeline_apigen.py"
+```
+
+Example row:
+
+```json
+{
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ "tools": [
+ {
+ "function": {
+ "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ "name": "final_velocity",
+ "parameters": {
+ "properties": {
+ "acceleration": {
+ "description": "The acceleration of the object.",
+ "type": "number"
+ },
+ "initial_velocity": {
+ "description": "The initial velocity of the object.",
+ "type": "number"
+ },
+ "time": {
+ "description": "The time elapsed.",
+ "type": "number"
+ }
+ },
+ "required": [
+ "initial_velocity",
+ "acceleration",
+ "time"
+ ],
+ "type": "object"
+ }
+ },
+ "type": "function"
+ }
+ ],
+ "examples": "## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]",
+ "query": "What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.",
+ "answers": "[{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]",
+ "distilabel_metadata": {
+ "raw_input_a_p_i_gen_generator_0": [
+ {
+ "content": "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively",
+ "role": "system"
+ },
+ {
+ "content": "Here are examples of queries and the corresponding answers for similar functions:\n## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\n\nBased on these examples, generate 1 diverse query and answer pairs for the function `final_velocity`.\nThe detailed function description is the following:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n\nThese are the available tools to help you:\n[{'type': 'function', 'function': {'name': 'final_velocity', 'description': 'Calculates the final velocity of an object given its initial velocity, acceleration, and time.', 'parameters': {'type': 'object', 'properties': {'initial_velocity': {'type': 'number', 'description': 'The initial velocity of the object.'}, 'acceleration': {'type': 'number', 'description': 'The acceleration of the object.'}, 'time': {'type': 'number', 'description': 'The time elapsed.'}}, 'required': ['initial_velocity', 'acceleration', 'time']}}}]\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n \"query\": \"The generated query.\",\n \"answers\": [\n {\n \"name\": \"api_name\",\n \"arguments\": {\n \"arg_name\": \"value\"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 1 diverse query and answer pairs following the above format.",
+ "role": "user"
+ }
+ ],
+ "raw_input_a_p_i_gen_semantic_checker_0": [
+ {
+ "content": "As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user\u2019s intentions.\n\nDo not pass if:\n1. The function call does not align with the query\u2019s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user\u2019s intentions.\n4. The execution results are irrelevant and do not match the function\u2019s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.",
+ "role": "system"
+ },
+ {
+ "content": "Given Information:\n- All Available Functions:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n- User Query: What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\n- Generated Function Calls: [{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]\n- Execution Results: ['9.8']\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query's intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n \"thought\": \"Concisely describe your reasoning here\",\n \"passes\": \"yes\" or \"no\"\n}\n```\n",
+ "role": "user"
+ }
+ ],
+ "raw_output_a_p_i_gen_generator_0": "{\"pairs\": [\n {\n \"answers\": [\n {\n \"arguments\": {\n \"acceleration\": \"9.8\",\n \"initial_velocity\": \"0\",\n \"time\": \"10\"\n },\n \"name\": \"final_velocity\"\n }\n ],\n \"query\": \"What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\"\n }\n]}",
+ "raw_output_a_p_i_gen_semantic_checker_0": "{\n \"thought\": \"\",\n \"passes\": \"yes\"\n}"
+ },
+ "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+ "keep_row_after_execution_check": true,
+ "execution_result": [
+ "9.8"
+ ],
+ "thought": "",
+ "keep_row_after_semantic_check": true
+}
+```
diff --git a/docs/sections/pipeline_samples/papers/clair.md b/docs/sections/pipeline_samples/papers/clair.md
new file mode 100644
index 000000000..8c0887460
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/clair.md
@@ -0,0 +1,84 @@
+# Contrastive Learning From AI Revisions (CLAIR)
+
+["Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment"](https://huggingface.co/papers/2408.06266) introduces both Contrastive
+Learning from AI Revisions (CLAIR), a data-creation method which leads to more contrastive preference pairs, and Anchored Preference Optimization (APO), a controllable and more stable alignment objective. While APO can be found in [TRL](https://huggingface.co/docs/trl/dpo_trainer#loss-functions), we have implemented a task for CLAIR in `distilabel`.
+
+CLAIR is a method for creating preference pairs which minimally revises one output to express a preference, resulting in a more precise learning signal as opposed to conventional methods which use a judge to select a preferred response.
+
+![CLAIR overview](../../../assets/pipelines/clair.png)
+
+The authors of the original paper shared a [collection of datasets from CLAIR and APO](https://huggingface.co/collections/ContextualAI/clair-and-apo-66b52868672bb1c984d1f3d5), where [ContextualAI/ultrafeedback_clair_32k](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k) corresponds to the CLAIR implementation.
+
+### Replication
+
+!!! NOTE
+    The section is named `Replication` but in this case we are showing how to use the [`CLAIR`][distilabel.steps.tasks.clair.CLAIR] task to create revisions for your generations using `distilabel`.
+
+To showcase CLAIR we will be using the [`CLAIR`][distilabel.steps.tasks.clair.CLAIR] task implemented in `distilabel`, and we are reusing a small sample of the dataset already generated by ContextualAI, [`ContextualAI/ultrafeedback_clair_32k`](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k), for testing.
+
+#### Installation
+
+To reproduce the code below, one will need to install `distilabel` as follows:
+
+```bash
+pip install "distilabel>=1.4.0"
+```
+
+Depending on the LLM provider you want to use, the requirements may vary, so take a look at the corresponding dependencies in that case. For this example we are using the free Hugging Face Inference Endpoints, but that won't be enough for a bigger dataset.
+
+#### Building blocks
+
+In this case where we already have instructions and their generations, we will just need to load the data and the corresponding CLAIR task for the revisions:
+
+- [`CLAIR`](https://distilabel.argilla.io/dev/components-gallery/tasks/clair/) to generate the revisions.
+
+#### Code
+
+Let's see the full pipeline applied to `ContextualAI/ultrafeedback_clair_32k` in `distilabel`:
+
+```python
+from typing import Any, Dict
+
+from datasets import load_dataset
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps.tasks import CLAIR
+from distilabel.llms import InferenceEndpointsLLM
+
+
+def transform_ultrafeedback(example: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "task": example["prompt"],
+ "student_solution": example["rejected"][1]["content"],
+ }
+
+dataset = (
+ load_dataset("ContextualAI/ultrafeedback_clair_32k", split="train")
+ .select(range(10)) # We collect just 10 examples
+ .map(transform_ultrafeedback) # Apply the transformation to get just the text
+)
+
+with Pipeline(name="CLAIR UltraFeedback sample") as pipeline:
+ clair = CLAIR( # (1)
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 4096
+ }
+ )
+ )
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run(dataset=dataset) # (2)
+ distiset.push_to_hub(repo_id="username/clair-test", include_script=True) # (3)
+```
+
+1. This Pipeline uses just CLAIR because we already have the generations, but one can include a first task to create generations from instructions, and then obtain the revisions with CLAIR.
+
+2. Include the dataset directly in the run method for simplicity.
+
+3. Push the distiset to the hub with the script for reproducibility.
+
+An example dataset can be found at: [distilabel-internal-testing/clair-test](https://huggingface.co/datasets/distilabel-internal-testing/clair-test).
diff --git a/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb
new file mode 100644
index 000000000..c993f6acd
--- /dev/null
+++ b/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb
@@ -0,0 +1,981 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Generate synthetic text classification data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **Goal**: Generate synthetic text classification data to augment an imbalanced and limited dataset for training a topic classifier. In addition, generate new data for training a fact-based versus opinion-based classifier to add a new label.\n",
+ "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub), [SetFit](https://github.com/huggingface/setfit)\n",
+ "- **Components**: [LoadDataFromDicts](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromdicts/), [EmbeddingTaskGenerator](https://distilabel.argilla.io/latest/components-gallery/tasks/embeddingtaskgenerator/), [GenerateTextClassificationData](https://distilabel.argilla.io/latest/components-gallery/tasks/generatetextclassificationdata/)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting started\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install the dependencies\n",
+ "\n",
+ "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"transformers~=4.40\" \"torch~=2.0\" \"setfit~=1.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the required imports:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "from collections import Counter\n",
+ "\n",
+ "from datasets import load_dataset, Dataset\n",
+ "from distilabel.llms import InferenceEndpointsLLM\n",
+ "from distilabel.pipeline import Pipeline\n",
+ "from distilabel.steps import LoadDataFromDicts\n",
+ "from distilabel.steps.tasks import (\n",
+ " GenerateTextClassificationData,\n",
+ ")\n",
+ "from setfit import SetFitModel, Trainer, sample_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Log in to use it directly within this notebook.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (optional) Deploy Argilla\n",
+ "\n",
+ "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n",
+ "\n",
+ "Along with that, you will need to install Argilla as a distilabel extra.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The dataset\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will use the [`fancyzhx/ag_news`](https://huggingface.co/datasets/fancyzhx/ag_news) dataset from the Hugging Face Hub as our original data source. To simulate a real-world scenario with imbalanced and limited data, we will load only 20 samples from this dataset.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_dataset = load_dataset(\"fancyzhx/ag_news\", split=\"train[-20:]\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can retrieve the available labels in the dataset and examine the current data distribution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}\n",
+ "Counter({0: 12, 1: 6, 2: 2})\n"
+ ]
+ }
+ ],
+ "source": [
+ "labels_topic = hf_dataset.features[\"label\"].names\n",
+ "id2str = {i: labels_topic[i] for i in range(len(labels_topic))}\n",
+ "print(id2str)\n",
+ "print(Counter(hf_dataset[\"label\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As observed, the dataset is imbalanced, with most samples falling under the `World` category, while the `Sci/Tech` category is entirely missing. Moreover, there are insufficient samples to effectively train a topic classification model.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will also define the labels for the new classification task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels_fact_opinion = [\"Fact-based\", \"Opinion-based\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the text classification task\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "To generate the data we will use the `GenerateTextClassificationData` task. This task will take classification tasks as input, and we can define the language, difficulty, and clarity required for the generated data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'role': 'user', 'content': 'You have been assigned a text classification task: Classify the news article as fact-based or opinion-based\\n\\nYour mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:\\n - \"input_text\": a string, the input text specified by the classification task.\\n - \"label\": a string, the correct label of the input text.\\n - \"misleading_label\": a string, an incorrect label that is related to the task.\\n\\nPlease adhere to the following guidelines:\\n - The \"input_text\" should be diverse in expression.\\n - The \"misleading_label\" must be a valid label for the given task, but not as appropriate as the \"label\" for the \"input_text\".\\n - The values for all fields should be in English.\\n - Avoid including the values of the \"label\" and \"misleading_label\" fields in the \"input_text\", that would make the task too easy.\\n - The \"input_text\" is clear and requires college level education to comprehend.\\n\\nYour output must always be a JSON object only, do not explain yourself or output anything else. Be creative!'}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "task = GenerateTextClassificationData(\n",
+ " language=\"English\",\n",
+ " difficulty=\"college\",\n",
+ " clarity=\"clear\",\n",
+ " num_generations=1,\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.4},\n",
+ " ),\n",
+ " input_batch_size=5,\n",
+ ")\n",
+ "task.load()\n",
+ "result = next(\n",
+ " task.process([{\"task\": \"Classify the news article as fact-based or opinion-based\"}])\n",
+ ")\n",
+ "print(result[0][\"distilabel_metadata\"][\"raw_input_generate_text_classification_data_0\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For our use case, we only need to generate data for two tasks: a topic classification task and a fact versus opinion classification task. Therefore, we will define the tasks accordingly. As we will be using an smaller model for generation, we will select 2 random labels for each topic classification task and change the order for the fact versus opinion classification task ensuring more diversity in the generated data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "task_templates = [\n",
+ " \"Determine the news article as {}\",\n",
+ " \"Classify news article as {}\",\n",
+ " \"Identify the news article as {}\",\n",
+ " \"Categorize the news article as {}\",\n",
+ " \"Label the news article using {}\",\n",
+ " \"Annotate the news article based on {}\",\n",
+ " \"Determine the theme of a news article from {}\",\n",
+ " \"Recognize the topic of the news article as {}\",\n",
+ "]\n",
+ "\n",
+ "classification_tasks = [\n",
+ " {\"task\": action.format(\" or \".join(random.sample(labels_topic, 2)))}\n",
+ " for action in task_templates for _ in range(4)\n",
+ "] + [\n",
+ " {\"task\": action.format(\" or \".join(random.sample(labels_fact_opinion, 2)))}\n",
+ " for action in task_templates\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run the pipeline\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, it's time to define and run the pipeline. As mentioned, we will load the written tasks and feed them into the `GenerateTextClassificationData` task. For our use case, we will be using `Meta-Llama-3.1-8B-Instruct` via the `InferenceEndpointsLLM`, with different degrees of difficulty and clarity.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "difficulties = [\"college\", \"high school\", \"PhD\"]\n",
+ "clarity = [\"clear\", \"understandable with some effort\", \"ambiguous\"]\n",
+ "\n",
+ "with Pipeline(\"texcat-generation-pipeline\") as pipeline:\n",
+ "\n",
+ " tasks_generator = LoadDataFromDicts(data=classification_tasks)\n",
+ "\n",
+ " generate_data = []\n",
+ " for difficulty in difficulties:\n",
+ " for clarity_level in clarity:\n",
+ " task = GenerateTextClassificationData(\n",
+ " language=\"English\",\n",
+ " difficulty=difficulty,\n",
+ " clarity=clarity_level,\n",
+ " num_generations=2,\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " input_batch_size=5,\n",
+ " )\n",
+ " generate_data.append(task)\n",
+ "\n",
+ " for task in generate_data:\n",
+ " tasks_generator.connect(task)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's now run the pipeline and generate the synthetic data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset = pipeline.run()"
+ ]
+ },
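+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The resulting `Distiset` behaves like a dictionary, with one subset per leaf step of the pipeline. As an optional sanity check (a minimal sketch relying only on this dict-like interface), we can list the available subsets to know which one to inspect:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Each leaf step of the pipeline produces its own subset in the distiset\n",
+ "print(list(distiset.keys()))"
+ ]
+ },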
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'task': 'Determine the news article as Business or World',\n",
+ " 'input_text': \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone's economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\n",
+ " 'label': 'Business',\n",
+ " 'misleading_label': 'World',\n",
+ " 'distilabel_metadata': {'raw_output_generate_text_classification_data_0': '{\\n \"input_text\": \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone\\'s economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\\n \"label\": \"Business\",\\n \"misleading_label\": \"World\"\\n}'},\n",
+ " 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct'}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "distiset[\"generate_text_classification_data_0\"][\"train\"][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can push the dataset to the Hub for sharing with the community and [embed it to explore the data](https://huggingface.co/docs/hub/datasets-viewer-embed).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset.push_to_hub(\"[your-owner-name]/example-texcat-generation-dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By examining the distiset distribution, we can confirm that it includes at least the 8 required samples for each label to train our classification models with SetFit."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Sci/Tech': 275,\n",
+ " 'Business': 130,\n",
+ " 'World': 86,\n",
+ " 'Fact-based': 86,\n",
+ " 'Sports': 64,\n",
+ " 'Opinion-based': 54,\n",
+ " None: 20,\n",
+ " 'Opinion Based': 1,\n",
+ " 'News/Opinion': 1,\n",
+ " 'Science': 1,\n",
+ " 'Environment': 1,\n",
+ " 'Opinion': 1})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "all_labels = [\n",
+ " entry[\"label\"]\n",
+ " for dataset_name in distiset\n",
+ " for entry in distiset[dataset_name][\"train\"]\n",
+ "]\n",
+ "\n",
+ "Counter(all_labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will create two datasets with the required labels and data for our use cases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def extract_rows(distiset, labels):\n",
+ " return [\n",
+ " {\n",
+ " \"text\": entry[\"input_text\"],\n",
+ " \"label\": entry[\"label\"],\n",
+ " \"id\": i\n",
+ " }\n",
+ " for dataset_name in distiset\n",
+ " for i, entry in enumerate(distiset[dataset_name][\"train\"])\n",
+ " if entry[\"label\"] in labels\n",
+ " ]\n",
+ "\n",
+ "data_topic = extract_rows(distiset, labels_topic)\n",
+ "data_fact_opinion = extract_rows(distiset, labels_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## (Optional) Evaluate with Argilla\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note \"Get started in Argilla\"\n",
+ " If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to login to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).\n",
+ "\n",
+ "To get the most out of our data, we will use Argilla. First, we need to connect to the Argilla instance.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argilla as rg\n",
+ "\n",
+ "# Replace api_url with your url if using Docker\n",
+ "# Replace api_key with your API key under \"My Settings\" in the UI\n",
+ "# Uncomment the last line and set your HF_TOKEN if your space is private\n",
+ "client = rg.Argilla(\n",
+ " api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will create a `Dataset` for each task, with an input `TextField` for the text classification text and a `LabelQuestion` to ensure the generated labels are correct.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_texcat_dataset(dataset_name, labels):\n",
+ " settings = rg.Settings(\n",
+ " fields=[rg.TextField(\"text\")],\n",
+ " questions=[\n",
+ " rg.LabelQuestion(\n",
+ " name=\"label\",\n",
+ " title=\"Classify the texts according to the following labels\",\n",
+ " labels=labels,\n",
+ " ),\n",
+ " ],\n",
+ " )\n",
+ " return rg.Dataset(name=dataset_name, settings=settings).create()\n",
+ "\n",
+ "\n",
+ "rg_dataset_topic = create_texcat_dataset(\"topic-classification\", labels_topic)\n",
+ "rg_dataset_fact_opinion = create_texcat_dataset(\n",
+ " \"fact-opinion-classification\", labels_fact_opinion\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can upload the generated data to Argilla and evaluate it. We will use the generated labels as suggestions.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rg_dataset_topic.records.log(data_topic)\n",
+ "rg_dataset_fact_opinion.records.log(data_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on `Submit`. Otherwise, you can select the correct label.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ " Check this [how-to guide](https://docs.argilla.io/latest/how_to_guides/annotate/) to know more about annotating in the UI.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Once, you get the annotations, let's continue by retrieving the data from Argilla and format it as a dataset with the required data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rg_dataset_topic = client.datasets(\"topic-classification\")\n",
+ "rg_dataset_fact_opinion = client.datasets(\"fact-opinion-classification\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
+ "\n",
+ "submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)\n",
+ "submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(\n",
+ " flatten=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def format_submitted(submitted):\n",
+ " return [\n",
+ " {\n",
+ " \"text\": r[\"text\"],\n",
+ " \"label\": r[\"label.responses\"][0],\n",
+ " \"id\": i,\n",
+ " }\n",
+ " for i, r in enumerate(submitted)\n",
+ " ]\n",
+ "\n",
+ "data_topic = format_submitted(submitted_topic)\n",
+ "data_fact_opinion = format_submitted(submitted_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train your models\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In our case, we will fine-tune using SetFit. However, you can select the one that best fits your requirements.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Formatting the data\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next step will be to format the data to be compatible with SetFit. In the case of the topic classification, we will need to combine the synthetic data with the original data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_topic = hf_dataset.to_list()\n",
+ "num = len(data_topic)\n",
+ "\n",
+ "data_topic.extend(\n",
+ " [\n",
+ " {\n",
+ " \"text\": r[\"text\"],\n",
+ " \"label\": id2str[r[\"label\"]],\n",
+ " \"id\": num + i,\n",
+ " }\n",
+ " for i, r in enumerate(hf_topic)\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If we check the data distribution now, we can see that we have enough samples for each label to train our models.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "labels = [record[\"label\"] for record in data_topic]\n",
+ "Counter(labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Fact-based': 86, 'Opinion-based': 54})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "labels = [record[\"label\"] for record in data_fact_opinion]\n",
+ "Counter(labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, let's create our training and validation datasets. The training dataset will gather 8 samples by label. In this case, the validation datasets will contain the remaining samples not included in the training datasets.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sample_and_split(dataset, label_column, num_samples):\n",
+ " train_dataset = sample_dataset(\n",
+ " dataset, label_column=label_column, num_samples=num_samples\n",
+ " )\n",
+ " eval_dataset = dataset.filter(lambda x: x[\"id\"] not in set(train_dataset[\"id\"]))\n",
+ " return train_dataset, eval_dataset\n",
+ "\n",
+ "\n",
+ "dataset_topic_full = Dataset.from_list(data_topic)\n",
+ "dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)\n",
+ "\n",
+ "train_dataset_topic, eval_dataset_topic = sample_and_split(\n",
+ " dataset_topic_full, \"label\", 8\n",
+ ")\n",
+ "train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(\n",
+ " dataset_fact_opinion_full, \"label\", 8\n",
+ ")"
+ ]
+ },
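+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional sanity check (a minimal sketch reusing the `Counter` from above), we can confirm that each training split indeed contains 8 samples per label:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Each training split should contain 8 samples per label\n",
+ "print(Counter(train_dataset_topic[\"label\"]))\n",
+ "print(Counter(train_dataset_fact_opinion[\"label\"]))"
+ ]
+ },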
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### The actual training\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's train our models for each task! We will use [TaylorAI/bge-micro-v2](https://huggingface.co/TaylorAI/bge-micro-v2), available in the Hugging Face Hub. You can check the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select the best model for your use case."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 126,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_model(model_name, dataset, eval_dataset):\n",
+ " model = SetFitModel.from_pretrained(model_name)\n",
+ "\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " )\n",
+ " trainer.train()\n",
+ " metrics = trainer.evaluate(eval_dataset)\n",
+ " print(metrics)\n",
+ "\n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 125,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running training *****\n",
+ " Num unique pairs = 768\n",
+ " Batch size = 16\n",
+ " Num epochs = 1\n",
+ " Total optimization steps = 48\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'embedding_loss': 0.1873, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running evaluation *****\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 4.9767, 'train_samples_per_second': 154.318, 'train_steps_per_second': 9.645, 'epoch': 1.0}\n",
+ "{'accuracy': 0.8333333333333334}\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_topic = train_model(\n",
+ " model_name=\"TaylorAI/bge-micro-v2\",\n",
+ " dataset=train_dataset_topic,\n",
+ " eval_dataset=eval_dataset_topic,\n",
+ ")\n",
+ "model_topic.save_pretrained(\"topic_classification_model\")\n",
+ "model_topic = SetFitModel.from_pretrained(\"topic_classification_model\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 128,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running training *****\n",
+ " Num unique pairs = 144\n",
+ " Batch size = 16\n",
+ " Num epochs = 1\n",
+ " Total optimization steps = 9\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'embedding_loss': 0.2985, 'learning_rate': 2e-05, 'epoch': 0.11}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running evaluation *****\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 0.8327, 'train_samples_per_second': 172.931, 'train_steps_per_second': 10.808, 'epoch': 1.0}\n",
+ "{'accuracy': 0.9090909090909091}\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_fact_opinion = train_model(\n",
+ " model_name=\"TaylorAI/bge-micro-v2\",\n",
+ " dataset=train_dataset_fact_opinion,\n",
+ " eval_dataset=eval_dataset_fact_opinion,\n",
+ ")\n",
+ "model_fact_opinion.save_pretrained(\"fact_opinion_classification_model\")\n",
+ "model_fact_opinion = SetFitModel.from_pretrained(\"fact_opinion_classification_model\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Voilà! The models are now trained and ready to be used. You can start making predictions to check the model's performance and add the new label. Optionally, you can continue using distilabel to generate additional data or Argilla to verify the quality of the predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict(model, input, labels):\n",
+ " model.labels = labels\n",
+ " prediction = model.predict([input])\n",
+ " return prediction[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 130,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Sci/Tech'"
+ ]
+ },
+ "execution_count": 130,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict(\n",
+ " model_topic, \"The new iPhone is expected to be released next month.\", labels_topic\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 131,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Opinion-based'"
+ ]
+ },
+ "execution_count": 131,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict(\n",
+ " model_fact_opinion,\n",
+ " \"The new iPhone is expected to be released next month.\",\n",
+ " labels_fact_opinion,\n",
+ ")"
+ ]
+ },
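+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a final example, here is a minimal sketch (reusing the `predict` helper defined above, with a couple of made-up headlines) that runs both classifiers over several texts at once:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical headlines, used only for illustration\n",
+ "texts = [\n",
+ "    \"The central bank kept interest rates unchanged amid slowing growth.\",\n",
+ "    \"In my opinion, the latest smartphone release is a huge disappointment.\",\n",
+ "]\n",
+ "\n",
+ "for text in texts:\n",
+ "    topic = predict(model_topic, text, labels_topic)\n",
+ "    kind = predict(model_fact_opinion, text, labels_fact_opinion)\n",
+ "    print(f\"{topic} | {kind} | {text}\")"
+ ]
+ },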
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we showcased the detailed steps to build a pipeline for generating text classification data using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub.\n",
+ "\n",
+ "We defined two text classification tasks—a topic classification task and a fact versus opinion classification task—and generated new data using various models via the serverless Hugging Face Inference API. Then, we curated the generated data with Argilla. Finally, we trained the models with SetFit using both the original and synthetic data."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "distilabel-tutorials",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/arena_hard.py b/examples/arena_hard.py
index 4f6f0c05c..b193bc234 100644
--- a/examples/arena_hard.py
+++ b/examples/arena_hard.py
@@ -15,11 +15,12 @@
import re
from typing import Any, Dict, List, Optional, Union
+from typing_extensions import override
+
from distilabel.steps import GlobalStep, StepInput
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import StepOutput
-from typing_extensions import override
class ArenaHard(Task):
diff --git a/examples/deepseek_prover.py b/examples/deepseek_prover.py
index b61f1c683..07b050964 100644
--- a/examples/deepseek_prover.py
+++ b/examples/deepseek_prover.py
@@ -17,14 +17,15 @@
from textwrap import dedent
from typing import Any, Dict, List, Optional, Union
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
-from jinja2 import Template
-from pydantic import PrivateAttr
-from typing_extensions import override
_PARSE_DEEPSEEK_PROVER_AUTOFORMAL_REGEX = r"```lean4(.*?)```"
diff --git a/examples/lib_apigen.py b/examples/lib_apigen.py
new file mode 100644
index 000000000..d49f414e6
--- /dev/null
+++ b/examples/lib_apigen.py
@@ -0,0 +1,146 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ # Tool:
+ # {"name": "final_velocity", "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.", "parameters": {"initial_velocity": {"description": "The initial velocity of the object.", "type": "float"}, "acceleration": {"description": "The acceleration of the object.", "type": "float"}, "time": {"description": "The time elapsed.", "type": "float"}}}
+ # Answer:
+ # {"name": "final_velocity", "arguments": {"initial_velocity": 5, "acceleration": 1.5, "time": 40}}
+ return initial_velocity + acceleration * time
+
+
+def permutation_count(n: int, k: int) -> int:
+ """Calculates the number of permutations of k elements from a set of n elements.
+
+ Args:
+ n: The total number of elements in the set.
+ k: The number of elements to choose for the permutation.
+
+ Returns:
+ The number of permutations.
+ """
+ # Tool:
+ # {"name": "permutation_count", "description": "Calculates the number of permutations of k elements from a set of n elements.", "parameters": {"n": {"description": "The total number of elements in the set.", "type": "int"}, "k": {"description": "The number of elements to choose for the permutation.", "type": "int"}}}
+ # Answer:
+ # {"name": "permutation_count", "arguments": {"n": 10, "k": 3}}
+ import math
+
+ return math.factorial(n) / math.factorial(n - k)
+
+
+def getdivision(dividend: int, divisor: int) -> float:
+ """Divides two numbers by making an API call to a division service.
+
+ Args:
+ dividend: The dividend in the division operation.
+ divisor: The divisor in the division operation.
+
+ Returns:
+ Division of the 2 numbers.
+ """
+ # Tool:
+ # {"name": "getdivision", "description": "Divides two numbers by making an API call to a division service.", "parameters": {"divisor": {"description": "The divisor in the division operation.", "type": "int", "default": ""}, "dividend": {"description": "The dividend in the division operation.", "type": "int", "default": ""}}}
+ # Answer:
+ # {"name": "getdivision", "arguments": {"divisor": 25, "dividend": 100}}
+ return dividend / divisor
+
+
+def binary_addition(a: str, b: str) -> str:
+ """Adds two binary numbers and returns the result as a binary string.
+
+ Args:
+ a: The first binary number.
+ b: The second binary number.
+
+ Raises:
+ ValueError: On invalid binary number.
+
+ Returns:
+ Binary string of the sum of the two numbers.
+ """
+ # Tool:
+ # {"name": "binary_addition", "description": "Adds two binary numbers and returns the result as a binary string.", "parameters": {"a": {"description": "The first binary number.", "type": "str"}, "b": {"description": "The second binary number.", "type": "str"}}}
+ # Answer:
+ # {"name": "binary_addition", "arguments": {"a": "1010", "b": "1101"}}
+ if not set(a).issubset("01") or not set(b).issubset("01"):
+ raise ValueError("Invalid binary number")
+
+ return bin(int(a, 2) + int(b, 2))[2:]
+
+
+def _make_request(url: str, params: Optional[Dict[str, Any]] = None):
+ import requests
+
+ req = requests.get(url, params=params)
+ return req.json()
+
+
+def swapi_planet_resource(id: str) -> Dict[str, Any]:
+ """get a specific planets resource
+
+ Args:
+ id: identifier of the planet
+
+ Returns:
+ Information about the planet.
+ """
+ # url = "https://swapi.dev/api/planets/1"
+ return _make_request(r"https://swapi.dev/api/planets/", params={"id": id})
+
+
+def disney_character(name: str) -> Dict[str, Any]:
+ """Find a specific character using this endpoint
+
+ Args:
+ name: Name of the character to look for.
+
+ Returns:
+ Information about the character.
+ """
+ # Example:
+ # url = "https://api.disneyapi.dev/character"
+ # params = {"name": "mulan"}
+ return _make_request(r"https://api.disneyapi.dev/character", params={"name": name})
+
+
+def get_lib():
+ return {
+ "swapi_planet_resource": swapi_planet_resource,
+ "disney_character": disney_character,
+ "final_velocity": final_velocity,
+ "permutation_count": permutation_count,
+ "getdivision": getdivision,
+ "binary_addition": binary_addition,
+ }
+
+
+def get_tools() -> Dict[str, Dict[str, Any]]:
+ """Returns the tool representation of the functions in the library."""
+ # TODO: Improve the `get_json_schema`, it fails on a lot of examples.
+ from transformers.utils import get_json_schema
+
+ return {name: get_json_schema(func) for name, func in get_lib().items()}
diff --git a/examples/pipeline_apigen.py b/examples/pipeline_apigen.py
new file mode 100644
index 000000000..e63e16e39
--- /dev/null
+++ b/examples/pipeline_apigen.py
@@ -0,0 +1,116 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from datasets import load_dataset
+
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
+from distilabel.steps.tasks import (
+ APIGenExecutionChecker,
+ APIGenGenerator,
+ APIGenSemanticChecker,
+)
+from distilabel.steps.tasks.apigen.utils import PrepareExamples, load_module_from_path
+
+libpath = Path(__file__).parent / "lib_apigen.py"
+
+data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ {
+ "func_name": "binary_addition",
+ "func_desc": "Adds two binary numbers and returns the result as a binary string.",
+ },
+ {
+ "func_name": "swapi_planet_resource",
+ "func_desc": "get a specific planets resource",
+ },
+ {
+ "func_name": "disney_character",
+ "func_desc": "Find a specific character using this endpoint",
+ },
+]
+
+libpath_module = load_module_from_path(libpath)
+tools = libpath_module.get_tools() # call get_tools()
+
+# TODO: Add in the tools between 0 and 2 extra tools to make the task more challenging.
+for row in data:
+ # The tools should have a mix where both the correct and irrelevant tools are present.
+ row.update({"tools": [tools[row["func_name"]]]})
+
+
+ds_og = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+)
+
+
+with Pipeline(name="APIGenPipeline") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data)
+ sampler = DataSampler(
+ data=ds_og,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+
+ prep_examples = PrepareExamples()
+
+ model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+ llm = InferenceEndpointsLLM(
+ model_id=model_id,
+ tokenizer_id=model_id,
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 2048,
+ },
+ )
+ apigen = APIGenGenerator(
+ llm=llm,
+ use_default_structured_output=True,
+ )
+ combine_steps = CombineOutputs()
+
+ execution_checker = APIGenExecutionChecker(libpath=str(libpath))
+ semantic_checker = APIGenSemanticChecker(llm=llm)
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ >> apigen
+ >> execution_checker
+ >> semantic_checker
+ )
+
+
+if __name__ == "__main__":
+ distiset = pipeline.run()
+ print(distiset["default"]["train"][0])
diff --git a/examples/structured_generation_with_instructor.py b/examples/structured_generation_with_instructor.py
index 48082886f..0808e56ca 100644
--- a/examples/structured_generation_with_instructor.py
+++ b/examples/structured_generation_with_instructor.py
@@ -14,11 +14,12 @@
from typing import List
+from pydantic import BaseModel, Field
+
from distilabel.llms import MistralLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
-from pydantic import BaseModel, Field
class Node(BaseModel):
diff --git a/examples/structured_generation_with_outlines.py b/examples/structured_generation_with_outlines.py
index 98ee59ed6..b92cb6082 100644
--- a/examples/structured_generation_with_outlines.py
+++ b/examples/structured_generation_with_outlines.py
@@ -15,12 +15,13 @@
from enum import Enum
from pathlib import Path
+from pydantic import BaseModel, StringConstraints, conint
+from typing_extensions import Annotated
+
from distilabel.llms import LlamaCppLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
-from pydantic import BaseModel, StringConstraints, conint
-from typing_extensions import Annotated
class Weapon(str, Enum):
diff --git a/mkdocs.yml b/mkdocs.yml
index 2ef623467..b174850f3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -162,6 +162,7 @@ plugins:
- social
- mknotebooks
- material-plausible
+ - glightbox
- distilabel/components-gallery:
add_after_page: How-to guides
@@ -185,7 +186,7 @@ nav:
- Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md"
- Advanced:
- The Distiset dataset object: "sections/how_to_guides/advanced/distiset.md"
- - Cachinc and recovering pipelines: "sections/how_to_guides/advanced/caching.md"
+ - Pipeline cache: "sections/how_to_guides/advanced/caching.md"
- Exporting data to Argilla: "sections/how_to_guides/advanced/argilla.md"
- Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
- Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
@@ -202,12 +203,15 @@ nav:
- Generate a preference dataset: "sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb"
- Clean an existing preference dataset: "sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb"
- Synthetic data generation for fine-tuning custom retrieval and reranking models: "sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb"
+ - Generate synthetic text classification data: "sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb"
- Papers:
- DeepSeek Prover: "sections/pipeline_samples/papers/deepseek_prover.md"
- DEITA: "sections/pipeline_samples/papers/deita.md"
- Instruction Backtranslation: "sections/pipeline_samples/papers/instruction_backtranslation.md"
- Prometheus 2: "sections/pipeline_samples/papers/prometheus.md"
- UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md"
+ - APIGen: "sections/pipeline_samples/papers/apigen.md"
+ - CLAIR: "sections/pipeline_samples/papers/clair.md"
- Examples:
- Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md"
- Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md"
diff --git a/pyproject.toml b/pyproject.toml
index b3b25efa1..72fac378e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ docs = [
"mkdocs-literate-nav >= 0.6.1",
"mkdocs-section-index >= 0.3.8",
"mkdocs-gen-files >= 0.5.0",
+ "mkdocs-glightbox >= 0.4.0",
"material-plausible-plugin>=0.2.0",
"mike >= 2.0.0",
"Pillow >= 9.5.0",
@@ -84,7 +85,7 @@ llama-cpp = ["llama-cpp-python >= 0.2.0"]
mistralai = ["mistralai >= 1.0.0"]
ollama = ["ollama >= 0.1.7"]
openai = ["openai >= 1.0.0"]
-outlines = ["outlines >= 0.0.40"]
+outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
@@ -99,7 +100,7 @@ faiss-gpu = ["faiss-gpu >= 1.7.2"]
text-clustering = [
"umap-learn >= 0.5.6",
"scikit-learn >= 1.4.1",
- "matplotlib >= 3.8.3" # For the figure (even though it's optional)
+ "matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
# minhash
diff --git a/scripts/install_cpu_vllm.sh b/scripts/install_cpu_vllm.sh
index 7535c8821..bdaa7ad74 100755
--- a/scripts/install_cpu_vllm.sh
+++ b/scripts/install_cpu_vllm.sh
@@ -15,7 +15,7 @@ which python
echo "Installing Python build dependencies..."
python -m pip install --upgrade pip
-python -m pip install wheel packaging ninja "setuptools>=49.4.0" numpy
+python -m pip install wheel packaging ninja "setuptools>=49.4.0" numpy setuptools-scm
echo "Cloning 'vllm-project/vllm' GitHub repository..."
git clone https://github.com/vllm-project/vllm.git
diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh
index 767f6e6dd..0b2277f0f 100755
--- a/scripts/install_dependencies.sh
+++ b/scripts/install_dependencies.sh
@@ -9,10 +9,9 @@ python -m pip install uv
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"
if [ "${python_version}" != "(3, 12)" ]; then
- uv pip install --system -e .[ray]
+ uv pip install --system -e .[ray]
fi
./scripts/install_cpu_vllm.sh
-uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git
uv pip install --system -e ".[dev,tests]"
diff --git a/scripts/install_docs_dependencies.sh b/scripts/install_docs_dependencies.sh
new file mode 100755
index 000000000..c768b2295
--- /dev/null
+++ b/scripts/install_docs_dependencies.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+set -e
+
+python_version=$(python -c "import sys; print(sys.version_info[:2])")
+
+python -m pip install uv
+
+uv pip install --system -e ".[docs]"
diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py
index bafff914b..47628af33 100644
--- a/src/distilabel/__init__.py
+++ b/src/distilabel/__init__.py
@@ -14,6 +14,6 @@
from rich import traceback as rich_traceback
-__version__ = "1.4.0"
+__version__ = "1.5.0"
rich_traceback.install(show_locals=True)
diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py
index b6b31f7a5..8e52c667d 100644
--- a/src/distilabel/distiset.py
+++ b/src/distilabel/distiset.py
@@ -695,8 +695,9 @@ def _grab_citations(dag: "DAG") -> List[str]:
for ref in references.values():
try:
bibtex_refs.append(get_bibtex(ref))
- except ValueError as e:
- print(f"Error: {e}")
+ except ValueError:
+ # No need to inform in this case, it's noise
+ pass
except AttributeError as e:
print(
f"Couldn't obtain the bibtex format for the ref: '{ref}', error: {e}"
diff --git a/src/distilabel/llms/_dummy.py b/src/distilabel/llms/_dummy.py
new file mode 100644
index 000000000..740f98cd4
--- /dev/null
+++ b/src/distilabel/llms/_dummy.py
@@ -0,0 +1,70 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, List
+
+from distilabel.llms.base import LLM, AsyncLLM
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+
+class DummyAsyncLLM(AsyncLLM):
+ structured_output: Any = None
+
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ async def agenerate( # type: ignore
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+
+
+class DummySyncLLM(LLM):
+ structured_output: Any = None
+
+ def load(self) -> None:
+ super().load()
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate( # type: ignore
+ self, inputs: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return [["output" for _ in range(num_generations)] for _ in range(len(inputs))]
+
+
+class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate(
+ self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
+ ) -> List["GenerateOutput"]:
+ return [
+ ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs))
+ ]
diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py
index b73efc223..48cac8a50 100644
--- a/src/distilabel/llms/openai.py
+++ b/src/distilabel/llms/openai.py
@@ -667,7 +667,7 @@ def _create_jsonl_row(
"""Creates a JSONL formatted row to be used by the OpenAI Batch API.
Args:
- inputs: a list of inputs in chat format to generate responses for, optionally
+ input: a list of inputs in chat format to generate responses for, optionally
including structured output.
custom_id: a custom ID to use for the row.
kwargs: the keyword arguments to use for the generation.
diff --git a/src/distilabel/mixins/signature.py b/src/distilabel/mixins/signature.py
new file mode 100644
index 000000000..b014f03e9
--- /dev/null
+++ b/src/distilabel/mixins/signature.py
@@ -0,0 +1,83 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+from typing import TYPE_CHECKING, Any, List, Set
+
+from pydantic import BaseModel, Field
+
+from distilabel.utils.serialization import TYPE_INFO_KEY
+
+if TYPE_CHECKING:
+ pass
+
+# Add here the name of the attributes that shouldn't be used to generate the signature.
+# Attributes from a `BaseModel` that is an attribute from the root class must be prefixed
+# with the name of the attribute followed by an underscore. For example, if the attribute
+# `jobs_ids` is an attribute from the `llm` attribute of the root class it should be added
+# as `llm_jobs_ids`.
+_EXCLUDE_FROM_SIGNATURE_DEFAULTS = {
+ TYPE_INFO_KEY,
+ "disable_cuda_device_placement",
+ "input_batch_size",
+ "gpu_memory_utilization",
+ "resources",
+ "exclude_from_signature",
+ "llm_jobs_ids",
+ "llm_offline_batch_generation_block_until_done",
+}
+
+
+class SignatureMixin(BaseModel):
+ """Mixin for creating a signature (for cache) of the class.
+
+ Attributes:
+ exclude_from_signature: list of attributes to exclude from the signature.
+ """
+
+ exclude_from_signature: Set[str] = Field(
+ default=_EXCLUDE_FROM_SIGNATURE_DEFAULTS, exclude=True
+ )
+
+ @property
+ def signature(self) -> str:
+ """Makes a signature (hash) of the class, using its attributes.
+
+ Returns:
+ signature of the class.
+ """
+
+ def flatten_dump(d: Any, parent_key: str = "", sep: str = "_") -> List:
+ items = []
+ for k, v in d.items():
+ new_key = parent_key + sep + k if parent_key else k
+ if isinstance(v, dict):
+ items.extend(flatten_dump(v, new_key, sep=sep))
+ elif isinstance(v, list):
+ if len(v) == 0:
+ items.append((new_key, ""))
+ elif isinstance(v[0], str):
+ items.append((new_key, "-".join(v)))
+ else:
+ for i, x in enumerate(v):
+ items.extend(flatten_dump(x, f"{new_key}-{i}", sep=sep))
+ elif new_key not in self.exclude_from_signature:
+ items.append((new_key, v))
+ return items
+
+ info = []
+ for name, value in flatten_dump(self.dump()):
+ info.append(f"{name}-{str(value)}")
+
+ return hashlib.sha1("-".join(info).encode()).hexdigest()
diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py
index 3cc22a534..5962ecc4f 100644
--- a/src/distilabel/pipeline/_dag.py
+++ b/src/distilabel/pipeline/_dag.py
@@ -33,6 +33,7 @@
from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
+ RECEIVES_ROUTED_BATCHES_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
@@ -253,6 +254,21 @@ def is_step_in_trophic_level(self, step_name: str, trophic_level: int) -> bool:
"""
return self.get_step_trophic_level(step_name) == trophic_level
+ def is_convergence_step(self, step_name: str) -> bool:
+ """Checks if a given step is a convegence step.
+
+ Args:
+ step_name: Name of the step to check whether it is a convergence step.
+
+ Returns:
+ True if it is, False otherwise.
+ """
+ predecessors = list(self.get_step_predecessors(step_name))
+ return all(
+ self.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False)
+ for predecessor in predecessors
+ )
+
def step_in_last_trophic_level(self, step_name: str) -> bool:
"""Checks if a step is in the last trophic level.
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index 1b4a2d827..c3397c392 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -14,6 +14,7 @@
import hashlib
import logging
import os
+import shutil
import signal
import threading
import time
@@ -51,7 +52,6 @@
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import (
- TYPE_INFO_KEY,
_Serializable,
read_json,
)
@@ -87,6 +87,7 @@ class _CacheLocation(TypedDict):
pipeline: Path
batch_manager: Path
+ steps_data: Path
data: Path
batch_input_data: Path
log_file: Path
@@ -125,7 +126,6 @@ def get_pipeline(cls) -> Union["BasePipeline", None]:
_STEP_LOAD_FAILED_CODE = -666
_STEP_NOT_LOADED_CODE = -999
-_ATTRIBUTES_IGNORED_CACHE = ("disable_cuda_device_placement", "jobs_ids")
_PIPELINE_DEFAULT_NAME = "__default_pipeline_name__"
@@ -242,48 +242,18 @@ def _set_pipeline_name(self) -> None:
if self.name == _PIPELINE_DEFAULT_NAME:
self.name = f"pipeline_{'_'.join(self.dag)}"
- def _create_signature(self) -> str:
+ @property
+ def signature(self) -> str:
"""Makes a signature (hash) of a pipeline, using the step ids and the adjacency between them.
The main use is to find the pipeline in the cache folder.
Returns:
- int: Signature of the pipeline.
+ Signature of the pipeline.
"""
- hasher = hashlib.sha1()
- steps_info = []
pipeline_dump = self.dump()["pipeline"]
-
- for step in pipeline_dump["steps"]:
- step_info = step["name"]
- for argument, value in sorted(step[constants.STEP_ATTR_NAME].items()):
- if (argument == TYPE_INFO_KEY) or (value is None):
- continue
-
- if isinstance(value, dict):
- # input_mappings/output_mappings
- step_info += "-".join(
- [
- f"{str(k)}={str(v)}"
- for k, v in value.items()
- if k not in _ATTRIBUTES_IGNORED_CACHE
- ]
- )
- elif isinstance(value, (list, tuple)):
- # runtime_parameters_info
- step_info += "-".join([str(v) for v in value])
- elif isinstance(value, (int, str, float, bool)):
- if argument not in _ATTRIBUTES_IGNORED_CACHE:
- # batch_size/name
- step_info += str(value)
- else:
- raise ValueError(
- f"Field '{argument}' in step '{step['name']}' has type {type(value)}, explicitly cast the type to 'str'."
- )
-
- steps_info.append(step_info)
-
+ steps_names = list(self.dag)
connections_info = [
f"{c['from']}-{'-'.join(c['to'])}" for c in pipeline_dump["connections"]
]
@@ -296,14 +266,13 @@ def _create_signature(self) -> str:
]
if type_info := routing_batch_function._get_type_info():
step += f"-{type_info}"
+ routing_batch_functions_info.append(step)
- hasher.update(
+ return hashlib.sha1(
",".join(
- steps_info + connections_info + routing_batch_functions_info
+ steps_names + connections_info + routing_batch_functions_info
).encode()
- )
-
- return hasher.hexdigest()
+ ).hexdigest()
def run(
self,
@@ -415,7 +384,7 @@ def run(
stop_logging()
return distiset
- self._setup_write_buffer()
+ self._setup_write_buffer(use_cache)
self._print_load_stages_info()
@@ -716,16 +685,32 @@ def _cache_location(self) -> "_CacheLocation":
Returns:
Path: Filenames where the pipeline content will be serialized.
"""
- folder = self._cache_dir / self.name / self._create_signature()
+ folder = self._cache_dir / self.name / self.signature
+ pipeline_execution_dir = folder / "executions" / self.aggregated_steps_signature
return {
- "pipeline": folder / "pipeline.yaml",
- "batch_manager": folder / "batch_manager.json",
- "data": folder / "data",
- "batch_input_data": folder / "batch_input_data",
- "log_file": folder / "pipeline.log",
- "stages_file": folder / "stages.json",
+ "pipeline": pipeline_execution_dir / "pipeline.yaml",
+ "batch_manager": pipeline_execution_dir / "batch_manager.json",
+ "steps_data": self._cache_dir / self.name / "steps_data",
+ "data": pipeline_execution_dir / "data",
+ "batch_input_data": pipeline_execution_dir / "batch_input_data",
+ "log_file": pipeline_execution_dir / "pipeline.log",
+ "stages_file": pipeline_execution_dir / "stages.json",
}
+ @property
+ def aggregated_steps_signature(self) -> str:
+ """Creates an aggregated signature using `Step`s signature that will be used for
+ the `_BatchManager`.
+
+ Returns:
+ The aggregated signature.
+ """
+ signatures = []
+ for step_name in self.dag:
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
+ signatures.append(step.signature)
+ return hashlib.sha1("".join(signatures).encode()).hexdigest()
+
def _cache(self) -> None:
"""Saves the `BasePipeline` using the `_cache_filename`."""
if self._dry_run:
@@ -737,7 +722,10 @@ def _cache(self) -> None:
)
if self._batch_manager is not None:
- self._batch_manager.cache(self._cache_location["batch_manager"])
+ self._batch_manager.cache(
+ path=self._cache_location["batch_manager"],
+ steps_data_path=self._cache_location["steps_data"],
+ )
self._save_stages_status()
@@ -814,30 +802,90 @@ def recursively_handle_secrets_and_excluded_attributes(
def _load_batch_manager(self, use_cache: bool = True) -> None:
"""Will try to load the `_BatchManager` from the cache dir if found. Otherwise,
it will create one from scratch.
+
+ If the `_BatchManager` is loaded from cache, we check for invalid steps (those that
+ may have a different signature than the original in the pipeline folder), and
+ restart them, as well as their successors.
+
+ Args:
+ use_cache: whether the cache should be used or not.
"""
batch_manager_cache_loc = self._cache_location["batch_manager"]
+
+ # This first condition handles the case in which the pipeline is exactly the same
+ # no steps have been added, removed or changed.
if use_cache and batch_manager_cache_loc.exists():
self._logger.info(
f"💾 Loading `_BatchManager` from cache: '{batch_manager_cache_loc}'"
)
- self._batch_manager = _BatchManager.load_from_cache(batch_manager_cache_loc)
+ self._batch_manager = _BatchManager.load_from_cache(
+ dag=self.dag,
+ batch_manager_path=batch_manager_cache_loc,
+ steps_data_path=self._cache_location["steps_data"],
+ )
+ self._invalidate_steps_cache_if_required()
+ # In this other case, the pipeline has been changed. We need to create a new batch
+ # manager and, if `use_cache==True`, check which outputs we have already computed and
+ # cached for steps that haven't changed but were executed in another pipeline, so we
+ # can reuse them.
else:
- self._batch_manager = _BatchManager.from_dag(self.dag)
+ self._batch_manager = _BatchManager.from_dag(
+ dag=self.dag,
+ use_cache=use_cache,
+ steps_data_path=self._cache_location["steps_data"],
+ )
+
+ def _invalidate_steps_cache_if_required(self) -> None:
+ """Iterates over the steps of the pipeline and invalidates their cache if required."""
+ for step_name in self.dag:
+ # `GeneratorStep`s doesn't receive input data so no need to check their
+ # `_BatchManagerStep`
+ if self.dag.get_step(step_name)[constants.STEP_ATTR_NAME].is_generator:
+ continue
+
+ step: "_Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME]
+ if not step.use_cache:
+ self._batch_manager.invalidate_cache_for(
+ step_name=step.name,
+ dag=self.dag,
+ steps_data_path=self._cache_location["steps_data"],
+ ) # type: ignore
+ self._logger.info(
+ f"♻️ Step '{step.name}' won't use cache (`use_cache=False`). The cache of this step and their successors won't be"
+ " reused and the results will have to be recomputed."
+ )
+ break
- def _setup_write_buffer(self) -> None:
+ def _setup_write_buffer(self, use_cache: bool = True) -> None:
"""Setups the `_WriteBuffer` that will store the data of the leaf steps of the
pipeline while running, so the `Distiset` can be created at the end.
"""
+ if not use_cache and self._cache_location["data"].exists():
+ shutil.rmtree(self._cache_location["data"])
buffer_data_path = self._cache_location["data"] / constants.STEPS_OUTPUTS_PATH
self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'")
- self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps)
+ self._write_buffer = _WriteBuffer(
+ buffer_data_path,
+ self.dag.leaf_steps,
+ steps_cached={
+ step_name: self.dag.get_step(step_name)[
+ constants.STEP_ATTR_NAME
+ ].use_cache
+ for step_name in self.dag
+ },
+ )
def _print_load_stages_info(self) -> None:
"""Prints the information about the load stages."""
stages, _ = self.dag.get_steps_load_stages()
msg = ""
for stage, steps in enumerate(stages):
- msg += f"\n * Stage {stage}: {steps}"
+ steps_to_be_loaded = self._steps_to_be_loaded_in_stage(stage)
+ msg += f"\n * Stage {stage}:"
+ for step in steps:
+ msg += f"\n - '{step}'"
+ if step not in steps_to_be_loaded:
+ msg += " (results cached, won't be loaded and executed)"
self._logger.info(
f"⌛ The steps of the pipeline will be loaded in stages:{msg}"
)
@@ -1094,6 +1142,26 @@ def _is_step_running(self, step_name: str) -> bool:
with self._steps_load_status_lock:
return self._steps_load_status[step_name] >= 1
+ def _steps_to_be_loaded_in_stage(self, stage: int) -> List[str]:
+ """Returns the list of steps of the provided stage that should be loaded taking
+ into account if they have finished.
+
+ Args:
+ stage: the stage number
+
+ Returns:
+ A list containing the name of the steps that should be loaded in this stage.
+ """
+ assert self._batch_manager, "Batch manager is not set"
+
+ steps_stages, _ = self.dag.get_steps_load_stages()
+
+ return [
+ step
+ for step in steps_stages[stage]
+ if not self._batch_manager.step_has_finished(step)
+ ]
+
def _run_stage_steps_and_wait(self, stage: int) -> bool:
"""Runs the steps of the specified stage and waits for them to be ready.
@@ -1103,9 +1171,10 @@ def _run_stage_steps_and_wait(self, stage: int) -> bool:
Returns:
`True` if all the steps have been loaded correctly, `False` otherwise.
"""
+ assert self._batch_manager, "Batch manager is not set"
- steps_stages, _ = self.dag.get_steps_load_stages()
- steps = steps_stages[stage]
+ steps = self._steps_to_be_loaded_in_stage(stage)
+ self._logger.debug(f"Steps to be loaded in stage {stage}: {steps}")
# Run the steps of the stage
self._run_steps(steps=steps)
@@ -1290,7 +1359,9 @@ def _add_batches_back_to_batch_manager(self) -> None:
if not isinstance(batch, _Batch):
continue
self._batch_manager.add_batch( # type: ignore
- to_step=step_name, batch=batch, prepend=True
+ to_step=step_name,
+ batch=batch,
+ prepend=True,
)
self._logger.debug(
f"Adding batch back to the batch manager: {batch}"
@@ -1319,10 +1390,10 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
"""
assert self._batch_manager, "Batch manager is not set"
- self._register_batch(batch)
-
route_to, do_not_route_to, routed = self._get_successors(batch)
+ self._register_batch(batch)
+
# Keep track of the steps that the batch was routed to
if routed:
batch.batch_routed_to = route_to
@@ -1363,6 +1434,7 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
while new_batch := self._batch_manager.get_batch(step.name): # type: ignore
# if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
self._send_batch_to_step(new_batch)
+
else:
self._request_more_batches_if_needed(step)
else:
@@ -1429,7 +1501,10 @@ def _register_batch(self, batch: "_Batch") -> None:
Args:
batch: The batch to register.
"""
- self._batch_manager.register_batch(batch) # type: ignore
+ assert self._batch_manager, "Batch manager is not set"
+ self._batch_manager.register_batch(
+ batch, steps_data_path=self._cache_location["steps_data"]
+ ) # type: ignore
self._logger.debug(
f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch"
" manager"
@@ -1453,7 +1528,6 @@ def _send_last_batch_flag_to_step(self, step_name: str) -> None:
def _request_initial_batches(self) -> None:
"""Requests the initial batches to the generator steps."""
assert self._batch_manager, "Batch manager is not set"
-
for step in self._batch_manager._steps.values():
if not self._is_step_running(step.step_name):
continue
@@ -1513,7 +1587,9 @@ def _handle_batch_on_stop(self, batch: "_Batch") -> None:
"""
assert self._batch_manager, "Batch manager is not set"
- self._batch_manager.register_batch(batch)
+ self._batch_manager.register_batch(
+ batch, steps_data_path=self._cache_location["steps_data"]
+ )
step: "Step" = self.dag.get_step(batch.step_name)[constants.STEP_ATTR_NAME]
for successor in self.dag.get_step_successors(step.name): # type: ignore
self._batch_manager.add_batch(successor, batch)
diff --git a/src/distilabel/pipeline/batch.py b/src/distilabel/pipeline/batch.py
index d8ad4312a..684328f53 100644
--- a/src/distilabel/pipeline/batch.py
+++ b/src/distilabel/pipeline/batch.py
@@ -37,8 +37,11 @@ class _Batch(_Serializable):
data_hash: The hash of the data. Defaults to `None`.
data_path: The path where the data of the batch is stored. Defaults to `None`.
accumulated: A flag to indicate if the batch is accumulated.
- created_from: A dictionary containing the `seq_no` of the batches of the steps that
- were used to create this batch.
+ created_from: A dictionary containing which batches from which steps were used
+ to create this batch. The keys are the names of the steps and the values
+ are lists for each step containing the `seq_no` of each batch used, the original
+ size of the batch used and the number of rows used from the batch to create
+ this batch.
size: The size of the batch.
"""
@@ -49,7 +52,7 @@ class _Batch(_Serializable):
data_hash: Optional[str] = None
data_path: Optional[str] = None
accumulated: bool = False
- created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict)
+ created_from: Dict[str, List[Tuple[int, int, int]]] = field(default_factory=dict)
batch_routed_to: List[str] = field(default_factory=list)
size: int = 0
_fs: Optional[fsspec.AbstractFileSystem] = None
@@ -99,6 +102,7 @@ def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]:
data = self.data[0][:num_rows]
self.data[0] = self.data[0][num_rows:]
+ # self.size = len(self.data[0])
self._update_data_hash()
return data
diff --git a/src/distilabel/pipeline/batch_manager.py b/src/distilabel/pipeline/batch_manager.py
index 8ddfa30e0..9ca05e48e 100644
--- a/src/distilabel/pipeline/batch_manager.py
+++ b/src/distilabel/pipeline/batch_manager.py
@@ -13,9 +13,10 @@
# limitations under the License.
from collections import defaultdict
+from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from distilabel.constants import (
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
@@ -70,6 +71,14 @@ class _BatchManagerStep(_Serializable):
batch from step A used by steps B and C and obtained from the `created_from`
of the batches created by them. It's used to avoid messing up the order of the
batches. Only used if `convergence_step=True`. Defaults to `0`.
+ step_signature: The signature that defines a given `Step`. It will be used for the
+ caching mechanism.
+ use_cache: Flag from the original `Step` to indicate whether this step should make use of
+ the cached data.
+ step_offset: Dictionary with the name of each predecessor step as key and a tuple
+ as value, containing the `seq_no` of the last batch consumed from that predecessor
+ and the number of rows that were already read from that batch, respectively. Used
+ for the caching mechanism.
"""
step_name: str
@@ -85,6 +94,9 @@ class _BatchManagerStep(_Serializable):
)
next_expected_created_from_batch_seq_no: int = 0
next_expected_seq_no: Dict[str, Tuple[int, int]] = field(default_factory=dict)
+ step_signature: Optional[str] = None
+ use_cache: bool = False
+ step_offset: Dict[str, Tuple[int, int]] = field(default_factory=dict)
def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
"""Add a batch of data from `batch.step_name` to the step. It will accumulate the
@@ -124,14 +136,22 @@ def get_batch(self) -> Union[_Batch, None]:
if not self._ready_to_create_batch():
return None
+ seq_no = self._get_seq_no()
+
# `_last_batch` must be called before `_get_data`, as `_get_data` will update the
# list of data which is used to determine if the batch to be created is the last one.
- # TODO: remove `_last_batch` method and integrate logic in `_get_data`
last_batch = self._last_batch()
+
+ # Get the batch data and the information from which batches of the upstream steps
+ # the data was taken.
data, created_from, batch_routed_to = self._get_data()
+ # Update the step offset i.e. which is the last batch and last row index from that
+ # batch that the step has consumed
+ self._update_offset(created_from)
+
return _Batch(
- seq_no=self._get_seq_no(),
+ seq_no=seq_no,
step_name=self.step_name,
last_batch=last_batch,
data=data,
@@ -212,6 +232,9 @@ def from_step(
data={predecessor: [] for predecessor in predecessors},
convergence_step=convergence_step,
next_expected_seq_no={predecessor: (0, 0) for predecessor in predecessors},
+ step_signature=step.signature,
+ use_cache=step.use_cache,
+ step_offset={predecessor: (0, 0) for predecessor in predecessors},
)
def _get_seq_no(self) -> int:
@@ -226,7 +249,9 @@ def _get_seq_no(self) -> int:
def _get_data(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
+ ) -> Tuple[
+ List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
+ ]:
"""Gets the data needed to create a batch for the step to process. If the step is
accumulating data, then it will return a list with all the data received from the
predecessors. Otherwise, it will return a list of data with the `input_batch_size`
@@ -252,7 +277,7 @@ def _get_data(
def _get_data_for_accumulate(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step
is accumulating data. It will return a list with all the data received from the
predecessors. In addition, it will remove the data used to create the batch from
@@ -268,7 +293,7 @@ def _get_data_for_accumulate(
for step_name, batches in self.data.items():
batches_used[step_name] = []
for batch in batches:
- batches_used[step_name].append((batch.seq_no, batch.size))
+ batches_used[step_name].append((batch.seq_no, batch.size, batch.size))
data.append([row for batch in batches for row in batch.get_data()])
# Reset the data buffer
self.data = {step_name: [] for step_name in self.data}
@@ -276,7 +301,7 @@ def _get_data_for_accumulate(
def _get_data_for_convergence_step(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step is
a convergence step.
@@ -315,7 +340,7 @@ def _get_data_for_convergence_step(
remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore
# Keep track of the batches used to create the batch
- batches_used[batch.step_name].append((batch.seq_no, batch.size))
+ batches_used[batch.step_name].append((batch.seq_no, batch.size, num_rows))
# If the batch was entirely consumed, then remove it from the buffer
if len(batch.data[0]) == 0:
@@ -336,7 +361,9 @@ def _get_data_for_convergence_step(
def _get_data_normal(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
+ ) -> Tuple[
+ List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
+ ]:
"""Gets the data needed to create a batch for the step to process when the step is
not accumulating data. It will return a list of data with the `input_batch_size`
for each predecessor. In addition, it will remove the data used to create the batch
@@ -374,7 +401,7 @@ def _get_data_normal(
remaining_rows -= num_rows
# Keep track of the batches used to create the batch
- batches_used[step_name].append((batch.seq_no, batch.size))
+ batches_used[step_name].append((batch.seq_no, batch.size, num_rows))
next_expected_seq_no = batch.seq_no
@@ -552,6 +579,35 @@ def _last_batch(self) -> bool:
return self._last_batch_normal()
+ def _update_offset(
+ self, created_from: Dict[str, List[Tuple[int, int, int]]]
+ ) -> None:
+ """Update the offset for the batch buffers of the upstream steps.
+
+ Args:
+ created_from: A dictionary containing which batches from which steps were used
+ to create this batch. The keys are the names of the steps and the values
+ are lists for each step containing the `seq_no` of each batch used, the original
+ size of the batch used and the number of rows used from the batch to create
+ this batch.
+ """
+ for predecessor, seq_no_and_batch in created_from.items():
+ prev_last_batch_seq_no, prev_last_batch_offset = self.step_offset[
+ predecessor
+ ]
+ last_batch_seq_no, _, last_batch_size = seq_no_and_batch[-1]
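+ # If the step is still consuming the same batch as before, accumulate the rows
+ # read from it; otherwise, start counting from the rows read of the new batch.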
+ batch_offset = (
+ prev_last_batch_offset + last_batch_size
+ if prev_last_batch_seq_no == last_batch_seq_no
+ else last_batch_size
+ )
+ last_batch_seq_no = (
+ last_batch_seq_no
+ if last_batch_seq_no > prev_last_batch_seq_no
+ else prev_last_batch_seq_no
+ )
+ self.step_offset[predecessor] = (last_batch_seq_no, batch_offset)
+
def _last_batch_accumulate(self) -> bool:
"""Checks if the batch to be created is the last one for an step accumulating data.
`True` if the last batch was received from all the predecessors.
@@ -596,11 +652,7 @@ def _last_batch_normal(self) -> bool:
num_rows = sum(len(batch.data[0]) for batch in batches)
- if (
- self.input_batch_size
- and num_rows > self.input_batch_size
- and step_name in self.last_batch_received
- ):
+ if self.input_batch_size and num_rows > self.input_batch_size:
return False
return True
@@ -619,12 +671,12 @@ def _group_batches_by_created_from(
for batches in self.data.values():
for batch in batches:
first_key = next(iter(batch.created_from))
- batch_seq_no, batch_size = batch.created_from[first_key][0]
+ batch_seq_no, batch_size, _ = batch.created_from[first_key][0]
grouped_batches[batch_seq_no].append((batch, batch_size))
return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items())
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
- """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function.
+ """Dumps the content of the `_BatchManagerStep` to a dictionary.
Args:
obj: Unused, just kept to match the signature of the parent method.
@@ -648,8 +700,15 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"convergence_step_batches_consumed": self.convergence_step_batches_consumed,
"next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no,
"next_expected_seq_no": self.next_expected_seq_no,
+ "step_signature": self.step_signature,
+ "use_cache": self.use_cache,
+ "step_offset": self.step_offset,
}
+ @property
+ def signature(self) -> str:
+ return f"{self.step_name}_{self.step_signature}"
+
class _BatchManager(_Serializable):
"""Class to manage the batches received from the steps. It keeps track of the
@@ -675,9 +734,9 @@ def __init__(
Args:
steps: A dictionary with the step name as the key and a dictionary with the
predecessor step name as the key and a list of batches as the value.
- last_batch_received: A dictionary with the step name as the key and a the last
+ last_batch_received: A dictionary with the step name as the key and the last
`_Batch` received from the step.
- last_batch_sent: A dictionary with the step name as the key and a the last
+ last_batch_sent: A dictionary with the step name as the key and the last
`_Batch` sent to the step.
last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG`
was sent.
@@ -695,7 +754,6 @@ def can_generate(self) -> bool:
`True` if there are still batches to be processed by the steps. Otherwise,
`False`.
"""
-
for step_name, batch in self._last_batch_received.items():
if step_name not in self._last_batch_flag_sent_to:
if not batch:
@@ -709,17 +767,38 @@ def can_generate(self) -> bool:
return False
- def register_batch(self, batch: _Batch) -> None:
+ def register_batch(
+ self, batch: _Batch, steps_data_path: Optional["StrOrPath"] = None
+ ) -> None:
"""Method to register a batch received from a step. It will keep track of the
sequence number and the last batch received from the step in the internal maps.
Args:
batch: _Batch from which we will register the sequence number and the last batch received.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in other pipeline executions.
"""
last_batch = self._last_batch_received[batch.step_name]
if not last_batch or (last_batch and last_batch.seq_no < batch.seq_no):
self._last_batch_received[batch.step_name] = batch
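+ # Persist the batch to disk so its data can be reused in future executions of the
+ # pipeline (step output caching).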
+ if steps_data_path:
+ self.write_batch_data(batch, steps_data_path)
+
+ def write_batch_data(self, batch: _Batch, steps_data_path: Path) -> None:
+ """Writes the batch to the steps data directory.
+
+ Args:
+ batch: the batch to be written.
+ steps_data_path: the steps data base directory.
+ """
+ step = self._steps[batch.step_name]
+ batch_manager_data_dir = Path(steps_data_path) / step.signature
+ batch_manager_data_dir.mkdir(parents=True, exist_ok=True)
+ filename = batch_manager_data_dir / f"batch_{batch.seq_no}.json"
+ if not filename.exists():
+ self.save(path=filename, format="json", dump=batch.dump())
+
def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
"""Gets the last batch received from a step.
@@ -731,7 +810,12 @@ def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
"""
return self._last_batch_received.get(step_name)
- def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
+ def add_batch(
+ self,
+ to_step: str,
+ batch: _Batch,
+ prepend: bool = False,
+ ) -> None:
"""Add an output batch from `batch.step_name` to `to_step`.
Args:
@@ -745,7 +829,6 @@ def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
"""
if to_step not in self._steps:
raise ValueError(f"Step '{to_step}' not found in the batch manager.")
-
step = self._steps[to_step]
step.add_batch(batch, prepend)
@@ -824,25 +907,46 @@ def set_last_batch_flag_sent_to(self, step_name: str) -> None:
def set_next_expected_seq_no(
self, step_name: str, from_step: str, next_expected_seq_no: int
) -> None:
- """Sets the next expected sequence number of a `_Batch` received by `step` comming
+ """Sets the next expected sequence number of a `_Batch` received by `step` coming
from `from_step`.
Args:
- step_name: The step name which next expected sequence number for `from_step`
+ step_name: The step name whose next expected sequence number for `from_step`
has to be updated.
from_step: The name of the step from which its next expected sequence number
in step has to be updated.
- next_expected_seq_no: the next expected sequence number of a `_Batch` comming
+ next_expected_seq_no: the next expected sequence number of a `_Batch` coming
from `from_step`.
"""
self._steps[step_name].set_next_expected_seq_no(from_step, next_expected_seq_no)
+ def step_has_finished(self, step_name: str) -> bool:
+ """Indicates if the step has finished by checking if it sent a batch with `last_batch==True`
+ or it was sent the `LAST_BATCH_SENT_FLAG`.
+
+ Args:
+ step_name: the name of the step to be checked.
+
+ Returns:
+ `True` if step has finished generating batches, `False` otherwise.
+ """
+ return step_name in self._last_batch_flag_sent_to or (
+ self._last_batch_received[step_name] is not None
+ and self._last_batch_received[step_name].last_batch # type: ignore
+ )
+
@classmethod
- def from_dag(cls, dag: "DAG") -> "_BatchManager":
+ def from_dag( # noqa: C901
+ cls, dag: "DAG", use_cache: bool = False, steps_data_path: Optional[Path] = None
+ ) -> "_BatchManager":
"""Create a `_BatchManager` instance from a `DAG` instance.
Args:
dag: The `DAG` instance.
+ use_cache: whether or not to try loading outputs from steps of previous pipeline
+ executions. Defaults to `False`.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in other pipeline executions.
Returns:
A `_BatchManager` instance.
@@ -850,12 +954,14 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager":
steps = {}
last_batch_received = {}
last_batch_sent = {}
+ last_batch_flag_sent_to = []
+
+ load_batches = {}
+ steps_to_load_data_from_previous_executions: Dict[str, Union[Path, None]] = {}
for step_name in dag:
step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME]
last_batch_received[step.name] = None
last_batch_sent[step.name] = None
- if step.is_generator:
- continue
predecessors = list(dag.get_step_predecessors(step_name))
convergence_step = all(
dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False)
@@ -866,8 +972,55 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager":
predecessors=predecessors,
convergence_step=convergence_step,
)
+
+ all_step_predecessors_use_cache = all(
+ dag.get_step(step_name)[STEP_ATTR_NAME].use_cache
+ for step_name in predecessors
+ )
+ if use_cache and step.use_cache and all_step_predecessors_use_cache:
+ step_data_path = steps_data_path / batch_manager_step.signature
+ if step_data_path.exists():
+ steps_to_load_data_from_previous_executions[step_name] = (
+ step_data_path
+ )
+ # We only want to load the outputs that are directly needed by the added
+ # steps, so if we need to load the outputs of one step and one of its
+ # predecessors is also in the list, then we skip loading the predecessor's outputs.
+ for predecessor in predecessors:
+ if predecessor in steps_to_load_data_from_previous_executions:
+ steps_to_load_data_from_previous_executions[predecessor] = (
+ None
+ )
+
steps[step_name] = batch_manager_step
- return cls(steps, last_batch_received, last_batch_sent, [])
+
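+ # Steps whose cached outputs will be reused are marked as if the last batch flag
+ # had already been sent to them, so they are treated as finished and not run again.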
+ for (
+ step_name,
+ step_outputs_path,
+ ) in steps_to_load_data_from_previous_executions.items():
+ last_batch_flag_sent_to.append(step_name)
+ if step_outputs_path is None:
+ continue
+ load_batches[step_name] = sorted(
+ [
+ _Batch.from_json(batch_file)
+ for batch_file in step_outputs_path.glob("*.json")
+ if batch_file.is_file() and batch_file.suffix == ".json"
+ ],
+ key=lambda x: x.seq_no,
+ )
+ last_batch_received[step_name] = load_batches[step_name][-1]
+
+ # Load batches from previous steps in batch manager steps
+ for step_name, batch_manager_step in steps.items():
+ for predecessor in dag.get_step_predecessors(step_name):
+ if predecessor in load_batches:
+ batch_manager_step.data[predecessor] = deepcopy(
+ load_batches[predecessor]
+ )
+ batch_manager_step.last_batch_received.append(predecessor)
+
+ return cls(steps, last_batch_received, last_batch_sent, last_batch_flag_sent_to)
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManager` to a dictionary.
@@ -892,12 +1045,14 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"last_batch_flag_sent_to": self._last_batch_flag_sent_to,
}
- def cache(self, path: "StrOrPath") -> None:
+ def cache(self, path: Path, steps_data_path: Path) -> None: # noqa: C901
"""Cache the `_BatchManager` to a file.
Args:
path: The path to the file where the `_BatchManager` will be cached. If `None`,
then the `_BatchManager` will be cached in the default cache folder.
+ steps_data_path: The path where the outputs of each `Step` (considering its
+ signature) will be saved for later reuse in other pipeline executions.
"""
def save_batch(
@@ -953,26 +1108,6 @@ def remove_files(keep_files: List[str], dir: Path) -> None:
# Remove built `_Batch`es that were consumed from cache
remove_files(step_dump["built_batches"], built_batches_dir)
- # Store each `_BatchManagerStep` `_Batch`es in a separate file
- for buffered_step_name in step_dump["data"]:
- step_batches_dir = batch_manager_step_dir / buffered_step_name
- step_batches_dir.mkdir(parents=True, exist_ok=True)
-
- # Store each `_Batch` in a separate file
- step_dump["data"][buffered_step_name] = [
- str(
- save_batch(
- batches_dir=step_batches_dir,
- batch_dump=batch_dump,
- batch_list=self._steps[step_name].data[buffered_step_name],
- )
- )
- for batch_dump in step_dump["data"][buffered_step_name]
- ]
-
- # Remove `_Batch`es that were consumed from cache
- remove_files(step_dump["data"][buffered_step_name], step_batches_dir)
-
# Store the `_BatchManagerStep` info
batch_manager_step_file = str(
path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json"
@@ -986,29 +1121,138 @@ def remove_files(keep_files: List[str], dir: Path) -> None:
self.save(path=path, format="json", dump=dump)
@classmethod
- def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager":
+ def load_from_cache(
+ cls, dag: "DAG", batch_manager_path: "StrOrPath", steps_data_path: "StrOrPath"
+ ) -> "_BatchManager":
"""Loads the `_BatchManager` from a cache file.
Args:
path: The path to the cache file.
"""
- _check_is_dir(path)
- content = read_json(path)
+ _check_is_dir(batch_manager_path)
+ content = read_json(batch_manager_path)
# Read each `_BatchManagerStep` from file
steps = {}
for step_name, step_file in content["steps"].items():
steps[step_name] = read_json(step_file)
+ # When reading back from JSON, `next_expected_seq_no` and `step_offset` are
+ # lists (because JSON does not have tuples).
+ steps[step_name]["next_expected_seq_no"] = {
+ k: tuple(v) for k, v in steps[step_name]["next_expected_seq_no"].items()
+ }
+ steps[step_name]["step_offset"] = {
+ k: tuple(v) for k, v in steps[step_name]["step_offset"].items()
+ }
+
+ # TODO: where are we writing built batches now? xD
# Read each `_Batch` from file
steps[step_name]["built_batches"] = [
read_json(batch) for batch in steps[step_name]["built_batches"]
]
- for buffered_step_name, batch_files in steps[step_name]["data"].items():
- steps[step_name]["data"][buffered_step_name] = [
- read_json(batch_file) for batch_file in batch_files
- ]
+ # Read the batches from the `steps_data` directory to populate back the `_BatchManagerStep`
+ step_offset = steps[step_name]["step_offset"]
+ for successor_step_name, offset in step_offset.items():
+ batch_offset, batch_row_offset = offset
+ step: "_Step" = dag.get_step(successor_step_name)[STEP_ATTR_NAME]
+ successor_step_data_path = (
+ steps_data_path / f"{step.name}_{step.signature}"
+ )
+
+ # read batches from successor step from the step data directory taking into
+ # account offset
+ batches = []
+ for batch_file in successor_step_data_path.glob("*.json"):
+ if not batch_file.is_file() or batch_file.suffix != ".json":
+ continue
+
+ # If the batch number is lower than the batch offset then we should
+ # skip it as it has already been processed by the step
+ batch_no = int(batch_file.stem.split("batch_")[1])
+ if batch_no < batch_offset:
+ continue
+
+ # read the batch and skip the first N rows of the first batch
+ batch = read_json(batch_file)
+ if batch_no == batch_offset:
+ batch["data"][0] = batch["data"][0][batch_row_offset:]
+
+ batches.append(batch)
+
+ # sort batches by `seq_no` as it's a requirement for checking if ready to
+ # create next batch
+ batches.sort(key=lambda batch: batch["seq_no"])
+ steps[step_name]["data"][successor_step_name] = batches
content["steps"] = steps
return cls.from_dict(content)
+
+ def invalidate_cache_for(
+ self, step_name: str, dag: "DAG", steps_data_path: Path
+ ) -> None:
+ """Invalidates the cache for the given step and its predecessors.
+
+ Args:
+ step_name: the name of the step for which the cache will be invalidated.
+ dag: the `DAG` of the pipeline containing the steps.
+ steps_data_path: the path where the output batches of each `Step` were saved
+ for reuse in another pipeline execution.
+ """
+ invalidate_if_predecessor = []
+ for sorted_step in dag:
+ if (sorted_step == step_name) or any(
+ predecessor in invalidate_if_predecessor
+ for predecessor in dag.get_step_predecessors(sorted_step)
+ ):
+ self._reset_batch_manager_for_step(sorted_step, dag)
+ invalidate_if_predecessor.append(sorted_step)
+
+ self._load_predecessor_batches(step_name, dag, steps_data_path)
+
+ def _reset_batch_manager_for_step(self, step_name: str, dag: "DAG") -> None:
+ """Resets the batch manager state for a given step i.e. creates a new clean `_BatchManagerStep`
+ for the step and removes the step name from the lists of states of the `BatchManager`.
+
+ Args:
+ step_name: the name of the step whose batch manager state needs to be cleaned.
+ dag: the `DAG` of the pipeline containing the steps.
+ """
+ predecessors = list(dag.get_step_predecessors(step_name))
+ convergence_step = dag.is_convergence_step(step_name)
+ step = dag.get_step(step_name)[STEP_ATTR_NAME]
+ self._steps[step_name] = _BatchManagerStep.from_step(
+ step, predecessors=predecessors, convergence_step=convergence_step
+ )
+
+ self._last_batch_received[step_name] = None
+ self._last_batch_sent[step_name] = None
+ if step_name in self._last_batch_flag_sent_to:
+ self._last_batch_flag_sent_to.remove(step_name)
+
+ def _load_predecessor_batches(
+ self, step_name: str, dag: "DAG", steps_data_path: Path
+ ) -> None:
+ """Loads the cached batches of the predecessors of the step in its `_BatchManagerStep`.
+
+ Args:
+ step_name: the name of the step whose predecessors' batches will be loaded.
+ dag: the `DAG` of the pipeline containing the steps.
+ steps_data_path: the path where the output batches of each `Step` were saved
+ for reuse in another pipeline execution.
+ """
+ for predecessor in dag.get_step_predecessors(step_name):
+ step_predecessor = dag.get_step(predecessor)[STEP_ATTR_NAME]
+ predecessor_step_data_path = (
+ steps_data_path
+ / f"{step_predecessor.name}_{step_predecessor.signature}"
+ )
+ batch_files = list_files_in_dir(
+ predecessor_step_data_path, key=lambda x: int(x.stem.split("_")[-1])
+ )
+ for file in batch_files:
+ batch = _Batch.from_file(file)
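+ # If the cached batch was the predecessor's last one, record it so the step
+ # knows that no more batches will arrive from that predecessor.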
+ if batch.last_batch:
+ self._steps[step_name].last_batch_received.append(batch.step_name)
+ self._steps[step_name].data[predecessor].append(batch)
diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py
index be7919d56..c01cce303 100644
--- a/src/distilabel/pipeline/local.py
+++ b/src/distilabel/pipeline/local.py
@@ -216,7 +216,7 @@ def run(
initargs=(
self._log_queue,
self.name,
- self._create_signature(),
+ self.signature,
),
) as pool,
):
diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py
index cf9a26064..70bf205ab 100644
--- a/src/distilabel/pipeline/ray.py
+++ b/src/distilabel/pipeline/ray.py
@@ -310,7 +310,7 @@ def run(self) -> str:
),
log_queue=self._log_queue,
pipeline_name=self.name,
- pipeline_cache_id=self._create_signature(),
+ pipeline_cache_id=self.signature,
)
self._logger.debug(
diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py
index 661f99b1f..844648f20 100644
--- a/src/distilabel/pipeline/step_wrapper.py
+++ b/src/distilabel/pipeline/step_wrapper.py
@@ -166,6 +166,7 @@ def _generator_step_process_loop(self) -> None:
`process` method.
"""
step = cast("GeneratorStep", self.step)
+
try:
if (batch := self.input_queue.get()) is None:
self.step._logger.info(
@@ -277,13 +278,7 @@ def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]:
Args:
batch: The batch to impute.
"""
- result = []
- for row in batch.data[0]:
- data = row.copy()
- for output in self.step.outputs:
- data[output] = None
- result.append(data)
- return result
+ return self.step.impute_step_outputs(batch.data[0])
def _send_batch(self, batch: _Batch) -> None:
"""Sends a batch to the `output_queue`."""
diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py
index a71ffdd9b..3fdb037e1 100644
--- a/src/distilabel/pipeline/write_buffer.py
+++ b/src/distilabel/pipeline/write_buffer.py
@@ -15,7 +15,7 @@
import logging
from os import PathLike
from pathlib import Path
-from typing import Any, Dict, List, Set
+from typing import Any, Dict, List, Optional, Set
import pyarrow as pa
import pyarrow.parquet as pq
@@ -33,12 +33,21 @@ class _WriteBuffer:
is full, the content is written to a parquet file.
"""
- def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None:
+ def __init__(
+ self,
+ path: "PathLike",
+ leaf_steps: Set[str],
+ steps_cached: Optional[Dict[str, bool]] = None,
+ ) -> None:
"""
Args:
path: Folder where the files will be written, the idea
is for this path to be in the cache folder under /data.
leaf_steps: Leaf steps from either the DAG of the Pipeline.
+ steps_cached: Dictionary with the name of a step as key and its `use_cache`
+ flag as value. It is used to determine whether a previously written parquet
+ table has to be read and concatenated before saving the cached
+ datasets.
Raises:
ValueError: If the path is not a directory.
@@ -61,6 +70,7 @@ def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None:
}
self._buffer_last_schema = {}
self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps}
+ self._steps_cached = steps_cached or {}
self._logger = logging.getLogger("distilabel.write_buffer")
def _get_filename(self, step_name: str) -> Path:
@@ -130,14 +140,28 @@ def _write(self, step_name: str) -> None:
self._buffer_last_schema[step_name] = table.schema
else:
if not last_schema.equals(table.schema):
- new_schema = pa.unify_schemas([last_schema, table.schema])
- self._buffer_last_schema[step_name] = new_schema
- table = table.cast(new_schema)
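+ # If only the column order differs, reorder the new table to match the previous
+ # schema; otherwise unify both schemas and cast the table to the merged one.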
+ if set(last_schema.names) == set(table.schema.names):
+ table = table.select(last_schema.names)
+ else:
+ new_schema = pa.unify_schemas([last_schema, table.schema])
+ self._buffer_last_schema[step_name] = new_schema
+ table = table.cast(new_schema)
next_file_number = self._buffers_last_file[step_name]
self._buffers_last_file[step_name] = next_file_number + 1
parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet"
+ if parquet_file.exists():
+ # If the file already exists, it comes from a previous (possibly interrupted) execution of the pipeline that was cached
+ prev_table = pq.read_table(parquet_file)
+ # If some columns differ, it means one of the steps changed, so we won't load the previous table
+ # NOTE: If any step has use_cache=False, we cannot assume the previous parquet file is
+ # valid, so we will overwrite the previous parquet file. Is this the best option?
+ use_cache = False not in self._steps_cached.values()
+
+ if prev_table.column_names == table.column_names and use_cache:
+ table = pa.concat_tables([prev_table, table])
+
pq.write_table(table, parquet_file)
self._logger.debug(f"Written to file '{parquet_file}'")
diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py
index 79c10a268..cc1be59f9 100644
--- a/src/distilabel/steps/__init__.py
+++ b/src/distilabel/steps/__init__.py
@@ -45,6 +45,7 @@
FormatTextGenerationSFT,
)
from distilabel.steps.generators.data import LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
from distilabel.steps.generators.huggingface import (
LoadDataFromDisk,
LoadDataFromFileSystem,
@@ -83,6 +84,7 @@
"FormatChatGenerationSFT",
"FormatTextGenerationSFT",
"LoadDataFromDicts",
+ "DataSampler",
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py
index cc0c0b2e1..b98c0e827 100644
--- a/src/distilabel/steps/base.py
+++ b/src/distilabel/steps/base.py
@@ -39,6 +39,7 @@
RuntimeParameter,
RuntimeParametersMixin,
)
+from distilabel.mixins.signature import SignatureMixin
from distilabel.utils.serialization import _Serializable, write_json
from distilabel.utils.typing_ import is_parameter_annotated_with
@@ -133,7 +134,14 @@ class StepResources(RuntimeParametersMixin, BaseModel):
)
-class _Step(RuntimeParametersMixin, RequirementsMixin, BaseModel, _Serializable, ABC):
+class _Step(
+ RuntimeParametersMixin,
+ RequirementsMixin,
+ SignatureMixin,
+ BaseModel,
+ _Serializable,
+ ABC,
+):
"""Base class for the steps that can be included in a `Pipeline`.
A `Step` is a class defining some processing logic. The input and outputs for this
@@ -193,6 +201,7 @@ def process(self, inputs: *StepInput) -> StepOutput:
pipeline: Any = Field(default=None, exclude=True, repr=False)
input_mappings: Dict[str, str] = {}
output_mappings: Dict[str, str] = {}
+ use_cache: bool = True
_pipeline_artifacts_path: Path = PrivateAttr(None)
_built_from_decorator: bool = PrivateAttr(default=False)
@@ -582,6 +591,20 @@ def save_artifact(
)
write_json(filename=metadata_path, data=metadata or {})
+ def impute_step_outputs(
+ self, step_output: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ Imputes the output columns of the step that are not present in the step output.
+ """
+ result = []
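+ # For each row, copy it and fill every declared output column with `None`.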
+ for row in step_output:
+ data = row.copy()
+ for output in self.get_outputs().keys():
+ data[output] = None
+ result.append(data)
+ return result
+
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
dump = super()._model_dump(obj, **kwargs)
dump["runtime_parameters_info"] = self.get_runtime_parameters_info()
@@ -644,10 +667,19 @@ def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput"
)
for output_rows in generator:
- yield [
- self._apply_mappings_and_restore_overriden(row, overriden_inputs[i])
- for i, row in enumerate(output_rows)
- ]
+ restored = []
+ for i, row in enumerate(output_rows):
+ # Correct the index here because we don't know the num_generations from the llm
+ # ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
+ # from `num_generations==2` and `group_generations=False` in the LLM:
+ # The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
+ ntimes_i = i % len(overriden_inputs)
+ restored.append(
+ self._apply_mappings_and_restore_overriden(
+ row, overriden_inputs[ntimes_i]
+ )
+ )
+ yield restored
def _apply_input_mappings(
self, inputs: Tuple[List[Dict[str, Any]], ...]
diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py
index 4bf583c16..7e640bf5c 100644
--- a/src/distilabel/steps/clustering/text_clustering.py
+++ b/src/distilabel/steps/clustering/text_clustering.py
@@ -223,7 +223,6 @@ def _create_figure(
inputs: The inputs of the step, as we will extract information from them again.
label2docs: Map from each label to the list of documents (texts) that belong to that cluster.
cluster_summaries: The summaries of the clusters, obtained from the LLM.
- labels: The labels of the clusters (integers representing each predicted class).
"""
self._logger.info("🖼️ Creating figure for the clusters...")
diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py
index 0816ca13e..3e84df66f 100644
--- a/src/distilabel/steps/decorator.py
+++ b/src/distilabel/steps/decorator.py
@@ -17,7 +17,6 @@
TYPE_CHECKING,
Any,
Callable,
- List,
Literal,
Type,
Union,
@@ -175,10 +174,10 @@ def decorator(func: ProcessingFunc) -> Type["_Step"]:
**runtime_parameters, # type: ignore
)
- def inputs_property(self) -> List[str]:
+ def inputs_property(self) -> "StepColumns":
return inputs
- def outputs_property(self) -> List[str]:
+ def outputs_property(self) -> "StepColumns":
return outputs
def process(
diff --git a/src/distilabel/steps/generators/data_sampler.py b/src/distilabel/steps/generators/data_sampler.py
new file mode 100644
index 000000000..6b2e55bf0
--- /dev/null
+++ b/src/distilabel/steps/generators/data_sampler.py
@@ -0,0 +1,179 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from itertools import islice
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from pydantic import Field
+from typing_extensions import override
+
+from distilabel.steps.base import GeneratorStep
+
+if TYPE_CHECKING:
+ from distilabel.steps.base import GeneratorStepOutput
+
+
+class DataSampler(GeneratorStep):
+ """Step to sample from a dataset.
+
+ `GeneratorStep` that samples from a dataset and yields it in batches.
+ This step is useful when you have a pipeline that can benefit from using examples
+ in the prompts, for example as few-shot examples, and those examples can change on each row.
+ For example, you can pass a list of dictionaries with N examples and generate M samples
+ from it (assuming you have another step loading data, this M should be the same as the
+ number of rows loaded by that step). The `size` argument S is the number of samples per
+ generated row, so each row would contain S examples to be used as few-shot examples.
+
+ Attributes:
+ data: The list of dictionaries to sample from.
+ size: Number of samples per example. For example in a few-shot learning scenario,
+ the number of few-shot examples that will be generated per example. Defaults to 2.
+ samples: Number of examples that will be generated by the step in total.
+ If used with another loader step, this should be the same as the number
+ of samples in the loader step. Defaults to 100.
+
+ Output columns:
+ - dynamic (based on the keys found on the first dictionary of the list): The columns
+ of the dataset.
+
+ Categories:
+ - load
+
+ Examples:
+ Sample data from a list of dictionaries:
+
+ ```python
+ from distilabel.steps import DataSampler
+
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ samples=10,
+ size=2,
+ batch_size=4
+ )
+ sampler.load()
+
+ result = next(sampler.process())
+ # >>> result
+ # ([{'sample': ['sample 7', 'sample 0']}, {'sample': ['sample 2', 'sample 21']}, {'sample': ['sample 17', 'sample 12']}, {'sample': ['sample 2', 'sample 14']}], False)
+ ```
+
+ Pipeline with a loader and a sampler combined in a single stream:
+
+ ```python
+ from datasets import load_dataset
+
+ from distilabel.steps import LoadDataFromDicts, DataSampler
+ from distilabel.steps.tasks.apigen.utils import PrepareExamples
+ from distilabel.pipeline import Pipeline
+
+ ds = (
+ load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+ .shuffle(seed=42)
+ .select(range(500))
+ .to_list()
+ )
+ data = [
+ {
+ "func_name": "final_velocity",
+ "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+ },
+ {
+ "func_name": "permutation_count",
+ "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+ },
+ {
+ "func_name": "getdivision",
+ "func_desc": "Divides two numbers by making an API call to a division service.",
+ },
+ ]
+ with Pipeline(name="APIGenPipeline") as pipeline:
+ loader_seeds = LoadDataFromDicts(data=data)
+ sampler = DataSampler(
+ data=ds,
+ size=2,
+ samples=len(data),
+ batch_size=8,
+ )
+ prep_examples = PrepareExamples()
+
+ sampler >> prep_examples
+ (
+ [loader_seeds, prep_examples]
+ >> combine_steps
+ )
+ # Now we have a single stream of data with the loader and the sampler data
+ ```
+ """
+
+ data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
+ size: int = Field(
+ default=2,
+ description=(
+ "Number of samples per example. For example in a few-shot learning scenario, the number "
+ "of few-shot examples that will be generated per example."
+ ),
+ )
+ samples: int = Field(
+ default=100,
+ description=(
+ "Number of examples that will be generated by the step in total. "
+ "If used with another loader step, this should be the same as the number of "
+ "samples in the loader step."
+ ),
+ )
+
+ @override
+ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
+ """Yields batches from a list of dictionaries.
+
+ Args:
+ offset: The offset to start the generation from. Defaults to `0`.
+
+ Yields:
+ A list of Python dictionaries as read from the inputs (propagated in batches)
+ and a flag indicating whether the yielded batch is the last one.
+ """
+
+ total_samples = 0
+
+ while total_samples < self.samples:
+ batch = []
+ bs = min(self.batch_size, self.samples - total_samples)
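+ # Each iteration samples `size` rows and combines them into a single example;
+ # the batch is later trimmed to `bs` so no more than `samples` examples are yielded.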
+ for _ in range(self.batch_size):
+ choices = random.choices(self.data, k=self.size)
+ choices = self._transform_data(choices)
+ batch.extend(choices)
+ total_samples += bs
+ batch = list(islice(batch, bs))
+ yield (batch, True if total_samples >= self.samples else False)
+ batch = []
+
+ @staticmethod
+ def _transform_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ if not data:
+ return []
+
+ result = {key: [] for key in data[0].keys()}
+
+ for item in data:
+ for key, value in item.items():
+ result[key].append(value)
+
+ return [result]
+
+ @property
+ def outputs(self) -> List[str]:
+ return list(self.data[0].keys())
diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py
index f6e782a75..721b3d408 100644
--- a/src/distilabel/steps/generators/huggingface.py
+++ b/src/distilabel/steps/generators/huggingface.py
@@ -219,11 +219,11 @@ def _get_dataset_num_examples(self) -> int:
Returns:
The number of examples in the dataset.
"""
- return (
- self._dataset_info[self.config if self.config else "default"]
- .splits[self.split]
- .num_examples
- )
+ default_config = self.config
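+ # Fall back to the first available config when none was provided explicitly.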
+ if not default_config:
+ default_config = list(self._dataset_info.keys())[0]
+
+ return self._dataset_info[default_config].splits[self.split].num_examples
def _get_dataset_columns(self) -> List[str]:
"""Get the columns of the dataset, based on the `config` runtime parameter provided.
diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index eb90c6dba..98974b00d 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -12,8 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker
+from distilabel.steps.tasks.apigen.generator import APIGenGenerator
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
+from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller
from distilabel.steps.tasks.base import GeneratorTask, Task
+from distilabel.steps.tasks.clair import CLAIR
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
+from distilabel.steps.tasks.decorator import task
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.evol_instruct.evol_complexity.base import EvolComplexity
from distilabel.steps.tasks.evol_instruct.evol_complexity.generator import (
@@ -52,7 +58,12 @@
__all__ = [
"GeneratorTask",
"Task",
+ "ArgillaLabeller",
+ "APIGenExecutionChecker",
+ "APIGenGenerator",
+ "APIGenSemanticChecker",
"ComplexityScorer",
+ "task",
"EvolInstruct",
"EvolComplexity",
"EvolComplexityGenerator",
@@ -81,6 +92,7 @@
"TextGeneration",
"ChatItem",
"ChatType",
+ "CLAIR",
"UltraFeedback",
"URIAL",
]
diff --git a/src/distilabel/steps/tasks/apigen/__init__.py b/src/distilabel/steps/tasks/apigen/__init__.py
new file mode 100644
index 000000000..20ce00bda
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/distilabel/steps/tasks/apigen/execution_checker.py b/src/distilabel/steps/tasks/apigen/execution_checker.py
new file mode 100644
index 000000000..7d30dd1f7
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/execution_checker.py
@@ -0,0 +1,268 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# - Try to import the function from a given module
+# - If function, try to import it and run it
+# - If fails, track the error message, and return it
+
+import inspect
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Union
+
+from pydantic import Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.base import Step, StepInput
+from distilabel.steps.tasks.apigen.utils import (
+ execute_from_response,
+ load_module_from_path,
+)
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from distilabel.steps.typing import StepColumns, StepOutput
+
+
+class APIGenExecutionChecker(Step):
+ """Executes the generated function calls.
+
+ This step checks if a given answer from a model as generated by `APIGenGenerator`
+ can be executed against the given library (given by `libpath`, which is a string
+ pointing to a python .py file with functions).
+
+ Attributes:
+ libpath: The path to the library where we will retrieve the functions.
+ It can also point to a folder with the functions. In this case, the folder
+ layout should be a folder with .py files, each containing a single function,
+ the name of the function being the same as the filename.
+ check_is_dangerous: Bool to exclude some potentially dangerous functions; it contains
+ some heuristics found while testing. These functions can run subprocesses, deal with
+ the OS, or have other potentially dangerous operations. Defaults to True.
+
+ Input columns:
+ - answers (`str`): List with arguments to be passed to the function,
+ dumped as a string from a list of dictionaries. Should be loaded using
+ `json.loads`.
+
+ Output columns:
+ - keep_row_after_execution_check (`bool`): Whether the function should be kept or not.
+ - execution_result (`str`): The result from executing the function.
+
+ Categories:
+ - filtering
+ - execution
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+ Execute a function from a given library with the answer from an LLM:
+
+ ```python
+ from distilabel.steps.tasks import APIGenExecutionChecker
+
+ # For the libpath you can use as an example the file at the tests folder:
+ # ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
+ task = APIGenExecutionChecker(
+ libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
+ )
+ task.load()
+
+ res = next(
+ task.process(
+ [
+ {
+ "answers": [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "final_velocity",
+ }
+ ],
+ }
+ ]
+ )
+ )
+ res
+ #[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
+ ```
+ """
+
+ libpath: str = Field(
+ default=...,
+ description=(
+ "The path to the library where we will retrieve the functions, "
+ "or a folder with python files named the same as the functions they contain.",
+ ),
+ )
+ check_is_dangerous: bool = Field(
+ default=True,
+ description=(
+ "Bool to exclude some potentially dangerous functions, it contains "
+ "some heuristics found while testing. This functions can run subprocesses, "
+ "deal with the OS, or have other potentially dangerous operations.",
+ ),
+ )
+
+ _toolbox: Union["ModuleType", None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the library where the functions will be extracted from."""
+ super().load()
+ if Path(self.libpath).suffix == ".py":
+ self._toolbox = load_module_from_path(self.libpath)
+
+ def unload(self) -> None:
+ self._toolbox = None
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task are those found in the original dataset."""
+ return ["answers"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ """The outputs are the columns required by `APIGenGenerator` task."""
+ return ["keep_row_after_execution_check", "execution_result"]
+
+ def _get_function(self, function_name: str) -> Callable:
+ """Retrieves the function from the toolbox.
+
+ Args:
+ function_name: The name of the function to retrieve.
+
+ Returns:
+ Callable: The function to be executed.
+ """
+ if self._toolbox:
+ return getattr(self._toolbox, function_name, None)
+ try:
+ toolbox = load_module_from_path(
+ str(Path(self.libpath) / f"{function_name}.py")
+ )
+ return getattr(toolbox, function_name, None)
+ except FileNotFoundError:
+ return None
+ except Exception as e:
+ self._logger.warning(f"Error loading function '{function_name}': {e}")
+ return None
+
+ def _is_dangerous(self, function: Callable) -> bool:
+ """Checks if a function is dangerous to remove it.
+ Contains a list of heuristics to avoid executing possibly dangerous functions.
+ """
+ source_code = inspect.getsource(function)
+ # We don't want to execute functions that use subprocess
+ if (
+ ("subprocess." in source_code)
+ or ("os.system(" in source_code)
+ or ("input(" in source_code)
+ # Avoiding threading
+ or ("threading.Thread(" in source_code)
+ or ("exec(" in source_code)
+ # Avoiding argparse (not sure why)
+ or ("argparse.ArgumentParser(" in source_code)
+ # Avoiding logging changing the levels to not mess with the logs
+ or (".setLevel(" in source_code)
+ # Don't run a test battery
+ or ("unittest.main(" in source_code)
+ # Avoid exiting the program
+ or ("sys.exit(" in source_code)
+ or ("exit(" in source_code)
+ or ("raise SystemExit(" in source_code)
+ or ("multiprocessing.Pool(" in source_code)
+ ):
+ return True
+ return False
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """Checks the answer to see if it can be executed.
+ Captures the possible errors and returns them.
+
+ If a single example is provided, it is copied to avoid raising an error.
+
+ Args:
+ inputs: A list of dictionaries with the input data.
+
+ Yields:
+ A list of dictionaries with the output data.
+ """
+ for input in inputs:
+ output = []
+ if input["answers"]:
+ answers = json.loads(input["answers"])
+ else:
+ input.update(
+ **{
+ "keep_row_after_execution_check": False,
+ "execution_result": ["No answers were provided."],
+ }
+ )
+ continue
+ for answer in answers:
+ if answer is None:
+ output.append(
+ {
+ "keep": False,
+ "execution_result": "Nothing was generated for this answer.",
+ }
+ )
+ continue
+
+ function_name = answer.get("name", None)
+ arguments = answer.get("arguments", None)
+
+ self._logger.debug(
+ f"Executing function '{function_name}' with arguments: {arguments}"
+ )
+ function = self._get_function(function_name)
+
+ if self.check_is_dangerous:
+ if function and self._is_dangerous(function):
+ function = None
+
+ if function is None:
+ output.append(
+ {
+ "keep": False,
+ "execution_result": f"Function '{function_name}' not found.",
+ }
+ )
+ else:
+ execution = execute_from_response(function, arguments)
+ output.append(
+ {
+ "keep": execution["keep"],
+ "execution_result": execution["execution_result"],
+ }
+ )
+ # We only consider a good response if all the answers were executed successfully,
+ # but keep the reasons for further review if needed.
+ input.update(
+ **{
+ "keep_row_after_execution_check": all(
+ o["keep"] is True for o in output
+ ),
+ "execution_result": [o["execution_result"] for o in output],
+ }
+ )
+
+ yield inputs
diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py
new file mode 100644
index 000000000..c1c691e37
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/generator.py
@@ -0,0 +1,448 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+import json
+import random
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.apigen.utils import remove_fences
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT_API_GEN: Final[str] = """\
+You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.
+
+Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.
+
+Ensure the query:
+- Is clear and concise
+- Demonstrates typical use cases
+- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words
+- Across a variety level of difficulties, ranging from beginner and advanced use cases
+- The corresponding result's parameter types and ranges match with the function's descriptions
+
+Ensure the answer:
+- Is a list of function calls in JSON format
+- The length of the answer list should be equal to the number of requests in the query
+- Can solve all the requests in the query effectively"""
+
+
+class APIGenGenerator(Task):
+ """Generate queries and answers for the given functions in JSON format.
+
+ The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
+ verifiable and diverse function-calling datasets. The task generates a set of diverse queries
+ and corresponding answers for the given functions in JSON format.
+
+ Attributes:
+ system_prompt: The system prompt to guide the user in the generation of queries and answers.
+ use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
+ In case the tools are given in the input, they will be added to the prompt.
+ number: The number of queries to generate. It can be a list, where each number will be
+ chosen randomly, or a dictionary with the number of queries and the probability of each.
+ E.g.: `number=1`, `number=[1, 2, 3]`, `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
+ It corresponds to the number of parallel queries to generate.
+ use_default_structured_output: Whether to use the default structured output or not.
+
+ Input columns:
+ - examples (`str`): Examples used as few shots to guide the model.
+ - func_name (`str`): Name for the function to generate.
+ - func_desc (`str`): Description of what the function should do.
+ - tools (`str`): JSON formatted string containing the tool representation of the function.
+
+ Output columns:
+ - query (`str`): The list of queries.
+ - answers (`str`): JSON formatted string with the list of answers, containing the info as
+ a dictionary to be passed to the functions.
+
+ Categories:
+ - text-generation
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+ Generate without structured output (original implementation):
+
+ ```python
+ from distilabel.steps.tasks import APIGenGenerator
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+        apigen = APIGenGenerator(
+ use_default_structured_output=False,
+ llm=llm
+ )
+ apigen.load()
+
+ res = next(
+ apigen.process(
+ [
+ {
+ "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ "func_name": "getrandommovie",
+ "func_desc": "Returns a list of random movies from a database by calling an external API."
+ }
+ ]
+ )
+ )
+ res
+ # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ # 'number': 1,
+ # 'func_name': 'getrandommovie',
+ # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+ # 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?',
+ # 'Give me 5 random movie suggestions from your database to plan my weekend.'],
+ # 'answers': [[{'name': 'getrandommovie', 'arguments': {}}],
+ # [{'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}},
+ # {'name': 'getrandommovie', 'arguments': {}}]],
+ # 'raw_input_api_gen_generator_0': [{'role': 'system',
+ # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+ # {'role': 'user',
+ # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n "query": "The generated query.",\n "answers": [\n {\n "name": "api_name",\n "arguments": {\n "arg_name": "value"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Generate with structured output:
+
+ ```python
+        from distilabel.steps.tasks import APIGenGenerator
+ from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+        apigen = APIGenGenerator(
+ use_default_structured_output=True,
+ llm=llm
+ )
+ apigen.load()
+
+ res_struct = next(
+ apigen.process(
+ [
+ {
+ "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ "func_name": "getrandommovie",
+ "func_desc": "Returns a list of random movies from a database by calling an external API."
+ }
+ ]
+ )
+ )
+ res_struct
+ # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+ # 'number': 1,
+ # 'func_name': 'getrandommovie',
+ # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+ # 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?",
+ # "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"],
+ # 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}],
+ # [{'arguments': {}, 'name': 'getrandommovie'}]],
+ # 'raw_input_api_gen_generator_0': [{'role': 'system',
+ # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+ # {'role': 'user',
+ # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
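+
+        Generate sampling the number of queries per call (a sketch, reusing the `llm`
+        defined in the examples above; `number` may be an `int`, a list to sample from,
+        or a dict mapping each number of queries to its probability):
+
+        ```python
+        apigen = APIGenGenerator(
+            number={1: 0.5, 2: 0.3, 3: 0.2},
+            llm=llm,
+        )
+        apigen.load()
+        ```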
+ """
+
+ system_prompt: str = SYSTEM_PROMPT_API_GEN
+ use_default_structured_output: bool = False
+ number: Union[int, List[int], Dict[int, float]] = 1
+ use_tools: bool = True
+
+ _number: Union[int, None] = PrivateAttr(None)
+ _fn_parallel_queries: Union[Callable[[], str], None] = PrivateAttr(None)
+ _format_inst: Union[str, None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the template for the generator prompt."""
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "apigen"
+ / "generator.jinja2"
+ )
+ self._template = Template(open(_path).read())
+ self._format_inst = self._set_format_inst()
+
+    def _parallel_queries(self, number: int) -> str:
+        """Prepares the parallel queries guide to be included in the prompt.
+
+        Args:
+            number: The number of queries to generate in a single call.
+
+        Returns:
+            The guide about parallel queries to add to the prompt, or an empty
+            string when a single query is requested.
+        """
+ if number > 1:
+ return (
+ "It can contain multiple parallel queries in natural language for the given functions. "
+ "They could use either the same function with different arguments or different functions.\n"
+ )
+ return ""
+
+ def _get_number(self) -> int:
+        """Generates the number of queries to generate in a single call.
+
+        The sampled value is stored in `_number` so the original `number` attribute is
+        left untouched and `_default_error` can reuse it to fill failed responses.
+ """
+ if isinstance(self.number, list):
+ self._number = random.choice(self.number)
+ elif isinstance(self.number, dict):
+ self._number = random.choices(
+ list(self.number.keys()), list(self.number.values())
+ )[0]
+ else:
+ self._number = self.number
+ return self._number
+
+ def _set_format_inst(self) -> str:
+        """Prepares the formatted instructions to be appended to the prompt.
+
+        When the default structured output is used these instructions are not needed;
+        otherwise, they are added to the prompt to guide the model into generating a
+        properly formatted JSON output.
+ """
+ return (
+ "\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
+ "```\n"
+ "[\n"
+ " {\n"
+ ' "query": "The generated query.",\n'
+ ' "answers": [\n'
+ " {\n"
+ ' "name": "api_name",\n'
+ ' "arguments": {\n'
+ ' "arg_name": "value"\n'
+ " ... (more arguments as required)\n"
+ " }\n"
+ " },\n"
+ " ... (more API calls as required)\n"
+ " ]\n"
+ " }\n"
+ "]\n"
+ "```\n"
+ )
+
+ def _get_func_desc(self, input: Dict[str, Any]) -> str:
+        """If available and requested, adds the info from the tools to the prompt
+        as extra information. Otherwise uses just the function description.
+ """
+ if not self.use_tools:
+ return input["func_desc"]
+ extra = "" # Extra information from the tools (if available will be added)
+ if "tools" in input:
+ extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
+ return input["func_desc"] + extra
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task."""
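+        # The boolean flags whether the column is required (True) or optional (False).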
+ return {
+ "examples": True,
+ "func_name": True,
+ "func_desc": True,
+ "tools": False,
+ }
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType`."""
+ number = self._get_number()
+ parallel_queries = self._parallel_queries(number)
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ examples=input["examples"],
+ parallel_queries=parallel_queries,
+ number=number,
+ func_name=input["func_name"],
+ func_desc=self._get_func_desc(input),
+ format_inst=self._format_inst,
+ ),
+ },
+ ]
+
+ @property
+ def outputs(self) -> "StepColumns":
+        """The outputs for the task are the query, the corresponding answers and the model name."""
+ return ["query", "answers", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+        """Formats the raw LLM output into a dictionary with the query and its answers.
+
+ Args:
+ output: the raw output of the LLM.
+ input: the input to the task. Used for obtaining the number of responses.
+
+ Returns:
+ A dict with the queries and answers pairs.
+ The answers are an array of answers corresponding to the query.
+ Each answer is represented as an object with the following properties:
+ - name (string): The name of the tool used to generate the answer.
+ - arguments (object): An object representing the arguments passed to the tool to generate the answer.
+ Each argument is represented as a key-value pair, where the key is the parameter name and the
+ value is the corresponding value.
+ """
+ if output is None:
+ return self._default_error(input)
+
+ if not self.use_default_structured_output:
+ output = remove_fences(output)
+
+ try:
+ pairs = orjson.loads(output)
+ except orjson.JSONDecodeError:
+ return self._default_error(input)
+
+ pairs = pairs["pairs"] if self.use_default_structured_output else pairs
+
+ return self._format_output(pairs, input)
+
+ def _format_output(
+        self, pairs: List[Dict[str, Any]], input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Parses the response, returning a dictionary with the query and answers.
+
+        Args:
+            pairs: The parsed list of query/answers pairs from the LLM's output.
+            input: The input row to be updated.
+
+        Returns:
+            The input row updated with the `query` (a string) and the `answers`
+            (a JSON encoded string with the list of function calls).
+ """
+ try:
+ input.update(
+ **{
+ "query": pairs[0]["query"],
+ "answers": json.dumps(pairs[0]["answers"]),
+ }
+ )
+ return input
+ except Exception as e:
+ self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'")
+ return self._default_error(input)
+
+ def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """Returns a default error output, to fill the responses in case of failure."""
+ input.update(
+ **{
+ "query": None,
+ "answers": json.dumps([None] * self._number),
+ }
+ )
+ return input
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from typing import Dict, List
+ from pydantic import BaseModel
+
+
+ class Answer(BaseModel):
+ name: str
+ arguments: Dict[str, str]
+
+ class QueryAnswer(BaseModel):
+ query: str
+ answers: List[Answer]
+
+ class QueryAnswerPairs(BaseModel):
+ pairs: List[QueryAnswer]
+
+ json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "$defs": {
+ "Answer": {
+ "properties": {
+ "name": {"title": "Name", "type": "string"},
+ "arguments": {
+ "additionalProperties": {"type": "string"},
+ "title": "Arguments",
+ "type": "object",
+ },
+ },
+ "required": ["name", "arguments"],
+ "title": "Answer",
+ "type": "object",
+ },
+ "QueryAnswer": {
+ "properties": {
+ "query": {"title": "Query", "type": "string"},
+ "answers": {
+ "items": {"$ref": "#/$defs/Answer"},
+ "title": "Answers",
+ "type": "array",
+ },
+ },
+ "required": ["query", "answers"],
+ "title": "QueryAnswer",
+ "type": "object",
+ },
+ },
+ "properties": {
+ "pairs": {
+ "items": {"$ref": "#/$defs/QueryAnswer"},
+ "title": "Pairs",
+ "type": "array",
+ }
+ },
+ "required": ["pairs"],
+ "title": "QueryAnswerPairs",
+ "type": "object",
+ }
diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py
new file mode 100644
index 000000000..5ec7cdc57
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py
@@ -0,0 +1,308 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+from typing import TYPE_CHECKING, Any, Dict, Final, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.apigen.utils import remove_fences
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\
+As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.
+These function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.
+
+Do not pass if:
+1. The function call does not align with the query’s objective, or the input arguments appear incorrect.
+2. The function call and arguments are not properly chosen from the available functions.
+3. The number of function calls does not correspond to the user’s intentions.
+4. The execution results are irrelevant and do not match the function’s purpose.
+5. The execution results contain errors or reflect that the function calls were not executed successfully.
+""".rstrip()
+
+
+class APIGenSemanticChecker(Task):
+    r"""Checks the semantic correctness of the generated function calls.
+
+    The `APIGenSemanticChecker` is inspired by the APIGen pipeline, which was designed to generate
+    verifiable and diverse function-calling datasets. The task assesses whether the generated
+    function calls and their execution results accurately reflect the user's query, producing a
+    reasoning (`thought`) and a flag to keep or discard the row.
+
+ Attributes:
+ system_prompt: System prompt for the task. Has a default one.
+ exclude_failed_execution: Whether to exclude failed executions (won't run on those
+ rows that have a False in `keep_row_after_execution_check` column, which
+ comes from running `APIGenExecutionChecker`). Defaults to True.
+
+ Input columns:
+ - func_desc (`str`): Description of what the function should do.
+ - query (`str`): Instruction from the user.
+ - answers (`str`): JSON encoded list with arguments to be passed to the function/API.
+ Should be loaded using `json.loads`.
+ - execution_result (`str`): Result of the function/API executed.
+
+ Output columns:
+ - thought (`str`): Reasoning for the output on whether to keep this output or not.
+ - keep_row_after_semantic_check (`bool`): True or False, can be used to filter
+ afterwards.
+
+ Categories:
+ - filtering
+ - text-generation
+
+ References:
+ - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+ - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+ Examples:
+
+ Semantic checker for generated function calls (original implementation):
+
+ ```python
+        import json
+
+        from distilabel.steps.tasks import APIGenSemanticChecker
+        from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ semantic_checker = APIGenSemanticChecker(
+ use_default_structured_output=False,
+ llm=llm
+ )
+ semantic_checker.load()
+
+ res = next(
+ semantic_checker.process(
+ [
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+ "execution_result": "The Maine Coon is a big and hairy breed of cat",
+ }
+ ]
+ )
+ )
+ res
+ # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+ # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+ # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+ # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+ # 'thought': '',
+ # 'keep_row_after_semantic_check': True,
+ # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+ # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+ # {'role': 'user',
+ # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n "thought": "Concisely describe your reasoning here",\n "pass": "yes" or "no"\n}\n```\n'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Semantic checker for generated function calls (structured output):
+
+ ```python
+        import json
+
+        from distilabel.steps.tasks import APIGenSemanticChecker
+        from distilabel.llms import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 1024,
+ },
+ )
+ semantic_checker = APIGenSemanticChecker(
+ use_default_structured_output=True,
+ llm=llm
+ )
+ semantic_checker.load()
+
+ res = next(
+ semantic_checker.process(
+ [
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+ "execution_result": "The Maine Coon is a big and hairy breed of cat",
+ }
+ ]
+ )
+ )
+ res
+ # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+ # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+ # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+ # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+ # 'keep_row_after_semantic_check': True,
+ # 'thought': '',
+ # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+ # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+ # {'role': 'user',
+ # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+ """
+
+ system_prompt: str = SYSTEM_PROMPT_SEMANTIC_CHECKER
+ use_default_structured_output: bool = False
+
+ _format_inst: Union[str, None] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the template for the generator prompt."""
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "apigen"
+ / "semantic_checker.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+ self._format_inst = self._set_format_inst()
+
+ def _set_format_inst(self) -> str:
+        """Prepares the formatted instructions to be appended to the prompt.
+
+        When the default structured output is used these instructions are not needed;
+        otherwise, they are added to the prompt to guide the model into generating a
+        properly formatted JSON output.
+ """
+ return (
+ "\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
+ "```\n"
+ "{\n"
+ ' "thought": "Concisely describe your reasoning here",\n'
+ ' "passes": "yes" or "no"\n'
+ "}\n"
+ "```\n"
+ )
+
+ @property
+ def inputs(self) -> "StepColumns":
+ """The inputs for the task."""
+ return {
+ "func_desc": True,
+ "query": True,
+ "answers": True,
+ "execution_result": True,
+ "keep_row_after_execution_check": True,
+ }
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType`."""
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ func_desc=input["func_desc"],
+ query=input["query"] or "",
+ func_call=input["answers"] or "",
+ execution_result=input["execution_result"],
+ format_inst=self._format_inst,
+ ),
+ },
+ ]
+
+ @property
+ def outputs(self) -> "StepColumns":
+        """The outputs for the task are the semantic check flag and the model's reasoning."""
+ return ["keep_row_after_semantic_check", "thought"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+        """Parses the LLM response into the semantic check columns.
+
+        Args:
+            output: the raw output of the LLM.
+            input: the input to the task, which is updated with the parsed fields.
+
+        Returns:
+            A dict with the original input plus `thought` (the model's reasoning) and
+            `keep_row_after_semantic_check` (`True` if the model answered "yes",
+            `False` otherwise).
+ """
+ if output is None:
+ return self._default_error(input)
+
+ output = remove_fences(output)
+
+ try:
+ result = orjson.loads(output)
+ # Update the column name and change to bool
+ result["keep_row_after_semantic_check"] = (
+ result.pop("passes").lower() == "yes"
+ )
+ input.update(**result)
+ return input
+ except orjson.JSONDecodeError:
+ return self._default_error(input)
+
+ def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """Default error message for the task."""
+ input.update({"thought": None, "keep_row_after_semantic_check": None})
+ return input
+
+ @override
+ def get_structured_output(self) -> Dict[str, Any]:
+ """Creates the json schema to be passed to the LLM, to enforce generating
+ a dictionary with the output which can be directly parsed as a python dictionary.
+
+ The schema corresponds to the following:
+
+ ```python
+ from typing import Literal
+ from pydantic import BaseModel
+ import json
+
+ class Checker(BaseModel):
+ thought: str
+ passes: Literal["yes", "no"]
+
+ json.dumps(Checker.model_json_schema(), indent=4)
+ ```
+
+ Returns:
+ JSON Schema of the response to enforce.
+ """
+ return {
+ "properties": {
+ "thought": {"title": "Thought", "type": "string"},
+ "passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"},
+ },
+ "required": ["thought", "passes"],
+ "title": "Checker",
+ "type": "object",
+ }
diff --git a/src/distilabel/steps/tasks/apigen/utils.py b/src/distilabel/steps/tasks/apigen/utils.py
new file mode 100644
index 000000000..85ff0b764
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/utils.py
@@ -0,0 +1,194 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import re
+import signal
+from typing import TYPE_CHECKING, Any, Callable, Dict, TypedDict, Union
+
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from distilabel.steps.typing import StepColumns, StepOutput
+
+
+class PrepareExamples(Step):
+    r"""Helper step to create few-shot examples from `query` and `answers` pairs to be used in APIGen.
+
+ Attributes:
+ template (str): The template to format the examples.
+
+ Input columns:
+ - query (`str`): The query to generate examples from.
+ - answers (`str`): The answers to the query.
+
+ Output columns:
+ - examples (`str`): The formatted examples.
+
+ Categories:
+ - format
+
+ Examples:
+ Generate examples for APIGen:
+
+ ```python
+ from distilabel.steps.tasks.apigen.utils import PrepareExamples
+
+ prepare_examples = PrepareExamples()
+ result = next(prepare_examples.process(
+ [
+ {
+ "query": ['I need the area of circles with radius 2.5, 5, and 7.5 inches, please.', 'Can you provide the current locations of buses and trolleys on route 12?'],
+ "answers": ['[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]', '[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]']
+ }
+ ]
+        ))
+ # result
+ # [{'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}, {'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}]
+ ```
+ """
+
+ template: str = "## Query:\n{query}\n## Answers:\n{answers}"
+
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["query", "answers"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["examples"]
+
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """The process prepares the data for the `APIGenGenerator` task.
+
+ If a single example is provided, it is copied to avoid raising an error.
+
+ Args:
+ inputs: A list of dictionaries with the input data.
+
+ Yields:
+ A list of dictionaries with the output data.
+ """
+ outputs = []
+ for input in inputs:
+ example_list = []
+ for query, answers in zip(input["query"], input["answers"]):
+ example_list.append(self.template.format(query=query, answers=answers))
+ outputs.append({"examples": "\n\n".join(example_list)})
+
+ yield outputs
+
+
+def load_module_from_path(path: str) -> "ModuleType":
+ """Loads a python module from a given path.
+
+ Args:
+ path: Path pointing to the module.
+
+ Returns:
+ ModuleType
+
+ Example:
+ ```python
+ path = "/path/to/module.py"
+ module = load_module_from_path(path)
+ # And you can load functions from the module like this:
+ function = getattr(module, "function_name")
+ function(*args, **kwargs)
+ ```
+ """
+ spec = importlib.util.spec_from_file_location("module.name", path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
+
+
+class FunctionResult(TypedDict):
+ keep: bool
+ execution_result: str
+
+
+def execute_from_response(
+ function: Callable, call_answer: Union[Dict[str, Any], None]
+) -> FunctionResult:
+ """Executes a function with the given arguments as generated by `APIGenGenerator`.
+
+    Given that we cannot arbitrarily cast all the arguments, we try to evaluate them,
+    so that strings are converted to their intended type when possible (e.g. a list
+    of lists of ints is passed as such instead of its string representation).
+
+ Args:
+ function: A callable object.
+ call_answer: The arguments to call the function, as generated by the model.
+
+ Returns:
+ A container with the result of the execution and if the row should be kept.
+ """
+ if not function:
+ return FunctionResult(keep=False, execution_result="Function not found")
+
+ if call_answer:
+ for key, value in call_answer.items():
+ if isinstance(value, str):
+ try:
+ call_answer[key] = eval(value)
+ except Exception:
+ # Leave as is and expect the function to handle it
+ pass
+
+ try:
+ if call_answer:
+ result = run_function_with_timeout(function, 5, *call_answer.values())
+ else:
+ # There can be functions that do not require arguments
+ result = run_function_with_timeout(function, 5)
+ return FunctionResult(keep=True, execution_result=str(result))
+ except Exception as e:
+ return FunctionResult(keep=False, execution_result=str(e))
+
+
+def remove_json_fences(text: str) -> str:
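+    """Removes the leading ```json and trailing ``` fences from `text`, if present,
+    returning the text unchanged otherwise."""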
+ pattern = r"^```json\n([\s\S]*)\n```$"
+ match = re.match(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1)
+ return text
+
+
+def remove_fences(text: str) -> str:
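+    """Removes plain leading and trailing ``` fences from `text`, if present,
+    returning the text unchanged otherwise."""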
+ pattern = r"^```\n([\s\S]*)\n```$"
+ match = re.match(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1)
+ return text
+
+
+def timeout_handler(signum, frame):
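+    """Signal handler used by `run_function_with_timeout` to abort long-running calls."""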
+ raise TimeoutError("Function execution timed out")
+
+
+def run_function_with_timeout(function: Callable, timeout: int = 5, *args: Any) -> Any:
+ """Run a function with a timeout, to limit the total time waiting for a result."""
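+    # Note: this relies on `signal.SIGALRM`, so it only works on Unix-like systems
+    # and when called from the main thread.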
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(timeout)
+
+ try:
+ result = function(*args)
+ finally:
+ # Cancel the alarm
+ signal.alarm(0)
+
+ return result
diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py
new file mode 100644
index 000000000..d0874ed3d
--- /dev/null
+++ b/src/distilabel/steps/tasks/argilla_labeller.py
@@ -0,0 +1,614 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import warnings
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import orjson as json
+from jinja2 import Template
+from pydantic import BaseModel, Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.base import StepInput
+from distilabel.steps.tasks.base import Task
+
+if sys.version_info < (3, 9):
+ import importlib_resources
+else:
+ import importlib.resources as importlib_resources
+
+if TYPE_CHECKING:
+ from argilla import (
+ LabelQuestion,
+ MultiLabelQuestion,
+ RatingQuestion,
+ Record,
+ TextField,
+ TextQuestion,
+ )
+
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepOutput
+
+
+class ArgillaLabeller(Task):
+ """
+ Annotate Argilla records based on input fields, example records and question settings.
+
+ This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM.
+ It uses a system prompt that guides the LLM to understand the input fields, the question type,
+ and the question settings. The task then formats the input data and generates a response based on the question.
+ The response is validated against the question's value model, and the final suggestion is prepared for annotation.
+
+ Attributes:
+ _template: a Jinja2 template used to format the input for the LLM.
+
+ Input columns:
+ - record (`argilla.Record`): The record to be annotated.
+ - fields (`Optional[List[Dict[str, Any]]]`): The list of field settings for the input fields.
+ - question (`Optional[Dict[str, Any]]`): The question settings for the question to be answered.
+ - example_records (`Optional[List[Dict[str, Any]]]`): The few shot example records with responses to be used to answer the question.
+ - guidelines (`Optional[str]`): The guidelines for the annotation task.
+
+ Output columns:
+ - suggestion (`Dict[str, Any]`): The final suggestion for annotation.
+
+ Categories:
+ - text-classification
+ - scorer
+ - text-generation
+
+ References:
+ - [`Argilla: Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets`](https://github.com/argilla-io/argilla/)
+
+ Examples:
+ Annotate a record with the same dataset and question:
+
+ ```python
+ import argilla as rg
+ from argilla import Suggestion
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Get information from Argilla dataset definition
+ dataset = rg.Dataset("my_dataset")
+ pending_records_filter = rg.Filter(("status", "==", "pending"))
+ completed_records_filter = rg.Filter(("status", "==", "completed"))
+ pending_records = list(
+ dataset.records(
+ query=rg.Query(filter=pending_records_filter),
+ limit=5,
+ )
+ )
+ example_records = list(
+ dataset.records(
+ query=rg.Query(filter=completed_records_filter),
+ limit=5,
+ )
+ )
+ field = dataset.settings.fields["text"]
+ question = dataset.settings.questions["label"]
+
+ # Initialize the labeller with the model and fields
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ fields=[field],
+ question=question,
+ example_records=example_records,
+ guidelines=dataset.guidelines
+ )
+ labeller.load()
+
+ # Process the pending records
+ result = next(
+ labeller.process(
+ [
+ {
+ "record": record
+ } for record in pending_records
+ ]
+ )
+ )
+
+ # Add the suggestions to the records
+ for record, suggestion in zip(pending_records, result):
+ record.suggestions.add(Suggestion(**suggestion["suggestion"]))
+
+ # Log the updated records
+ dataset.records.log(pending_records)
+ ```
+
+ Annotate a record with alternating datasets and questions:
+
+ ```python
+ import argilla as rg
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Get information from Argilla dataset definition
+ dataset = rg.Dataset("my_dataset")
+ field = dataset.settings.fields["text"]
+ question = dataset.settings.questions["label"]
+ question2 = dataset.settings.questions["label2"]
+
+ # Initialize the labeller with the model and fields
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
+ labeller.load()
+
+ # Process the record
+ record = next(dataset.records())
+ result = next(
+ labeller.process(
+ [
+ {
+ "record": record,
+ "fields": [field],
+ "question": question,
+ },
+ {
+ "record": record,
+ "fields": [field],
+ "question": question2,
+ }
+ ]
+ )
+ )
+
+ # Add the suggestions to the record
+ for suggestion in result:
+ record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))
+
+ # Log the updated record
+ dataset.records.log([record])
+ ```
+
+ Overwrite default prompts and instructions:
+
+ ```python
+ import argilla as rg
+ from distilabel.steps.tasks import ArgillaLabeller
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Overwrite default prompts and instructions
+ labeller = ArgillaLabeller(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
+ question_to_label_instruction={
+ "label_selection": "Select the appropriate label from the list of provided labels.",
+ "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
+ "text": "Provide a text response to the question.",
+ "rating": "Provide a rating for the question.",
+ },
+ )
+ labeller.load()
+ ```
+ """
+
+ system_prompt: str = (
+ "You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
+ "You are given input fields and a question. "
+ "You should create a valid JSON object as an answer to the question based on the input fields. "
+ "1. Understand the input fields and optional guidelines. "
+ "2. Understand the question type and the question settings. "
+ "3. Reason through your response step-by-step. "
+ "4. Provide a valid JSON object as an answer to the question."
+ )
+ question_to_label_instruction: Dict[str, str] = {
+ "label_selection": "Select the appropriate label from the list of provided labels.",
+ "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
+ "text": "Provide a text response to the question.",
+ "rating": "Provide a rating for the question.",
+ }
+ example_records: Optional[
+ RuntimeParameter[Union[List[Union[Dict[str, Any], BaseModel]], None]]
+ ] = Field(
+ default=None,
+ description="The few shot serialized example records or `BaseModel`s with responses to be used to answer the question.",
+ )
+ fields: Optional[
+ RuntimeParameter[Union[List[Union[BaseModel, Dict[str, Any]]], None]]
+ ] = Field(
+ default=None,
+ description="The field serialized field settings or `BaseModel` for the fields to be used to answer the question.",
+ )
+ question: Optional[
+ RuntimeParameter[
+ Union[
+ Dict[str, Any],
+ BaseModel,
+ None,
+ ]
+ ]
+ ] = Field(
+ default=None,
+ description="The question serialized question settings or `BaseModel` for the question to be answered.",
+ )
+ guidelines: Optional[RuntimeParameter[str]] = Field(
+ default=None,
+ description="The guidelines for the annotation task.",
+ )
+
+ _template: Union[Template, None] = PrivateAttr(...)
+ _client: Optional[Any] = PrivateAttr(None)
+
+ def load(self) -> None:
+ """Loads the Jinja2 template."""
+ super().load()
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "argillalabeller.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def inputs(self) -> Dict[str, bool]:
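+        """The inputs for the task; the boolean flags whether the column is required."""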
+ return {
+ "record": True,
+ "fields": False,
+ "question": False,
+ "example_records": False,
+ "guidelines": False,
+ }
+
+ def _format_record(
+ self, record: Dict[str, Any], fields: List[Dict[str, Any]]
+ ) -> str:
+ """Format the record fields into a string.
+
+ Args:
+ record (Dict[str, Any]): The record to format.
+ fields (List[Dict[str, Any]]): The fields to format.
+
+ Returns:
+ str: The formatted record fields.
+ """
+ output = []
+ for field in fields:
+ if title := field.get("title"):
+ output.append(f"title: {title}")
+ if description := field.get("description"):
+ output.append(f"description: {description}")
+ output.append(record.get("fields", {}).get(field.get("name", "")))
+ return "\n".join(output)
+
+ def _get_label_instruction(self, question: Dict[str, Any]) -> str:
+ """Get the label instruction for the question.
+
+ Args:
+ question (Dict[str, Any]): The question to get the label instruction for.
+
+ Returns:
+ str: The label instruction for the question.
+ """
+ question_type = question["settings"]["type"]
+ return self.question_to_label_instruction[question_type]
+
+ def _format_question(self, question: Dict[str, Any]) -> str:
+ """Format the question settings into a string.
+
+ Args:
+ question (Dict[str, Any]): The question to format.
+
+ Returns:
+ str: The formatted question.
+ """
+ output = [
+ f"title: {question.get('title', '')}",
+ f"description: {question.get('description', '')}",
+ f"label_instruction: {self._get_label_instruction(question)}",
+ ]
+ settings = question.get("settings", {})
+ if "options" in settings:
+ output.append(
+ f"labels: {[option['value'] for option in settings.get('options', [])]}"
+ )
+ return "\n".join(output)
+
+ def _format_example_records(
+ self,
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ question: Dict[str, Any],
+ ) -> str:
+ """Format the example records into a string.
+
+ Args:
+ records (List[Dict[str, Any]]): The records to format.
+ fields (List[Dict[str, Any]]): The fields to format.
+ question (Dict[str, Any]): The question to format.
+
+ Returns:
+ str: The formatted example records.
+ """
+ base = []
+ for record in records:
+ responses = record.get("responses", {})
+ if responses.get(question["name"]):
+ base.append(self._format_record(record, fields))
+ value = responses[question["name"]][0]["value"]
+ formatted_value = self._assign_value_to_question_value_model(
+ value, question
+ )
+ base.append(f"Response: {formatted_value}")
+ base.append("")
+ else:
+ warnings.warn(
+ f"Record {record} has no response for question {question['name']}. Skipping example record.",
+ stacklevel=2,
+ )
+ return "\n".join(base)
+
+ def format_input(
+ self,
+ input: Dict[
+ str,
+ Union[
+ Dict[str, Any],
+ "Record",
+ "TextField",
+ "MultiLabelQuestion",
+ "LabelQuestion",
+ "RatingQuestion",
+ "TextQuestion",
+ ],
+ ],
+ ) -> "ChatType":
+ """Format the input into a chat message.
+
+ Args:
+ input: The input to format.
+
+ Returns:
+ The formatted chat message.
+
+ Raises:
+ ValueError: If question or fields are not provided.
+ """
+ input_keys = list(self.inputs.keys())
+ record = input[input_keys[0]]
+ fields = input.get(input_keys[1], self.fields)
+ question = input.get(input_keys[2], self.question)
+ examples = input.get(input_keys[3], self.example_records)
+ guidelines = input.get(input_keys[4], self.guidelines)
+
+ if question is None:
+ raise ValueError("Question must be provided.")
+ if fields is None or any(field is None for field in fields):
+ raise ValueError("Fields must be provided.")
+
+ record = record.to_dict() if not isinstance(record, dict) else record
+ question = question.serialize() if not isinstance(question, dict) else question
+ fields = [
+ field.serialize() if not isinstance(field, dict) else field
+ for field in fields
+ ]
+ examples = (
+ [
+ example.to_dict() if not isinstance(example, dict) else example
+ for example in examples
+ ]
+ if examples
+ else None
+ )
+
+ formatted_fields = self._format_record(record, fields)
+ formatted_question = self._format_question(question)
+ formatted_examples = (
+ self._format_example_records(examples, fields, question)
+ if examples
+ else False
+ )
+
+ prompt = self._template.render(
+ fields=formatted_fields,
+ question=formatted_question,
+ examples=formatted_examples,
+ guidelines=guidelines,
+ )
+
+ messages = []
+ if self.system_prompt:
+ messages.append({"role": "system", "content": self.system_prompt})
+ messages.append({"role": "user", "content": prompt})
+ return messages
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["suggestion"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Format the output into a dictionary.
+
+ Args:
+ output (Union[str, None]): The output to format.
+ input (Dict[str, Any]): The input to format.
+
+ Returns:
+ Dict[str, Any]: The formatted output.
+ """
+ from argilla import Suggestion
+
+ question: Union[
+ Any,
+ Dict[str, Any],
+ LabelQuestion,
+ MultiLabelQuestion,
+ RatingQuestion,
+ TextQuestion,
+ None,
+ ] = input.get(list(self.inputs.keys())[2], self.question) or self.question
+ question = question.serialize() if not isinstance(question, dict) else question
+ model = self._get_pydantic_model_of_structured_output(question)
+ validated_output = model(**json.loads(output))
+ value = self._get_value_from_question_value_model(validated_output)
+ suggestion = Suggestion(
+ value=value,
+ question_name=question["name"],
+ type="model",
+ agent=self.llm.model_name,
+ ).serialize()
+ return {
+ self.outputs[0]: {
+ k: v
+ for k, v in suggestion.items()
+ if k in ["value", "question_name", "type", "agent"]
+ }
+ }
+
+ def _set_llm_structured_output_for_question(self, question: Dict[str, Any]) -> None:
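+        """Updates the LLM runtime parameters so its structured output follows the JSON
+        schema of the Pydantic model built for the given question."""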
+ runtime_parameters = self.llm._runtime_parameters
+ runtime_parameters.update(
+ {
+ "structured_output": {
+ "format": "json",
+ "schema": self._get_pydantic_model_of_structured_output(question),
+ },
+ }
+ )
+ self.llm.set_runtime_parameters(runtime_parameters)
+
+ @override
+ def process(self, inputs: StepInput) -> "StepOutput":
+ """Process the input through the task.
+
+ Args:
+ inputs (StepInput): The input to process.
+
+ Returns:
+ StepOutput: The output of the task.
+ """
+
+ question_list = [input.get("question", self.question) for input in inputs]
+ fields_list = [input.get("fields", self.fields) for input in inputs]
+ # check if any field for the field in fields is None
+ for fields in fields_list:
+ if any(field is None for field in fields):
+ raise ValueError(
+ "Fields must be provided during init or through `process` method."
+ )
+ # check if any question is None
+ if any(question is None for question in question_list):
+ raise ValueError(
+ "Question must be provided during init or through `process` method."
+ )
+ question_list = [
+ question.serialize() if not isinstance(question, dict) else question
+ for question in question_list
+ ]
+ if not all(question == question_list[0] for question in question_list):
+ warnings.warn(
+ "Not all questions are the same. Processing each question separately by setting the structured output for each question. This may impact performance.",
+ stacklevel=2,
+ )
+ for input, question in zip(inputs, question_list):
+ self._set_llm_structured_output_for_question(question)
+ yield from super().process([input])
+ else:
+ question = question_list[0]
+ self._set_llm_structured_output_for_question(question)
+ yield from super().process(inputs)
+
+ def _get_value_from_question_value_model(
+ self, question_value_model: BaseModel
+ ) -> Any:
+ """Get the value from the question value model.
+
+ Args:
+ question_value_model (BaseModel): The question value model to get the value from.
+
+ Returns:
+ Any: The value from the question value model.
+ """
+ for attr in ["label", "labels", "rating", "text"]:
+ if hasattr(question_value_model, attr):
+ return getattr(question_value_model, attr)
+ raise ValueError(f"Unsupported question type: {question_value_model}")
+
+ def _assign_value_to_question_value_model(
+ self, value: Any, question: Dict[str, Any]
+ ) -> BaseModel:
+ """Assign the value to the question value model.
+
+ Args:
+ value (Any): The value to assign.
+ question (Dict[str, Any]): The question to assign the value to.
+
+ Returns:
+ BaseModel: The question value model with the assigned value.
+ """
+ question_value_model = self._get_pydantic_model_of_structured_output(question)
+ for attr in ["label", "labels", "rating", "text"]:
+ try:
+ model_dict = {attr: value}
+ question_value_model = question_value_model(**model_dict)
+ return question_value_model.model_dump_json()
+ except AttributeError:
+ pass
+ return value
+
+ def _get_pydantic_model_of_structured_output(
+ self,
+ question: Dict[str, Any],
+ ) -> BaseModel:
+ """Get the Pydantic model of the structured output.
+
+ Args:
+ question (Dict[str, Any]): The question to get the Pydantic model of the structured output for.
+
+ Returns:
+ BaseModel: The Pydantic model of the structured output.
+ """
+
+ question_type = question["settings"]["type"]
+
+ if question_type == "multi_label_selection":
+
+ class QuestionValueModel(BaseModel):
+ labels: Optional[List[str]] = Field(default_factory=list)
+
+ elif question_type == "label_selection":
+
+ class QuestionValueModel(BaseModel):
+ label: str
+
+ elif question_type == "text":
+
+ class QuestionValueModel(BaseModel):
+ text: str
+
+ elif question_type == "rating":
+
+ class QuestionValueModel(BaseModel):
+ rating: int
+ else:
+ raise ValueError(f"Unsupported question type: {question}")
+
+ return QuestionValueModel
diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py
index e5600abf6..a2787cd36 100644
--- a/src/distilabel/steps/tasks/base.py
+++ b/src/distilabel/steps/tasks/base.py
@@ -13,7 +13,7 @@
# limitations under the License.
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field, PrivateAttr
from typing_extensions import override
@@ -34,7 +34,7 @@
if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput
- from distilabel.steps.tasks.typing import FormattedInput
+ from distilabel.steps.tasks.typing import ChatType, FormattedInput
from distilabel.steps.typing import StepOutput
@@ -117,6 +117,28 @@ def unload(self) -> None:
self._logger.debug("Executing task unload logic.")
self.llm.unload()
+ @override
+ def impute_step_outputs(
+ self, step_output: List[Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+ """
+ Imputes the outputs of the task in case the LLM failed to generate a response.
+ """
+ result = []
+ for row in step_output:
+ data = row.copy()
+ for output in self.get_outputs().keys():
+ data[output] = None
+ data = self._maybe_add_raw_input_output(
+ data,
+ None,
+ None,
+ add_raw_output=self.add_raw_output,
+ add_raw_input=self.add_raw_input,
+ )
+ result.append(data)
+ return result
+
@abstractmethod
def format_output(
self,
@@ -201,7 +223,7 @@ def _maybe_add_raw_input_output(
if add_raw_output:
meta[f"raw_output_{self.name}"] = raw_output
if add_raw_input:
- meta[f"raw_input_{self.name}"] = self.format_input(input)
+ meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None
if meta:
output[DISTILABEL_METADATA_KEY] = meta
@@ -247,6 +269,93 @@ def get_structured_output(self) -> Union[Dict[str, Any], None]:
"""
return None
+ def _sample_input(self) -> "ChatType":
+ """Returns a sample input to be used in the `print` method.
+ Tasks that don't adhere to a format input that returns a map of the type
+ str -> str should override this method to return a sample input.
+ """
+ return self.format_input(
+            {input: f"<PLACEHOLDER_{input.upper()}>" for input in self.inputs}
+ )
+
+ def print(self, sample_input: Optional["ChatType"] = None) -> None:
+ """Prints a sample input to the console using the `rich` library.
+ Helper method to visualize the prompt of the task.
+
+ Args:
+ sample_input: A sample input to be printed. If not provided, a default will be
+                generated using the `_sample_input` method, which can be overridden by
+                subclasses. This should correspond to the same example you could pass to
+                the `format_input` method. The variables will be given placeholder names
+                by default.
+
+ Examples:
+ Print the URIAL prompt:
+
+ ```python
+ from distilabel.steps.tasks import URIAL
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ urial = URIAL(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ ),
+ )
+ urial.load()
+ urial.print()
+ ╭─────────────────────────────────────── Prompt: URIAL ────────────────────────────────────────╮
+ │ ╭────────────────────────────────────── User Message ───────────────────────────────────────╮ │
+ │ │ # Instruction │ │
+ │ │ │ │
+ │ │ Below is a list of conversations between a human and an AI assistant (you). │ │
+ │ │ Users place their queries under "# User:", and your responses are under "# Assistant:". │ │
+ │ │ You are a helpful, respectful, and honest assistant. │ │
+ │ │ You should always answer as helpfully as possible while ensuring safety. │ │
+ │ │ Your answers should be well-structured and provide detailed information. They should also │ │
+ │ │ have an engaging tone. │ │
+ │ │ Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, │ │
+ │ │ dangerous, or illegal content, even if it may be helpful. │ │
+ │ │ Your response must be socially responsible, and thus you can refuse to answer some │ │
+ │ │ controversial topics. │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ # User: │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ │ │
+ │ │ # Assistant: │ │
+ │ ╰───────────────────────────────────────────────────────────────────────────────────────────╯ │
+ ╰───────────────────────────────────────────────────────────────────────────────────────────────╯
+ ```
+ """
+ from rich.console import Console, Group
+ from rich.panel import Panel
+ from rich.text import Text
+
+ console = Console()
+ sample_input = sample_input or self._sample_input()
+
+ panels = []
+ for item in sample_input:
+ content = Text.assemble((item.get("content", ""),))
+ panel = Panel(
+ content,
+ title=f"[bold][magenta]{item.get('role', '').capitalize()} Message[/magenta][/bold]",
+ border_style="light_cyan3",
+ )
+ panels.append(panel)
+
+ # Create a group of panels
+ # Wrap the group in an outer panel
+ outer_panel = Panel(
+ Group(*panels),
+ title=f"[bold][magenta]Prompt: {type(self).__name__} [/magenta][/bold]",
+ border_style="light_cyan3",
+ expand=False,
+ )
+ console.print(outer_panel)
+
class Task(_Task, Step):
"""Task is a class that implements the `_Task` abstract class and adds the `Step`
diff --git a/src/distilabel/steps/tasks/clair.py b/src/distilabel/steps/tasks/clair.py
new file mode 100644
index 000000000..cbf189ab7
--- /dev/null
+++ b/src/distilabel/steps/tasks/clair.py
@@ -0,0 +1,199 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+from typing import TYPE_CHECKING, Any, Dict, Final, Union
+
+from jinja2 import Template
+from pydantic import PrivateAttr
+
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT: Final[str] = (
+ "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."
+)
+
+
+class CLAIR(Task):
+ r"""Contrastive Learning from AI Revisions (CLAIR).
+
+    CLAIR uses an AI system to minimally revise a solution A→A´ such that the resulting
+    preference pair, with A´ `preferred` over A, is much more contrastive and precise.
+
+ Input columns:
+ - task (`str`): The task or instruction.
+ - student_solution (`str`): An answer to the task that is to be revised.
+
+ Output columns:
+ - revision (`str`): The revised text.
+        - rational (`str`): The rationale for the provided revision.
+        - model_name (`str`): The name of the model used to generate the revision and rationale.
+
+ Categories:
+ - preference
+ - text-generation
+
+ References:
+ - [`Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment`](https://arxiv.org/abs/2408.06266v1)
+ - [`APO and CLAIR - GitHub Repository`](https://github.com/ContextualAI/CLAIR_and_APO)
+
+ Examples:
+ Create contrastive preference pairs:
+
+ ```python
+ from distilabel.steps.tasks import CLAIR
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+ generation_kwargs={
+ "temperature": 0.7,
+ "max_new_tokens": 4096,
+ },
+ )
+ clair_task = CLAIR(llm=llm)
+
+ clair_task.load()
+
+ result = next(
+ clair_task.process(
+ [
+ {
+ "task": "How many gaps are there between the earth and the moon?",
+ "student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'
+ }
+ ]
+ )
+ )
+ # result
+ # [{'task': 'How many gaps are there between the earth and the moon?',
+ # 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
+ # 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
+ # 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.',
+ # 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
+ # 'raw_input_c_l_a_i_r_0': [{'role': 'system',
+ # 'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."},
+ # {'role': 'user',
+ # 'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]},
+ # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+ ```
+
+ Citations:
+
+ ```
+ @misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive,
+ title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment},
+ author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri},
+ year={2024},
+ eprint={2408.06266},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2408.06266},
+ }
+ ```
+ """
+
+ system_prompt: str = SYSTEM_PROMPT
+ _template: Union[Template, None] = PrivateAttr(...)
+
+ def load(self) -> None:
+ super().load()
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "clair.jinja2"
+ )
+ with open(_path, "r") as f:
+ self._template = Template(f.read())
+
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["task", "student_solution"]
+
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["revision", "rational", "model_name"]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The input is formatted as a `ChatType` assuming that the instruction
+ is the first interaction from the user within a conversation."""
+ return [
+ {"role": "system", "content": self.system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ task=input["task"], student_solution=input["student_solution"]
+ ),
+ },
+ ]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+        """The output is formatted as a dict with the revised solution and the rationale behind it.
+
+        Args:
+            output: the raw output of the LLM.
+            input: the input to the task, kept for consistency with the `Task` interface.
+
+        Returns:
+            A dict with the keys `revision` and `rational`, both set to `None` if the
+            output couldn't be parsed.
+        """
+ if output is None:
+ return self._default_error()
+
+ return self._format_output(output)
+
+ def _format_output(self, output: Union[str, None]) -> Dict[str, Any]:
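+        # The prompt asks the model to mention the `{teacher_reasoning}` and
+        # `{corrected_student_solution}` identifiers, but models sometimes echo them
+        # with different formatting, so several marker variants are checked here.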
+ if "**Corrected Student Solution:**" in output:
+ splits = output.split("**Corrected Student Solution:**")
+ elif "{corrected_student_solution}:" in output:
+ splits = output.split("{corrected_student_solution}:")
+ elif "{corrected_student_solution}" in output:
+ splits = output.split("{corrected_student_solution}")
+ elif "**Worsened Student Solution:**" in output:
+ splits = output.split("**Worsened Student Solution:**")
+ elif "{worsened_student_solution}:" in output:
+ splits = output.split("{worsened_student_solution}:")
+ elif "{worsened_student_solution}" in output:
+ splits = output.split("{worsened_student_solution}")
+ else:
+ splits = None
+
+ # Safety check when the output doesn't follow the expected format
+ if not splits:
+ return self._default_error()
+
+ if len(splits) >= 2:
+ revision = splits[1]
+ revision = revision.strip("\n\n").strip() # noqa: B005
+
+ rational = splits[0]
+ if "{teacher_reasoning}" in rational:
+ rational = rational.split("{teacher_reasoning}")[1].strip(":").strip()
+ rational = rational.strip("\n\n").strip() # noqa: B005
+ else:
+ return self._default_error()
+ return {"revision": revision, "rational": rational}
+
+ def _default_error(self) -> Dict[str, None]:
+ return {"revision": None, "rational": None}
diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py
index 9ee8befa5..401e3b760 100644
--- a/src/distilabel/steps/tasks/complexity_scorer.py
+++ b/src/distilabel/steps/tasks/complexity_scorer.py
@@ -239,3 +239,17 @@ def _format_structured_output(
return orjson.loads(output)
except orjson.JSONDecodeError:
return {"scores": [None] * len(input["instructions"])}
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ """Returns a sample input to be used in the `print` method.
+        Tasks whose `format_input` doesn't expect a simple `str -> str` mapping
+        should override this method to return a valid sample input.
+ """
+ return self.format_input(
+ {
+ "instructions": [
+                    f"<PLACEHOLDER_{f'instruction_{i}'.upper()}>" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/steps/tasks/decorator.py b/src/distilabel/steps/tasks/decorator.py
new file mode 100644
index 000000000..8862734f8
--- /dev/null
+++ b/src/distilabel/steps/tasks/decorator.py
@@ -0,0 +1,220 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Tuple, Type, Union
+
+import yaml
+
+from distilabel.errors import DistilabelUserError
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.typing import FormattedInput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns
+
+
+TaskFormattingOutputFunc = Callable[..., Dict[str, Any]]
+
+
+def task(
+ inputs: Union["StepColumns", None] = None,
+ outputs: Union["StepColumns", None] = None,
+) -> Callable[..., Type["Task"]]:
+ """Creates a `Task` from a formatting output function.
+
+ Args:
+ inputs: a list containing the name of the inputs columns/keys or a dictionary
+ where the keys are the columns and the values are booleans indicating whether
+ the column is required or not, that are required by the step. If not provided
+ the default will be an empty list `[]` and it will be assumed that the step
+ doesn't need any specific columns. Defaults to `None`.
+ outputs: a list containing the name of the outputs columns/keys or a dictionary
+ where the keys are the columns and the values are booleans indicating whether
+ the column will be generated or not. If not provided the default will be an
+            empty list `[]` and it will be assumed that the step doesn't generate any specific
+ columns. Defaults to `None`.
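+
+    Examples:
+        A minimal sketch of a custom task built with the `task` decorator (the
+        `MyCustomTask`, `instruction` and `response` names are just illustrative).
+        The docstring between the two `---` markers is parsed as YAML to extract the
+        `system_prompt` and the `user_message_template`, and the decorated function
+        itself becomes `format_output`:
+
+        ```python
+        from typing import Any, Dict, Union
+
+        from distilabel.steps.tasks.decorator import task
+
+
+        @task(inputs=["instruction"], outputs=["response"])
+        def MyCustomTask(
+            output: Union[str, None], input: Union[Dict[str, Any], None] = None
+        ) -> Dict[str, Any]:
+            '''
+            ---
+            system_prompt: |
+                You are a helpful assistant.
+
+            user_message_template: |
+                {instruction}
+            ---
+            '''
+            return {"response": output}
+        ```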
+ """
+
+ inputs = inputs or []
+ outputs = outputs or []
+
+ def decorator(func: TaskFormattingOutputFunc) -> Type["Task"]:
+ doc = inspect.getdoc(func)
+ if doc is None:
+ raise DistilabelUserError(
+ "When using the `task` decorator, including a docstring in the formatting"
+ " function is mandatory. The docstring must follow the format described"
+ " in the documentation.",
+ page="",
+ )
+
+ system_prompt, user_message_template = _parse_docstring(doc)
+ _validate_templates(inputs, system_prompt, user_message_template)
+
+ def inputs_property(self) -> "StepColumns":
+ return inputs
+
+ def outputs_property(self) -> "StepColumns":
+ return outputs
+
+ def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
+ return [
+ {"role": "system", "content": system_prompt.format(**input)},
+ {"role": "user", "content": user_message_template.format(**input)},
+ ]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ return func(output, input)
+
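+        # Dynamically build a `Task` subclass: its inputs/outputs come from the
+        # decorator arguments, its prompt formatting from the parsed docstring, and
+        # the decorated function itself is used as `format_output`.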
+ return type(
+ func.__name__,
+ (Task,),
+ {
+ "inputs": property(inputs_property),
+ "outputs": property(outputs_property),
+ "__module__": func.__module__,
+ "format_input": format_input,
+ "format_output": format_output,
+ },
+ )
+
+ return decorator
+
+
+_SYSTEM_PROMPT_YAML_KEY: Final[str] = "system_prompt"
+_USER_MESSAGE_TEMPLATE_YAML_KEY: Final[str] = "user_message_template"
+_DOCSTRING_FORMATTING_FUNCTION_ERROR: Final[str] = (
+ "Formatting function decorated with `task` doesn't follow the expected format. Please,"
+ " check the documentation and update the function to include a docstring with the expected"
+ " format."
+)
+
+
+def _parse_docstring(docstring: str) -> Tuple[str, str]:
+ """Parses the docstring of the formatting function that was built using the `task`
+ decorator.
+
+ Args:
+ docstring: the docstring of the formatting function.
+
+ Returns:
+ A tuple containing the system prompt and the user message template.
+
+ Raises:
+ DistilabelUserError: if the docstring doesn't follow the expected format or if
+ the expected keys are missing.
+ """
+ parts = docstring.split("---")
+
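+    # The docstring is expected to wrap a YAML block between two `---` markers, so
+    # splitting on `---` must yield exactly three parts: the text before the block,
+    # the YAML content itself, and the text after it.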
+ if len(parts) != 3:
+ raise DistilabelUserError(
+ _DOCSTRING_FORMATTING_FUNCTION_ERROR,
+ page="",
+ )
+
+ yaml_content = parts[1]
+
+ try:
+ parsed_yaml = yaml.safe_load(yaml_content)
+ if not isinstance(parsed_yaml, dict):
+ raise DistilabelUserError(
+ _DOCSTRING_FORMATTING_FUNCTION_ERROR,
+ page="",
+ )
+
+ system_prompt = parsed_yaml.get(_SYSTEM_PROMPT_YAML_KEY)
+ user_template = parsed_yaml.get(_USER_MESSAGE_TEMPLATE_YAML_KEY)
+ if system_prompt is None or user_template is None:
+ raise DistilabelUserError(
+ "The formatting function decorated with `task` must include both the `system_prompt`"
+ " and `user_message_template` keys in the docstring. Please, check the documentation"
+ " and update the docstring of the formatting function to include the expected"
+ " keys.",
+ page="",
+ )
+
+ return system_prompt.strip(), user_template.strip()
+
+ except yaml.YAMLError as e:
+ raise DistilabelUserError(_DOCSTRING_FORMATTING_FUNCTION_ERROR, page="") from e
+
+
+TEMPLATE_PLACEHOLDERS_REGEX = re.compile(r"\{(\w+)\}")
+
+
+def _validate_templates(
+ inputs: "StepColumns", system_prompt: str, user_message_template: str
+) -> None:
+ """Validates the system prompt and user message template to ensure that they only
+ contain the allowed placeholders i.e. the columns/keys that are provided as inputs.
+
+ Args:
+ inputs: the list of inputs columns/keys.
+ system_prompt: the system prompt.
+ user_message_template: the user message template.
+
+ Raises:
+ DistilabelUserError: if the system prompt or the user message template contain
+ invalid placeholders.
+ """
+ list_inputs = list(inputs.keys()) if isinstance(inputs, dict) else inputs
+
+ valid_system_prompt, invalid_system_prompt_placeholders = _validate_template(
+ system_prompt, list_inputs
+ )
+ if not valid_system_prompt:
+ raise DistilabelUserError(
+ f"The formatting function decorated with `task` includes invalid placeholders"
+ f" in the extracted `system_prompt` from the function docstring. Valid placeholders"
+ f" are: {list_inputs}, but the following placeholders were found: {invalid_system_prompt_placeholders}."
+ f" Please, update the `system_prompt` to only include the valid placeholders.",
+ page="",
+ )
+
+ valid_user_message_template, invalid_user_message_template_placeholders = (
+ _validate_template(user_message_template, list_inputs)
+ )
+ if not valid_user_message_template:
+ raise DistilabelUserError(
+ f"The formatting function decorated with `task` includes invalid placeholders"
+ f" in the extracted `user_message_template` from the function docstring. Valid"
+ f" placeholders are: {list_inputs}, but the following placeholders were found:"
+            f" {invalid_user_message_template_placeholders}. Please, update the `user_message_template`"
+ " to only include the valid placeholders.",
+ page="",
+ )
+
+
+def _validate_template(
+ template: str, allowed_placeholders: List[str]
+) -> Tuple[bool, set[str]]:
+ """Validates that the template only contains the allowed placeholders.
+
+ Args:
+ template: the template to validate.
+ allowed_placeholders: the list of allowed placeholders.
+
+ Returns:
+ A tuple containing a boolean indicating if the template is valid and a set
+ with the invalid placeholders.
+ """
+ placeholders = set(TEMPLATE_PLACEHOLDERS_REGEX.findall(template))
+ allowed_placeholders_set = set(allowed_placeholders)
+ are_valid = placeholders.issubset(allowed_placeholders_set)
+ invalid_placeholders = placeholders - allowed_placeholders_set
+ return are_valid, invalid_placeholders
diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py
index ea73f2503..95f271a11 100644
--- a/src/distilabel/steps/tasks/evol_instruct/base.py
+++ b/src/distilabel/steps/tasks/evol_instruct/base.py
@@ -388,3 +388,9 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
):
input.update(self.format_output(instruction, answers[idx]))
yield inputs
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+            self._apply_random_mutation("<PLACEHOLDER_INSTRUCTION>")
+ )
diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py
index bc15655ba..1f56c866a 100644
--- a/src/distilabel/steps/tasks/evol_instruct/generator.py
+++ b/src/distilabel/steps/tasks/evol_instruct/generator.py
@@ -347,3 +347,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
],
True,
)
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ return self._apply_random_mutation(iter_no=0)[0]
diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py
index 743deeb4f..5c899aa68 100644
--- a/src/distilabel/steps/tasks/evol_quality/base.py
+++ b/src/distilabel/steps/tasks/evol_quality/base.py
@@ -271,3 +271,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
yield inputs
self._logger.info(f"🎉 Finished evolving {len(responses)} instructions!")
+
+ @override
+ def _sample_input(self) -> ChatType:
+        return self.format_input("<PLACEHOLDER_INSTRUCTION>")
diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py
index a23b9dbba..d806e3ade 100644
--- a/src/distilabel/steps/tasks/improving_text_embeddings.py
+++ b/src/distilabel/steps/tasks/improving_text_embeddings.py
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.resources as importlib_resources
import random
import re
-import sys
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union
-if sys.version_info < (3, 9):
- import importlib_resources
-else:
- import importlib.resources as importlib_resources
-
from jinja2 import Template
from pydantic import Field, PrivateAttr
from typing_extensions import override
@@ -232,6 +227,10 @@ def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore
)
yield task_outputs, True
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.prompt
+
# IMPLEMENTED TASKS
class EmbeddingTaskGenerator(GeneratorTask):
@@ -402,6 +401,10 @@ def format_output(
pass
return {"tasks": output}
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.prompt
+
class GenerateTextRetrievalData(_EmbeddingDataGeneration):
"""Generate text retrieval data with an `LLM` to later on train an embedding model.
diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py
index 2df8bf621..c1e413d32 100644
--- a/src/distilabel/steps/tasks/magpie/generator.py
+++ b/src/distilabel/steps/tasks/magpie/generator.py
@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Dict, Union
from pydantic import Field
+from typing_extensions import override
from distilabel.errors import DistilabelUserError
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
@@ -23,6 +24,7 @@
from distilabel.steps.tasks.magpie.base import MagpieBase
if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import GeneratorStepOutput, StepColumns
@@ -312,3 +314,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput":
)
generated += rows_to_generate # type: ignore
yield (conversations, generated == self.num_rows)
+
+ @override
+ def _sample_input(self) -> "ChatType":
+ return self._generate_with_pre_query_template(inputs=[{}])
diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py
index 57b7e38e2..604f2a027 100644
--- a/src/distilabel/steps/tasks/quality_scorer.py
+++ b/src/distilabel/steps/tasks/quality_scorer.py
@@ -262,3 +262,14 @@ def _format_structured_output(
return orjson.loads(output)
except orjson.JSONDecodeError:
return {"scores": [None] * len(input["responses"])}
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+ {
+                "instruction": f"<PLACEHOLDER_{'instruction'.upper()}>",
+                "responses": [
+                    f"<PLACEHOLDER_{f'response_{i}'.upper()}>" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/steps/tasks/templates/apigen/generator.jinja2 b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
new file mode 100644
index 000000000..cc92c725c
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
@@ -0,0 +1,10 @@
+Here are examples of queries and the corresponding answers for similar functions:
+{{ examples }}
+
+Note that the query could be interpreted as a combination of several independent requests.
+{{ parallel_queries }}
+Based on these examples, generate {{ number }} diverse query and answer pairs for the function `{{ func_name }}`.
+The detailed function description is the following:
+{{ func_desc }}
+{{ format_inst }}
+Now please generate {{ number }} diverse query and answer pairs following the above format.
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2 b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
new file mode 100644
index 000000000..8d94357e7
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
@@ -0,0 +1,13 @@
+Given Information:
+- All Available Functions:
+{{ func_desc }}
+- User Query: {{ query }}
+- Generated Function Calls: {{ func_call }}
+- Execution Results: {{ execution_result }}
+
+Note: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.
+
+The main decision factor is whether the function calls accurately reflect the query's intentions and the function descriptions.
+Provide your reasoning in the thought section and decide if the data passes (answer yes or no).
+If not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.
+{{ format_inst }}
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/argillalabeller.jinja2 b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
new file mode 100644
index 000000000..d5afa75d2
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
@@ -0,0 +1,13 @@
+Please provide an answer to the question based on the input fields{% if examples %} and examples{% endif %}.
+{% if guidelines %}
+# Guidelines
+{{ guidelines }}
+{% endif %}
+# Input Fields
+{{ fields }}
+# Question
+{{ question }}
+{% if examples %}
+# Examples
+{{ examples }}
+{% endif %}
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/clair.jinja2 b/src/distilabel/steps/tasks/templates/clair.jinja2
new file mode 100644
index 000000000..3815c6db8
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/clair.jinja2
@@ -0,0 +1,7 @@
+{task}: {{ task }}
+
+{student_solution}: {{ student_solution }}
+
+-----------------
+
+Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py
index 04af9c177..aeb57bda3 100644
--- a/src/distilabel/steps/tasks/ultrafeedback.py
+++ b/src/distilabel/steps/tasks/ultrafeedback.py
@@ -480,3 +480,14 @@ def _format_structured_output(
"types": [None] * len(input["generations"]),
"rationales-for-ratings": [None] * len(input["generations"]),
}
+
+ @override
+ def _sample_input(self) -> ChatType:
+ return self.format_input(
+ {
+                "instruction": f"<PLACEHOLDER_{'instruction'.upper()}>",
+                "generations": [
+                    f"<PLACEHOLDER_{f'generation_{i}'.upper()}>" for i in range(2)
+ ],
+ }
+ )
diff --git a/src/distilabel/typing.py b/src/distilabel/typing.py
new file mode 100644
index 000000000..e034f216d
--- /dev/null
+++ b/src/distilabel/typing.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.llms.typing import GenerateOutput
+from distilabel.pipeline.typing import (
+ DownstreamConnectable,
+ DownstreamConnectableSteps,
+ InputDataset,
+ PipelineRuntimeParametersInfo,
+ StepLoadStatus,
+ UpstreamConnectableSteps,
+)
+from distilabel.steps.tasks.typing import (
+ ChatItem,
+ ChatType,
+ FormattedInput,
+ InstructorStructuredOutputType,
+ OutlinesStructuredOutputType,
+ StandardInput,
+ StructuredInput,
+ StructuredOutputType,
+)
+from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
+
+__all__ = [
+ "GenerateOutput",
+ "DownstreamConnectable",
+ "DownstreamConnectableSteps",
+ "InputDataset",
+ "PipelineRuntimeParametersInfo",
+ "StepLoadStatus",
+ "UpstreamConnectableSteps",
+ "ChatItem",
+ "ChatType",
+ "FormattedInput",
+ "InstructorStructuredOutputType",
+ "OutlinesStructuredOutputType",
+ "StandardInput",
+ "StructuredInput",
+ "StructuredOutputType",
+ "GeneratorStepOutput",
+ "StepColumns",
+ "StepOutput",
+]
diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py
index 9d5d9b59e..621f4b61d 100644
--- a/src/distilabel/utils/mkdocs/components_gallery.py
+++ b/src/distilabel/utils/mkdocs/components_gallery.py
@@ -90,6 +90,7 @@
"filtering": ":material-filter:",
"format": ":material-format-list-bulleted:",
"load": ":material-file-download:",
+ "execution": ":octicons-code-16:",
"save": ":material-content-save:",
}
@@ -108,6 +109,7 @@
"filtering": "Filtering steps are used to filter the data based on some criteria.",
"format": "Format steps are used to format the data.",
"load": "Load steps are used to load the data.",
+    "execution": "Executes Python functions.",
"save": "Save steps are used to save the data.",
}
diff --git a/tests/integration/test_caching_steps.py b/tests/integration/test_caching_steps.py
new file mode 100644
index 000000000..5ed8af993
--- /dev/null
+++ b/tests/integration/test_caching_steps.py
@@ -0,0 +1,499 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Any, Dict, Generator, List
+from unittest import mock
+from uuid import uuid4
+
+from pydantic import PrivateAttr
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromDicts
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+ from distilabel.pipeline.batch import _Batch
+
+
+class DummyStep(Step):
+ attr: int = 5
+ do_fail: bool = False
+ _ctr: int = PrivateAttr(default=0)
+
+ _random: str = PrivateAttr(default="")
+
+ def load(self) -> None:
+ super().load()
+ self._random = str(uuid4())
+
+ @property
+ def inputs(self) -> List[str]:
+ return ["instruction"]
+
+ def process(self, inputs: StepInput) -> Generator[List[Dict[str, Any]], None, None]:
+ for input in inputs:
+ input["response"] = f"I don't know - {self._ctr} - {self._random}"
+ self._ctr += 1
+
+ if self.do_fail:
+ raise ValueError("The step failed")
+ yield inputs
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["response"]
+
+
+class DummyStep2(DummyStep):
+ def process(
+ self, *inputs: StepInput
+ ) -> Generator[List[Dict[str, Any]], None, None]:
+ outputs = []
+ for input_a, input_b in zip(*inputs):
+ output = {**input_a, **input_b}
+ output["response"] = f"I don't know - {self._ctr}"
+ self._ctr += 1
+ outputs.append(output)
+ yield outputs
+
+
+class OtherDummyStep(DummyStep):
+ pass
+
+
+def test_cache() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_0 = pipeline.run()
+ distiset_1 = pipeline.run()
+ assert (
+ distiset_0["default"]["train"].to_list()
+ == distiset_1["default"]["train"].to_list()
+ )
+
+ distiset_2 = pipeline.run(use_cache=False)
+ assert len(distiset_2["default"]["train"]) == 48
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_2["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_step_cache_false() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=False,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_b` has been executed
+ assert run_step_spy.call_count == 1
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_step_changing() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ attr=103401234, # change attribute so step is not the same
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_b` has been executed
+ assert run_step_spy.call_count == 1
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_with_intermediate_step_cache_false() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=False,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_0 = pipeline.run()
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+        # check that only `step_b` and `step_c` have been executed
+ assert run_step_spy.call_count == 2
+
+ assert (
+ distiset_0["default"]["train"].to_list()
+ != distiset_1["default"]["train"].to_list()
+ )
+
+
+def test_cache_adding_step() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=True,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_c` has been executed
+ assert run_step_spy.call_count == 1
+
+ dict_0 = distiset_0["default"]["train"].to_dict()
+ dict_1 = distiset_1["default"]["train"].to_dict()
+ del dict_1["response_2"]
+ assert dict_0 == dict_1
+
+
+def test_cache_adding_step_with_multiple_predecessor() -> None:
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ output_mappings={"response": "response_1"},
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ output_mappings={"response": "response_2"},
+ do_fail=False,
+ use_cache=True,
+ )
+
+ step_generator >> [step_a, step_b]
+
+ distiset_0 = pipeline.run()
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a",
+ input_batch_size=4,
+ output_mappings={"response": "response_1"},
+ use_cache=True,
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ output_mappings={"response": "response_2"},
+ do_fail=False,
+ use_cache=True,
+ )
+ step_c = DummyStep2(
+ name="step_c",
+ input_batch_size=12,
+ output_mappings={"response": "response_3"},
+ use_cache=True,
+ )
+
+ step_generator >> [step_a, step_b] >> step_c
+
+ with mock.patch.object(
+ pipeline, "_run_step", wraps=pipeline._run_step
+ ) as run_step_spy:
+ distiset_1 = pipeline.run()
+
+ # check that only `step_c` has been executed
+ assert run_step_spy.call_count == 1
+
+ for row_1, row_0_a, row_0_b in zip(
+ distiset_1["default"]["train"],
+ distiset_0["step_a"]["train"],
+ distiset_0["step_b"]["train"],
+ ):
+ assert row_1["response_1"] == row_0_a["response_1"]
+ assert row_1["response_2"] == row_0_b["response_2"]
+
+
+def test_cache_with_offset() -> None:
+ use_cache_per_step = True
+ do_fail = False
+ with TemporaryDirectory() as tmp_dir:
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline_0:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a", input_batch_size=4, use_cache=use_cache_per_step
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=do_fail,
+ use_cache=use_cache_per_step,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=use_cache_per_step,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ # Controlled failure of the Pipeline
+ original_process_batch = pipeline_0._process_batch
+
+ def _process_batch_wrapper(
+ batch: "_Batch", send_last_batch_flag: bool = True
+ ) -> None:
+ if batch.step_name == step_b.name and batch.seq_no == 2:
+ pipeline_0._stop_called = True
+ original_process_batch(batch)
+
+ # Run first time and stop the pipeline when specific batch received (simulate CTRL + C)
+ with mock.patch.object(pipeline_0, "_process_batch", _process_batch_wrapper):
+ distiset_0 = pipeline_0.run(use_cache=False)
+
+ assert len(distiset_0["default"]["train"]) == 12
+
+ with Pipeline(name="test_pipeline_caching", cache_dir=tmp_dir) as pipeline_1:
+ initial_batch_size = 8
+ step_generator = LoadDataFromDicts(
+ data=[{"instruction": "some text"}] * initial_batch_size * 6,
+ batch_size=initial_batch_size,
+ )
+
+ step_a = DummyStep(
+ name="step_a", input_batch_size=4, use_cache=use_cache_per_step
+ )
+ step_b = DummyStep(
+ name="step_b",
+ input_batch_size=10,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_1"},
+ do_fail=do_fail,
+ use_cache=use_cache_per_step,
+ )
+ step_c = DummyStep(
+ name="step_c",
+ input_batch_size=12,
+ input_mappings={"instruction": "response"},
+ output_mappings={"response": "response_2"},
+ use_cache=use_cache_per_step,
+ )
+
+ step_generator >> step_a >> step_b >> step_c
+
+ distiset_1 = pipeline_1.run()
+
+ assert len(distiset_1["default"]["train"]) == 48
diff --git a/tests/integration/test_generator_and_sampler.py b/tests/integration/test_generator_and_sampler.py
new file mode 100644
index 000000000..1bb0a457b
--- /dev/null
+++ b/tests/integration/test_generator_and_sampler.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.llms._dummy import DummyAsyncLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
+from distilabel.steps.tasks import TextGeneration
+
+
+def get_pipeline():
+ with Pipeline() as pipe:
+ size_dataset_1 = 10
+ loader_1 = LoadDataFromDicts(
+ data=[{"instruction": f"instruction {i}"} for i in range(size_dataset_1)]
+ )
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ size=2,
+ samples=size_dataset_1,
+ batch_size=8,
+ )
+ text_generation = TextGeneration(llm=DummyAsyncLLM(), input_batch_size=8)
+
+ combine = CombineOutputs()
+ [loader_1, sampler] >> combine >> text_generation
+ return pipe
+
+
+def test_sampler():
+ pipe = get_pipeline()
+ distiset = pipe.run(use_cache=False)
+ assert len(distiset["default"]["train"]) == 10
+ row = distiset["default"]["train"][0]
+ assert isinstance(row["sample"], list)
+ assert len(row["sample"]) == 2
+ assert isinstance(row["instruction"], str)
+
+
+if __name__ == "__main__":
+ pipe = get_pipeline()
+ distiset = pipe.run(use_cache=False)
+ print(distiset)
+ print(distiset["default"]["train"][0])
diff --git a/tests/integration/test_prints.py b/tests/integration/test_prints.py
new file mode 100644
index 000000000..7db85caf8
--- /dev/null
+++ b/tests/integration/test_prints.py
@@ -0,0 +1,72 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Union
+
+import pytest
+
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.steps import tasks as tasks_
+from tests.unit.conftest import DummyLLM
+
+# The tasks not listed here don't have a print method (or don't have a print method that works)
+tasks = [
+ tasks_.ComplexityScorer,
+ partial(tasks_.EvolInstruct, num_evolutions=1),
+ partial(tasks_.EvolComplexity, num_evolutions=1),
+ partial(tasks_.EvolComplexityGenerator, num_instructions=1),
+ partial(tasks_.EvolInstructGenerator, num_instructions=1),
+ partial(tasks_.EvolQuality, num_evolutions=1),
+ tasks_.Genstruct,
+ partial(
+ tasks_.BitextRetrievalGenerator,
+ source_language="English",
+ target_language="Spanish",
+ unit="sentence",
+ difficulty="elementary school",
+ high_score="4",
+ low_score="2.5",
+ ),
+ partial(tasks_.EmbeddingTaskGenerator, category="text-retrieval"),
+ tasks_.GenerateLongTextMatchingData,
+ tasks_.GenerateShortTextMatchingData,
+ tasks_.GenerateTextClassificationData,
+ tasks_.GenerateTextRetrievalData,
+ tasks_.MonolingualTripletGenerator,
+ tasks_.InstructionBacktranslation,
+ tasks_.Magpie,
+ tasks_.MagpieGenerator,
+ partial(tasks_.PrometheusEval, mode="absolute", rubric="factual-validity"),
+ tasks_.QualityScorer,
+ tasks_.SelfInstruct,
+ partial(tasks_.GenerateSentencePair, action="paraphrase"),
+ tasks_.UltraFeedback,
+ tasks_.URIAL,
+]
+
+
+class TestLLM(DummyLLM, MagpieChatTemplateMixin):
+ magpie_pre_query_template: Union[str, None] = "llama3"
+
+
+llm = TestLLM()
+
+
+@pytest.mark.parametrize("task", tasks)
+def test_prints(task) -> None:
+ t = task(llm=llm)
+ t.load()
+ t.print()
+ t.unload()
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 1903d10e3..0e2e157e6 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -53,9 +53,9 @@ def model_name(self) -> str:
return "test"
def generate( # type: ignore
- self, input: "FormattedInput", num_generations: int = 1
- ) -> "GenerateOutput":
- return ["output" for _ in range(num_generations)]
+ self, inputs: "FormattedInput", num_generations: int = 1
+ ) -> List["GenerateOutput"]:
+ return [["output" for _ in range(num_generations)]]
class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
diff --git a/tests/unit/pipeline/conftest.py b/tests/unit/pipeline/conftest.py
index b3e708a17..a2bf2b932 100644
--- a/tests/unit/pipeline/conftest.py
+++ b/tests/unit/pipeline/conftest.py
@@ -14,7 +14,10 @@
import pytest
+from distilabel.pipeline._dag import DAG
+from distilabel.pipeline.batch_manager import _BatchManager
from distilabel.pipeline.local import Pipeline
+from distilabel.steps.base import GeneratorStep, GlobalStep, Step
from .utils import DummyGeneratorStep, DummyGlobalStep, DummyStep1, DummyStep2
@@ -42,3 +45,26 @@ def dummy_generator_step_fixture(pipeline: "Pipeline") -> DummyGeneratorStep:
@pytest.fixture(name="dummy_global_step")
def dummy_global_step_fixture(pipeline: "Pipeline") -> DummyGlobalStep:
return DummyGlobalStep(name="dummy_global_step", pipeline=pipeline)
+
+
+@pytest.fixture(name="dummy_dag")
+def dummy_dag_fixture(
+ dummy_generator_step: "GeneratorStep",
+ dummy_step_1: "Step",
+ dummy_step_2: "Step",
+ dummy_global_step: "GlobalStep",
+) -> DAG:
+ dag = DAG()
+ dag.add_step(dummy_generator_step)
+ dag.add_step(dummy_step_1)
+ dag.add_step(dummy_step_2)
+ dag.add_step(dummy_global_step)
+ dag.add_edge("dummy_generator_step", "dummy_step_1")
+ dag.add_edge("dummy_generator_step", "dummy_global_step")
+ dag.add_edge("dummy_step_1", "dummy_step_2")
+ return dag
+
+
+@pytest.fixture(name="dummy_batch_manager")
+def dummy_batch_manager_from_dag_fixture(dummy_dag: DAG) -> _BatchManager:
+ return _BatchManager.from_dag(dummy_dag)
diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py
index 5f38781a0..86db3f5cf 100644
--- a/tests/unit/pipeline/test_base.py
+++ b/tests/unit/pipeline/test_base.py
@@ -95,6 +95,28 @@ def test_get_pipeline(self) -> None:
class TestBasePipeline:
+ def test_aggregated_steps_signature(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline_0:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ with DummyPipeline(name="dummy") as pipeline_1:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ assert (
+ pipeline_0.aggregated_steps_signature
+ == pipeline_1.aggregated_steps_signature
+ )
+
def test_context_manager(self) -> None:
assert _GlobalPipelineManager.get_pipeline() is None
@@ -123,12 +145,18 @@ def test_load_batch_manager(self, use_cache: bool) -> None:
if use_cache:
mock_load_from_cache.assert_called_once_with(
- pipeline._cache_location["batch_manager"]
+ dag=pipeline.dag,
+ batch_manager_path=pipeline._cache_location["batch_manager"],
+ steps_data_path=pipeline._cache_location["steps_data"],
)
mock_from_dag.assert_not_called()
else:
mock_load_from_cache.assert_not_called()
- mock_from_dag.assert_called_once_with(pipeline.dag)
+ mock_from_dag.assert_called_once_with(
+ dag=pipeline.dag,
+ use_cache=use_cache,
+ steps_data_path=pipeline._cache_location["steps_data"],
+ )
def test_setup_write_buffer(self) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
@@ -328,6 +356,7 @@ def test_run_stage_steps_and_wait(self, caplog) -> None:
generator >> [step, step2] >> step3 >> step4
+ pipeline._load_batch_manager()
pipeline._steps_load_status = { # type: ignore
generator.name: 1,
step.name: 1,
@@ -351,6 +380,7 @@ def test_run_stage_steps_and_wait_with_failing_step(self, caplog) -> None:
generator >> [step, step2] >> step3 >> step4
pipeline._init_steps_load_status()
+ pipeline._load_batch_manager()
pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore
caplog.set_level(logging.INFO)
@@ -368,6 +398,7 @@ def test_run_stage_steps_and_wait_stop_called(self) -> None:
generator >> [step, step2] >> step3 >> step4
pipeline._init_steps_load_status()
+ pipeline._load_batch_manager()
pipeline._stop_called = True
assert pipeline._run_stage_steps_and_wait(stage=0) is False
@@ -626,7 +657,9 @@ def test_register_batch(self) -> None:
batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
pipeline._register_batch(batch)
- pipeline._batch_manager.register_batch.assert_called_once_with(batch)
+ pipeline._batch_manager.register_batch.assert_called_once_with(
+ batch, steps_data_path=pipeline._cache_location["steps_data"]
+ )
def test_send_last_batch_flag_to_step(self) -> None:
with DummyPipeline(name="unit-test-pipeline") as pipeline:
@@ -743,7 +776,9 @@ def test_handle_batch_on_stop(self) -> None:
batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
pipeline._handle_batch_on_stop(batch)
- batch_manager_mock.register_batch.assert_called_once_with(batch)
+ batch_manager_mock.register_batch.assert_called_once_with(
+ batch, steps_data_path=pipeline._cache_location["steps_data"]
+ )
batch_manager_mock.add_batch.assert_has_calls(
[
mock.call(step.name, batch),
@@ -1300,8 +1335,7 @@ def test_base_pipeline_signature(self) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
# Doesn't matter if it's exactly this or not, the test should fail if we change the
# way this is created.
- signature = pipeline._create_signature()
- assert signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709"
+ assert pipeline.signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709"
# Maybe not the best place for this test, but does the work for now
from distilabel.pipeline.local import Pipeline
@@ -1311,11 +1345,28 @@ def test_base_pipeline_signature(self) -> None:
sample_two_steps = sample_n_steps(2)
with Pipeline(name="unit-test-pipeline") as pipeline:
- dummy_generator = DummyGeneratorStep()
- dummy_step_1_0 = DummyStep1()
- dummy_step_1_1 = DummyStep1()
- dummy_step_1_2 = DummyStep1()
- dummy_step_2 = DummyStep2()
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0")
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+
+ (
+ dummy_generator
+ >> sample_two_steps
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "edff8f5bb8b51da406ff274e640f87264f014e3b"
+
+ # attribute values shouldn't affect the pipeline signature
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0", attr1=17238497128934)
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
(
dummy_generator
@@ -1324,8 +1375,51 @@ def test_base_pipeline_signature(self) -> None:
>> dummy_step_2
)
- signature = pipeline._create_signature()
- assert signature == "d3c7c572fe31233aa1198174c6c793b67ef3744b"
+ assert pipeline.signature == "edff8f5bb8b51da406ff274e640f87264f014e3b"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator")
+ dummy_step_1_0 = DummyStep1(name="dummy_step_1_0")
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+
+ (
+ dummy_generator
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "5634172be496319d50848b1679b2a8781cc5581f"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator_second_time")
+ dummy_step_1_0 = DummyStep1(
+ name="dummy_step_1_0_second_time", attr1=17238497128934
+ )
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1_second_time")
+ dummy_step_1_2 = DummyStep1(name="dummy_step_1_2_second_time")
+ dummy_step_2 = DummyStep2(name="dummy_step_2_second_time")
+
+ (
+ dummy_generator
+ >> sample_two_steps
+ >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
+ >> dummy_step_2
+ )
+
+ assert pipeline.signature == "806dad3fca0f8274af0f374660d4e3eb25d62d12"
+
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator_second_time")
+ dummy_step_1_0 = DummyStep1(
+ name="dummy_step_1_0_second_time", attr1=17238497128934
+ )
+ dummy_step_1_1 = DummyStep1(name="dummy_step_1_1_second_time")
+
+ (dummy_generator >> sample_two_steps >> [dummy_step_1_0, dummy_step_1_1])
+
+ assert pipeline.signature == "7222ce34c677bea3720ef3d08c2673b29b61ff9b"
def test_binary_rshift_operator(self) -> None:
# Tests the steps can be connected using the >> operator.
@@ -1340,7 +1434,7 @@ def test_binary_rshift_operator(self) -> None:
dummy_generator.connect(dummy_step_1)
dummy_step_1.connect(dummy_step_2)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-3") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1349,7 +1443,7 @@ def test_binary_rshift_operator(self) -> None:
dummy_generator >> dummy_step_1 >> dummy_step_2
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1366,7 +1460,7 @@ def test_binary_rshift_operator_with_list(self) -> None:
dummy_generator.connect(dummy_step_1)
dummy_generator.connect(dummy_step_2)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1375,7 +1469,7 @@ def test_binary_rshift_operator_with_list(self) -> None:
dummy_generator >> [dummy_step_1, dummy_step_2]
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1395,7 +1489,7 @@ def test_binary_rrshift_operator(self) -> None:
dummy_step_1.connect(dummy_global)
dummy_step_2.connect(dummy_global)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_step_1 = DummyStep1(name="dummy_step_1")
@@ -1403,7 +1497,7 @@ def test_binary_rrshift_operator(self) -> None:
dummy_global = DummyGlobalStep(name="dummy_global_step")
[dummy_step_1, dummy_step_2] >> dummy_global
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
@@ -1429,7 +1523,7 @@ def test_binary_operators(self) -> None:
dummy_step_1.connect(dummy_global)
dummy_step_2.connect(dummy_global)
- signature_1 = pipeline_1._create_signature()
+ signature_1 = pipeline_1.signature
with Pipeline(name="unit-test-pipeline-2") as pipeline_2:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
@@ -1438,6 +1532,6 @@ def test_binary_operators(self) -> None:
dummy_global = DummyGlobalStep(name="dummy_global_step")
dummy_generator >> [dummy_step_1, dummy_step_2] >> dummy_global
- signature_2 = pipeline_2._create_signature()
+ signature_2 = pipeline_2.signature
assert signature_1 == signature_2
diff --git a/tests/unit/pipeline/test_batch_manager.py b/tests/unit/pipeline/test_batch_manager.py
index c5023b261..8801096ce 100644
--- a/tests/unit/pipeline/test_batch_manager.py
+++ b/tests/unit/pipeline/test_batch_manager.py
@@ -15,14 +15,18 @@
import tempfile
from pathlib import Path
from typing import Dict, List
+from unittest import mock
import pytest
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.batch_manager import _BatchManager, _BatchManagerStep
+from distilabel.pipeline.local import Pipeline
from distilabel.steps.base import GeneratorStep, GlobalStep, Step
+from .utils import DummyGeneratorStep, DummyStep1, DummyStep2
+
class TestBatchManagerStep:
def test_add_batch(self) -> None:
@@ -144,6 +148,7 @@ def test_get_batch(self) -> None:
)
],
},
+ step_offset={"step1": (0, 0), "step2": (0, 0)},
built_batches=[previously_built_batch],
next_expected_seq_no={"step1": (1, 1), "step2": (1, 1)},
)
@@ -168,7 +173,7 @@ def test_get_batch(self) -> None:
{"b": 2},
],
],
- created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ created_from={"step1": [(1, 5, 2)], "step2": [(1, 5, 2)]},
)
batch = batch_manager_step.get_batch()
@@ -187,7 +192,7 @@ def test_get_batch(self) -> None:
{"b": 4},
],
],
- created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ created_from={"step1": [(1, 5, 2)], "step2": [(1, 5, 2)]},
)
def test_get_batches_accumulate(self) -> None:
@@ -231,6 +236,7 @@ def test_get_batches_accumulate(self) -> None:
)
],
},
+ step_offset={"step1": (0, 0), "step2": (0, 0)},
last_batch_received=["step1", "step2"],
)
@@ -258,7 +264,7 @@ def test_get_batches_accumulate(self) -> None:
{"b": 6},
],
],
- created_from={"step1": [(0, 5)], "step2": [(0, 6)]},
+ created_from={"step1": [(0, 5, 5)], "step2": [(0, 6, 6)]},
)
def test_get_batches_not_enough_data(self) -> None:
@@ -430,7 +436,7 @@ def test_get_data(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}],
]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert created_from == {"step1": [(0, 6, 5)], "step2": [(0, 7, 5)]}
assert routed_to == ["step1", "step2"]
assert batch_manager_step.data == {
@@ -502,7 +508,7 @@ def test_get_data_accumulate(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}],
]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert created_from == {"step1": [(0, 6, 6)], "step2": [(0, 7, 7)]}
assert routed_to == []
assert batch_manager_step.data == {"step1": [], "step2": []}
@@ -520,7 +526,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(0, 3)]},
+ created_from={"Z": [(0, 3, 3)]},
)
batch_a_1 = _Batch(
@@ -535,7 +541,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(1, 3)]},
+ created_from={"Z": [(1, 3, 3)]},
)
batch_b_0 = _Batch(
@@ -550,7 +556,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(0, 3)]},
+ created_from={"Z": [(0, 3, 3)]},
)
batch_c_0 = _Batch(
@@ -565,7 +571,7 @@ def test_get_data_convergence_step(self) -> None:
]
],
size=3,
- created_from={"Z": [(1, 3)]},
+ created_from={"Z": [(1, 3, 3)]},
)
batch_manager_step = _BatchManagerStep(
@@ -590,7 +596,7 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm B 0"},
],
]
- assert created_from == {"A": [(0, 3)], "B": [(0, 3)]}
+ assert created_from == {"A": [(0, 3, 3)], "B": [(0, 3, 3)]}
assert routed_to == []
assert batch_manager_step.next_expected_created_from_batch_seq_no == 1
@@ -608,7 +614,7 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm C 0"},
],
]
- assert created_from == {"A": [(1, 3)], "C": [(0, 3)]}
+ assert created_from == {"A": [(1, 3, 3)], "C": [(0, 3, 3)]}
assert routed_to == []
assert batch_manager_step.next_expected_created_from_batch_seq_no == 2
@@ -803,7 +809,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -812,7 +818,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -827,7 +833,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -836,7 +842,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -851,7 +857,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- created_from={"step0": [(0, 3)]},
+ created_from={"step0": [(0, 3, 3)]},
)
],
"step2": [
@@ -860,7 +866,7 @@ def test_last_batch_accumulate(
step_name="step1",
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
- created_from={"step0": [(0, 3)]},
+ created_from={"step0": [(0, 3, 3)]},
)
],
},
@@ -1217,6 +1223,9 @@ def test_dump(self) -> None:
"step1": (0, 0),
"step2": (0, 0),
},
+ "step_offset": {},
+ "step_signature": None,
+ "use_cache": False,
"type_info": {
"module": "distilabel.pipeline.batch_manager",
"name": "_BatchManagerStep",
@@ -1235,7 +1244,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [],
@@ -1252,7 +1261,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
"step2": [
@@ -1262,7 +1271,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1278,7 +1287,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1288,7 +1297,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1304,7 +1313,7 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1314,7 +1323,7 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
},
@@ -1330,7 +1339,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
+ created_from={"step0": [(0, 4, 4)]},
)
],
"step2": [
@@ -1340,7 +1349,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
+ created_from={"step0": [(0, 5, 5)]},
)
],
},
@@ -1467,6 +1476,41 @@ def test_add_batch(self) -> None:
"step2": [],
}
+ def test_step_hash_finished(self) -> None:
+ batch_manager = _BatchManager(
+ steps={
+ "step1": _BatchManagerStep(
+ step_name="step1",
+ accumulate=False,
+ input_batch_size=5,
+ data={},
+ ),
+ "step2": _BatchManagerStep(
+ step_name="step2",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step_1": []},
+ ),
+ "step3": _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step2": []},
+ ),
+ },
+ last_batch_received={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True),
+ "step2": None,
+ "step3": None,
+ },
+ last_batch_sent={"step1": None, "step2": None, "step3": None},
+ last_batch_flag_sent_to=["step2"],
+ )
+
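+ # step1 has received its last batch and step2 has been sent the last-batch
+ # flag, so both are finished; step3 has neither, so it is still running.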
+ assert batch_manager.step_has_finished("step1") is True
+ assert batch_manager.step_has_finished("step2") is True
+ assert batch_manager.step_has_finished("step3") is False
+
def test_add_batch_with_prepend(self) -> None:
batch_1 = _Batch(
seq_no=1,
@@ -1554,12 +1598,26 @@ def test_from_dag(
batch_manager = _BatchManager.from_dag(dag)
assert batch_manager._steps == {
+ "dummy_generator_step": _BatchManagerStep(
+ step_name="dummy_generator_step",
+ accumulate=False,
+ input_batch_size=None,
+ data={},
+ convergence_step=True,
+ next_expected_seq_no={},
+ step_signature="963a16b6081170f39eef011d64d992a0a6e9f0e9",
+ use_cache=True,
+ step_offset={},
+ ),
"dummy_step_1": _BatchManagerStep(
step_name="dummy_step_1",
accumulate=False,
input_batch_size=50,
data={"dummy_generator_step": []},
next_expected_seq_no={"dummy_generator_step": (0, 0)},
+ step_signature="bc765d5801dc71c88a1a444e1b1e26035d309724",
+ use_cache=True,
+ step_offset={"dummy_generator_step": (0, 0)},
),
"dummy_global_step": _BatchManagerStep(
step_name="dummy_global_step",
@@ -1567,6 +1625,9 @@ def test_from_dag(
input_batch_size=50,
data={"dummy_generator_step": []},
next_expected_seq_no={"dummy_generator_step": (0, 0)},
+ step_signature="6a0e9f45043fa7dc37e2b36269d660dfef63dbb7",
+ use_cache=True,
+ step_offset={"dummy_generator_step": (0, 0)},
),
"dummy_step_2": _BatchManagerStep(
step_name="dummy_step_2",
@@ -1574,9 +1635,73 @@ def test_from_dag(
input_batch_size=50,
data={"dummy_step_1": []},
next_expected_seq_no={"dummy_step_1": (0, 0)},
+ step_signature="2d1076164acb43431aad1a54a781b7bad22c7037",
+ use_cache=True,
+ step_offset={"dummy_step_1": (0, 0)},
),
}
+ def test_cache(self, dummy_batch_manager: _BatchManager) -> None:
+ # We test the cache starting from the DAG because we need the signature
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ dummy_batch_manager.cache(batch_manager_path, Path(tmp_dir))
+
+ assert batch_manager_path.exists() and batch_manager_path.is_file()
+
+ for step_name, step in dummy_batch_manager._steps.items():
+ batch_manager_step_dir = (
+ Path(tmp_dir) / "batch_manager_steps" / step_name
+ )
+ assert (
+ batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir()
+ )
+
+ batch_manager_step_path = (
+ batch_manager_step_dir / "batch_manager_step.json"
+ )
+ assert (
+ batch_manager_step_path.exists()
+ and batch_manager_step_path.is_file()
+ )
+
+ built_batches_dir = batch_manager_step_dir / "built_batches"
+ assert built_batches_dir.exists()
+
+ for batch in step.built_batches:
+ batch_path = (
+ built_batches_dir
+ / f"batch_{batch.seq_no}_{batch.data_hash}.json"
+ )
+ assert batch_path.exists() and batch_path.is_file()
+
+ # The batches buffered in `step.data` are no longer persisted per buffered
+ # step in the new cache layout, so they are not asserted here.
+
+ def test_load_from_cache(
+ self, dummy_dag: DAG, dummy_batch_manager: _BatchManager
+ ) -> None:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ dummy_batch_manager.cache(batch_manager_path, Path(tmp_dir))
+ loaded_batch_manager = _BatchManager.load_from_cache(
+ dummy_dag, batch_manager_path, Path(tmp_dir)
+ )
+
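+ # The cache/load round-trip should preserve the batch manager state exactly.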
+ assert dummy_batch_manager.dump() == loaded_batch_manager.dump()
+
def test_can_generate(self) -> None:
batch_manager = _BatchManager(
steps={},
@@ -1608,6 +1733,108 @@ def test_can_generate(self) -> None:
assert not batch_manager.can_generate()
+ def test_invalidate_cache_for(self) -> None:
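+ # Invalidating the cache for step_a should reset step_a and its downstream
+ # step_c, reload step_a's predecessor batches, and leave step_b untouched.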
+ with Pipeline() as pipeline:
+ generator = DummyGeneratorStep()
+ step_a = DummyStep1()
+ step_b = DummyStep1()
+ step_c = DummyStep2()
+
+ generator >> [step_a, step_b] >> step_c
+
+ pipeline._load_batch_manager()
+ batch_manager: "_BatchManager" = pipeline._batch_manager # type: ignore
+
+ with (
+ mock.patch.object(
+ batch_manager, "_reset_batch_manager_for_step"
+ ) as reset_mock,
+ mock.patch.object(batch_manager, "_load_predecessor_batches") as load_mock,
+ ):
+ batch_manager.invalidate_cache_for(
+ step_name=step_a.name, # type: ignore
+ dag=pipeline.dag,
+ steps_data_path=pipeline._cache_location["steps_data"],
+ )
+
+ # it shouldn't have been called for step_b, which is on an untouched branch
+ reset_mock.assert_has_calls(
+ [
+ mock.call(step_a.name, pipeline.dag),
+ mock.call(step_c.name, pipeline.dag),
+ ]
+ )
+
+ load_mock.assert_called_once_with(
+ step_a.name, pipeline.dag, pipeline._cache_location["steps_data"]
+ )
+
+ def test_reset_batch_manager_for_step(self) -> None:
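+ # Resetting a step clears its buffered data and the last-batch bookkeeping
+ # kept for it by the batch manager.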
+ batch_manager = _BatchManager(
+ steps={
+ "step1": _BatchManagerStep(
+ step_name="step1",
+ accumulate=True,
+ input_batch_size=5,
+ data={
+ "step0": [_Batch(seq_no=0, step_name="step0", last_batch=True)]
+ },
+ )
+ },
+ last_batch_received={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True)
+ },
+ last_batch_sent={
+ "step1": _Batch(seq_no=0, step_name="step1", last_batch=True)
+ },
+ last_batch_flag_sent_to=["step1"],
+ )
+
+ dag = DAG()
+ dag.add_step(DummyStep1(name="step1"))
+
+ batch_manager._reset_batch_manager_for_step("step1", dag)
+ assert batch_manager._steps["step1"].data == {}
+ assert batch_manager._last_batch_received["step1"] is None
+ assert batch_manager._last_batch_sent["step1"] is None
+ assert batch_manager._last_batch_flag_sent_to == []
+
+ def test_load_predecessor_batches(self) -> None:
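+ # Batches previously written to disk by the generator should be loaded back
+ # into the buffer of its successor step, registering the generator as having
+ # sent its last batch.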
+ with Pipeline() as pipeline:
+ generator = DummyGeneratorStep()
+ step_a = DummyStep1()
+ step_b = DummyStep1()
+ step_c = DummyStep2()
+
+ generator >> [step_a, step_b] >> step_c
+
+ pipeline._load_batch_manager()
+ batch_manager: "_BatchManager" = pipeline._batch_manager # type: ignore
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ previous_step_dir = (
+ Path(tmp_dir) / f"{generator.name}_{generator.signature}"
+ ) # type: ignore
+ batches = []
+ for i in range(3):
+ batch = _Batch(
+ seq_no=i,
+ step_name=generator.name, # type: ignore
+ data=[[{"a": i} for _ in range(5)]],
+ last_batch=i % 3 == 0,
+ )
+ batches.append(batch)
+ batch.save(path=previous_step_dir / f"batch_{i}.json")
+
+ batch_manager._load_predecessor_batches(
+ step_name=step_a.name, # type: ignore
+ dag=pipeline.dag,
+ steps_data_path=Path(tmp_dir), # type: ignore
+ )
+
+ assert batch_manager._steps[step_a.name].data[generator.name] == batches # type: ignore
+ assert generator.name in batch_manager._steps[step_a.name].last_batch_received # type: ignore
+
def test_dump(self) -> None:
built_batch = _Batch(
seq_no=0,
@@ -1681,6 +1908,9 @@ def test_dump(self) -> None:
"step1": (1, 1),
"step2": (1, 1),
},
+ "step_offset": {},
+ "step_signature": None,
+ "use_cache": False,
"type_info": {
"module": "distilabel.pipeline.batch_manager",
"name": "_BatchManagerStep",
@@ -1898,467 +2128,3 @@ def test_from_dict(self) -> None:
assert isinstance(step, _Batch)
assert batch_manager._last_batch_flag_sent_to == ["step3"]
-
- def test_cache(self) -> None:
- batch_manager = _BatchManager.from_dict(
- {
- "steps": {
- "step1": {
- "step_name": "step1",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": None,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- "step2": {
- "step_name": "step2",
- "accumulate": False,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": 50,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- },
- "last_batch_received": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- },
- "last_batch_sent": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_Batch",
- },
- },
- },
- "last_batch_flag_sent_to": ["step3"],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManager",
- },
- }
- )
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- batch_manager_path = Path(tmp_dir) / "batch_manager.json"
- batch_manager.cache(batch_manager_path)
-
- assert batch_manager_path.exists() and batch_manager_path.is_file()
-
- for step_name, step in batch_manager._steps.items():
- batch_manager_step_dir = (
- Path(tmp_dir) / "batch_manager_steps" / step_name
- )
- assert (
- batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir()
- )
-
- batch_manager_step_path = (
- batch_manager_step_dir / "batch_manager_step.json"
- )
- assert (
- batch_manager_step_path.exists()
- and batch_manager_step_path.is_file()
- )
-
- built_batches_dir = batch_manager_step_dir / "built_batches"
- assert built_batches_dir.exists()
-
- for batch in step.built_batches:
- batch_path = (
- built_batches_dir
- / f"batch_{batch.seq_no}_{batch.data_hash}.json"
- )
- assert batch_path.exists() and batch_path.is_file()
-
- for buffered_step_name in step.data:
- buffered_step_dir = batch_manager_step_dir / buffered_step_name
- assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
-
- for batch in step.data[buffered_step_name]:
- batch_path = (
- buffered_step_dir
- / f"batch_{batch.seq_no}_{batch.data_hash}.json"
- )
- assert batch_path.exists() and batch_path.is_file()
-
- def test_load_from_cache(self) -> None:
- batch_manager = _BatchManager.from_dict(
- {
- "steps": {
- "step1": {
- "step_name": "step1",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": None,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- "step2": {
- "step_name": "step2",
- "accumulate": False,
- "convergence_step": False,
- "convergence_step_batches_consumed": {"0": {"Z": 1234}},
- "input_batch_size": 50,
- "data": {
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "data_hash": "1234",
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- },
- "built_batches": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- "data_hash": "1234",
- "size": 5,
- "accumulated": False,
- "batch_routed_to": [],
- "created_from": {},
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- }
- ],
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManagerStep",
- },
- },
- },
- "last_batch_received": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- },
- "last_batch_sent": {
- "step1": {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- "step2": {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {
- "module": "distilabel.pipeline.batch",
- "name": "_Batch",
- },
- },
- },
- "last_batch_flag_sent_to": ["step3"],
- "type_info": {
- "module": "distilabel.pipeline.batch_manager",
- "name": "_BatchManager",
- },
- }
- )
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- batch_manager_path = Path(tmp_dir) / "batch_manager.json"
- batch_manager.cache(batch_manager_path)
- loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path)
-
- assert batch_manager.dump() == loaded_batch_manager.dump()
diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py
index 937a9c68b..cb223755a 100644
--- a/tests/unit/pipeline/utils.py
+++ b/tests/unit/pipeline/utils.py
@@ -42,6 +42,8 @@ def outputs(self) -> List[str]:
class DummyStep1(Step):
+ attr1: int = 5
+
@property
def inputs(self) -> List[str]:
return ["instruction"]
diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py
index f7e881273..8ecc96961 100644
--- a/tests/unit/steps/argilla/test_base.py
+++ b/tests/unit/steps/argilla/test_base.py
@@ -188,6 +188,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "tests.unit.steps.argilla.test_base",
"name": "CustomArgilla",
diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py
index 9df1df461..ab63ee541 100644
--- a/tests/unit/steps/argilla/test_preference.py
+++ b/tests/unit/steps/argilla/test_preference.py
@@ -180,6 +180,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.argilla.preference",
"name": "PreferenceToArgilla",
diff --git a/tests/unit/steps/argilla/test_text_generation.py b/tests/unit/steps/argilla/test_text_generation.py
index 689b8e092..356bf5a2e 100644
--- a/tests/unit/steps/argilla/test_text_generation.py
+++ b/tests/unit/steps/argilla/test_text_generation.py
@@ -155,6 +155,7 @@ def test_serialization(self) -> None:
"description": "The API key to authenticate the requests to the Argilla API.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.argilla.text_generation",
"name": "TextGenerationToArgilla",
diff --git a/tests/unit/steps/generators/test_data_sampler.py b/tests/unit/steps/generators/test_data_sampler.py
new file mode 100644
index 000000000..32882e037
--- /dev/null
+++ b/tests/unit/steps/generators/test_data_sampler.py
@@ -0,0 +1,45 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import pytest
+
+from distilabel.steps.generators.data_sampler import DataSampler
+
+
+@pytest.mark.parametrize(
+ "samples, size, batch_size, expected",
+ [
+ (10, 2, 4, [4, 4, 2]),
+ (7, 5, 6, [6, 1]),
+ (20, 5, 20, [20]),
+ (20, 50, 8, [8, 8, 4]),
+ ],
+)
+def test_generator_and_sampler(
+ samples: int, size: int, batch_size: int, expected: List[int]
+):
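+ # The sampler yields `samples` rows in batches of `batch_size`, so the
+ # resulting batch sizes must match `expected`.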
+ sampler = DataSampler(
+ data=[{"sample": f"sample {i}"} for i in range(30)],
+ size=size,
+ samples=samples,
+ batch_size=batch_size,
+ )
+ sampler.load()
+ results = [item[0] for item in sampler.process()]
+ assert len(results) == len(expected)
+ assert len(results[0]) == batch_size
+ for i, result in enumerate(results):
+ assert len(result) == expected[i]
diff --git a/tests/unit/steps/tasks/apigen/__init__.py b/tests/unit/steps/tasks/apigen/__init__.py
new file mode 100644
index 000000000..20ce00bda
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
new file mode 100644
index 000000000..abcc66214
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
@@ -0,0 +1,27 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ return initial_velocity + acceleration * time
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
new file mode 100644
index 000000000..db3bd1bcc
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
@@ -0,0 +1,33 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+ """Gets the value at the specified index in the matrix.
+
+ Args:
+ matrix: A list of lists representing the matrix.
+ indices: A tuple containing the row and column indices.
+ """
+ row_index, col_index = indices
+ if (
+ row_index < 0
+ or row_index >= len(matrix)
+ or col_index < 0
+ or col_index >= len(matrix[row_index])
+ ):
+ return None
+ return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/_sample_module.py b/tests/unit/steps/tasks/apigen/_sample_module.py
new file mode 100644
index 000000000..6e9e08502
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_module.py
@@ -0,0 +1,47 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> int:
+ """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+ Args:
+ initial_velocity: The initial velocity of the object.
+ acceleration: The acceleration of the object.
+ time: The time elapsed.
+
+ Returns:
+ The final velocity
+ """
+ return initial_velocity + acceleration * time
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+ """Gets the value at the specified index in the matrix.
+
+ Args:
+ matrix: A list of lists representing the matrix.
+ indices: A tuple containing the row and column indices.
+ """
+ row_index, col_index = indices
+ if (
+ row_index < 0
+ or row_index >= len(matrix)
+ or col_index < 0
+ or col_index >= len(matrix[row_index])
+ ):
+ return None
+ return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/test_execution_checker.py b/tests/unit/steps/tasks/apigen/test_execution_checker.py
new file mode 100644
index 000000000..d70e42271
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_execution_checker.py
@@ -0,0 +1,140 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker
+
+SAMPLE_LIB = Path(__file__).parent / "_sample_module.py"
+SAMPLE_LIB_FOLDER = Path(__file__).parent / "_sample_lib"
+
+
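+# The execution checker is exercised both with a single module file and with a
+# folder containing one function per file.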
+class TestAPIGenExecutionChecker:
+ @pytest.mark.parametrize("lib", (SAMPLE_LIB, SAMPLE_LIB_FOLDER))
+ @pytest.mark.parametrize(
+ "answers, expected",
+ [
+ (
+ {
+ "query": "Whats the velocity of X?",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": "0.1",
+ "time": 5,
+ },
+ "name": "final_velocity",
+ }
+ ]
+ ),
+ },
+ [
+ {
+ "query": "Whats the velocity of X?",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": "0.1",
+ "time": 5,
+ },
+ "name": "final_velocity",
+ }
+ ]
+ ),
+ "keep_row_after_execution_check": True,
+ "execution_result": ["0.7"],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "unknown_function",
+ }
+ ]
+ ),
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": json.dumps(
+ [
+ {
+ "arguments": {
+ "initial_velocity": 0.2,
+ "acceleration": 0.1,
+ "time": 0.5,
+ },
+ "name": "unknown_function",
+ }
+ ]
+ ),
+ "keep_row_after_execution_check": False,
+ "execution_result": ["Function 'unknown_function' not found."],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]',
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]',
+ "keep_row_after_execution_check": True,
+ "execution_result": ["6"],
+ }
+ ],
+ ),
+ (
+ {
+ "query": "Other query",
+ "answers": None,
+ },
+ [
+ {
+ "query": "Other query",
+ "answers": None,
+ "keep_row_after_execution_check": False,
+ "execution_result": ["No answers were provided."],
+ }
+ ],
+ ),
+ ],
+ )
+ def test_process(
+ self, lib: str, answers: Dict[str, str], expected: Dict[str, Any]
+ ) -> None:
+ task = APIGenExecutionChecker(libpath=str(lib))
+ task.load()
+ result = next(task.process([answers]))
+ assert result == expected
diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py
new file mode 100644
index 000000000..a290666a6
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_generator.py
@@ -0,0 +1,172 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import TYPE_CHECKING, List, Union
+
+import pytest
+
+from distilabel.steps.tasks.apigen.generator import APIGenGenerator
+from tests.unit.conftest import DummyLLM
+
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
+
+import json
+
+
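+# Dummy LLM that always returns the same query/answers pair, repeating the
+# answer `number` times and optionally wrapping it for structured output.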
+class DummyAPIGenLLM(DummyLLM):
+ use_structured_output: bool = False
+ number: int = 1
+
+ def generate(
+ self, inputs: List["FormattedInput"], num_generations: int = 1
+ ) -> "GenerateOutput":
+ query_answers = [
+ {
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": [
+ {
+ "name": "get_breed_information",
+ "arguments": {"breed": "Maine Coon"},
+ }
+ ]
+ * self.number,
+ }
+ ]
+ if self.use_structured_output:
+ query_answers = {"pairs": query_answers}
+ return [
+ [json.dumps(query_answers) for _ in range(num_generations)]
+ for _ in range(len(inputs))
+ ]
+
+
+# Example of 3 rows from Salesforce/xlam-function-calling-60k
+SAMPLE_DATA = [
+ {
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "id": 3493,
+ "tools": '[{"name": "get_breed_information", "description": "Fetch information about a specific cat breed from the Cat Breeds API.", "parameters": {"breed": {"description": "The name of the cat breed to fetch information for.", "type": "str", "default": "aegean"}}}, {"name": "country_region_cities", "description": "Fetches a list of cities within a specified region of a given country from the GeoDB API.", "parameters": {"countryid": {"description": "An ISO-3166 country code or WikiData ID.", "type": "str", "default": "US"}, "regioncode": {"description": "An ISO-3166 or FIPS region code.", "type": "str", "default": "CA"}, "limit": {"description": "The maximum number of results to retrieve. Defaults to None.", "type": "int, optional", "default": ""}, "hateoasmode": {"description": "Include HATEOAS-style links in results. Defaults to None.", "type": "bool, optional", "default": ""}, "asciimode": {"description": "Display results using ASCII characters. Defaults to None.", "type": "bool, optional", "default": ""}, "nameprefixdefaultlangresults": {"description": "Match on names in the default language if a non-default language is requested when prefix-matching. Defaults to None.", "type": "bool, optional", "default": ""}, "timezoneids": {"description": "Only include cities in these time zones. Comma-separated values. Defaults to None.", "type": "str, optional", "default": ""}, "nameprefix": {"description": "Only include cities whose names start with this prefix. If languagecode is set, the prefix will be matched on the name as it appears in that language. Defaults to None.", "type": "str, optional", "default": ""}, "types": {"description": "Only include cities of these types (comma-separated): CITY, ADM2. Defaults to None.", "type": "str, optional", "default": ""}, "minpopulation": {"description": "Only include cities with at least this population. Defaults to None.", "type": "int, optional", "default": ""}, "languagecode": {"description": "Display results in this language. Defaults to None.", "type": "str, optional", "default": ""}, "offset": {"description": "The zero-based offset into the results. Defaults to None.", "type": "int, optional", "default": ""}, "maxpopulation": {"description": "Only include cities with no more than this population. Defaults to None.", "type": "int, optional", "default": ""}, "includedeleted": {"description": "Whether to include any cities marked deleted. Options are: ALL, SINCE_YESTERDAY, SINCE_LAST_WEEK, NONE. Defaults to None.", "type": "str, optional", "default": ""}, "sort": {"description": "How to sort the results. Format: \\u00b1SORT_FIELD,\\u00b1SORT_FIELD where SORT_FIELD = elevation, name, population. Defaults to None.", "type": "str, optional", "default": ""}}}, {"name": "company_details", "description": "Fetch details of a company from Indeed\'s API.", "parameters": {"company_id": {"description": "The unique identifier of the company to fetch details for.", "type": "str", "default": "Microsoft"}, "locality": {"description": "The locality or country code for Indeed\'s subdomain. Default is \'us\' if not provided.", "type": "str, optional", "default": ""}}}]',
+ },
+ {
+ "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]',
+ "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.",
+ "id": 57546,
+ "tools": '[{"name": "mailcheck", "description": "Checks if an email domain is valid or a disposable/temporary address.", "parameters": {"domain": {"description": "The email or domain to check for validity. It is recommended to enter just the domain for user privacy.", "type": "str", "default": "mailinator.com"}}}, {"name": "get_products_in_category", "description": "Fetches a list of products from a specified category in a store with pagination.", "parameters": {"skip": {"description": "The number of items to skip before starting to collect the result set.", "type": "int", "default": ""}, "limit": {"description": "The number of items to return in the result set.", "type": "int", "default": ""}, "category": {"description": "The category from which to fetch products.", "type": "str", "default": ""}}}, {"name": "product_by_id", "description": "Fetches detailed information about a specific product from the AliExpress API using the provided product ID.", "parameters": {"product_id": {"description": "The unique identifier for the product on AliExpress.", "type": "int", "default": "32841070485"}}}]',
+ },
+ {
+ "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]',
+ "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?",
+ "id": 8815,
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ },
+]
+
+
+class TestApiGenGenerator:
+ @pytest.mark.parametrize("number", [1, 2, [3]])
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ @pytest.mark.parametrize("use_tools", [True, False])
+ def test_format_input(
+ self,
+ number: Union[int, List[int]],
+ use_default_structured_output: bool,
+ use_tools: bool,
+ ) -> None:
+ random.seed(42)
+ task = APIGenGenerator(
+ llm=DummyLLM(),
+ number=number,
+ use_tools=use_tools,
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
+ formatted = task.format_input(
+ input={
+ "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "func_name": "get_breed_information",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ }
+ )
+
+ assert isinstance(formatted, list)
+ # Check only the user prompt; the system one is fixed
+ formatted_prompt = formatted[1]["content"]
+
+ if isinstance(number, list):
+ # Fix the number for the tests for simplicity
+ number = 3
+ assert f"Now please generate {number} diverse" in formatted_prompt
+
+ assert (
+ "The output MUST strictly adhere to the following JSON format, and NO other text MUST be included:"
+ in formatted_prompt
+ )
+
+ tools_entry = "This is the available tool to guide you (respect the order of the parameters):"
+ if use_tools:
+ assert tools_entry in formatted_prompt
+ else:
+ assert tools_entry not in formatted_prompt
+
+ is_parallel_check = "It can contain multiple parallel queries in natural language for the given functions. They could use either the same function with different arguments or different functions."
+ if number > 1:
+ assert is_parallel_check in formatted_prompt
+ else:
+ assert is_parallel_check not in formatted_prompt
+
+ @pytest.mark.parametrize("number", [1, 2])
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ @pytest.mark.parametrize("use_tools", [True, False])
+ def test_process(
+ self,
+ number: Union[int, List[int]],
+ use_default_structured_output: bool,
+ use_tools: bool,
+ ) -> None:
+ # Whether the generation is parallel is not relevant in this case; it only
+ # matters for format_input, as that is where multiple queries are added to the prompt
+ random.seed(42)
+ task = APIGenGenerator(
+ llm=DummyAPIGenLLM(
+ use_structured_output=use_default_structured_output, number=number
+ ),
+ number=number,
+ use_tools=use_tools,
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
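+ # The dummy LLM returns `number` generations; the task should expose a single query string and a JSON list of answers with `number` entries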
+ result = next(
+ task.process(
+ [
+ {
+ "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "func_name": "get_breed_information",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+ }
+ ]
+ )
+ )[0]
+ assert "query" in result
+ assert "answers" in result
+ query = result["query"]
+ assert isinstance(query, str)
+ answers = json.loads(result["answers"])
+ assert isinstance(answers, list)
+ assert len(answers) == number
diff --git a/tests/unit/steps/tasks/apigen/test_semantic_checker.py b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
new file mode 100644
index 000000000..e73b71c3a
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
@@ -0,0 +1,113 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
+from tests.unit.conftest import DummyLLM
+
+SAMPLE_DATA = [
+ # The info for the function description can be obtained from the tool itself
+ {
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ },
+ {
+ "func_desc": "Checks if an email domain is valid or a disposable/temporary address.",
+ "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.",
+ "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]',
+ "execution_result": "Response for the emails",
+ },
+ {
+ "func_desc": "Fetches the content of a node in a navigation hierarchy.",
+ "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?",
+ "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]',
+ "execution_result": "Response for the node contents",
+ },
+]
+
+
+class TestAPIGenSemanticChecker:
+ @pytest.mark.parametrize("use_default_structured_output", [True, False])
+ def test_format_input(self, use_default_structured_output: bool) -> None:
+ task = APIGenSemanticChecker(
+ llm=DummyLLM(),
+ use_default_structured_output=use_default_structured_output,
+ )
+ task.load()
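+ # The user prompt should embed the generated function calls, the function description and the execution results from the sample row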
+ result = task.format_input(SAMPLE_DATA[0])
+ assert isinstance(result, list)
+ formatted_prompt = result[1]["content"]
+
+ default_structured_output_check = "Your response MUST strictly adhere to the following JSON format, and NO other text MUST be included"
+ assert default_structured_output_check in formatted_prompt
+ assert (
+ '- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]'
+ in formatted_prompt
+ )
+ assert (
+ "- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API."
+ in formatted_prompt
+ )
+ assert (
+ "- Execution Results: Hopefully some info about the Maine Coon"
+ in formatted_prompt
+ )
+
+ @pytest.mark.parametrize(
+ "result, expected",
+ [
+ (
+ '{"thought": "thought", "keep_row_after_semantic_check": "no", "passes": "no"}',
+ {
+ "thought": "thought",
+ "keep_row_after_semantic_check": False,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ (
+ None,
+ {
+ "thought": None,
+ "keep_row_after_semantic_check": None,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ (
+ "wrong",
+ {
+ "thought": None,
+ "keep_row_after_semantic_check": None,
+ "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+ "execution_result": "Hopefully some info about the Maine Coon",
+ "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+ "query": "What information can be obtained about the Maine Coon cat breed?",
+ },
+ ),
+ ],
+ )
+ def test_format_output(self, result: str, expected: Dict[str, Any]) -> None:
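+ # Missing or non-JSON model outputs should fall back to None for the parsed fields while the original input columns are preserved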
+ task = APIGenSemanticChecker(llm=DummyLLM())
+ task.load()
+ assert task.format_output(result, SAMPLE_DATA[0]) == expected
diff --git a/tests/unit/steps/tasks/apigen/test_utils.py b/tests/unit/steps/tasks/apigen/test_utils.py
new file mode 100644
index 000000000..00707f17a
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_utils.py
@@ -0,0 +1,77 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.utils import (
+ execute_from_response,
+ load_module_from_path,
+)
+
+
+@pytest.mark.parametrize(
+ "function_name, answer, expected_result",
+ [
+ (
+ "final_velocity",
+ {"initial_velocity": 10, "acceleration": 5, "time": 2},
+ {"execution_result": "20", "keep": True},
+ ),
+ # In this case the string arguments should be cast internally
+ (
+ "final_velocity",
+ {"initial_velocity": "10", "acceleration": "5", "time": "2"},
+ {"execution_result": "20", "keep": True},
+ ),
+ # Different names for the arguments but correctly positioned
+ (
+ "final_velocity",
+ {"v0": "10", "a": "5", "t": "2"},
+ {"execution_result": "20", "keep": True},
+ ),
+ # Fail casting one of the values
+ (
+ "final_velocity",
+ {"initial_velocity": "10", "acceleration": "5", "time": "1m/s"},
+ {
+ "execution_result": "unsupported operand type(s) for +: 'int' and 'str'",
+ "keep": False,
+ },
+ ),
+ (
+ "final_velocity",
+ {"initial_velocity": 10, "acceleration": 5},
+ {
+ "execution_result": "final_velocity() missing 1 required positional argument: 'time'",
+ "keep": False,
+ },
+ ),
+ (
+ "unknwown_function",
+ {"initial_velocity": 10, "acceleration": 5, "time": 2},
+ {"execution_result": "Function not found", "keep": False},
+ ),
+ ],
+)
+def test_execute_from_response(
+ function_name: str, answer: Dict[str, Any], expected_result: Dict[str, Any]
+):
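+ # Load the sample module next to this test, resolve the requested function (None when it does not exist) and execute it with the provided arguments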
+ module_path = Path(__file__).parent / "_sample_module.py"
+ sample_module = load_module_from_path(module_path)
+ function = getattr(sample_module, function_name, None)
+ result = execute_from_response(function, answer)
+ assert result == expected_result
diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py
index 4f6e12c6f..66f67347b 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_base.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -241,6 +241,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to seed a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.evol_instruct.base",
"name": "EvolInstruct",
diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py
index 77b4a8ea0..8f86b9490 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -246,6 +246,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to seed a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": EvolInstructGenerator.__module__,
"name": EvolInstructGenerator.__name__,
diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py
index a7346b106..2ac460afc 100644
--- a/tests/unit/steps/tasks/evol_quality/test_base.py
+++ b/tests/unit/steps/tasks/evol_quality/test_base.py
@@ -205,6 +205,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"description": "As `numpy` is being used in order to randomly pick a mutation method, then is nice to set a random seed.",
},
],
+ "use_cache": True,
"type_info": {
"module": task.__module__,
"name": task.__class__.__name__,
diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py
index 8b830a0db..cc13681f9 100644
--- a/tests/unit/steps/tasks/magpie/test_base.py
+++ b/tests/unit/steps/tasks/magpie/test_base.py
@@ -762,6 +762,7 @@ def test_serialization(self) -> None:
"description": "The number of generations to be produced per input.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.magpie.base",
"name": "Magpie",
diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py
index b5e41c355..d1d142635 100644
--- a/tests/unit/steps/tasks/magpie/test_generator.py
+++ b/tests/unit/steps/tasks/magpie/test_generator.py
@@ -202,6 +202,7 @@ def test_serialization(self) -> None:
"description": "The number of rows to generate.",
},
],
+ "use_cache": True,
"type_info": {
"module": "distilabel.steps.tasks.magpie.generator",
"name": "MagpieGenerator",
diff --git a/tests/unit/steps/tasks/test_argilla_labeller.py b/tests/unit/steps/tasks/test_argilla_labeller.py
new file mode 100644
index 000000000..926118dd6
--- /dev/null
+++ b/tests/unit/steps/tasks/test_argilla_labeller.py
@@ -0,0 +1,210 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, Dict, List
+
+import pytest
+
+from distilabel.pipeline.local import Pipeline
+from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller
+from distilabel.steps.tasks.typing import ChatItem
+from tests.unit.conftest import DummyAsyncLLM
+
+
+@pytest.fixture
+def fields() -> List[Dict[str, Any]]:
+ return [
+ {
+ "name": "text",
+ "description": "The text of the question",
+ "title": "The text of the question",
+ "settings": {"type": "text"},
+ }
+ ]
+
+
+@pytest.fixture
+def questions() -> List[Dict[str, Any]]:
+ return [
+ {
+ "name": "label_selection",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "label_selection",
+ "options": [
+ {"value": "yes", "text": "Yes"},
+ {"value": "no", "text": "No"},
+ ],
+ },
+ },
+ {
+ "name": "multi_label_selection",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "multi_label_selection",
+ "options": [
+ {"value": "yes", "text": "Yes"},
+ {"value": "no", "text": "No"},
+ ],
+ },
+ },
+ {
+ "name": "rating",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "rating",
+ "options": [
+ {"value": "1", "text": "1"},
+ ],
+ },
+ },
+ {
+ "name": "text",
+ "description": "The class of the question",
+ "title": "Is the question a question?",
+ "settings": {
+ "type": "text",
+ },
+ },
+ ]
+
+
+@pytest.fixture
+def outputs() -> List[Dict[str, Any]]:
+ return [
+ {
+ "label": "yes",
+ },
+ {
+ "labels": ["yes", "no"],
+ },
+ {
+ "rating": "1",
+ },
+ {
+ "text": "yes",
+ },
+ ]
+
+
+@pytest.fixture
+def records() -> List[Dict[str, Any]]:
+ return [
+ {
+ "fields": {
+ "text": "What is the capital of France?",
+ },
+ "responses": [
+ {
+ "quesion_name": "label_selection",
+ "value": "yes",
+ }
+ ],
+ }
+ ]
+
+
+class TestArgillaLabeller:
+ def test_format_input(
+ self,
+ questions: List[Dict[str, Any]],
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
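+ # Every question type should render its description and title into the user prompt; option-based questions should also surface their first option value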
+ for question in questions:
+ result: List[ChatItem] = task.format_input(
+ input={
+ "question": question,
+ "fields": fields,
+ "record": records[0],
+ }
+ )
+ assert question["description"] in result[-1]["content"]
+ assert question["title"] in result[-1]["content"]
+ if question["settings"]["type"] in [
+ "label_selection",
+ "multi_label_selection",
+ "span",
+ "rating",
+ ]:
+ assert (
+ question["settings"]["options"][0]["value"] in result[-1]["content"]
+ )
+
+ def test_format_output(
+ self,
+ questions: List[Dict[str, Any]],
+ records: List[Dict[str, Any]],
+ fields: List[Dict[str, Any]],
+ outputs: List[Dict[str, Any]],
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
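+ # format_output should accept the serialized model output for every question type without raising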
+ for question, output in zip(questions, outputs):
+ task.format_output(
+ input={
+ "question": question,
+ "fields": fields,
+ "record": records[0],
+ },
+ output=json.dumps(output),
+ )
+
+ def test_fail_on_invalid_question_type(
+ self, questions: List[Dict[str, Any]], records: List[Dict[str, Any]]
+ ) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
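+ # An unsupported question type should make format_input raise a ValueError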
+ fake_question = questions[0]
+ fake_question["settings"]["type"] = "invalid_type"
+
+ with pytest.raises(ValueError):
+ task.format_input(
+ input={
+ "record": records[0],
+ "question": fake_question,
+ }
+ )
+
+ def test_fail_on_no_question(self, records: List[Dict[str, Any]]) -> None:
+ task = ArgillaLabeller(
+ name="argilla_labeller",
+ llm=DummyAsyncLLM(),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ with pytest.raises(ValueError):
+ task.format_input(input={"record": records[0]})
diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py
index d00b8edac..29341052f 100644
--- a/tests/unit/steps/tasks/test_base.py
+++ b/tests/unit/steps/tasks/test_base.py
@@ -414,6 +414,94 @@ def test_process(
result = next(task.process(input))
assert result == expected
+ def test_process_overriding_inputs(self) -> None:
+ llm = DummyAsyncLLM()
+ task = DummyTask(
+ name="task",
+ llm=llm,
+ group_generations=False,
+ num_generations=3,
+ input_mappings={"instruction": "instruction_2"},
+ )
+
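+ # With input_mappings={"instruction": "instruction_2"}, the value of "instruction_2" is formatted into the prompt, while the original "instruction" column is kept unchanged in the output rows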
+ result = next(
+ task.process_applying_mappings(
+ [
+ {
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "additional_info": "info",
+ }
+ ]
+ )
+ )
+
+ assert result == [
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ {
+ "additional_info": "info",
+ "distilabel_metadata": {
+ "raw_input_task": [
+ {
+ "content": "",
+ "role": "system",
+ },
+ {
+ "content": "instruction that will be used as input",
+ "role": "user",
+ },
+ ],
+ "raw_output_task": "output",
+ },
+ "info_from_input": "info",
+ "instruction": "instruction that won't be used but overriden by input mapping",
+ "instruction_2": "instruction that will be used as input",
+ "model_name": "test",
+ "output": "output",
+ },
+ ]
+
def test_process_with_runtime_parameters(self) -> None:
# 1. Runtime parameters provided
llm = DummyRuntimeLLM() # type: ignore
@@ -576,6 +664,7 @@ def test_serialization(self) -> None:
"optional": True,
},
],
+ "use_cache": True,
"type_info": {
"module": "tests.unit.conftest",
"name": "DummyTask",
diff --git a/tests/unit/steps/tasks/test_clair.py b/tests/unit/steps/tasks/test_clair.py
new file mode 100644
index 000000000..3d16c0bf4
--- /dev/null
+++ b/tests/unit/steps/tasks/test_clair.py
@@ -0,0 +1,74 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Union
+
+import pytest
+
+from distilabel.steps.tasks.clair import CLAIR
+from tests.unit.conftest import DummyLLM
+
+
+class TestCLAIR:
+ def test_format_input(self) -> None:
+ task = CLAIR(llm=DummyLLM())
+ task.load()
+
+ result = task.format_input(
+ input={"task": "TASK", "student_solution": "SOLUTION"}
+ )
+ # System prompt
+ assert (
+ result[0]["content"]
+ == "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."
+ )
+ # User prompt
+ assert (
+ result[1]["content"]
+ == """\
+{task}: TASK
+
+{student_solution}: SOLUTION
+
+-----------------
+
+Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.
+""".strip()
+ )
+
+ @pytest.mark.parametrize(
+ "output, expected",
+ [
+ (None, {"revision": None, "rational": None}),
+ ("WRONG", {"revision": None, "rational": None}),
+ (
+ "{teacher_reasoning}\n\nreasoning\n\n{corrected_student_solution}\n\ncorrected",
+ {"revision": "corrected", "rational": "reasoning"},
+ ),
+ ],
+ )
+ def test_format_output(
+ self,
+ output: Union[str, None],
+ expected: Dict[str, Any],
+ ) -> None:
+ task = CLAIR(llm=DummyLLM())
+ task.load()
+
+ result = task.format_output(
+ output=output,
+ input={},
+ )
+
+ assert result == expected
diff --git a/tests/unit/steps/tasks/test_decorator.py b/tests/unit/steps/tasks/test_decorator.py
new file mode 100644
index 000000000..085153c1f
--- /dev/null
+++ b/tests/unit/steps/tasks/test_decorator.py
@@ -0,0 +1,200 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Union
+
+import pytest
+
+from distilabel.errors import DistilabelUserError
+from distilabel.steps.tasks.decorator import task
+from tests.unit.conftest import DummyLLM
+
+
+class TestTaskDecorator:
+ def test_decorator_raise_if_no_docstring(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"When using the `task` decorator, including a docstring in the formatting function is mandatory",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ return {"response": output}
+
+ def test_decorator_raise_if_docstring_invalid(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"Formatting function decorated with `task` doesn't follow the expected format.",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """This is not valid"""
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"Formatting function decorated with `task` doesn't follow the expected format.",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ - this
+ - is
+ - a
+ - list
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_raise_if_no_system_prompt_or_user_message_template(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` must include both the `system_prompt` and `user_message_template` keys in the docstring",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: prompt
+ ---
+ """
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` must include both the `system_prompt` and `user_message_template` keys in the docstring",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ user_message_template: prompt
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_raise_if_template_invalid_placeholders(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` includes invalid placeholders in the extracted `system_prompt`",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ {instruction}
+ ---
+ """
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` includes invalid placeholders in the extracted `user_message_template`",
+ ):
+
+ @task(inputs=["task"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ {instruction}
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_task(self) -> None:
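+ # The docstring front matter between the `---` markers provides the system_prompt and user_message_template; their placeholders must match the declared inputs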
+ @task(inputs=["task", "instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ `MyTask` is a simple `Task` for bla bla bla
+
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ Text: {instruction}
+ ---
+ """
+ return {"response": output}
+
+ my_task = MyTask(llm=DummyLLM())
+
+ my_task.load()
+
+ assert my_task.inputs == ["task", "instruction"]
+ assert my_task.outputs == ["response"]
+ assert my_task.format_input(
+ {"task": "summarize", "instruction": "The cell..."}
+ ) == [
+ {
+ "role": "system",
+ "content": "You are an AI assistant designed to summarize",
+ },
+ {"role": "user", "content": "Text: The cell..."},
+ ]
+ assert next(
+ my_task.process_applying_mappings(
+ [{"task": "summarize", "instruction": "The cell..."}]
+ )
+ ) == [
+ {
+ "task": "summarize",
+ "instruction": "The cell...",
+ "response": "output",
+ "model_name": "test",
+ "distilabel_metadata": {
+ "raw_input_my_task_0": [
+ {
+ "content": "You are an AI assistant designed to summarize",
+ "role": "system",
+ },
+ {
+ "content": "Text: The cell...",
+ "role": "user",
+ },
+ ],
+ "raw_output_my_task_0": "output",
+ },
+ }
+ ]
diff --git a/tests/unit/steps/tasks/test_pair_rm.py b/tests/unit/steps/tasks/test_pair_rm.py
index 1903ccfb2..104726307 100644
--- a/tests/unit/steps/tasks/test_pair_rm.py
+++ b/tests/unit/steps/tasks/test_pair_rm.py
@@ -111,5 +111,6 @@ def test_serialization(self, _: MagicMock) -> None:
"optional": True,
},
],
+ "use_cache": True,
"type_info": {"module": "distilabel.steps.tasks.pair_rm", "name": "PairRM"},
}
diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py
index 4791c5be2..6e8297bb0 100644
--- a/tests/unit/steps/test_base.py
+++ b/tests/unit/steps/test_base.py
@@ -29,6 +29,8 @@
class DummyStep(Step):
+ attr1: int = 5
+
@property
def inputs(self) -> List[str]:
return ["instruction"]
@@ -66,6 +68,16 @@ def process(self, inputs: StepInput) -> StepOutput:
class TestStep:
+ def test_signature(self) -> None:
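+ # The signature should be deterministic for identical attribute values and change when an attribute value changes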
+ step = DummyStep(attr1=5)
+ assert step.signature == "a0ce83adedabec3fba270ec7bc8a52a62cbbee40"
+
+ step = DummyStep(attr1=5)
+ assert step.signature == "a0ce83adedabec3fba270ec7bc8a52a62cbbee40"
+
+ step = DummyStep(attr1=1234)
+ assert step.signature == "c00e67df4f7ed97a2bf8d9b1178d6c728e577c3b"
+
def test_create_step_with_invalid_name(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
@@ -397,6 +409,7 @@ def test_step_dump(self) -> None:
step = DummyStep(name="dummy", pipeline=pipeline)
assert step.dump() == {
"name": "dummy",
+ "attr1": 5,
"input_batch_size": 50,
"input_mappings": {},
"output_mappings": {},
@@ -444,6 +457,7 @@ def test_step_dump(self) -> None:
"optional": True,
},
],
+ "use_cache": True,
TYPE_INFO_KEY: {
"module": "tests.unit.steps.test_base",
"name": "DummyStep",