Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
github.event_name == 'workflow_dispatch' ||
(github.event_name == 'schedule' && github.repository == 'apache/beam') ||
github.event.comment.body == 'Run Inference Benchmarks'
runs-on: [self-hosted, ubuntu-24.04, main]
runs-on: [self-hosted, ubuntu-20.04, main]
timeout-minutes: 1000
name: ${{ matrix.job_name }} (${{ matrix.job_phrase }})
strategy:
Expand Down Expand Up @@ -99,12 +99,14 @@ jobs:
run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV
- name: Build VLLM Development Image
id: build_vllm_image
if: false
uses: ./.github/actions/build-push-docker-action
with:
dockerfile_path: 'sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile'
image_name: 'us-docker.pkg.dev/apache-beam-testing/beam-temp/beam-vllm-gpu-base'
image_tag: ${{ github.sha }}
- name: Run VLLM Gemma Batch Test
if: false
uses: ./.github/actions/gradle-command-self-hosted-action
timeout-minutes: 180
with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@
--device=CPU
--input_file=gs://apache-beam-ml/testing/inputs/sentences_50k.txt
--runner=DataflowRunner
--sdk_location=container
--sdk_container_image=us.gcr.io/apache-beam-testing/python-postcommit-it/tensor_rt@sha256:884d67e96d9a3c22fb21fcd412c10a012d4c82a7c723f1c1ffe41fca609b5a6a
--model_path=distilbert-base-uncased-finetuned-sst-2-english
--model_state_dict_path=gs://apache-beam-ml/models/huggingface.sentiment.distilbert-base-uncased.pth
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
--device=CPU
--input_file=gs://apache-beam-ml/testing/inputs/sentences_50k.txt
--runner=DataflowRunner
--sdk_location=container
--sdk_container_image=us.gcr.io/apache-beam-testing/python-postcommit-it/tensor_rt@sha256:884d67e96d9a3c22fb21fcd412c10a012d4c82a7c723f1c1ffe41fca609b5a6a
--dataflow_service_options=worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver
--model_path=distilbert-base-uncased-finetuned-sst-2-english
--model_state_dict_path=gs://apache-beam-ml/models/huggingface.sentiment.distilbert-base-uncased.pth
59 changes: 41 additions & 18 deletions sdks/python/apache_beam/examples/inference/pytorch_sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@

class SentimentPostProcessor(beam.DoFn):
"""Processes PredictionResult to extract sentiment label and confidence."""
def __init__(self, tokenizer: DistilBertTokenizerFast):
self.tokenizer = tokenizer

def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
text, prediction_result = element
logits = prediction_result.inference['logits']
Expand All @@ -62,16 +59,26 @@ def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
}


def tokenize_text(text: str,
tokenizer: DistilBertTokenizerFast) -> tuple[str, dict]:
"""Tokenizes input text using the specified tokenizer."""
tokenized = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=128,
return_tensors="pt")
return text, {k: torch.squeeze(v) for k, v in tokenized.items()}
class TokenizeTextDoFn(beam.DoFn):
  """Tokenizes input text with a per-worker DistilBERT tokenizer.

  The tokenizer is built lazily in ``setup`` so each worker constructs its
  own instance from ``model_path`` rather than pickling one across from the
  launcher process.
  """
  def __init__(self, model_path: str):
    self.model_path = model_path
    self.tokenizer = None  # created per worker in setup()

  def setup(self):
    self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.model_path)
    # Older/newer transformers releases may surface the pad token via a
    # legacy private attribute; backfill it when absent.
    if not hasattr(self.tokenizer, '_pad_token'):
      self.tokenizer._pad_token = '[PAD]'

  def process(self, text: str) -> Iterable[tuple[str, dict]]:
    encoded = self.tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors="pt")
    # Drop the batch dimension so each value is a plain per-example tensor.
    squeezed = {key: torch.squeeze(val) for key, val in encoded.items()}
    yield text, squeezed


class RateLimitDoFn(beam.DoFn):
Expand All @@ -83,6 +90,21 @@ def process(self, element):
yield element


def _ensure_transformers_config_compat(config: DistilBertConfig) -> DistilBertConfig:
  """Backfills config attributes for cross-version transformers compatibility.

  The benchmark can run with container images whose transformers version
  differs from the launcher environment; some versions assume these
  attributes exist on the config object.

  Args:
    config: the model config to patch in place.

  Returns:
    The same config object, with any missing attributes defaulted.
  """
  # Attribute -> default expected by some transformers versions.
  compat_defaults = {
      'pruned_heads': {},
      'torchscript': False,
      'return_dict': True,
  }
  for attr, default in compat_defaults.items():
    if not hasattr(config, attr):
      setattr(config, attr, default)
  return config


def parse_known_args(argv):
"""Parses command-line arguments for pipeline execution."""
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -234,14 +256,15 @@ def run(
method = beam.io.WriteToBigQuery.Method.STREAMING_INSERTS
pipeline_options.view_as(StandardOptions).streaming = True

model_config = _ensure_transformers_config_compat(
DistilBertConfig.from_pretrained(known_args.model_path, num_labels=2))

model_handler = PytorchModelHandlerKeyedTensor(
model_class=DistilBertForSequenceClassification,
model_params={'config': DistilBertConfig(num_labels=2)},
model_params={'config': model_config},
state_dict_path=known_args.model_state_dict_path,
device='GPU')

tokenizer = DistilBertTokenizerFast.from_pretrained(known_args.model_path)

pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)

# Main pipeline: read, process, write result to BigQuery output table
Expand All @@ -264,9 +287,9 @@ def run(

_ = (
input
| 'Tokenize' >> beam.Map(lambda text: tokenize_text(text, tokenizer))
| 'Tokenize' >> beam.ParDo(TokenizeTextDoFn(known_args.model_path))
| 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
| 'PostProcess' >> beam.ParDo(SentimentPostProcessor(tokenizer))
| 'PostProcess' >> beam.ParDo(SentimentPostProcessor())
| 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
known_args.output_table,
schema='text:STRING, sentiment:STRING, confidence:FLOAT',
Expand Down
Loading