diff --git a/.vale/styles/config/vocabularies/Data/accept.txt b/.vale/styles/config/vocabularies/Data/accept.txt index a796081ee2cf..dacabaeec1a1 100644 --- a/.vale/styles/config/vocabularies/Data/accept.txt +++ b/.vale/styles/config/vocabularies/Data/accept.txt @@ -28,10 +28,13 @@ PDFs Predibase('s)? [Pp]refetch [Pp]refetching +[Pp]ostprocess +[Pp]ostprocessor(s)? [Pp]reprocess [Pp]reprocessor(s)? process_file [Pp]ushdown +[Rr]eformat(s)? queryable RGB runai diff --git a/doc/source/data/examples.yml b/doc/source/data/examples.yml index 757c8560bd73..3380cdf3a83a 100644 --- a/doc/source/data/examples.yml +++ b/doc/source/data/examples.yml @@ -24,14 +24,19 @@ examples: use_cases: - computer vision link: examples/huggingface_vit_batch_prediction - - title: Batch Inference with LoRA Adapter + - title: Tabular Data Training and Batch Inference with XGBoost + skill_level: beginner + frameworks: + - xgboost + link: ../train/examples/xgboost/distributed-xgboost-lightgbm + - title: LLM Batch Inference skill_level: beginner frameworks: - vLLM use_cases: - large language models - generative ai - link: ../llm/examples/batch/vllm-with-lora + link: examples/llm_batch_inference_text/content/README - title: Batch Inference with Structural Output skill_level: beginner frameworks: @@ -40,11 +45,23 @@ examples: - large language models - generative ai link: ../llm/examples/batch/vllm-with-structural-output - - title: Tabular Data Training and Batch Inference with XGBoost + - title: Batch Inference with LoRA Adapter skill_level: beginner frameworks: - - xgboost - link: ../train/examples/xgboost/distributed-xgboost-lightgbm + - vLLM + use_cases: + - large language models + - generative ai + link: ../llm/examples/batch/vllm-with-lora + - title: Multimodal LLM Batch Inference + skill_level: beginner + frameworks: + - vLLM + use_cases: + - large language models + - generative ai + - computer vision + link: examples/llm_batch_inference_vision/content/README - title: Unstructured Data Ingestion and Processing skill_level: intermediate frameworks: @@ -53,4 +70,4 @@ examples: use_cases: - document processing - data ingestion - link: examples/unstructured_data_ingestion/content/unstructured_data_ingestion \ No newline at end of file + link: examples/unstructured_data_ingestion/content/unstructured_data_ingestion diff --git a/doc/source/data/examples/BUILD.bazel b/doc/source/data/examples/BUILD.bazel index 3c507b754b1b..9b374dc707ce 100644 --- a/doc/source/data/examples/BUILD.bazel +++ b/doc/source/data/examples/BUILD.bazel @@ -25,4 +25,4 @@ filegroup( "**/ci/gce.yaml" ]), visibility = ["//release:__pkg__"], -) \ No newline at end of file +) diff --git a/doc/source/data/examples/llm_batch_inference_text/ci/aws.yaml b/doc/source/data/examples/llm_batch_inference_text/ci/aws.yaml new file mode 100644 index 000000000000..8205befe144b --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/ci/aws.yaml @@ -0,0 +1,17 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: m5.2xlarge + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g6.2xlarge + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true diff --git a/doc/source/data/examples/llm_batch_inference_text/ci/gce.yaml b/doc/source/data/examples/llm_batch_inference_text/ci/gce.yaml new file mode 100644 index 000000000000..deadbcc6461d --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/ci/gce.yaml @@ -0,0 +1,17 @@ 
+cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-central1 + +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: n2-standard-8 + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g2-standard-8-nvidia-l4-1 + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true diff --git a/doc/source/data/examples/llm_batch_inference_text/ci/nb2py.py b/doc/source/data/examples/llm_batch_inference_text/ci/nb2py.py new file mode 100644 index 000000000000..5183b4348096 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/ci/nb2py.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +import argparse +import nbformat + + +def convert_notebook( + input_path: str, output_path: str, ignore_cmds: bool = False +) -> None: + """ + Read a Jupyter notebook and write a Python script, converting all %%bash + cells and IPython "!" commands into subprocess.run calls that raise on error. + Cells that load or autoreload extensions are ignored. + """ + nb = nbformat.read(input_path, as_version=4) + with open(output_path, "w") as out: + for cell in nb.cells: + # Only process code cells + if cell.cell_type != "code": + continue + + lines = cell.source.splitlines() + + # Detect a %%bash cell + if lines: + # Detect any IPython '!' shell commands in code lines + has_bang = any(line.lstrip().startswith("!") for line in lines) + # Detect %pip magic commands + has_pip_magic = any(line.lstrip().startswith("%pip") for line in lines) + # Start with "serve run" "serve shutdown" "curl" or "anyscale service" commands + to_ignore_cmd = ( + "serve run", + "serve shutdown", + "curl", + "anyscale service", + ) + has_ignored_start = any( + line.lstrip().startswith(to_ignore_cmd) for line in lines + ) + # Skip %pip cells entirely + if has_pip_magic: + continue + if has_bang or has_ignored_start: + if ignore_cmds: + continue + out.write("import subprocess\n") + for line in lines: + stripped = line.lstrip() + if stripped.startswith("!"): + cmd = stripped[1:].lstrip() + out.write( + f"subprocess.run(r'''{cmd}''',\n" + " shell=True,\n" + " check=True,\n" + " executable='/bin/bash')\n" + ) + else: + out.write(line.rstrip() + "\n") + out.write("\n") + else: + # Regular Python cell: + code = cell.source.rstrip() + if "ds_large = ds.limit(1_000_000)" in code: + # Instead of testing a large dataset in CI, test a small dataset + code = code.replace("ds.limit(1_000_000)", "ds.limit(10_000)") + # else, dump as-is + out.write(code + "\n\n") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert a Jupyter notebook to a Python script, preserving bash cells and '!' commands as subprocess calls unless ignored with --ignore-cmds." + ) + parser.add_argument("input_nb", help="Path to the input .ipynb file") + parser.add_argument("output_py", help="Path for the output .py script") + parser.add_argument( + "--ignore-cmds", action="store_true", help="Ignore bash cells and '!' 
commands" + ) + args = parser.parse_args() + convert_notebook(args.input_nb, args.output_py, ignore_cmds=args.ignore_cmds) + + +if __name__ == "__main__": + main() diff --git a/doc/source/data/examples/llm_batch_inference_text/ci/tests.sh b/doc/source/data/examples/llm_batch_inference_text/ci/tests.sh new file mode 100755 index 000000000000..928aa464c37a --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/ci/tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Install requirements first (done by CI automatically): +# release/ray_release/byod/byod_llm_batch_inference_text.sh + +# Don't use nbconvert or jupytext unless you're willing +# to check each subprocess unit and validate that errors +# aren't being consumed/hidden + +set -exo pipefail + +python ci/nb2py.py "content/README.ipynb" "content/README.py" --ignore-cmds +python "content/README.py" +rm "content/README.py" diff --git a/doc/source/data/examples/llm_batch_inference_text/configs/aws.yaml b/doc/source/data/examples/llm_batch_inference_text/configs/aws.yaml new file mode 100644 index 000000000000..19ee0be9b644 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/configs/aws.yaml @@ -0,0 +1,14 @@ +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: m5.2xlarge + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g6.2xlarge + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true diff --git a/doc/source/data/examples/llm_batch_inference_text/configs/gce.yaml b/doc/source/data/examples/llm_batch_inference_text/configs/gce.yaml new file mode 100644 index 000000000000..49e866eb8b1b --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/configs/gce.yaml @@ -0,0 +1,14 @@ +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: n2-standard-8 + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g2-standard-8-nvidia-l4-1 + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true diff --git a/doc/source/data/examples/llm_batch_inference_text/content/README.ipynb b/doc/source/data/examples/llm_batch_inference_text/content/README.ipynb new file mode 100644 index 000000000000..7b64de18024b --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/content/README.ipynb @@ -0,0 +1,541 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLM batch inference with Ray Data LLM\n", + "\n", + "
\n", + "\n", + "**⏱️ Time to complete**: 15 minutes\n", + "\n", + "This example shows you how to run batch inference for large language models (LLMs) using [Ray Data LLM APIs](https://docs.ray.io/en/latest/data/api/llm.html). In this use case, the batch inference job infers company industries from company names across a large customer dataset.\n", + "\n", + "\n", + "## When to use LLM batch inference\n", + "\n", + "Offline (batch) inference optimizes for throughput over latency. Unlike online inference, which processes requests one at a time in real-time, batch inference processes thousands or millions of inputs together, maximizing GPU utilization and reducing per-inference costs.\n", + "\n", + "Choose batch inference when:\n", + "- You have a fixed dataset to process (such as daily reports or data migrations)\n", + "- Throughput matters more than immediate results\n", + "- You want to take advantage of fault tolerance for long-running jobs\n", + "\n", + "On the contrary, if you are more interested in optimizing for latency, consider [deploying your LLM with Ray Serve LLM for online inference](https://docs.ray.io/en/latest/serve/llm/index.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare a Ray Data dataset\n", + "\n", + "Ray Data LLM runs batch inference for LLMs on Ray Data datasets. In this tutorial, you perform batch inference with an LLM to infer company industries from company names. The source is a 2-million-row CSV file containing sample customer data.\n", + "\n", + "First, load the data from a remote URL then repartition the dataset to ensure the workload can be distributed across multiple GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "\n", + "# Define the path to the sample CSV file hosted on S3.\n", + "# This dataset contains 2 million rows of synthetic customer data.\n", + "path = \"https://llm-guide.s3.us-west-2.amazonaws.com/data/ray-data-llm/customers-2000000.csv\"\n", + "\n", + "# Load the CSV file into a Ray Dataset.\n", + "print(\"Loading dataset from remote URL...\")\n", + "ds = ray.data.read_csv(path)\n", + "\n", + "# Inspect the dataset schema and a few rows to verify it loaded correctly.\n", + "print(\"\\nDataset schema:\")\n", + "print(ds.schema())\n", + "print(\"\\nSample rows:\")\n", + "ds.show(limit=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this initial example, limit the dataset to 10,000 rows so you can process and test faster. Later, you can scale up to the full dataset.\n", + "\n", + "If you don't repartition, the system might read a large file into only a few blocks, which limits parallelism in later steps. For example, you might see that only 4 out of 8 GPUs in your cluster are being used. To address this, you can repartition the data into a specific number of blocks so the system can better parallelize work across all available GPUs in the pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Limit the dataset to 10,000 rows for this example.\n", + "print(\"Limiting dataset to 10,000 rows for initial processing.\")\n", + "ds_small = ds.limit(10_000)\n", + "\n", + "\n", + "# Repartition the dataset to enable parallelism across multiple workers (GPUs).\n", + "# By default, streaming datasets might not be optimally partitioned. 
Repartitioning\n", + "# splits the data into a specified number of blocks, allowing Ray to process them\n", + "# in parallel.\n", + "# Tip: Repartition count should typically be 2-4x your worker (GPU) count.\n", + "# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions.\n", + "# This ensures enough parallelism while avoiding excessive overhead.\n", + "num_partitions = 128\n", + "print(f\"Repartitioning dataset into {num_partitions} blocks for parallelism...\")\n", + "ds_small = ds_small.repartition(num_blocks=num_partitions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Ray Data LLM\n", + "\n", + "Ray Data LLM provides a unified interface to run batch inference with different LLM engines. Configure the vLLM engine, define preprocessing and postprocessing functions, and build the processor.\n", + "\n", + "### Configure the processor engine\n", + "\n", + "Configure the model and compute resources needed for inference using `vLLMEngineProcessorConfig`.\n", + "\n", + "This example uses the `unsloth/Llama-3.1-8B-Instruct` model. The configuration specifies:\n", + "- `model_source`: The Hugging Face model identifier.\n", + "- `engine_kwargs`: vLLM engine parameters such as tensor parallelism and memory settings.\n", + "- `batch_size`: Number of requests to batch together (set to 256 for small prompts and outputs).\n", + "- `accelerator_type`: GPU type to use (L4 in this case).\n", + "- `concurrency`: Number of parallel workers (4 in this case).\n", + "\n", + "**Note:** Because the input prompts and expected output token lengths are small, `batch_size=256` is appropriate. However, depending on your workload, a large batch size can lead to increased idle GPU time when decoding long sequences. Adjust this value to find the optimal trade-off between throughput and latency." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.data.llm import vLLMEngineProcessorConfig\n", + "\n", + "processor_config = vLLMEngineProcessorConfig(\n", + " model_source=\"unsloth/Llama-3.1-8B-Instruct\",\n", + " engine_kwargs=dict(\n", + " max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit\n", + " ),\n", + " batch_size=256,\n", + " accelerator_type=\"L4\",\n", + " concurrency=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details on the configuration options you can pass to the vLLM engine, see the [vLLM Engine Arguments documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the preprocess and postprocess functions\n", + "\n", + "The task is to infer the industry category from the `Company` field using an LLM.\n", + "\n", + "Define a preprocess function to prepare `messages` and `sampling_params` for the vLLM engine, and a postprocess function to extract the `generated_text`.\n", + "\n", + "Ray Data LLM sends the output of your preprocessing function directly to a vLLM engine, so you can take advantage of vLLM features. \n", + "\n", + "In this example, you use vLLM's structured output feature to restrict the LLM's responses to a predefined list of industry categories. This increases predictability and helps you reduce the cost of unnecessary output tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "# For better output token control, restrain generation to these choices\n", + "CHOICES = [\n", + " \"Law Firm\",\n", + " \"Healthcare\",\n", + " \"Technology\",\n", + " \"Retail\",\n", + " \"Consulting\",\n", + " \"Manufacturing\",\n", + " \"Finance\",\n", + " \"Real Estate\",\n", + " \"Other\",\n", + "]\n", + "\n", + "# Preprocess function prepares `messages` and `sampling_params` for vLLM engine.\n", + "# All other fields are ignored by the engine.\n", + "def preprocess(row: dict[str, Any]) -> dict[str, Any]:\n", + " return dict(\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful assistant that infers company industries. \"\n", + " \"Based on the company name provided, output only the industry category. \"\n", + " f\"Choose from: {', '.join(CHOICES)}.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"What industry is this company in: {row['Company']}\"\n", + " },\n", + " ],\n", + " sampling_params=dict(\n", + " temperature=0, # Use 0 for deterministic output\n", + " max_tokens=16, # Max output tokens. Industry names are short\n", + " structured_outputs=dict(choice=CHOICES), # Constraint generation\n", + " ),\n", + " )\n", + "\n", + "# Postprocess function extracts the generated text from the engine output.\n", + "# The **row syntax returns all original columns in the input dataset.\n", + "def postprocess(row: dict[str, Any]) -> dict[str, Any]:\n", + " return {\n", + " \"inferred_industry\": row[\"generated_text\"],\n", + " **row, # Include all original columns.\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Build the processor\n", + "\n", + "With the configuration and functions defined, build the processor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.data.llm import build_llm_processor\n", + "\n", + "# Build the LLM processor with the configuration and functions.\n", + "processor = build_llm_processor(\n", + " processor_config,\n", + " preprocess=preprocess,\n", + " postprocess=postprocess,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Process the dataset\n", + "\n", + "Run the processor on your small dataset to perform batch inference. Ray Data automatically distributes the workload across available GPUs and handles batching, retries, and resource management." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "# Run the processor on the small dataset.\n", + "processed_small = processor(ds_small)\n", + "\n", + "# Materialize the dataset to memory.\n", + "# You can also use writing APIs such as write_parquet() or write_csv() to persist the dataset.\n", + "processed_small = processed_small.materialize()\n", + "\n", + "print(f\"\\nProcessed {processed_small.count()} rows successfully.\")\n", + "# Display the first 3 entries to verify the output.\n", + "sampled = processed_small.take(3)\n", + "print(\"\\n==================GENERATED OUTPUT===============\\n\")\n", + "pprint(sampled)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch to production with Anyscale Jobs\n", + "\n", + "For production workloads, deploy your batch inference processor as an [Anyscale Job](https://docs.anyscale.com/platform/jobs). Anyscale takes care of the infrastructure layer and runs your jobs on your dedicated clusters with automatic retries, monitoring, and scheduling.\n", + "\n", + "### Anyscale Runtime\n", + "\n", + "Anyscale Jobs run on [Anyscale Runtime](https://docs.anyscale.com/runtime/data), which includes performance optimizations over open-source Ray Data. Key improvements include faster shuffles, optimized memory management, improved autoscaling, and enhanced fault tolerance for large-scale data processing.\n", + "\n", + "These optimizations are automatic and require no code changes. Your Ray Data pipelines benefit from them simply by running on Anyscale. For batch inference workloads specifically, Anyscale Runtime provides better GPU utilization and reduced overhead when scaling across many nodes.\n", + "\n", + "### Configure an Anyscale Job\n", + "\n", + "Save your batch inference code as `batch_inference_text.py`, then create a job configuration file:\n", + "\n", + "```yaml\n", + "# job.yaml\n", + "name: my-llm-batch-inference-text\n", + "entrypoint: python batch_inference_text.py\n", + "image_uri: anyscale/ray-llm:2.51.1-py311-cu128\n", + "compute_config:\n", + " head_node:\n", + " instance_type: m5.2xlarge\n", + " worker_nodes:\n", + " - instance_type: g6.2xlarge\n", + " min_nodes: 0\n", + " max_nodes: 10\n", + "working_dir: .\n", + "max_retries: 2\n", + "```\n", + "\n", + "### Submit\n", + "\n", + "Submit your job using the Anyscale CLI:\n", + "\n", + "```bash\n", + "anyscale job submit --config-file job.yaml\n", + "```\n", + "\n", + "### Monitoring\n", + "\n", + "Track your job's progress in the Anyscale Console or through the CLI:\n", + "\n", + "```bash\n", + "# Check job status\n", + "anyscale job status --name my-llm-batch-inference-text\n", + "\n", + "# View logs\n", + "anyscale job logs --name my-llm-batch-inference-text\n", + "```\n", + "\n", + "The Ray Dashboard remains available for detailed monitoring. To access it, go to your Anyscale Job in your console. \n", + "For cluster-level information, click the **Metrics** tab then **Data** tab, and for task-level information, click the **Ray Workloads** tab then **Data** tab." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Monitor the execution\n", + "\n", + "Use the Ray Dashboard to monitor the execution. 
See [Monitoring your Workload](https://docs.ray.io/en/latest/data/monitoring-your-workload.html) for more information on visualizing your Ray Data jobs.\n", + "\n", + "The dashboard shows:\n", + "- Operator-level metrics (throughput, task execution times).\n", + "- Resource utilization (CPU, GPU, memory).\n", + "- Progress and remaining time estimates.\n", + "- Task status breakdown.\n", + "\n", + "**Tip**: If you encounter CUDA out of memory errors, reduce your batch size, use a smaller model, or switch to a larger GPU. For more troubleshooting tips, see [GPU Memory Management](https://docs.ray.io/en/latest/data/working-with-llms.html#gpu-memory-management-and-cuda-oom-prevention)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scale up to larger datasets\n", + "\n", + "Your Ray Data processing pipeline can easily scale up to process more data. By default, this section processes 1M rows." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The dataset has ~2M rows\n", + "# Configure how many images to process (default: 1M for demonstration).\n", + "print(f\"Processing 1M rows... (or the whole dataset if you picked >2M)\")\n", + "ds_large = ds.limit(1_000_000)\n", + "\n", + "# As we increase our compute, we can increase the number of partitions for more parallelism\n", + "num_partitions_large = 256\n", + "print(f\"Repartitioning dataset into {num_partitions_large} blocks for parallelism...\")\n", + "ds_large = ds_large.repartition(num_blocks=num_partitions_large)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can scale the number of concurrent workers based on the compute available in your cluster. In this case, each replica is a copy of your Llama model and fits in a single L4 GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor_config_large = vLLMEngineProcessorConfig(\n", + " model_source=\"unsloth/Llama-3.1-8B-Instruct\",\n", + " engine_kwargs=dict(\n", + " max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit\n", + " ),\n", + " batch_size=256,\n", + " accelerator_type=\"L4\", # Or upgrade to larger GPU\n", + " concurrency=10, # Deploy 10 workers across 10 GPUs to maximize throughput\n", + ")\n", + "\n", + "# Build the LLM processor with the configuration and functions.\n", + "processor_large = build_llm_processor(\n", + " processor_config_large,\n", + " preprocess=preprocess,\n", + " postprocess=postprocess,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Execute the new pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the same processor on the larger dataset.\n", + "processed_large = processor_large(ds_large)\n", + "processed_large = processed_large.materialize()\n", + "\n", + "print(f\"\\nProcessed {processed_large.count()} rows successfully.\")\n", + "# Display the first 3 entries to verify the output.\n", + "sampled = processed_large.take(3)\n", + "print(\"\\n==================GENERATED OUTPUT===============\\n\")\n", + "pprint(sampled)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance optimization tips\n", + "\n", + "When scaling to larger datasets, consider these optimizations tips. 
For comprehensive guidance, see the [Ray Data performance guide](https://docs.ray.io/en/latest/data/performance-tips.html) and the [throughput optimization guide with Anyscale](https://docs.anyscale.com/llm/batch-inference/throughput-optimization).\n", + "\n", + "**Analyze your pipeline** \n", + "Use *stats()* to analyze each steps in your pipeline and identify any bottlenecks.\n", + "```python\n", + "processed = processor(ds).materialize()\n", + "print(processed.stats())\n", + "```\n", + "The outputs contains detailed description of each step in your pipeline.\n", + "```text\n", + "Operator 0 ...\n", + "\n", + "...\n", + "\n", + "Operator 8 MapBatches(vLLMEngineStageUDF): 3908 tasks executed, 3908 blocks produced in 340.21s\n", + " * Remote wall time: ...\n", + " ...\n", + "\n", + "...\n", + "\n", + "Dataset throughput:\n", + "\t* Ray Data throughput: ...\n", + "\t* Estimated single node throughput: ...\n", + "```\n", + "\n", + "**Adjust concurrency** \n", + "The `concurrency` parameter controls how many model replicas run in parallel. To determine the right value:\n", + "- *Available GPU count:* Start with the number of GPUs in your cluster. Each replica needs at least one GPU (more if using tensor parallelism).\n", + "- *Model memory footprint:* Ensure your model fits in GPU memory. For example, an 8 B parameter model in FP16 requires ~16 GB, fitting on a single L4 (2 GB) or A10G (24 GB).\n", + "- *CPU-bound preprocessing:* If preprocessing is slower than inference, adding more GPU replicas won't help. Check `stats()` output to identify if preprocessing is the bottleneck.\n", + "\n", + "**Tune batch size** \n", + "The `batch_size` parameter controls how many requests Ray Data sends to vLLM at once. vLLM uses continuous batching internally, controlled by `max_num_seqs` in `engine_kwargs`. This directly impacts GPU memory allocation since vLLM pre-allocates KV cache for up to `max_num_seqs` concurrent sequences.\n", + "\n", + "- *Too small `batch_size`:* vLLM scheduler is under-saturated, risking GPU idle time.\n", + "- *Too large `batch_size`:* vLLM scheduler is over-saturated, causing overhead latency. Also increases retry cost on failure since the entire batch is retried.\n", + "\n", + "You can try the following suggestions:\n", + "1. Start with `batch_size` equal to `max_num_seqs` in your vLLM engine parameters. See [vLLM engine arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html) for defaults.\n", + "2. Monitor GPU utilization in the Ray Dashboard (see [Monitor the execution](#monitor-the-execution) section).\n", + "3. Adjust `max_num_seqs` in `engine_kwargs` to optimize GPU utilization, and re-adapt `batch_size` accordingly.\n", + "\n", + "**Tune preprocessing and inference stage parallelism** \n", + "Use `repartition()` to control parallelism during your preprocessing stage. On the other hand, the number of inference tasks is determined by `dataset_size / batch_size`, where `batch_size` controls how many rows are grouped for each vLLM engine call. Ensure you have enough tasks to keep all workers busy and enable efficient load balancing.\n", + "\n", + "See [Configure parallelism for Ray Data LLM](https://docs.anyscale.com/llm/batch-inference/resource-allocation/concurrency-and-batching.md) for detailed guidance.\n", + "\n", + "**Use quantization to reduce memory footprint** \n", + "Quantization reduces model precision to save GPU memory and improve throughput; vLLM supports this with the `quantization` field in `engine_kwargs`. 
Note that lower precision may impact output quality, and not all models or GPUs support all quantization types, see [Quantization for LLM batch inference](https://docs.anyscale.com/llm/batch-inference/throughput-optimization/quantization.md) for more guidance.\n", + "\n", + "**Fault tolerance and checkpointing** \n", + "Ray Data automatically handles fault tolerance - if a worker fails, only that worker's current batch is retried. For long-running Anyscale Jobs, you can enable job-level checkpointing to resume from failures. See [Anyscale Runtime checkpointing documentation](https://docs.anyscale.com/runtime/data#enable-job-level-checkpointing) for more information.\n", + "\n", + "**Scale to larger models with model parallelism** \n", + "Model parallelism distributes large models across multiple GPUs when they don't fit on a single GPU. Use tensor parallelism to split model layers horizontally across multiple GPUs within a single node and use pipeline parallelism to split model layers vertically across multiple nodes, with each node processing different layers of the model.\n", + "\n", + "Forward model parallelism parameters to your inference engine using the `engine_kwargs` argument of your `vLLMEngineProcessorConfig` object. If your GPUs span multiple nodes, set `ray` as the distributed executor backend to enable cross-node parallelism. This example snippet uses DeepSeek-R1, a large reasoning model requiring multiple GPUs over multiple nodes:\n", + "\n", + "```python\n", + "processor_config = vLLMEngineProcessorConfig(\n", + " model_source=\"deepseek-ai/DeepSeek-R1\",\n", + " accelerator_type=\"H100\",\n", + " engine_kwargs={\n", + " \"tensor_parallel_size\": 8, # 8 GPUs per node\n", + " \"pipeline_parallel_size\": 2, # Split across 2 nodes\n", + " \"distributed_executor_backend\": \"ray\", # Required to enable cross-node parallelism\n", + "\n", + " },\n", + " concurrency=1,\n", + ")\n", + "# Each worker uses: 8 GPUs × 2 nodes = 16 GPUs total\n", + "```\n", + "\n", + "Each inference worker allocates GPUs based on `tensor_parallel_size × pipeline_parallel_size`. For detailed guidance on parallelism strategies, see the [vLLM parallelism and scaling documentation](https://docs.vllm.ai/en/stable/serving/distributed_serving.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this example, you built an end-to-end batch pipeline: loading a customer dataset from S3 into a Ray Dataset, configuring a vLLM processor for Llama 3.1 8 B, and adding simple pre/post-processing to infer company industries. You validated the flow on 10,000 rows, scaled to 1M+ records, monitored progress in the Ray Dashboard, and saved the results to persistent storage.\n", + "\n", + "See [Anyscale batch inference optimization](https://docs.anyscale.com/llm/batch-inference) for more information on using Ray Data with Anyscale and for more advanced use cases, see [Working with LLMs](https://docs.ray.io/en/latest/data/working-with-llms.html)." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + }, + "orphan": true + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/source/data/examples/llm_batch_inference_text/content/README.md b/doc/source/data/examples/llm_batch_inference_text/content/README.md new file mode 100644 index 000000000000..c988e419d171 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/content/README.md @@ -0,0 +1,409 @@ + + +# LLM batch inference with Ray Data LLM + + + +**⏱️ Time to complete**: 15 minutes + +This example shows you how to run batch inference for large language models (LLMs) using [Ray Data LLM APIs](https://docs.ray.io/en/latest/data/api/llm.html). In this use case, the batch inference job infers company industries from company names across a large customer dataset. + + +## When to use LLM batch inference + +Offline (batch) inference optimizes for throughput over latency. Unlike online inference, which processes requests one at a time in real-time, batch inference processes thousands or millions of inputs together, maximizing GPU utilization and reducing per-inference costs. + +Choose batch inference when: +- You have a fixed dataset to process (such as daily reports or data migrations) +- Throughput matters more than immediate results +- You want to take advantage of fault tolerance for long-running jobs + +On the contrary, if you are more interested in optimizing for latency, consider [deploying your LLM with Ray Serve LLM for online inference](https://docs.ray.io/en/latest/serve/llm/index.html). + +## Prepare a Ray Data dataset + +Ray Data LLM runs batch inference for LLMs on Ray Data datasets. In this tutorial, you perform batch inference with an LLM to infer company industries from company names. The source is a 2-million-row CSV file containing sample customer data. + +First, load the data from a remote URL then repartition the dataset to ensure the workload can be distributed across multiple GPUs. + + +```python +import ray + +# Define the path to the sample CSV file hosted on S3. +# This dataset contains 2 million rows of synthetic customer data. +path = "https://llm-guide.s3.us-west-2.amazonaws.com/data/ray-data-llm/customers-2000000.csv" + +# Load the CSV file into a Ray Dataset. +print("Loading dataset from remote URL...") +ds = ray.data.read_csv(path) + +# Inspect the dataset schema and a few rows to verify it loaded correctly. +print("\nDataset schema:") +print(ds.schema()) +print("\nSample rows:") +ds.show(limit=2) +``` + +For this initial example, limit the dataset to 10,000 rows so you can process and test faster. Later, you can scale up to the full dataset. + +If you don't repartition, the system might read a large file into only a few blocks, which limits parallelism in later steps. For example, you might see that only 4 out of 8 GPUs in your cluster are being used. To address this, you can repartition the data into a specific number of blocks so the system can better parallelize work across all available GPUs in the pipeline. + + +```python +# Limit the dataset to 10,000 rows for this example. 
+print("Limiting dataset to 10,000 rows for initial processing.") +ds_small = ds.limit(10_000) + + +# Repartition the dataset to enable parallelism across multiple workers (GPUs). +# By default, streaming datasets might not be optimally partitioned. Repartitioning +# splits the data into a specified number of blocks, allowing Ray to process them +# in parallel. +# Tip: Repartition count should typically be 2-4x your worker (GPU) count. +# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions. +# This ensures enough parallelism while avoiding excessive overhead. +num_partitions = 128 +print(f"Repartitioning dataset into {num_partitions} blocks for parallelism...") +ds_small = ds_small.repartition(num_blocks=num_partitions) +``` + +## Configure Ray Data LLM + +Ray Data LLM provides a unified interface to run batch inference with different LLM engines. Configure the vLLM engine, define preprocessing and postprocessing functions, and build the processor. + +### Configure the processor engine + +Configure the model and compute resources needed for inference using `vLLMEngineProcessorConfig`. + +This example uses the `unsloth/Llama-3.1-8B-Instruct` model. The configuration specifies: +- `model_source`: The Hugging Face model identifier. +- `engine_kwargs`: vLLM engine parameters such as tensor parallelism and memory settings. +- `batch_size`: Number of requests to batch together (set to 256 for small prompts and outputs). +- `accelerator_type`: GPU type to use (L4 in this case). +- `concurrency`: Number of parallel workers (4 in this case). + +**Note:** Because the input prompts and expected output token lengths are small, `batch_size=256` is appropriate. However, depending on your workload, a large batch size can lead to increased idle GPU time when decoding long sequences. Adjust this value to find the optimal trade-off between throughput and latency. + + +```python +from ray.data.llm import vLLMEngineProcessorConfig + +processor_config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs=dict( + max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit + ), + batch_size=256, + accelerator_type="L4", + concurrency=4, +) +``` + +For more details on the configuration options you can pass to the vLLM engine, see the [vLLM Engine Arguments documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html). + +### Define the preprocess and postprocess functions + +The task is to infer the industry category from the `Company` field using an LLM. + +Define a preprocess function to prepare `messages` and `sampling_params` for the vLLM engine, and a postprocess function to extract the `generated_text`. + +Ray Data LLM sends the output of your preprocessing function directly to a vLLM engine, so you can take advantage of vLLM features. + +In this example, you use vLLM's structured output feature to restrict the LLM's responses to a predefined list of industry categories. This increases predictability and helps you reduce the cost of unnecessary output tokens. + + +```python +from typing import Any + +# For better output token control, restrain generation to these choices +CHOICES = [ + "Law Firm", + "Healthcare", + "Technology", + "Retail", + "Consulting", + "Manufacturing", + "Finance", + "Real Estate", + "Other", +] + +# Preprocess function prepares `messages` and `sampling_params` for vLLM engine. +# All other fields are ignored by the engine. 
+def preprocess(row: dict[str, Any]) -> dict[str, Any]: + return dict( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that infers company industries. " + "Based on the company name provided, output only the industry category. " + f"Choose from: {', '.join(CHOICES)}." + }, + { + "role": "user", + "content": f"What industry is this company in: {row['Company']}" + }, + ], + sampling_params=dict( + temperature=0, # Use 0 for deterministic output + max_tokens=16, # Max output tokens. Industry names are short + structured_outputs=dict(choice=CHOICES), # Constraint generation + ), + ) + +# Postprocess function extracts the generated text from the engine output. +# The **row syntax returns all original columns in the input dataset. +def postprocess(row: dict[str, Any]) -> dict[str, Any]: + return { + "inferred_industry": row["generated_text"], + **row, # Include all original columns. + } +``` + +### Build the processor + +With the configuration and functions defined, build the processor. + + +```python +from ray.data.llm import build_llm_processor + +# Build the LLM processor with the configuration and functions. +processor = build_llm_processor( + processor_config, + preprocess=preprocess, + postprocess=postprocess, +) +``` + +## Process the dataset + +Run the processor on your small dataset to perform batch inference. Ray Data automatically distributes the workload across available GPUs and handles batching, retries, and resource management. + + +```python +from pprint import pprint + +# Run the processor on the small dataset. +processed_small = processor(ds_small) + +# Materialize the dataset to memory. +# You can also use writing APIs such as write_parquet() or write_csv() to persist the dataset. +processed_small = processed_small.materialize() + +print(f"\nProcessed {processed_small.count()} rows successfully.") +# Display the first 3 entries to verify the output. +sampled = processed_small.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) +``` + +## Launch to production with Anyscale Jobs + +For production workloads, deploy your batch inference processor as an [Anyscale Job](https://docs.anyscale.com/platform/jobs). Anyscale takes care of the infrastructure layer and runs your jobs on your dedicated clusters with automatic retries, monitoring, and scheduling. + +### Anyscale Runtime + +Anyscale Jobs run on [Anyscale Runtime](https://docs.anyscale.com/runtime/data), which includes performance optimizations over open-source Ray Data. Key improvements include faster shuffles, optimized memory management, improved autoscaling, and enhanced fault tolerance for large-scale data processing. + +These optimizations are automatic and require no code changes. Your Ray Data pipelines benefit from them simply by running on Anyscale. For batch inference workloads specifically, Anyscale Runtime provides better GPU utilization and reduced overhead when scaling across many nodes. + +### Configure an Anyscale Job + +Save your batch inference code as `batch_inference_text.py`, then create a job configuration file: + +```yaml +# job.yaml +name: my-llm-batch-inference-text +entrypoint: python batch_inference_text.py +image_uri: anyscale/ray-llm:2.51.1-py311-cu128 +compute_config: + head_node: + instance_type: m5.2xlarge + worker_nodes: + - instance_type: g6.2xlarge + min_nodes: 0 + max_nodes: 10 +working_dir: . 
+max_retries: 2 +``` + +### Submit + +Submit your job using the Anyscale CLI: + +```bash +anyscale job submit --config-file job.yaml +``` + +### Monitoring + +Track your job's progress in the Anyscale Console or through the CLI: + +```bash +# Check job status +anyscale job status --name my-llm-batch-inference-text + +# View logs +anyscale job logs --name my-llm-batch-inference-text +``` + +The Ray Dashboard remains available for detailed monitoring. To access it, go to your Anyscale Job in your console. +For cluster-level information, click the **Metrics** tab then **Data** tab, and for task-level information, click the **Ray Workloads** tab then **Data** tab. + +## Monitor the execution + +Use the Ray Dashboard to monitor the execution. See [Monitoring your Workload](https://docs.ray.io/en/latest/data/monitoring-your-workload.html) for more information on visualizing your Ray Data jobs. + +The dashboard shows: +- Operator-level metrics (throughput, task execution times). +- Resource utilization (CPU, GPU, memory). +- Progress and remaining time estimates. +- Task status breakdown. + +**Tip**: If you encounter CUDA out of memory errors, reduce your batch size, use a smaller model, or switch to a larger GPU. For more troubleshooting tips, see [GPU Memory Management](https://docs.ray.io/en/latest/data/working-with-llms.html#gpu-memory-management-and-cuda-oom-prevention). + +## Scale up to larger datasets + +Your Ray Data processing pipeline can easily scale up to process more data. By default, this section processes 1M rows. + + +```python +# The dataset has ~2M rows +# Configure how many images to process (default: 1M for demonstration). +print(f"Processing 1M rows... (or the whole dataset if you picked >2M)") +ds_large = ds.limit(1_000_000) + +# As we increase our compute, we can increase the number of partitions for more parallelism +num_partitions_large = 256 +print(f"Repartitioning dataset into {num_partitions_large} blocks for parallelism...") +ds_large = ds_large.repartition(num_blocks=num_partitions_large) +``` + +You can scale the number of concurrent workers based on the compute available in your cluster. In this case, each replica is a copy of your Llama model and fits in a single L4 GPU. + + +```python +processor_config_large = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs=dict( + max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit + ), + batch_size=256, + accelerator_type="L4", # Or upgrade to larger GPU + concurrency=10, # Deploy 10 workers across 10 GPUs to maximize throughput +) + +# Build the LLM processor with the configuration and functions. +processor_large = build_llm_processor( + processor_config_large, + preprocess=preprocess, + postprocess=postprocess, +) +``` + +Execute the new pipeline + + +```python +# Run the same processor on the larger dataset. +processed_large = processor_large(ds_large) +processed_large = processed_large.materialize() + +print(f"\nProcessed {processed_large.count()} rows successfully.") +# Display the first 3 entries to verify the output. +sampled = processed_large.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) +``` + +## Performance optimization tips + +When scaling to larger datasets, consider these optimizations tips. 
For comprehensive guidance, see the [Ray Data performance guide](https://docs.ray.io/en/latest/data/performance-tips.html) and the [throughput optimization guide with Anyscale](https://docs.anyscale.com/llm/batch-inference/throughput-optimization).
+
+**Analyze your pipeline**  
+Use `stats()` to analyze each step in your pipeline and identify bottlenecks.
+```python
+processed = processor(ds).materialize()
+print(processed.stats())
+```
+The output contains a detailed description of each step in your pipeline.
+```text
+Operator 0 ...
+
+...
+
+Operator 8 MapBatches(vLLMEngineStageUDF): 3908 tasks executed, 3908 blocks produced in 340.21s
+    * Remote wall time: ...
+    ...
+
+...
+
+Dataset throughput:
+    * Ray Data throughput: ...
+    * Estimated single node throughput: ...
+```
+
+**Adjust concurrency**  
+The `concurrency` parameter controls how many model replicas run in parallel. To determine the right value:
+- *Available GPU count:* Start with the number of GPUs in your cluster. Each replica needs at least one GPU (more if using tensor parallelism).
+- *Model memory footprint:* Ensure your model fits in GPU memory. For example, an 8B-parameter model in FP16 requires ~16 GB, fitting on a single L4 (24 GB) or A10G (24 GB).
+- *CPU-bound preprocessing:* If preprocessing is slower than inference, adding more GPU replicas won't help. Check `stats()` output to identify whether preprocessing is the bottleneck.
+
+**Tune batch size**  
+The `batch_size` parameter controls how many requests Ray Data sends to vLLM at once. vLLM uses continuous batching internally, controlled by `max_num_seqs` in `engine_kwargs`. This directly impacts GPU memory allocation because vLLM pre-allocates KV cache for up to `max_num_seqs` concurrent sequences.
+
+- *Too small `batch_size`:* The vLLM scheduler is under-saturated, risking GPU idle time.
+- *Too large `batch_size`:* The vLLM scheduler is over-saturated, which adds latency and increases the retry cost on failure because the entire batch is retried.
+
+To tune these values:
+1. Start with `batch_size` equal to `max_num_seqs` in your vLLM engine parameters. See [vLLM engine arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html) for defaults.
+2. Monitor GPU utilization in the Ray Dashboard (see the [Monitor the execution](#monitor-the-execution) section).
+3. Adjust `max_num_seqs` in `engine_kwargs` to optimize GPU utilization, and adapt `batch_size` accordingly.
+
+**Tune preprocessing and inference stage parallelism**  
+Use `repartition()` to control parallelism during the preprocessing stage. The number of inference tasks, by contrast, is determined by `dataset_size / batch_size`, where `batch_size` controls how many rows are grouped for each vLLM engine call. Ensure you have enough tasks to keep all workers busy and enable efficient load balancing.
+
+See [Configure parallelism for Ray Data LLM](https://docs.anyscale.com/llm/batch-inference/resource-allocation/concurrency-and-batching.md) for detailed guidance.
+
+**Use quantization to reduce memory footprint**  
+Quantization reduces model precision to save GPU memory and improve throughput; vLLM supports it through the `quantization` field in `engine_kwargs`. Lower precision can affect output quality, and not all models or GPUs support every quantization type. See [Quantization for LLM batch inference](https://docs.anyscale.com/llm/batch-inference/throughput-optimization/quantization.md) for more guidance.
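+
+As a minimal sketch, the configuration below enables vLLM's online FP8 weight quantization for the same model used in this example. It assumes a vLLM version and GPU with FP8 support (Ada Lovelace GPUs such as the L4 qualify); validate output quality on a sample before adopting it.
+
+```python
+processor_config_quantized = vLLMEngineProcessorConfig(
+    model_source="unsloth/Llama-3.1-8B-Instruct",
+    engine_kwargs=dict(
+        max_model_len=256,
+        quantization="fp8",  # Assumption: FP8 support in your vLLM build and GPU.
+    ),
+    batch_size=256,
+    accelerator_type="L4",
+    concurrency=4,
+)
+```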
+ +**Fault tolerance and checkpointing** +Ray Data automatically handles fault tolerance - if a worker fails, only that worker's current batch is retried. For long-running Anyscale Jobs, you can enable job-level checkpointing to resume from failures. See [Anyscale Runtime checkpointing documentation](https://docs.anyscale.com/runtime/data#enable-job-level-checkpointing) for more information. + +**Scale to larger models with model parallelism** +Model parallelism distributes large models across multiple GPUs when they don't fit on a single GPU. Use tensor parallelism to split model layers horizontally across multiple GPUs within a single node and use pipeline parallelism to split model layers vertically across multiple nodes, with each node processing different layers of the model. + +Forward model parallelism parameters to your inference engine using the `engine_kwargs` argument of your `vLLMEngineProcessorConfig` object. If your GPUs span multiple nodes, set `ray` as the distributed executor backend to enable cross-node parallelism. This example snippet uses DeepSeek-R1, a large reasoning model requiring multiple GPUs over multiple nodes: + +```python +processor_config = vLLMEngineProcessorConfig( + model_source="deepseek-ai/DeepSeek-R1", + accelerator_type="H100", + engine_kwargs={ + "tensor_parallel_size": 8, # 8 GPUs per node + "pipeline_parallel_size": 2, # Split across 2 nodes + "distributed_executor_backend": "ray", # Required to enable cross-node parallelism + + }, + concurrency=1, +) +# Each worker uses: 8 GPUs × 2 nodes = 16 GPUs total +``` + +Each inference worker allocates GPUs based on `tensor_parallel_size × pipeline_parallel_size`. For detailed guidance on parallelism strategies, see the [vLLM parallelism and scaling documentation](https://docs.vllm.ai/en/stable/serving/distributed_serving.html). + +## Summary + +In this example, you built an end-to-end batch pipeline: loading a customer dataset from S3 into a Ray Dataset, configuring a vLLM processor for Llama 3.1 8 B, and adding simple pre/post-processing to infer company industries. You validated the flow on 10,000 rows, scaled to 1M+ records, monitored progress in the Ray Dashboard, and saved the results to persistent storage. + +See [Anyscale batch inference optimization](https://docs.anyscale.com/llm/batch-inference) for more information on using Ray Data with Anyscale and for more advanced use cases, see [Working with LLMs](https://docs.ray.io/en/latest/data/working-with-llms.html). diff --git a/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text.py b/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text.py new file mode 100644 index 000000000000..e1b1498ce861 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text.py @@ -0,0 +1,105 @@ +from typing import Any + +from pprint import pprint +import ray +from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig +from vllm.sampling_params import StructuredOutputsParams + +DATASET_LIMIT = 10_000 + +# Define the path to the sample CSV file hosted on S3. +# This dataset contains 2 million rows of synthetic customer data. +path = "https://llm-guide.s3.us-west-2.amazonaws.com/data/ray-data-llm/customers-2000000.csv" + +# Load the CSV file into a Ray Dataset. +print("Loading dataset from remote URL...") +ds = ray.data.read_csv(path) + +# Limit the dataset. If DATASET_LIMIT > dataset size, the entire dataset will be processed. 
+print(f"Limiting dataset to {DATASET_LIMIT} images for initial processing.") +ds_small = ds.limit(DATASET_LIMIT) + +# Repartition the dataset to enable parallelism across multiple workers (GPUs). +# By default, streaming datasets might not be optimally partitioned. Repartitioning +# splits the data into a specified number of blocks, allowing Ray to process them +# in parallel. +# Tip: Repartition count should typically be 2-4x your worker (GPU) count. +# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions. +# This ensures enough parallelism while avoiding excessive overhead. +num_partitions = 128 +print(f"Repartitioning dataset into {num_partitions} blocks for parallelism...") +ds_small = ds_small.repartition(num_blocks=num_partitions) + +processor_config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs=dict( + max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit + ), + batch_size=256, + accelerator_type="L4", + concurrency=4, +) + +# For better output token control, restrain generation to these choices +CHOICES = [ + "Law Firm", + "Healthcare", + "Technology", + "Retail", + "Consulting", + "Manufacturing", + "Finance", + "Real Estate", + "Other", +] + +# Preprocess function prepares `messages` and `sampling_params` for vLLM engine. +# All other fields are ignored by the engine. +def preprocess(row: dict[str, Any]) -> dict[str, Any]: + return dict( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that infers company industries. " + "Based on the company name provided, output only the industry category. " + "Choose from: Law Firm, Healthcare, Technology, Retail, Consulting, Manufacturing, Finance, Real Estate, Other." + }, + { + "role": "user", + "content": f"What industry is this company in: {row['Company']}" + }, + ], + sampling_params=dict( + temperature=0, # Use 0 for deterministic output + max_tokens=16, # Max output tokens. Industry names are short + structured_outputs=dict(choice=CHOICES), # Constraint generation + ), + ) + +# Postprocess function extracts the generated text from the engine output. +# The **row syntax returns all original columns in the input dataset. +def postprocess(row: dict[str, Any]) -> dict[str, Any]: + return { + "inferred_industry": row["generated_text"], + **row, # Include all original columns. + } + +# Build the LLM processor with the configuration and functions. +processor = build_llm_processor( + processor_config, + preprocess=preprocess, + postprocess=postprocess, +) + +# Run the processor on the small dataset. +processed_small = processor(ds_small) + +# Materialize the dataset to memory. +# You can also use writing APIs such as write_parquet() or write_csv() to persist the dataset. +processed_small = processed_small.materialize() + +print(f"\nProcessed {processed_small.count()} rows successfully.") +# Display the first 3 entries to verify the output. 
+sampled = processed_small.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) diff --git a/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text_scaled.py b/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text_scaled.py new file mode 100644 index 000000000000..df2b5672f588 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/content/batch_inference_text_scaled.py @@ -0,0 +1,97 @@ +from typing import Any + +from pprint import pprint +import ray +from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig +from vllm.sampling_params import StructuredOutputsParams + +DATASET_LIMIT = 1_000_000 + +# Define the path to the sample CSV file hosted on S3. +# This dataset contains 2 million rows of synthetic customer data. +path = "https://llm-guide.s3.us-west-2.amazonaws.com/data/ray-data-llm/customers-2000000.csv" + +# Load the CSV file into a Ray Dataset. +print("Loading dataset from remote URL...") +ds = ray.data.read_csv(path) + +# Limit the dataset. If DATASET_LIMIT > dataset size, the entire dataset will be processed. +print(f"Limiting dataset to {DATASET_LIMIT} images for initial processing.") +ds_large = ds.limit(DATASET_LIMIT) + +# As we increase our compute, we can increase the number of partitions for more parallelism +num_partitions_large = 256 +print(f"Repartitioning dataset into {num_partitions_large} blocks for parallelism...") +ds_large = ds_large.repartition(num_blocks=num_partitions_large) + +processor_config_large = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs=dict( + max_model_len=256, # Hard cap: system prompt + user prompt + output tokens must fit within this limit + ), + batch_size=256, + accelerator_type="L4", # Or upgrade to larger GPU + concurrency=10, # Deploy 10 workers across 10 GPUs to maximize throughput +) + +# For better output token control, restrain generation to these choices +CHOICES = [ + "Law Firm", + "Healthcare", + "Technology", + "Retail", + "Consulting", + "Manufacturing", + "Finance", + "Real Estate", + "Other", +] + +# Preprocess function prepares `messages` and `sampling_params` for vLLM engine. +# All other fields are ignored by the engine. +def preprocess(row: dict[str, Any]) -> dict[str, Any]: + return dict( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that infers company industries. " + "Based on the company name provided, output only the industry category. " + "Choose from: Law Firm, Healthcare, Technology, Retail, Consulting, Manufacturing, Finance, Real Estate, Other." + }, + { + "role": "user", + "content": f"What industry is this company in: {row['Company']}" + }, + ], + sampling_params=dict( + temperature=0, # Use 0 for deterministic output + max_tokens=16, # Max output tokens. Industry names are short + structured_outputs=dict(choice=CHOICES), # Constraint generation + ), + ) + +# Postprocess function extracts the generated text from the engine output. +# The **row syntax returns all original columns in the input dataset. +def postprocess(row: dict[str, Any]) -> dict[str, Any]: + return { + "inferred_industry": row["generated_text"], + **row, # Include all original columns. + } + +# Build the LLM processor with the configuration and functions. +processor_large = build_llm_processor( + processor_config_large, + preprocess=preprocess, + postprocess=postprocess, +) + + +# Run the same processor on the larger dataset. 
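+# Note: materialize() keeps the full result set in the Ray object store. At this
+# scale you may prefer to persist results directly to storage instead, for example
+# with processed_large.write_parquet("<output path>") (the path is a placeholder).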
+processed_large = processor_large(ds_large) +processed_large = processed_large.materialize() + +print(f"\nProcessed {processed_large.count()} rows successfully.") +# Display the first 3 entries to verify the output. +sampled = processed_large.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) diff --git a/doc/source/data/examples/llm_batch_inference_text/content/job.yaml b/doc/source/data/examples/llm_batch_inference_text/content/job.yaml new file mode 100644 index 000000000000..5fe12d7c4b72 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/content/job.yaml @@ -0,0 +1,12 @@ +name: my-llm-batch-inference-text +entrypoint: python batch_inference_text.py +image_uri: anyscale/ray-llm:2.51.1-py311-cu128 +compute_config: + head_node: + instance_type: m5.2xlarge + worker_nodes: + - instance_type: g6.2xlarge + min_nodes: 0 + max_nodes: 10 +working_dir: . +max_retries: 2 diff --git a/doc/source/data/examples/llm_batch_inference_text/convert_to_md.sh b/doc/source/data/examples/llm_batch_inference_text/convert_to_md.sh new file mode 100755 index 000000000000..7ef7f34bceb1 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_text/convert_to_md.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -exo pipefail + +nb_rel_path="content/README.ipynb" +content_dir="$(dirname "$nb_rel_path")" +nb_filename="$(basename "$nb_rel_path")" +md_rel_path="$content_dir/README.md" + +# Delete README if it already exists +[ -f "$md_rel_path" ] && rm "$md_rel_path" + +# Convert notebook to Markdown +jupyter nbconvert "$nb_rel_path" --to markdown --output README.md + +# Prepend warning comment (will be hidden when rendered in the console) +tmp_file="$(mktemp)" +{ +echo "" +echo "" +cat "$md_rel_path" +} > "$tmp_file" +mv "$tmp_file" "$md_rel_path" diff --git a/doc/source/data/examples/llm_batch_inference_vision/ci/aws.yaml b/doc/source/data/examples/llm_batch_inference_vision/ci/aws.yaml new file mode 100644 index 000000000000..5438eeda3105 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/ci/aws.yaml @@ -0,0 +1,18 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: m5.2xlarge + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g6.2xlarge + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true + diff --git a/doc/source/data/examples/llm_batch_inference_vision/ci/gce.yaml b/doc/source/data/examples/llm_batch_inference_vision/ci/gce.yaml new file mode 100644 index 000000000000..24dc4b9106f1 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/ci/gce.yaml @@ -0,0 +1,18 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-central1 + +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: n2-standard-8 + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g2-standard-8-nvidia-l4-1 + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true + diff --git a/doc/source/data/examples/llm_batch_inference_vision/ci/nb2py.py b/doc/source/data/examples/llm_batch_inference_vision/ci/nb2py.py new file mode 100644 index 000000000000..b0a055698286 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/ci/nb2py.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +import argparse +import nbformat + + +def convert_notebook( + input_path: str, output_path: str, ignore_cmds: bool = False +) -> None: + """ + Read a Jupyter notebook and write a Python 
script, converting all %%bash + cells and IPython "!" commands into subprocess.run calls that raise on error. + Cells that load or autoreload extensions are ignored. + """ + nb = nbformat.read(input_path, as_version=4) + with open(output_path, "w") as out: + for cell in nb.cells: + # Only process code cells + if cell.cell_type != "code": + continue + + lines = cell.source.splitlines() + + if lines: + # Detect any IPython '!' shell commands in code lines + has_bang = any(line.lstrip().startswith("!") for line in lines) + # Detect %pip magic commands + has_pip_magic = any(line.lstrip().startswith("%pip") for line in lines) + # Start with "serve run" "serve shutdown" "curl" or "anyscale service" commands + to_ignore_cmd = ( + "serve run", + "serve shutdown", + "curl", + "anyscale service", + ) + has_ignored_start = any( + line.lstrip().startswith(to_ignore_cmd) for line in lines + ) + # Skip %pip cells entirely + if has_pip_magic: + continue + if has_bang or has_ignored_start: + if ignore_cmds: + continue + out.write("import subprocess\n") + for line in lines: + stripped = line.lstrip() + if stripped.startswith("!"): + cmd = stripped[1:].lstrip() + out.write( + f"subprocess.run(r'''{cmd}''',\n" + " shell=True,\n" + " check=True,\n" + " executable='/bin/bash')\n" + ) + else: + out.write(line.rstrip() + "\n") + out.write("\n") + else: + # Regular Python cell: + code = cell.source.rstrip() + if "ds_large = ds.limit(1_000_000)" in code: + # Instead of testing a large dataset in CI, test a small dataset + code = code.replace("ds.limit(1_000_000)", "ds.limit(10_000)") + # else, dump as-is + out.write(code + "\n\n") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert a Jupyter notebook to a Python script, preserving bash cells and '!' commands as subprocess calls unless ignored with --ignore-cmds." + ) + parser.add_argument("input_nb", help="Path to the input .ipynb file") + parser.add_argument("output_py", help="Path for the output .py script") + parser.add_argument( + "--ignore-cmds", action="store_true", help="Ignore bash cells and '!' 
commands" + ) + args = parser.parse_args() + convert_notebook(args.input_nb, args.output_py, ignore_cmds=args.ignore_cmds) + + +if __name__ == "__main__": + main() diff --git a/doc/source/data/examples/llm_batch_inference_vision/ci/tests.sh b/doc/source/data/examples/llm_batch_inference_vision/ci/tests.sh new file mode 100755 index 000000000000..93e957238ec4 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/ci/tests.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Install requirements first (done by CI automatically): +# release/ray_release/byod/byod_llm_batch_inference_vision.sh + +# Don't use nbconvert or jupytext unless you're willing +# to check each subprocess unit and validate that errors +# aren't being consumed/hidden + +set -exo pipefail + +python ci/nb2py.py "content/README.ipynb" "content/README.py" --ignore-cmds +python "content/README.py" +rm "content/README.py" + diff --git a/doc/source/data/examples/llm_batch_inference_vision/configs/aws.yaml b/doc/source/data/examples/llm_batch_inference_vision/configs/aws.yaml new file mode 100644 index 000000000000..6a940933e3de --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/configs/aws.yaml @@ -0,0 +1,15 @@ +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: m5.2xlarge + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g6.2xlarge + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true + diff --git a/doc/source/data/examples/llm_batch_inference_vision/configs/gce.yaml b/doc/source/data/examples/llm_batch_inference_vision/configs/gce.yaml new file mode 100644 index 000000000000..da09f147dfa7 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/configs/gce.yaml @@ -0,0 +1,15 @@ +# Head node +head_node_type: + name: 8CPU-32GB + instance_type: n2-standard-8 + +# Worker nodes +worker_node_types: + - name: 1xL4:8CPU-32GB + instance_type: g2-standard-8-nvidia-l4-1 + min_workers: 0 + max_workers: 10 + +flags: + allow-cross-zone-autoscaling: true + diff --git a/doc/source/data/examples/llm_batch_inference_vision/content/README.ipynb b/doc/source/data/examples/llm_batch_inference_vision/content/README.ipynb new file mode 100644 index 000000000000..bb03947ee11e --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/content/README.ipynb @@ -0,0 +1,586 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multimodal LLM batch inference with Ray Data LLM\n", + "\n", + "\n", + "\n", + "**⏱️ Time to complete**: 20 minutes\n", + "\n", + "This example shows you how to run batch inference for vision-language models (VLMs) using [Ray Data LLM APIs](https://docs.ray.io/en/latest/data/api/llm.html). In this use case, the batch inference job generates captions for a large-scale image dataset.\n", + "\n", + "## When to use LLM batch inference\n", + "\n", + "Offline (batch) inference optimizes for throughput over latency. 
Unlike online inference, which processes requests one at a time in real-time, batch inference processes thousands or millions of inputs together, maximizing GPU utilization and reducing per-inference costs.\n", + "\n", + "Choose batch inference when:\n", + "- You have a fixed dataset to process (such as daily reports or data migrations)\n", + "- Throughput matters more than immediate results\n", + "- You want to take advantage of fault tolerance for long-running jobs\n", + "\n", + "On the contrary, if you are more interested in optimizing for latency, consider [deploying your LLM with Ray Serve LLM for online inference](https://docs.ray.io/en/latest/serve/llm/index.html).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare a Ray Data dataset with images\n", + "\n", + "Ray Data LLM runs batch inference for VLMs on Ray Data datasets containing images. In this tutorial, you perform batch inference with a vision-language model to generate image captions from the `BLIP3o/BLIP3o-Pretrain-Short-Caption` dataset, which contains approximately 5 million images.\n", + "\n", + "First, load the data from a remote URL then repartition the dataset to ensure the workload can be distributed across multiple GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install datasets==4.4.2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "import datasets\n", + "from PIL import Image\n", + "from io import BytesIO\n", + "\n", + "# Load the BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face with ~5M images.\n", + "print(\"Loading BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face...\")\n", + "hf_dataset = datasets.load_dataset(\"BLIP3o/BLIP3o-Pretrain-Short-Caption\", split=\"train\", streaming=True)\n", + "hf_dataset = hf_dataset.select_columns([\"jpg\"])\n", + "\n", + "ds = ray.data.from_huggingface(hf_dataset)\n", + "print(\"Dataset loaded successfully.\")\n", + "\n", + "sample = ds.take(2)\n", + "print(\"Sample data:\")\n", + "for i, item in enumerate(sample):\n", + " print(f\"\\nSample {i+1}:\")\n", + " image = Image.open(BytesIO(item['jpg']['bytes']))\n", + " image.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this initial example, limit the dataset to 10,000 rows so you can process and test faster. Later, you can scale up to the full dataset.\n", + "\n", + "If you don't repartition, the system might read a large file into only a few blocks, which limits parallelism in later steps. For example, you might see that only 4 out of 8 GPUs in your cluster are being used. To address this, you can repartition the data into a specific number of blocks so the system can better parallelize work across all available GPUs in the pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Limit the dataset to 10,000 images for this example.\n", + "print(\"Limiting dataset to 10,000 images for initial processing.\")\n", + "ds_small = ds.limit(10_000)\n", + "\n", + "# Repartition the dataset to enable parallelism across multiple workers (GPUs).\n", + "# By default, streaming datasets might not be optimally partitioned. 
Repartitioning\n", + "# splits the data into a specified number of blocks, allowing Ray to process them\n", + "# in parallel.\n", + "# Tip: Repartition count should typically be 2-4x your worker (GPU) count.\n", + "# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions.\n", + "# This ensures enough parallelism while avoiding excessive overhead.\n", + "num_partitions = 64\n", + "print(f\"Repartitioning dataset into {num_partitions} blocks for parallelism...\")\n", + "ds_small = ds_small.repartition(num_blocks=num_partitions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Ray Data LLM\n", + "\n", + "Ray Data LLM provides a unified interface to run batch inference with different VLM engines. Configure the vLLM engine with a vision-language model, define preprocessing and postprocessing functions, and build the processor.\n", + "\n", + "### Configure the processor engine\n", + "\n", + "Configure the model and compute resources needed for inference using `vLLMEngineProcessorConfig` with vision support enabled.\n", + "\n", + "This example uses the `Qwen/Qwen2.5-VL-3B-Instruct` model, a vision-language model. The configuration specifies:\n", + "- `model_source`: The Hugging Face model identifier.\n", + "- `engine_kwargs`: vLLM engine parameters such as memory settings and batching.\n", + "- `batch_size`: Number of requests to batch together (set to 16 for vision models).\n", + "- `accelerator_type`: GPU type to use (L4 in this case).\n", + "- `concurrency`: Number of parallel workers (4 in this case).\n", + "- `has_image`: Enable image input support.\n", + "\n", + "Vision models process each image as hundreds or thousands of vision tokens, unlike text-only models. You can set a larger token limit using `max_model_len`. You also need to use smaller batch sizes because image processing increases per-request memory. Adjust both `max_model_len` and `batch_size` for your vision token requirements and available memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.data.llm import vLLMEngineProcessorConfig\n", + "\n", + "processor_config = vLLMEngineProcessorConfig(\n", + " model_source=\"Qwen/Qwen2.5-VL-3B-Instruct\",\n", + " engine_kwargs=dict(\n", + " max_model_len=8192 # Hard cap: all text + vision tokens must fit within this limit\n", + " ),\n", + " batch_size=16,\n", + " accelerator_type=\"L4\",\n", + " concurrency=4,\n", + " has_image=True, # Enable image input.\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details on the configuration options you can pass to the vLLM engine, see the [vLLM Engine Arguments documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the preprocess and postprocess functions\n", + "\n", + "The task is to generate descriptive captions for images using a vision-language model.\n", + "\n", + "Define a preprocess function to prepare `messages` with image content and `sampling_params` for the vLLM engine, and a postprocess function to extract the `generated_text`.\n", + "\n", + "For production workloads with potentially corrupt or malformed images, filter them out before processing. Use Ray Data's `filter()` to validate images upfront - this prevents failures during inference and provides cleaner error handling than catching exceptions in the preprocess function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "from PIL import Image\n", + "from io import BytesIO\n", + "\n", + "# Filter function to validate images before processing.\n", + "# Returns True for valid images, False for corrupt/malformed ones.\n", + "def is_valid_image(row: dict[str, Any]) -> bool:\n", + " try:\n", + " Image.open(BytesIO(row['jpg']['bytes']))\n", + " return True\n", + " except Exception:\n", + " return False\n", + "\n", + "# Preprocess function prepares messages with image content for the VLM.\n", + "def preprocess(row: dict[str, Any]) -> dict[str, Any]:\n", + " # Convert bytes image to PIL \n", + " image = row['jpg']['bytes']\n", + " image = Image.open(BytesIO(image))\n", + " # Resize to 225x225 for consistency and predictable vision-token budget.\n", + " # This resolution balances quality with memory usage. Adjust based on your\n", + " # model's expected input size and available GPU memory.\n", + " image = image.resize((225, 225), Image.Resampling.BICUBIC)\n", + " \n", + " return dict(\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful assistant that generates accurate and descriptive captions for images.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": \"Describe this image in detail. Focus on the main subjects, actions, and setting.\"\n", + " },\n", + " {\n", + " \"type\": \"image\",\n", + " \"image\": image # Ray Data accepts PIL Image or image URL.\n", + " }\n", + " ]\n", + " },\n", + " ],\n", + " sampling_params=dict(\n", + " temperature=0.3,\n", + " max_tokens=256\n", + " ),\n", + " )\n", + "\n", + "# Postprocess function extracts the generated caption.\n", + "def postprocess(row: dict[str, Any]) -> dict[str, Any]:\n", + " # Example: validation check, formatting...\n", + " \n", + " return {\n", + " \"generated_caption\": row[\"generated_text\"],\n", + " # Note: Don't include **row here to avoid returning the large image data.\n", + " # Include only the fields you need in the output.\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Build the processor\n", + "\n", + "With the configuration and functions defined, build the processor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.data.llm import build_llm_processor\n", + "\n", + "# Build the LLM processor with the configuration and functions.\n", + "processor = build_llm_processor(\n", + " processor_config,\n", + " preprocess=preprocess,\n", + " postprocess=postprocess,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Process the dataset\n", + "\n", + "Run the processor on your small dataset to perform batch inference. 
Ray Data automatically distributes the workload across available GPUs and handles batching, retries, and resource management.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "# Filter out invalid images before processing.\n", + "ds_small_filtered = ds_small.filter(is_valid_image)\n", + "\n", + "# Run the processor on the filtered dataset.\n", + "processed_small = processor(ds_small_filtered)\n", + "\n", + "# Materialize the dataset to memory.\n", + "# You can also use writing APIs such as write_parquet() or write_json() to persist the dataset.\n", + "processed_small = processed_small.materialize()\n", + "\n", + "print(f\"\\nProcessed {processed_small.count()} rows successfully.\")\n", + "# Display the first 3 entries to verify the output.\n", + "sampled = processed_small.take(3)\n", + "print(\"\\n==================GENERATED OUTPUT===============\\n\")\n", + "pprint(sampled)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch to production with Anyscale Jobs\n", + "\n", + "For production workloads, deploy your batch inference processor as an [Anyscale Job](https://docs.anyscale.com/platform/jobs). Anyscale takes care of the infrastructure layer and runs your jobs on your dedicated clusters with automatic retries, monitoring, and scheduling.\n", + "\n", + "### Anyscale Runtime\n", + "\n", + "Anyscale Jobs run on [Anyscale Runtime](https://docs.anyscale.com/runtime/data), which includes performance optimizations over open-source Ray Data. Key improvements include faster shuffles, optimized memory management, improved autoscaling, and enhanced fault tolerance for large-scale data processing.\n", + "\n", + "These optimizations are automatic and require no code changes. Your Ray Data pipelines benefit from them simply by running on Anyscale. For batch inference workloads specifically, Anyscale Runtime provides better GPU utilization and reduced overhead when scaling across many nodes.\n", + "\n", + "### Configure an Anyscale Job\n", + "\n", + "Save your batch inference code as `batch_inference_vision.py`, then create a job configuration file:\n", + "\n", + "```yaml\n", + "# job.yaml\n", + "name: my-llm-batch-inference-vision\n", + "entrypoint: python batch_inference_vision.py\n", + "image_uri: anyscale/ray-llm:2.51.1-py311-cu128\n", + "compute_config:\n", + " head_node:\n", + " instance_type: m5.2xlarge\n", + " worker_nodes:\n", + " - instance_type: g6.2xlarge\n", + " min_nodes: 0\n", + " max_nodes: 10\n", + "requirements: # Python dependencies - can be list or path to requirements.txt\n", + " - datasets==4.4.1\n", + "working_dir: .\n", + "max_retries: 2\n", + "\n", + "\n", + "```\n", + "\n", + "### Submit\n", + "\n", + "Submit your job using the Anyscale CLI:\n", + "\n", + "```bash\n", + "anyscale job submit --config-file job.yaml\n", + "```\n", + "\n", + "### Monitoring\n", + "\n", + "Track your job's progress in the Anyscale Console or through the CLI:\n", + "\n", + "```bash\n", + "# Check job status.\n", + "anyscale job status --name my-llm-batch-inference-vision\n", + "\n", + "# View logs.\n", + "anyscale job logs --name my-llm-batch-inference-vision\n", + "```\n", + "\n", + "The Ray Dashboard remains available for detailed monitoring. To access it, go to your Anyscale Job in your console. 
\n", + "For cluster-level information, click the **Metrics** tab then **Data** tab, and for task-level information, click the **Ray Workloads** tab then **Data** tab." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Monitor the execution\n", + "\n", + "Use the Ray Dashboard to monitor the execution. See [Monitoring your Workload](https://docs.ray.io/en/latest/data/monitoring-your-workload.html) for more information on visualizing your Ray Data jobs.\n", + "\n", + "The dashboard shows:\n", + "- Operator-level metrics (throughput, task execution times).\n", + "- Resource utilization (CPU, GPU, memory).\n", + "- Progress and remaining time estimates.\n", + "- Task status breakdown.\n", + "\n", + "**Tip**: If you encounter CUDA out of memory errors, reduce your batch size, use a smaller model, or switch to a larger GPU. For more troubleshooting tips, see [GPU Memory Management](https://docs.ray.io/en/latest/data/working-with-llms.html#gpu-memory-management-and-cuda-oom-prevention).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scale up to larger datasets\n", + "\n", + "Your Ray Data processing pipeline can easily scale up to process more images. By default, this section processes 1M images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The BLIP3o/BLIP3o-Pretrain-Short-Caption dataset has ~5M images\n", + "# Configure how many images to process (default: 1M for demonstration).\n", + "print(f\"Processing 1M images... (or the whole dataset if you picked >5M)\")\n", + "ds_large = ds.limit(1_000_000)\n", + "\n", + "# As we increase our compute, we can increase the number of partitions for more parallelism\n", + "num_partitions_large = 128\n", + "print(f\"Repartitioning dataset into {num_partitions_large} blocks for parallelism...\")\n", + "ds_large = ds_large.repartition(num_blocks=num_partitions_large)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can scale the number of concurrent replicas based on the compute available in your cluster. In this case, each replica is a copy of your Qwen-VL model and fits in a single L4 GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor_config_large = vLLMEngineProcessorConfig(\n", + " model_source=\"Qwen/Qwen2.5-VL-3B-Instruct\",\n", + " engine_kwargs=dict(\n", + " max_model_len=8192, # Hard cap: all text + vision tokens must fit within this limit\n", + " ),\n", + " batch_size=16,\n", + " accelerator_type=\"L4\", # Or upgrade to larger GPU\n", + " concurrency=10, # Increase the number of parallel workers\n", + " has_image=True, # Enable image input\n", + ")\n", + "\n", + "# Build the LLM processor with the configuration and functions.\n", + "processor_large = build_llm_processor(\n", + " processor_config_large,\n", + " preprocess=preprocess,\n", + " postprocess=postprocess,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Execute the new pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter out invalid images before processing.\n", + "ds_large_filtered = ds_large.filter(is_valid_image)\n", + "\n", + "# Run the compute-scaled processor on the larger dataset.\n", + "processed_large = processor_large(ds_large_filtered)\n", + "processed_large = processed_large.materialize()\n", + "\n", + "print(f\"\\nProcessed {processed_large.count()} rows successfully.\")\n", + "# Display the first 3 entries to verify the output.\n", + "sampled = processed_large.take(3)\n", + "print(\"\\n==================GENERATED OUTPUT===============\\n\")\n", + "pprint(sampled)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance optimization tips\n", + "\n", + "When scaling to larger datasets, consider these optimizations. For comprehensive guidance, see the [Ray Data performance guide](https://docs.ray.io/en/latest/data/performance-tips.html) and the [throughput optimization guide with Anyscale](https://docs.anyscale.com/llm/batch-inference/throughput-optimization).\n", + "\n", + "**Analyze your pipeline**\n", + "You can use *stats()* to examine the throughput and timing at every step in your pipeline and spot potential bottlenecks.\n", + "The *stats()* output reports how long each operator took and its throughput, so you can compare these values to expected throughput for your hardware. If you see a step with significantly lower throughput or much higher task durations than others, that's likely a bottleneck.\n", + "The following example shows how to print pipeline stats:\n", + "```python\n", + "processed = processor(ds).materialize()\n", + "print(processed.stats())\n", + "```\n", + "The outputs include detailed timing, throughput, and resource utilization for each pipeline operator.\n", + "For example:\n", + "```text\n", + "Operator 0 ...\n", + "\n", + "...\n", + "\n", + "Operator 8 MapBatches(vLLMEngineStageUDF): 3908 tasks executed, 3908 blocks produced in 340.21s\n", + " * Remote wall time: 340.21s \n", + " * Input/output rows: ...\n", + " * Throughput: 2,900 rows/s\n", + " ...\n", + "\n", + "...\n", + "\n", + "Dataset throughput:\n", + " * Ray Data throughput: 2,500 rows/s\n", + " * Estimated single node throughput: 5,000 rows/s\n", + "```\n", + "\n", + "Review the per-operator throughput numbers and durations to spot slowest stages or unexpected bottlenecks. You can then adjust batch size, concurrency, or optimize resource usage for affected steps.\n", + "\n", + "**Adjust concurrency** \n", + "The `concurrency` parameter controls how many model replicas run in parallel. 
To determine the right value:\n", + "- *Available GPU count:* Start with the number of GPUs in your cluster. Each replica needs at least one GPU (more if using tensor parallelism).\n", + "- *Model memory footprint:* Ensure your model fits in GPU memory. For example, an 8 B parameter model in FP16 requires ~16 GB, fitting on a single L4 (24 GB) or A10G (24 GB).\n", + "- *CPU-bound preprocessing:* If preprocessing (image decoding, resizing) is slower than inference, adding more GPU replicas won't help. Check `stats()` output to identify if preprocessing is the bottleneck.\n", + "\n", + "**Tune batch size** \n", + "The `batch_size` parameter controls how many requests Ray Data sends to vLLM at once. vLLM uses continuous batching internally, controlled by `max_num_seqs` in `engine_kwargs`. This directly impacts GPU memory allocation since vLLM pre-allocates KV cache for up to `max_num_seqs` concurrent sequences.\n", + "\n", + "- *Too small `batch_size`:* vLLM scheduler is under-saturated, risking GPU idle time.\n", + "- *Too large `batch_size`:* vLLM scheduler is over-saturated, causing overhead latency. Also increases retry cost on failure since the entire batch is retried.\n", + "\n", + "You can try the following suggestions:\n", + "1. Start with `batch_size` equal to `max_num_seqs` in your vLLM engine parameters. See [vLLM engine arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html) for defaults.\n", + "2. Monitor GPU utilization in the Ray Dashboard (see [Monitor the execution](#monitor-the-execution) section).\n", + "3. Adjust `max_num_seqs` in `engine_kwargs` to optimize GPU utilization, and re-adapt `batch_size` accordingly.\n", + "\n", + "**Optimize image loading** \n", + "Pre-resize images to a consistent size to reduce memory usage and improve throughput.\n", + "\n", + "**Tune preprocessing and inference stage parallelism** \n", + "Use `repartition()` to control parallelism during your preprocessing stage. On the other hand, the number of inference tasks is determined by `dataset_size / batch_size`, where `batch_size` controls how many rows are grouped for each vLLM engine call. Ensure you have enough tasks to keep all workers busy and enable efficient load balancing.\n", + "\n", + "See [Configure parallelism for Ray Data LLM](https://docs.anyscale.com/llm/batch-inference/resource-allocation/concurrency-and-batching.md) for detailed guidance.\n", + "\n", + "**Use quantization to reduce memory footprint** \n", + "Quantization reduces model precision to save GPU memory and improve throughput; vLLM supports this with the `quantization` field in `engine_kwargs`. Note that lower precision may impact output quality, and not all models or GPUs support all quantization types, see [Quantization for LLM batch inference](https://docs.anyscale.com/llm/batch-inference/throughput-optimization/quantization.md) for more guidance.\n", + "\n", + "**Fault tolerance and checkpointing** \n", + "Ray Data automatically handles fault tolerance - if a worker fails, only that worker's current batch is retried. For long-running Anyscale Jobs, you can enable job-level checkpointing to resume from failures. See [Anyscale Runtime checkpointing documentation](https://docs.anyscale.com/runtime/data#enable-job-level-checkpointing) for more information.\n", + "\n", + "**Scale to larger models with model parallelism** \n", + "Model parallelism distributes large models across multiple GPUs when they don't fit on a single GPU. 
Use tensor parallelism to split model layers horizontally across multiple GPUs within a single node and use pipeline parallelism to split model layers vertically across multiple nodes, with each node processing different layers of the model.\n", + "\n", + "Forward model parallelism parameters to your inference engine using the `engine_kwargs` argument of your `vLLMEngineProcessorConfig` object. If your GPUs span multiple nodes, set `ray` as the distributed executor backend to enable cross-node parallelism. This example snippet uses Llama-3.2-90 B-Vision-Instruct, a large vision model requiring multiple GPUs:\n", + "\n", + "```python\n", + "processor_config = vLLMEngineProcessorConfig(\n", + " model_source=\"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n", + " accelerator_type=\"H100\",\n", + " engine_kwargs={\n", + " \"tensor_parallel_size\": 8, # 8 GPUs per node\n", + " \"pipeline_parallel_size\": 2, # Split across 2 nodes\n", + " \"distributed_executor_backend\": \"ray\", # Required to enable cross-node parallelism\n", + "\n", + " },\n", + " concurrency=1,\n", + ")\n", + "# Each worker uses: 8 GPUs × 2 nodes = 16 GPUs total\n", + "```\n", + "\n", + "Each inference worker allocates GPUs based on `tensor_parallel_size × pipeline_parallel_size`. For detailed guidance on parallelism strategies, see the [vLLM parallelism and scaling documentation](https://docs.vllm.ai/en/stable/serving/distributed_serving.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this example, you built an end-to-end vision batch inference pipeline: loading an HuggingFace image dataset into Ray Dataset, configuring a vLLM processor for the Qwen2.5-VL vision-language model, and adding pre/post-processing to generate image captions. You validated the flow on 10,000 images, scaled to 1M images, monitored progress in the Ray Dashboard, and saved the results to persistent storage.\n", + "\n", + "See [Anyscale batch inference optimization](https://docs.anyscale.com/llm/batch-inference) for more information on using Ray Data with Anyscale and for more advanced use cases, see [Working with LLMs](https://docs.ray.io/en/latest/data/working-with-llms.html).\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + }, + "orphan": true + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/source/data/examples/llm_batch_inference_vision/content/README.md b/doc/source/data/examples/llm_batch_inference_vision/content/README.md new file mode 100644 index 000000000000..16999f624818 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/content/README.md @@ -0,0 +1,455 @@ + + +# Multimodal LLM batch inference with Ray Data LLM + + + +**⏱️ Time to complete**: 20 minutes + +This example shows you how to run batch inference for vision-language models (VLMs) using [Ray Data LLM APIs](https://docs.ray.io/en/latest/data/api/llm.html). In this use case, the batch inference job generates captions for a large-scale image dataset. + +## When to use LLM batch inference + +Offline (batch) inference optimizes for throughput over latency. 
Unlike online inference, which processes requests one at a time in real-time, batch inference processes thousands or millions of inputs together, maximizing GPU utilization and reducing per-inference costs. + +Choose batch inference when: +- You have a fixed dataset to process (such as daily reports or data migrations) +- Throughput matters more than immediate results +- You want to take advantage of fault tolerance for long-running jobs + +On the contrary, if you are more interested in optimizing for latency, consider [deploying your LLM with Ray Serve LLM for online inference](https://docs.ray.io/en/latest/serve/llm/index.html). + + +## Prepare a Ray Data dataset with images + +Ray Data LLM runs batch inference for VLMs on Ray Data datasets containing images. In this tutorial, you perform batch inference with a vision-language model to generate image captions from the `BLIP3o/BLIP3o-Pretrain-Short-Caption` dataset, which contains approximately 5 million images. + +First, load the data from a remote URL then repartition the dataset to ensure the workload can be distributed across multiple GPUs. + + +```python +%pip install datasets==4.4.2 +``` + + +```python +import ray +import datasets +from PIL import Image +from io import BytesIO + +# Load the BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face with ~5M images. +print("Loading BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face...") +hf_dataset = datasets.load_dataset("BLIP3o/BLIP3o-Pretrain-Short-Caption", split="train", streaming=True) +hf_dataset = hf_dataset.select_columns(["jpg"]) + +ds = ray.data.from_huggingface(hf_dataset) +print("Dataset loaded successfully.") + +sample = ds.take(2) +print("Sample data:") +for i, item in enumerate(sample): + print(f"\nSample {i+1}:") + image = Image.open(BytesIO(item['jpg']['bytes'])) + image.show() +``` + +For this initial example, limit the dataset to 10,000 rows so you can process and test faster. Later, you can scale up to the full dataset. + +If you don't repartition, the system might read a large file into only a few blocks, which limits parallelism in later steps. For example, you might see that only 4 out of 8 GPUs in your cluster are being used. To address this, you can repartition the data into a specific number of blocks so the system can better parallelize work across all available GPUs in the pipeline. + + +```python +# Limit the dataset to 10,000 images for this example. +print("Limiting dataset to 10,000 images for initial processing.") +ds_small = ds.limit(10_000) + +# Repartition the dataset to enable parallelism across multiple workers (GPUs). +# By default, streaming datasets might not be optimally partitioned. Repartitioning +# splits the data into a specified number of blocks, allowing Ray to process them +# in parallel. +# Tip: Repartition count should typically be 2-4x your worker (GPU) count. +# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions. +# This ensures enough parallelism while avoiding excessive overhead. +num_partitions = 64 +print(f"Repartitioning dataset into {num_partitions} blocks for parallelism...") +ds_small = ds_small.repartition(num_blocks=num_partitions) +``` + +## Configure Ray Data LLM + +Ray Data LLM provides a unified interface to run batch inference with different VLM engines. Configure the vLLM engine with a vision-language model, define preprocessing and postprocessing functions, and build the processor. 
+ +### Configure the processor engine + +Configure the model and compute resources needed for inference using `vLLMEngineProcessorConfig` with vision support enabled. + +This example uses the `Qwen/Qwen2.5-VL-3B-Instruct` model, a vision-language model. The configuration specifies: +- `model_source`: The Hugging Face model identifier. +- `engine_kwargs`: vLLM engine parameters such as memory settings and batching. +- `batch_size`: Number of requests to batch together (set to 16 for vision models). +- `accelerator_type`: GPU type to use (L4 in this case). +- `concurrency`: Number of parallel workers (4 in this case). +- `has_image`: Enable image input support. + +Vision models process each image as hundreds or thousands of vision tokens, unlike text-only models. You can set a larger token limit using `max_model_len`. You also need to use smaller batch sizes because image processing increases per-request memory. Adjust both `max_model_len` and `batch_size` for your vision token requirements and available memory. + + +```python +from ray.data.llm import vLLMEngineProcessorConfig + +processor_config = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + max_model_len=8192 # Hard cap: all text + vision tokens must fit within this limit + ), + batch_size=16, + accelerator_type="L4", + concurrency=4, + has_image=True, # Enable image input. +) + +``` + +For more details on the configuration options you can pass to the vLLM engine, see the [vLLM Engine Arguments documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html). + +### Define the preprocess and postprocess functions + +The task is to generate descriptive captions for images using a vision-language model. + +Define a preprocess function to prepare `messages` with image content and `sampling_params` for the vLLM engine, and a postprocess function to extract the `generated_text`. + +For production workloads with potentially corrupt or malformed images, filter them out before processing. Use Ray Data's `filter()` to validate images upfront - this prevents failures during inference and provides cleaner error handling than catching exceptions in the preprocess function. + + +```python +from typing import Any +from PIL import Image +from io import BytesIO + +# Filter function to validate images before processing. +# Returns True for valid images, False for corrupt/malformed ones. +def is_valid_image(row: dict[str, Any]) -> bool: + try: + Image.open(BytesIO(row['jpg']['bytes'])) + return True + except Exception: + return False + +# Preprocess function prepares messages with image content for the VLM. +def preprocess(row: dict[str, Any]) -> dict[str, Any]: + # Convert bytes image to PIL + image = row['jpg']['bytes'] + image = Image.open(BytesIO(image)) + # Resize to 225x225 for consistency and predictable vision-token budget. + # This resolution balances quality with memory usage. Adjust based on your + # model's expected input size and available GPU memory. + image = image.resize((225, 225), Image.Resampling.BICUBIC) + + return dict( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that generates accurate and descriptive captions for images." + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image in detail. Focus on the main subjects, actions, and setting." + }, + { + "type": "image", + "image": image # Ray Data accepts PIL Image or image URL. 
+ } + ] + }, + ], + sampling_params=dict( + temperature=0.3, + max_tokens=256 + ), + ) + +# Postprocess function extracts the generated caption. +def postprocess(row: dict[str, Any]) -> dict[str, Any]: + # Example: validation check, formatting... + + return { + "generated_caption": row["generated_text"], + # Note: Don't include **row here to avoid returning the large image data. + # Include only the fields you need in the output. + } +``` + +### Build the processor + +With the configuration and functions defined, build the processor. + + +```python +from ray.data.llm import build_llm_processor + +# Build the LLM processor with the configuration and functions. +processor = build_llm_processor( + processor_config, + preprocess=preprocess, + postprocess=postprocess, +) +``` + +## Process the dataset + +Run the processor on your small dataset to perform batch inference. Ray Data automatically distributes the workload across available GPUs and handles batching, retries, and resource management. + + + +```python +from pprint import pprint + +# Filter out invalid images before processing. +ds_small_filtered = ds_small.filter(is_valid_image) + +# Run the processor on the filtered dataset. +processed_small = processor(ds_small_filtered) + +# Materialize the dataset to memory. +# You can also use writing APIs such as write_parquet() or write_json() to persist the dataset. +processed_small = processed_small.materialize() + +print(f"\nProcessed {processed_small.count()} rows successfully.") +# Display the first 3 entries to verify the output. +sampled = processed_small.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) +``` + +## Launch to production with Anyscale Jobs + +For production workloads, deploy your batch inference processor as an [Anyscale Job](https://docs.anyscale.com/platform/jobs). Anyscale takes care of the infrastructure layer and runs your jobs on your dedicated clusters with automatic retries, monitoring, and scheduling. + +### Anyscale Runtime + +Anyscale Jobs run on [Anyscale Runtime](https://docs.anyscale.com/runtime/data), which includes performance optimizations over open-source Ray Data. Key improvements include faster shuffles, optimized memory management, improved autoscaling, and enhanced fault tolerance for large-scale data processing. + +These optimizations are automatic and require no code changes. Your Ray Data pipelines benefit from them simply by running on Anyscale. For batch inference workloads specifically, Anyscale Runtime provides better GPU utilization and reduced overhead when scaling across many nodes. + +### Configure an Anyscale Job + +Save your batch inference code as `batch_inference_vision.py`, then create a job configuration file: + +```yaml +# job.yaml +name: my-llm-batch-inference-vision +entrypoint: python batch_inference_vision.py +image_uri: anyscale/ray-llm:2.51.1-py311-cu128 +compute_config: + head_node: + instance_type: m5.2xlarge + worker_nodes: + - instance_type: g6.2xlarge + min_nodes: 0 + max_nodes: 10 +requirements: # Python dependencies - can be list or path to requirements.txt + - datasets==4.4.1 +working_dir: . +max_retries: 2 + + +``` + +### Submit + +Submit your job using the Anyscale CLI: + +```bash +anyscale job submit --config-file job.yaml +``` + +### Monitoring + +Track your job's progress in the Anyscale Console or through the CLI: + +```bash +# Check job status. +anyscale job status --name my-llm-batch-inference-vision + +# View logs. 
+anyscale job logs --name my-llm-batch-inference-vision +``` + +The Ray Dashboard remains available for detailed monitoring. To access it, go to your Anyscale Job in your console. +For cluster-level information, click the **Metrics** tab then **Data** tab, and for task-level information, click the **Ray Workloads** tab then **Data** tab. + +## Monitor the execution + +Use the Ray Dashboard to monitor the execution. See [Monitoring your Workload](https://docs.ray.io/en/latest/data/monitoring-your-workload.html) for more information on visualizing your Ray Data jobs. + +The dashboard shows: +- Operator-level metrics (throughput, task execution times). +- Resource utilization (CPU, GPU, memory). +- Progress and remaining time estimates. +- Task status breakdown. + +**Tip**: If you encounter CUDA out of memory errors, reduce your batch size, use a smaller model, or switch to a larger GPU. For more troubleshooting tips, see [GPU Memory Management](https://docs.ray.io/en/latest/data/working-with-llms.html#gpu-memory-management-and-cuda-oom-prevention). + + +## Scale up to larger datasets + +Your Ray Data processing pipeline can easily scale up to process more images. By default, this section processes 1M images. + + +```python +# The BLIP3o/BLIP3o-Pretrain-Short-Caption dataset has ~5M images +# Configure how many images to process (default: 1M for demonstration). +print(f"Processing 1M images... (or the whole dataset if you picked >5M)") +ds_large = ds.limit(1_000_000) + +# As we increase our compute, we can increase the number of partitions for more parallelism +num_partitions_large = 128 +print(f"Repartitioning dataset into {num_partitions_large} blocks for parallelism...") +ds_large = ds_large.repartition(num_blocks=num_partitions_large) +``` + +You can scale the number of concurrent replicas based on the compute available in your cluster. In this case, each replica is a copy of your Qwen-VL model and fits in a single L4 GPU. + + +```python +processor_config_large = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + max_model_len=8192, # Hard cap: all text + vision tokens must fit within this limit + ), + batch_size=16, + accelerator_type="L4", # Or upgrade to larger GPU + concurrency=10, # Increase the number of parallel workers + has_image=True, # Enable image input +) + +# Build the LLM processor with the configuration and functions. +processor_large = build_llm_processor( + processor_config_large, + preprocess=preprocess, + postprocess=postprocess, +) +``` + +Execute the new pipeline + + +```python +# Filter out invalid images before processing. +ds_large_filtered = ds_large.filter(is_valid_image) + +# Run the compute-scaled processor on the larger dataset. +processed_large = processor_large(ds_large_filtered) +processed_large = processed_large.materialize() + +print(f"\nProcessed {processed_large.count()} rows successfully.") +# Display the first 3 entries to verify the output. +sampled = processed_large.take(3) +print("\n==================GENERATED OUTPUT===============\n") +pprint(sampled) +``` + +## Performance optimization tips + +When scaling to larger datasets, consider these optimizations. For comprehensive guidance, see the [Ray Data performance guide](https://docs.ray.io/en/latest/data/performance-tips.html) and the [throughput optimization guide with Anyscale](https://docs.anyscale.com/llm/batch-inference/throughput-optimization). 
+
+**Analyze your pipeline**  
+You can use `stats()` to examine the throughput and timing at every step in your pipeline and spot potential bottlenecks.
+The `stats()` output reports how long each operator took and its throughput, so you can compare these values to the expected throughput for your hardware. If a step shows significantly lower throughput or much higher task durations than the others, it's likely a bottleneck.
+The following example shows how to print pipeline stats:
+```python
+processed = processor(ds).materialize()
+print(processed.stats())
+```
+The output includes detailed timing, throughput, and resource utilization for each pipeline operator.
+For example:
+```text
+Operator 0 ...
+
+...
+
+Operator 8 MapBatches(vLLMEngineStageUDF): 3908 tasks executed, 3908 blocks produced in 340.21s
+ * Remote wall time: 340.21s
+ * Input/output rows: ...
+ * Throughput: 2,900 rows/s
+ ...
+
+...
+
+Dataset throughput:
+ * Ray Data throughput: 2,500 rows/s
+ * Estimated single node throughput: 5,000 rows/s
+```
+
+Review the per-operator throughput numbers and durations to spot the slowest stages or unexpected bottlenecks. You can then adjust batch size or concurrency, or optimize resource usage for the affected steps.
+
+**Adjust concurrency**  
+The `concurrency` parameter controls how many model replicas run in parallel. To determine the right value:
+- *Available GPU count:* Start with the number of GPUs in your cluster. Each replica needs at least one GPU (more if using tensor parallelism).
+- *Model memory footprint:* Ensure your model fits in GPU memory. For example, an 8B-parameter model in FP16 requires ~16 GB, which fits on a single L4 (24 GB) or A10G (24 GB).
+- *CPU-bound preprocessing:* If preprocessing (image decoding, resizing) is slower than inference, adding more GPU replicas won't help. Check the `stats()` output to identify whether preprocessing is the bottleneck.
+
+**Tune batch size**  
+The `batch_size` parameter controls how many requests Ray Data sends to vLLM at once. vLLM uses continuous batching internally, controlled by `max_num_seqs` in `engine_kwargs`. This directly impacts GPU memory allocation because vLLM pre-allocates KV cache for up to `max_num_seqs` concurrent sequences.
+
+- *Too small a `batch_size`:* The vLLM scheduler is under-saturated, risking GPU idle time.
+- *Too large a `batch_size`:* The vLLM scheduler is over-saturated, which adds scheduling latency. It also increases the retry cost on failure because the entire batch is retried.
+
+You can try the following suggestions:
+1. Start with `batch_size` equal to `max_num_seqs` in your vLLM engine parameters. See [vLLM engine arguments](https://docs.vllm.ai/en/stable/serving/engine_args.html) for defaults.
+2. Monitor GPU utilization in the Ray Dashboard (see the [Monitor the execution](#monitor-the-execution) section).
+3. Adjust `max_num_seqs` in `engine_kwargs` to optimize GPU utilization, and re-adapt `batch_size` accordingly.
+
+**Optimize image loading**  
+Pre-resize images to a consistent size to reduce memory usage and improve throughput.
+
+**Tune preprocessing and inference stage parallelism**  
+Use `repartition()` to control parallelism during the preprocessing stage. In contrast, the number of inference tasks is determined by `dataset_size / batch_size`, where `batch_size` controls how many rows are grouped for each vLLM engine call. Ensure you have enough tasks to keep all workers busy and enable efficient load balancing.
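+
+As a rough sanity check before launching a large run, you can estimate both numbers from your own configuration. The following is a minimal sketch that assumes the 1M-image run above with `batch_size=16` and `concurrency=10`; the figures are illustrative assumptions, not recommendations:
+
+```python
+# Hypothetical sizing check: all numbers here are illustrative assumptions.
+dataset_size = 1_000_000   # rows you plan to process
+batch_size = 16            # rows per vLLM engine call (from the processor config)
+concurrency = 10           # model replicas, one L4 GPU each
+
+# Inference tasks: one per batch sent to the vLLM engine stage.
+num_inference_tasks = dataset_size // batch_size
+
+# Preprocessing blocks: 2-4x the worker count is the rule of thumb used earlier.
+num_preprocess_blocks = 4 * concurrency
+
+print(f"Inference tasks: {num_inference_tasks}")   # 62,500 tasks to keep 10 replicas busy
+print(f"Repartition into ~{num_preprocess_blocks} blocks for preprocessing")
+# ds = ds.repartition(num_blocks=num_preprocess_blocks)
+```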
+
+See [Configure parallelism for Ray Data LLM](https://docs.anyscale.com/llm/batch-inference/resource-allocation/concurrency-and-batching.md) for detailed guidance.
+
+**Use quantization to reduce memory footprint**  
+Quantization reduces model precision to save GPU memory and improve throughput; vLLM supports this with the `quantization` field in `engine_kwargs`. Note that lower precision may impact output quality, and not all models or GPUs support all quantization types; see [Quantization for LLM batch inference](https://docs.anyscale.com/llm/batch-inference/throughput-optimization/quantization.md) for more guidance.
+
+**Fault tolerance and checkpointing**  
+Ray Data automatically handles fault tolerance: if a worker fails, only that worker's current batch is retried. For long-running Anyscale Jobs, you can enable job-level checkpointing to resume from failures. See the [Anyscale Runtime checkpointing documentation](https://docs.anyscale.com/runtime/data#enable-job-level-checkpointing) for more information.
+
+**Scale to larger models with model parallelism**  
+Model parallelism distributes large models across multiple GPUs when they don't fit on a single GPU. Use tensor parallelism to split model layers horizontally across multiple GPUs within a single node, and use pipeline parallelism to split model layers vertically across multiple nodes, with each node processing different layers of the model.
+
+Forward model parallelism parameters to your inference engine using the `engine_kwargs` argument of your `vLLMEngineProcessorConfig` object. If your GPUs span multiple nodes, set `ray` as the distributed executor backend to enable cross-node parallelism. This example snippet uses Llama-3.2-90B-Vision-Instruct, a large vision model requiring multiple GPUs:
+
+```python
+processor_config = vLLMEngineProcessorConfig(
+    model_source="meta-llama/Llama-3.2-90B-Vision-Instruct",
+    accelerator_type="H100",
+    engine_kwargs={
+        "tensor_parallel_size": 8,  # 8 GPUs per node
+        "pipeline_parallel_size": 2,  # Split across 2 nodes
+        "distributed_executor_backend": "ray",  # Required to enable cross-node parallelism
+
+    },
+    concurrency=1,
+)
+# Each worker uses: 8 GPUs × 2 nodes = 16 GPUs total
+```
+
+Each inference worker allocates GPUs based on `tensor_parallel_size × pipeline_parallel_size`. For detailed guidance on parallelism strategies, see the [vLLM parallelism and scaling documentation](https://docs.vllm.ai/en/stable/serving/distributed_serving.html).
+
+## Summary
+
+In this example, you built an end-to-end vision batch inference pipeline: loading a Hugging Face image dataset into a Ray Dataset, configuring a vLLM processor for the Qwen2.5-VL vision-language model, and adding pre/post-processing to generate image captions. You validated the flow on 10,000 images, scaled to 1M images, monitored progress in the Ray Dashboard, and materialized the results, with the option to persist them using Ray Data write APIs such as `write_parquet()`.
+
+See [Anyscale batch inference optimization](https://docs.anyscale.com/llm/batch-inference) for more information on using Ray Data with Anyscale. For more advanced use cases, see [Working with LLMs](https://docs.ray.io/en/latest/data/working-with-llms.html).
+ diff --git a/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision.py b/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision.py new file mode 100644 index 000000000000..ddcb767d9034 --- /dev/null +++ b/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision.py @@ -0,0 +1,123 @@ +from typing import Any + +import ray +from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig +import datasets +from PIL import Image +from io import BytesIO + +DATASET_LIMIT = 10_000 + +# Load the BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face with ~5M images. +print("Loading BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face...") +hf_dataset = datasets.load_dataset("BLIP3o/BLIP3o-Pretrain-Short-Caption", split="train", streaming=True) +hf_dataset = hf_dataset.select_columns(["jpg"]) + +ds = ray.data.from_huggingface(hf_dataset) +print("Dataset loaded successfully.") + +# Limit the dataset. If DATASET_LIMIT > dataset size, the entire dataset will be processed. +print(f"Limiting dataset to {DATASET_LIMIT} images for initial processing.") +ds_small = ds.limit(DATASET_LIMIT) + +# Repartition the dataset to enable parallelism across multiple workers (GPUs). +# By default, streaming datasets might not be optimally partitioned. Repartitioning +# splits the data into a specified number of blocks, allowing Ray to process them +# in parallel. +# Tip: Repartition count should typically be 2-4x your worker (GPU) count. +# Example: 4 GPUs → 8-16 partitions, 10 GPUs → 20-40 partitions. +# This ensures enough parallelism while avoiding excessive overhead. +num_partitions = 64 +print(f"Repartitioning dataset into {num_partitions} blocks for parallelism...") +ds_small = ds_small.repartition(num_blocks=num_partitions) + + +processor_config = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + max_model_len=8192 + ), + batch_size=16, + accelerator_type="L4", + concurrency=4, + has_image=True, # Enable image input. +) + + +# Filter function to validate images before processing. +# Returns True for valid images, False for corrupt/malformed ones. +def is_valid_image(row: dict[str, Any]) -> bool: + try: + Image.open(BytesIO(row['jpg']['bytes'])) + return True + except Exception: + return False + +# Preprocess function prepares messages with image content for the VLM. +def preprocess(row: dict[str, Any]) -> dict[str, Any]: + # Convert bytes image to PIL + image = row['jpg']['bytes'] + image = Image.open(BytesIO(image)) + # Resize to 225x225 for consistency and predictable vision-token budget. + # This resolution balances quality with memory usage. Adjust based on your + # model's expected input size and available GPU memory. + image = image.resize((225, 225), Image.Resampling.BICUBIC) + + return dict( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that generates accurate and descriptive captions for images." + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image in detail. Focus on the main subjects, actions, and setting." + }, + { + "type": "image", + "image": image # Ray Data accepts PIL Image or image URL. + } + ] + }, + ], + sampling_params=dict( + temperature=0.3, + max_tokens=256 + ), + ) + +# Postprocess function extracts the generated caption. 
+def postprocess(row: dict[str, Any]) -> dict[str, Any]:
+    return {
+        "generated_caption": row["generated_text"],
+        # Note: Don't include **row here to avoid returning the large image data.
+        # Include only the fields you need in the output.
+    }
+
+# Build the LLM processor with the configuration and functions.
+processor = build_llm_processor(
+    processor_config,
+    preprocess=preprocess,
+    postprocess=postprocess,
+)
+
+from pprint import pprint
+
+# Filter out invalid images before processing.
+ds_small_filtered = ds_small.filter(is_valid_image)
+
+# Run the processor on the filtered dataset.
+processed_small = processor(ds_small_filtered)
+
+# Materialize the dataset to memory.
+# You can also use writing APIs such as write_parquet() or write_json() to persist the dataset.
+processed_small = processed_small.materialize()
+
+print(f"\nProcessed {processed_small.count()} rows successfully.")
+# Display the first 3 entries to verify the output.
+sampled = processed_small.take(3)
+print("\n==================GENERATED OUTPUT===============\n")
+pprint(sampled)
diff --git a/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision_scaled.py b/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision_scaled.py
new file mode 100644
index 000000000000..fc2e98f4967b
--- /dev/null
+++ b/doc/source/data/examples/llm_batch_inference_vision/content/batch_inference_vision_scaled.py
@@ -0,0 +1,114 @@
+from typing import Any
+
+from pprint import pprint
+import ray
+from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
+import datasets
+from PIL import Image
+from io import BytesIO
+
+# Dataset limit for this example.
+DATASET_LIMIT = 1_000_000
+
+# Load the BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face with ~5M images.
+print("Loading BLIP3o/BLIP3o-Pretrain-Short-Caption dataset from Hugging Face...")
+hf_dataset = datasets.load_dataset("BLIP3o/BLIP3o-Pretrain-Short-Caption", split="train", streaming=True)
+hf_dataset = hf_dataset.select_columns(["jpg"])
+
+ds = ray.data.from_huggingface(hf_dataset)
+print("Dataset loaded successfully.")
+
+# Limit the dataset. If DATASET_LIMIT > dataset size, the entire dataset will be processed.
+print(f"Limiting dataset to {DATASET_LIMIT} images for initial processing.")
+ds_large = ds.limit(DATASET_LIMIT)
+
+# As you add more compute, you can increase the number of partitions for more parallelism.
+num_partitions_large = 128
+print(f"Repartitioning dataset into {num_partitions_large} blocks for parallelism...")
+ds_large = ds_large.repartition(num_blocks=num_partitions_large)
+
+
+processor_config_large = vLLMEngineProcessorConfig(
+    model_source="Qwen/Qwen2.5-VL-3B-Instruct",
+    engine_kwargs=dict(
+        max_model_len=8192,
+    ),
+    batch_size=16,
+    accelerator_type="L4",  # Or upgrade to a larger GPU.
+    concurrency=10,  # Increase the number of parallel workers.
+    has_image=True,  # Enable image input.
+)
+
+
+# Filter function to validate images before processing.
+# Returns True for valid images, False for corrupt/malformed ones.
+def is_valid_image(row: dict[str, Any]) -> bool:
+    try:
+        Image.open(BytesIO(row['jpg']['bytes']))
+        return True
+    except Exception:
+        return False
+
+# Preprocess function prepares messages with image content for the VLM.
+def preprocess(row: dict[str, Any]) -> dict[str, Any]:
+    # Convert bytes image to PIL
+    image = row['jpg']['bytes']
+    image = Image.open(BytesIO(image))
+    # Resize to 225x225 for consistency and predictable vision-token budget.
+    # This resolution balances quality with memory usage. Adjust based on your
+    # model's expected input size and available GPU memory.
+    image = image.resize((225, 225), Image.Resampling.BICUBIC)
+
+    return dict(
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that generates accurate and descriptive captions for images."
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Describe this image in detail. Focus on the main subjects, actions, and setting."
+                    },
+                    {
+                        "type": "image",
+                        "image": image  # Ray Data accepts PIL Image or image URL.
+                    }
+                ]
+            },
+        ],
+        sampling_params=dict(
+            temperature=0.3,
+            max_tokens=256
+        ),
+    )
+
+# Postprocess function extracts the generated caption.
+def postprocess(row: dict[str, Any]) -> dict[str, Any]:
+    return {
+        "generated_caption": row["generated_text"],
+        # Note: Don't include **row here to avoid returning the large image data.
+        # Include only the fields you need in the output.
+    }
+
+# Build the LLM processor with the configuration and functions.
+processor_large = build_llm_processor(
+    processor_config_large,
+    preprocess=preprocess,
+    postprocess=postprocess,
+)
+
+# Filter out invalid images before processing.
+ds_large_filtered = ds_large.filter(is_valid_image)
+
+# Run the compute-scaled processor on the larger dataset.
+processed_large = processor_large(ds_large_filtered)
+processed_large = processed_large.materialize()
+
+print(f"\nProcessed {processed_large.count()} rows successfully.")
+# Display the first 3 entries to verify the output.
+sampled = processed_large.take(3)
+print("\n==================GENERATED OUTPUT===============\n")
+pprint(sampled)
diff --git a/doc/source/data/examples/llm_batch_inference_vision/content/job.yaml b/doc/source/data/examples/llm_batch_inference_vision/content/job.yaml
new file mode 100644
index 000000000000..d39a4449e2ea
--- /dev/null
+++ b/doc/source/data/examples/llm_batch_inference_vision/content/job.yaml
@@ -0,0 +1,14 @@
+name: llm-batch-inference-vision
+entrypoint: python batch_inference_vision.py
+image_uri: anyscale/ray-llm:2.51.1-py311-cu128
+compute_config:
+  head_node:
+    instance_type: m5.2xlarge
+  worker_nodes:
+    - instance_type: g6.2xlarge
+      min_nodes: 0
+      max_nodes: 10
+requirements: # Python dependencies - can be a list or a path to requirements.txt
+  - datasets==4.4.1
+working_dir: .
+max_retries: 2
diff --git a/doc/source/data/examples/llm_batch_inference_vision/convert_to_md.sh b/doc/source/data/examples/llm_batch_inference_vision/convert_to_md.sh
new file mode 100755
index 000000000000..7ef7f34bceb1
--- /dev/null
+++ b/doc/source/data/examples/llm_batch_inference_vision/convert_to_md.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -exo pipefail
+
+nb_rel_path="content/README.ipynb"
+content_dir="$(dirname "$nb_rel_path")"
+nb_filename="$(basename "$nb_rel_path")"
+md_rel_path="$content_dir/README.md"
+
+# Delete README if it already exists
+[ -f "$md_rel_path" ] && rm "$md_rel_path"
+
+# Convert notebook to Markdown
+jupyter nbconvert "$nb_rel_path" --to markdown --output README.md
+
+# Prepend warning comment (will be hidden when rendered in the console)
+tmp_file="$(mktemp)"
+{
+echo ""
+echo ""
+cat "$md_rel_path"
+} > "$tmp_file"
+mv "$tmp_file" "$md_rel_path"
diff --git a/release/ray_release/byod/byod_llm_batch_inference_text.sh b/release/ray_release/byod/byod_llm_batch_inference_text.sh
new file mode 100755
index 000000000000..ef7e19de90b6
--- /dev/null
+++ b/release/ray_release/byod/byod_llm_batch_inference_text.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+set -exo pipefail
diff --git a/release/ray_release/byod/byod_llm_batch_inference_vision.sh b/release/ray_release/byod/byod_llm_batch_inference_vision.sh
new file mode 100755
index 000000000000..bf55a16a6857
--- /dev/null
+++ b/release/ray_release/byod/byod_llm_batch_inference_vision.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+set -exo pipefail
+
+# Python dependencies
+pip3 install --no-cache-dir "datasets==4.4.2"
diff --git a/release/release_tests.yaml b/release/release_tests.yaml
index 0e1192d54f0c..050288f7d460 100644
--- a/release/release_tests.yaml
+++ b/release/release_tests.yaml
@@ -4816,3 +4816,55 @@
   frequency: manual
   cluster:
     cluster_compute: ci/gce.yaml # relative to working_dir
+
+
+- name: llm_batch_inference_text # do not use dashes (regex sensitive)
+  frequency: weekly
+  python: "3.11"
+  group: ray-examples
+  team: ml
+  working_dir: //doc/source/data/examples/llm_batch_inference_text # use // to access from repo's root
+
+  cluster:
+    byod:
+      type: llm-cu128 # anyscale/ray-llm: