From f6258716445a30c3750c031b39e7b9332db02e5a Mon Sep 17 00:00:00 2001
From: Joaquin Anton
Date: Mon, 27 Jan 2025 13:15:23 +0100
Subject: [PATCH] Add DALI proxy option to EfficientNet example (#5791)

Add DALI proxy option to EfficientNet example

Add performance check of the DALI proxy variant to the EfficientNet performance test

Signed-off-by: Joaquin Anton Guirao
---
 .../image_classification/autoaugment.py       |   2 +-
 .../efficientnet/image_classification/dali.py | 191 +++++++++++---
 .../image_classification/dataloaders.py       | 233 +++++++++++++++++-
 .../use_cases/pytorch/efficientnet/main.py    |  67 +++--
 qa/TL3_EfficientNet_benchmark/test_pytorch.sh |  21 +-
 5 files changed, 433 insertions(+), 81 deletions(-)

diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/autoaugment.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/autoaugment.py
index 41c886525db..b725b6f96a0 100644
--- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/autoaugment.py
+++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/autoaugment.py
@@ -96,7 +96,7 @@ def solarize_add(image, addition=0, threshold=128):
     for i in range(256):
         if i < threshold:
             res = i + addition if i + addition <= 255 else 255
-            res = res if res >= 0 else 0
+            res = int(res if res >= 0 else 0)
             lut.append(res)
         else:
             lut.append(i)
diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
index 3fff439d519..9faf8de083c 100644
--- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
+++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
@@ -20,31 +20,37 @@
 from nvidia.dali.auto_aug import auto_augment, trivial_augment
 
 
-@pipeline_def(enable_conditionals=True)
-def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_augmentation,
-                  dali_device="gpu", rank=0, world_size=1):
-    rng = fn.random.coin_flip(probability=0.5)
-
-    jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
-                                    num_shards=world_size, random_shuffle=True, pad_last_batch=True)
-
-    if dali_device == "gpu":
-        decoder_device = "mixed"
-        resize_device = "gpu"
-    else:
-        decoder_device = "cpu"
-        resize_device = "cpu"
-
-    images = fn.decoders.image_random_crop(jpegs, device=decoder_device, output_type=types.RGB,
-                                           random_aspect_ratio=[0.75, 4.0 / 3.0],
-                                           random_area=[0.08, 1.0])
-
-    images = fn.resize(images, device=resize_device, size=[image_size, image_size],
-                       interp_type=interpolation, antialias=False)
+def efficientnet_processing_training(
+    jpegs_input,
+    interpolation,
+    image_size,
+    output_layout,
+    automatic_augmentation,
+    dali_device="gpu",
+):
+    """
+    Image processing part of the EfficientNet training pipeline (excluding data loading)
+    """
+    decoder_device = "mixed" if dali_device == "gpu" else "cpu"
+    images = fn.decoders.image_random_crop(
+        jpegs_input,
+        device=decoder_device,
+        output_type=types.RGB,
+        random_aspect_ratio=[0.75, 4.0 / 3.0],
+        random_area=[0.08, 1.0],
+    )
+
+    images = fn.resize(
+        images,
+        size=[image_size, image_size],
+        interp_type=interpolation,
+        antialias=False,
+    )
 
     # Make sure that from this point we are processing on GPU regardless of dali_device parameter
     images = images.gpu()
 
+    rng = fn.random.coin_flip(probability=0.5)
     images = fn.flip(images, horizontal=rng)
 
     # Based on the specification, apply the automatic augmentation policy.
# Note that, from the point
@@ -53,33 +59,138 @@ def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_
     # We pass the shape of the image after the resize so the translate operations are done
     # relative to the image size.
     if automatic_augmentation == "autoaugment":
-        output = auto_augment.auto_augment_image_net(images, shape=[image_size, image_size])
+        output = auto_augment.auto_augment_image_net(
+            images, shape=[image_size, image_size]
+        )
     elif automatic_augmentation == "trivialaugment":
-        output = trivial_augment.trivial_augment_wide(images, shape=[image_size, image_size])
+        output = trivial_augment.trivial_augment_wide(
+            images, shape=[image_size, image_size]
+        )
     else:
         output = images
 
-    output = fn.crop_mirror_normalize(output, dtype=types.FLOAT, output_layout=output_layout,
-                                      crop=(image_size, image_size),
-                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+    output = fn.crop_mirror_normalize(
+        output,
+        dtype=types.FLOAT,
+        output_layout=output_layout,
+        crop=(image_size, image_size),
+        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
+    )
 
-    return output, labels
+    return output
 
 
-@pipeline_def
-def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layout, rank=0,
-                    world_size=1):
-    jpegs, label = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
-                                   num_shards=world_size, random_shuffle=False, pad_last_batch=True)
+@pipeline_def(enable_conditionals=True)
+def training_pipe(
+    data_dir,
+    interpolation,
+    image_size,
+    output_layout,
+    automatic_augmentation,
+    dali_device="gpu",
+    rank=0,
+    world_size=1,
+):
+    jpegs, labels = fn.readers.file(
+        name="Reader",
+        file_root=data_dir,
+        shard_id=rank,
+        num_shards=world_size,
+        random_shuffle=True,
+        pad_last_batch=True,
+    )
+    outputs = efficientnet_processing_training(
+        jpegs,
+        interpolation,
+        image_size,
+        output_layout,
+        automatic_augmentation,
+        dali_device,
+    )
+    return outputs, labels
+
+
+@pipeline_def(enable_conditionals=True)
+def training_pipe_external_source(
+    interpolation,
+    image_size,
+    output_layout,
+    automatic_augmentation,
+    dali_device="gpu",
+    rank=0,
+    world_size=1,
+):
+    filepaths = fn.external_source(name="images", no_copy=True)
+    jpegs = fn.io.file.read(filepaths)
+    outputs = efficientnet_processing_training(
+        jpegs,
+        interpolation,
+        image_size,
+        output_layout,
+        automatic_augmentation,
+        dali_device,
+    )
+    return outputs
+
+
+def efficientnet_processing_validation(
+    jpegs, interpolation, image_size, image_crop, output_layout
+):
+    """
+    Image processing part of the EfficientNet validation pipeline (excluding data loading)
+    """
     images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
-    images = fn.resize(images, resize_shorter=image_size, interp_type=interpolation,
-                       antialias=False)
+    images = fn.resize(
+        images,
+        resize_shorter=image_size,
+        interp_type=interpolation,
+        antialias=False,
+    )
+
+    output = fn.crop_mirror_normalize(
+        images,
+        dtype=types.FLOAT,
+        output_layout=output_layout,
+        crop=(image_crop, image_crop),
+        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
+    )
+    return output
+
 
-    output = fn.crop_mirror_normalize(images, dtype=types.FLOAT, output_layout=output_layout,
-                                      crop=(image_crop, image_crop),
-                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
-    return output, label
+@pipeline_def
+def validation_pipe(
+    data_dir,
+    interpolation,
image_size, + image_crop, + output_layout, + rank=0, + world_size=1, +): + jpegs, label = fn.readers.file( + name="Reader", + file_root=data_dir, + shard_id=rank, + num_shards=world_size, + random_shuffle=False, + pad_last_batch=True, + ) + outputs = efficientnet_processing_validation( + jpegs, interpolation, image_size, image_crop, output_layout + ) + return outputs, label + + +@pipeline_def +def validation_pipe_external_source( + interpolation, image_size, image_crop, output_layout +): + filepaths = fn.external_source(name="images", no_copy=True) + jpegs = fn.io.file.read(filepaths) + outputs = efficientnet_processing_validation( + jpegs, interpolation, image_size, image_crop, output_layout + ) + return outputs diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py index 5b43fbaa402..8dbf4c50296 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py @@ -38,11 +38,18 @@ DATA_BACKEND_CHOICES = ["pytorch", "pytorch_optimized", "synthetic"] try: from nvidia.dali.plugin.pytorch import DALIClassificationIterator + from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy import nvidia.dali.types as types - from image_classification.dali import training_pipe, validation_pipe + from image_classification.dali import ( + training_pipe, + training_pipe_external_source, + validation_pipe, + validation_pipe_external_source, + ) DATA_BACKEND_CHOICES.append("dali") + DATA_BACKEND_CHOICES.append("dali_proxy") except ImportError as e: print( "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example." @@ -81,9 +88,9 @@ def load_jpeg_from_file(path, cuda=True): class DALIWrapper(object): - - def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format): - for data in dalipipeline: + @staticmethod + def gen_wrapper(loader, num_classes, one_hot, memory_format): + for data in loader: if memory_format == torch.channels_last: # If we requested the data in channels_last form, utilize the fact that DALI # can return it as NHWC. 
The network expects NCHW shape with NHWC internal memory, @@ -107,17 +114,18 @@ def nhwc_to_nchw(t): if one_hot: target = expand(num_classes, torch.float, target) yield input, target - dalipipeline.reset() - def __init__(self, dalipipeline, num_classes, one_hot, memory_format): - self.dalipipeline = dalipipeline + loader.reset() + + def __init__(self, loader, num_classes, one_hot, memory_format): + self.loader = loader self.num_classes = num_classes self.one_hot = one_hot self.memory_format = memory_format def __iter__(self): return DALIWrapper.gen_wrapper( - self.dalipipeline, + self.loader, self.num_classes, self.one_hot, self.memory_format, @@ -274,24 +282,48 @@ def expand(num_classes, dtype, tensor): return e +def as_memory_format(next_input, memory_format): + if memory_format == torch.channels_last: + shape = next_input.shape + stride = next_input.stride() + + # permute shape and stride from NHWC to NCHW + def nhwc_to_nchw(t): + return t[0], t[3], t[1], t[2] + + next_input = torch.as_strided( + next_input, + size=nhwc_to_nchw(shape), + stride=nhwc_to_nchw(stride), + ) + elif memory_format == torch.contiguous_format: + next_input = next_input.contiguous(memory_format=memory_format) + return next_input + + class PrefetchedWrapper(object): @staticmethod - def prefetched_loader(loader, num_classes, one_hot): + def prefetched_loader(loader, num_classes, one_hot, memory_format): stream = torch.cuda.Stream() for next_input, next_target in loader: with torch.cuda.stream(stream): - next_input = next_input.to(device="cuda") + next_input = as_memory_format( + next_input, memory_format=memory_format + ).to(device="cuda") next_target = next_target.to(device="cuda") next_input = next_input.float() if one_hot: next_target = expand(num_classes, torch.float, next_target) yield next_input, next_target - def __init__(self, dataloader, start_epoch, num_classes, one_hot): + def __init__( + self, dataloader, start_epoch, num_classes, one_hot, memory_format=None + ): self.dataloader = dataloader self.epoch = start_epoch - self.one_hot = one_hot self.num_classes = num_classes + self.one_hot = one_hot + self.memory_format = memory_format def __iter__(self): if self.dataloader.sampler is not None and isinstance( @@ -302,7 +334,7 @@ def __iter__(self): self.dataloader.sampler.set_epoch(self.epoch) self.epoch += 1 return PrefetchedWrapper.prefetched_loader( - self.dataloader, self.num_classes, self.one_hot + self.dataloader, self.num_classes, self.one_hot, self.memory_format ) def __len__(self): @@ -642,6 +674,181 @@ def get_pytorch_optimize_val_loader( ) +def read_file(path): + return np.fromfile(path, dtype=np.uint8) + + +def read_filepath(path): + return np.frombuffer(path.encode(), dtype=np.int8) + + +def get_dali_proxy_train_loader(dali_device="gpu"): + def get_impl( + data_path, + image_size, + batch_size, + num_classes, + one_hot, + interpolation="bilinear", + augmentation=None, + start_epoch=0, + workers=5, + _worker_init_fn=None, + prefetch_factor=2, + memory_format=torch.contiguous_format, + ): + interpolation = { + "bicubic": types.INTERP_CUBIC, + "bilinear": types.INTERP_LINEAR, + "triangular": types.INTERP_TRIANGULAR, + }[interpolation] + + output_layout = "HWC" if memory_format == torch.channels_last else "CHW" + + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) + + pipeline_kwargs = { + "batch_size": batch_size, + "num_threads": workers, + "device_id": rank % torch.cuda.device_count(), + "seed": 12 + rank % torch.cuda.device_count(), + } + + pipe = 
training_pipe_external_source( + interpolation=interpolation, + image_size=image_size, + output_layout=output_layout, + automatic_augmentation=augmentation, + dali_device=dali_device, + prefetch_queue_depth=8, + **pipeline_kwargs, + ) + + dali_server = dali_proxy.DALIServer(pipe) + + train_dataset = datasets.ImageFolder( + os.path.join(data_path, "train"), + transform=dali_server.proxy, + loader=read_filepath, + ) + + if torch.distributed.is_initialized(): + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, shuffle=True + ) + else: + train_sampler = None + + train_loader = dali_proxy.DataLoader( + dali_server, + train_dataset, + sampler=train_sampler, + batch_size=batch_size, + shuffle=(train_sampler is None), + num_workers=workers, + worker_init_fn=_worker_init_fn, + pin_memory=True, + collate_fn=None, + drop_last=True, + persistent_workers=True, + prefetch_factor=prefetch_factor, + ) + + return ( + PrefetchedWrapper( + train_loader, start_epoch, num_classes, one_hot, memory_format + ), + len(train_loader), + ) + + return get_impl + + +def get_dali_proxy_val_loader(dali_device="gpu"): + def get_impl( + data_path, + image_size, + batch_size, + num_classes, + one_hot, + interpolation="bilinear", + workers=5, + _worker_init_fn=None, + crop_padding=32, + memory_format=torch.contiguous_format, + prefetch_factor=2, + ): + interpolation = { + "bicubic": types.INTERP_CUBIC, + "bilinear": types.INTERP_LINEAR, + "triangular": types.INTERP_TRIANGULAR, + }[interpolation] + + output_layout = "HWC" if memory_format == torch.channels_last else "CHW" + + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) + pipeline_kwargs = { + "batch_size": batch_size, + "num_threads": workers, + "device_id": rank % torch.cuda.device_count(), + "seed": 12 + rank % torch.cuda.device_count(), + } + + pipe = validation_pipe_external_source( + interpolation=interpolation, + image_size=image_size + crop_padding, + image_crop=image_size, + output_layout=output_layout, + **pipeline_kwargs, + ) + + dali_server = dali_proxy.DALIServer(pipe) + val_dataset = datasets.ImageFolder( + os.path.join(data_path, "val"), + transform=dali_server.proxy, + loader=read_filepath, + ) + + if torch.distributed.is_initialized(): + val_sampler = torch.utils.data.distributed.DistributedSampler( + val_dataset, shuffle=False + ) + else: + val_sampler = None + + val_loader = dali_proxy.DataLoader( + dali_server, + val_dataset, + sampler=val_sampler, + batch_size=batch_size, + shuffle=(val_sampler is None), + num_workers=workers, + worker_init_fn=_worker_init_fn, + pin_memory=True, + collate_fn=None, + drop_last=True, + persistent_workers=True, + prefetch_factor=prefetch_factor, + ) + + return ( + PrefetchedWrapper( + val_loader, 0, num_classes, one_hot, memory_format + ), + len(val_loader), + ) + + return get_impl + + class SynteticDataLoader(object): def __init__( self, diff --git a/docs/examples/use_cases/pytorch/efficientnet/main.py b/docs/examples/use_cases/pytorch/efficientnet/main.py index 43c6eacfa66..02c31eb303e 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/main.py +++ b/docs/examples/use_cases/pytorch/efficientnet/main.py @@ -35,7 +35,7 @@ import argparse import random -from copy import deepcopy +from contextlib import nullcontext import torch.backends.cudnn as cudnn import torch.distributed as dist @@ -526,6 +526,11 @@ def _worker_init_fn(id): elif args.data_backend == "dali": get_train_loader = get_dali_train_loader(dali_device=args.dali_device) 
        get_val_loader = get_dali_val_loader()
+    elif args.data_backend == "dali_proxy":
+        get_train_loader = get_dali_proxy_train_loader(
+            dali_device=args.dali_device
+        )
+        get_val_loader = get_dali_proxy_val_loader()
     elif args.data_backend == "synthetic":
         get_val_loader = get_synthetic_loader
         get_train_loader = get_synthetic_loader
@@ -648,30 +653,42 @@ def main(args, model_args, model_arch):
         best_prec1,
     ) = prepare_for_training(args, model_args, model_arch)
 
-    train_loop(
-        trainer,
-        lr_policy,
-        train_loader,
-        train_loader_len,
-        val_loader,
-        logger,
-        start_epoch=start_epoch,
-        end_epoch=(
-            min((start_epoch + args.run_epochs), args.epochs)
-            if args.run_epochs != -1
-            else args.epochs
-        ),
-        early_stopping_patience=args.early_stopping_patience,
-        best_prec1=best_prec1,
-        prof=args.prof,
-        skip_training=args.evaluate,
-        skip_validation=args.training_only,
-        save_checkpoints=args.save_checkpoints and not args.evaluate,
-        checkpoint_dir=args.workspace,
-        checkpoint_filename=args.checkpoint_filename,
-        keep_last_n_checkpoints=args.gather_checkpoints,
-        topk=args.topk,
-    )
+    def get_ctx(loader):
+        """
+        Get a context manager from a dataloader object. This is a utility so that we can run
+        the same code for DALI iterators, the PyTorch dataloader, or the DALI proxy dataloader.
+        """
+        if isinstance(loader, dali_proxy.DataLoader):
+            return loader.dali_server
+        if hasattr(loader, "dataloader"):
+            return get_ctx(loader.dataloader)
+        return nullcontext()
+
+    with get_ctx(train_loader), get_ctx(val_loader):
+        train_loop(
+            trainer,
+            lr_policy,
+            train_loader,
+            train_loader_len,
+            val_loader,
+            logger,
+            start_epoch=start_epoch,
+            end_epoch=(
+                min((start_epoch + args.run_epochs), args.epochs)
+                if args.run_epochs != -1
+                else args.epochs
+            ),
+            early_stopping_patience=args.early_stopping_patience,
+            best_prec1=best_prec1,
+            prof=args.prof,
+            skip_training=args.evaluate,
+            skip_validation=args.training_only,
+            save_checkpoints=args.save_checkpoints and not args.evaluate,
+            checkpoint_dir=args.workspace,
+            checkpoint_filename=args.checkpoint_filename,
+            keep_last_n_checkpoints=args.gather_checkpoints,
+            topk=args.topk,
+        )
     exp_duration = time.time() - exp_start_time
     if (
         not torch.distributed.is_initialized()
diff --git a/qa/TL3_EfficientNet_benchmark/test_pytorch.sh b/qa/TL3_EfficientNet_benchmark/test_pytorch.sh
index a09db046393..b643a65eef5 100644
--- a/qa/TL3_EfficientNet_benchmark/test_pytorch.sh
+++ b/qa/TL3_EfficientNet_benchmark/test_pytorch.sh
@@ -1,6 +1,6 @@
 #!/bin/bash -e
 
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
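
Note on the main.py changes above: the DALI proxy backend only works while its DALIServer
is running, which is why train_loop is now wrapped in
`with get_ctx(train_loader), get_ctx(val_loader):`. For reference, here is a minimal
single-process sketch of how the pieces introduced in this patch compose. The
`./imagenet/train` path and the batch/worker sizes are placeholders, and the snippet skips
the distributed sampler, prefetching, and memory-format handling that
get_dali_proxy_train_loader adds:

    import nvidia.dali.types as types
    from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
    from torchvision import datasets

    from image_classification.dali import training_pipe_external_source
    from image_classification.dataloaders import read_filepath

    # Data-loading-free DALI pipeline; its "images" external source is fed
    # with encoded file paths collected from the dataset workers.
    pipe = training_pipe_external_source(
        interpolation=types.INTERP_LINEAR,
        image_size=224,
        output_layout="CHW",
        automatic_augmentation="autoaugment",
        dali_device="gpu",
        batch_size=64,
        num_threads=4,
        device_id=0,
    )

    # DALIServer is a context manager (see get_ctx above); entering it starts
    # the server that actually runs the pipeline.
    with dali_proxy.DALIServer(pipe) as dali_server:
        # The proxy stands in for a regular torchvision transform: the dataset
        # workers only forward file paths, while decoding and augmentation run
        # inside the DALI server.
        dataset = datasets.ImageFolder(
            "./imagenet/train",  # placeholder dataset location
            transform=dali_server.proxy,
            loader=read_filepath,
        )
        loader = dali_proxy.DataLoader(
            dali_server,
            dataset,
            batch_size=64,
            num_workers=4,
            drop_last=True,
        )
        for images, labels in loader:
            ...  # images and labels arrive as torch tensors produced by the DALI pipeline
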
@@ -57,9 +57,10 @@ export RESULT_WORKSPACE=./ # synthetic benchmark python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --epochs 1 --prof 1000 --no-checkpoints --training-only --data-backend synthetic --workspace $RESULT_WORKSPACE --report-file bench_report_synthetic.json $PATH_TO_IMAGENET +# ----- # DALI without automatic augmentations -python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_dali.json $PATH_TO_IMAGENET +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_dali.json $PATH_TO_IMAGENET # DALI with AutoAugment python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_aa.json $PATH_TO_IMAGENET @@ -67,6 +68,15 @@ python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 - # DALI with TrivialAugment python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali --automatic-augmentation trivialaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_ta.json $PATH_TO_IMAGENET +# DALI proxy without automatic augmentations +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali_proxy --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_dali_proxy.json $PATH_TO_IMAGENET + +# DALI proxy with AutoAugment +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali_proxy --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_proxy_aa.json $PATH_TO_IMAGENET + +# DALI proxy with TrivialAugment +python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 13 --epochs 3 --no-checkpoints --training-only --data-backend dali_proxy --automatic-augmentation trivialaugment --workspace $RESULT_WORKSPACE --report-file bench_report_dali_proxy_ta.json $PATH_TO_IMAGENET + # PyTorch without automatic augmentations python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --epochs 3 --no-checkpoints --training-only --data-backend pytorch --automatic-augmentation disabled --workspace $RESULT_WORKSPACE --report-file bench_report_pytorch.json $PATH_TO_IMAGENET @@ -79,6 +89,7 @@ python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 - # Optimized PyTorch with AutoAugment: python multiproc.py --nproc_per_node 8 ./main.py --amp --static-loss-scale 128 --batch-size 512 --workers 10 --epochs 3 --no-checkpoints --training-only --data-backend pytorch_optimized --automatic-augmentation autoaugment --workspace $RESULT_WORKSPACE --report-file bench_report_optimized_pytorch_aa.json 
$PATH_TO_IMAGENET
+# -----
 
 # The line below finds the lines with `train.total_ips`, takes the last one (the result we
 # want), cuts the DLLL prefix (so the rest parses as valid JSON) from the logs, and parses it
@@ -90,6 +101,9 @@ SYNTH_THRESHOLD=38000
 DALI_NONE_THRESHOLD=32000
 DALI_AA_THRESHOLD=32000
 DALI_TA_THRESHOLD=32000
+DALI_PROXY_NONE_THRESHOLD=32000
+DALI_PROXY_AA_THRESHOLD=32000
+DALI_PROXY_TA_THRESHOLD=32000
 
 PYTORCH_NONE_THRESHOLD=32000
 PYTORCH_AA_THRESHOLD=32000
@@ -111,6 +125,9 @@ CHECK_PERF_THRESHOLD "bench_report_synthetic.json" $SYNTH_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_dali.json" $DALI_NONE_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_dali_aa.json" $DALI_AA_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_dali_ta.json" $DALI_TA_THRESHOLD
+CHECK_PERF_THRESHOLD "bench_report_dali_proxy.json" $DALI_PROXY_NONE_THRESHOLD
+CHECK_PERF_THRESHOLD "bench_report_dali_proxy_aa.json" $DALI_PROXY_AA_THRESHOLD
+CHECK_PERF_THRESHOLD "bench_report_dali_proxy_ta.json" $DALI_PROXY_TA_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_pytorch.json" $PYTORCH_NONE_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_pytorch_aa.json" $PYTORCH_AA_THRESHOLD
 CHECK_PERF_THRESHOLD "bench_report_optimized_pytorch.json" $PYTORCH_NONE_THRESHOLD
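
For context on the checks above: CHECK_PERF_THRESHOLD (defined elsewhere in this script)
extracts the last `train.total_ips` entry from a report and fails if it falls below the
given threshold. A rough Python sketch of that parsing step, assuming each matching log
line carries a `DLLL` prefix followed by a JSON object with the metric under a `data`
field (both the prefix handling and the field layout are assumptions, not the script's
actual implementation):

    import json

    def last_total_ips(report_path):
        # Scan the log and keep the value from the last line that carries the
        # train.total_ips metric.
        value = None
        with open(report_path) as f:
            for line in f:
                if "train.total_ips" not in line:
                    continue
                # Assumed line shape: "DLLL <json object>"; drop the prefix so
                # the remainder parses as JSON.
                _, sep, payload = line.partition("DLLL")
                if not sep:
                    continue  # not a DLLL-prefixed record
                record = json.loads(payload.strip())
                value = record["data"]["train.total_ips"]  # assumed field layout
        return value

    ips = last_total_ips("bench_report_dali_proxy.json")
    assert ips is not None and ips >= 32000, f"perf below threshold: {ips}"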