
Commit

Ensure allocated temp memory is usable by nvImageCodec streams, as we are skipping pre-sync due to unnecessary overhead in the general case

Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Feb 4, 2025
1 parent cefc16b commit 08282b3
Showing 3 changed files with 33 additions and 54 deletions.
13 changes: 3 additions & 10 deletions dali/operators/imgcodec/image_decoder.h
@@ -552,15 +552,6 @@ class ImageDecoder : public StatelessOperator<Backend> {
          MAKE_SEMANTIC_VERSION(req_major, req_minor, req_patch);
   }
 
-  /**
-   * @brief nvImageCodec up to 0.2 doesn't synchronize with the user stream before decoding.
-   * Because of that, we need to host synchronize before passing the async allocated buffer
-   * to the decoding function
-   */
-  bool need_host_sync_alloc() {
-    return !version_at_least(0, 3, 0);
-  }
-
   void PrepareOutput(SampleState &st, void *out_ptr, const ROI &roi, const Workspace &ws) {
     // Make a copy of the parsed img info. We might modify it
     // (for example, request planar vs. interleaved, etc)
@@ -794,7 +785,9 @@ class ImageDecoder : public StatelessOperator<Backend> {
     size_t nsamples_decode = batch_images_.size();
     size_t nsamples_cache = nsamples - nsamples_decode;
 
-    if (ws.has_stream() && need_host_sync_alloc() && any_need_processing) {
+    // Ensure allocated memory is usable by the decoder's internal streams,
+    // as we are intentionally skipping pre-sync to avoid slowing down the general case.
+    if (ws.has_stream() && any_need_processing) {
       DomainTimeRange tr("alloc sync", DomainTimeRange::kOrange);
       CUDA_CALL(cudaStreamSynchronize(ws.stream()));
     }
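The removed need_host_sync_alloc() helper only requested the host sync for nvImageCodec older than 0.3.0; after this change the sync runs whenever the workspace has a stream and any sample needs processing, since the decoder's pre-sync with the user stream is intentionally skipped. The hazard being guarded against is stream-ordered allocation: a temporary buffer allocated asynchronously on the workspace stream is only guaranteed to be ready for work ordered after that allocation, while nvImageCodec decodes on its own internal streams. Below is a minimal standalone sketch of that pattern, not DALI's actual code; the stream names, buffer size, and the use of cudaMallocAsync are illustrative assumptions.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                                   \
  do {                                                                     \
    cudaError_t err_ = (call);                                             \
    if (err_ != cudaSuccess) {                                             \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_));  \
      std::exit(EXIT_FAILURE);                                             \
    }                                                                      \
  } while (0)

int main() {
  // "user_stream" plays the role of the workspace stream (ws.stream());
  // "decoder_stream" stands in for one of the decoder's internal streams.
  cudaStream_t user_stream, decoder_stream;
  CHECK_CUDA(cudaStreamCreate(&user_stream));
  CHECK_CUDA(cudaStreamCreate(&decoder_stream));

  // Stream-ordered allocation: the buffer is guaranteed to be usable only by
  // work that is ordered after this call on user_stream.
  void *tmp = nullptr;
  CHECK_CUDA(cudaMallocAsync(&tmp, 64ull << 20, user_stream));  // 64 MiB temp buffer (illustrative)

  // Without this host sync (or an equivalent event dependency), work enqueued on
  // decoder_stream could touch the buffer before the allocation has completed,
  // because the two streams are not ordered with respect to each other.
  // This mirrors the CUDA_CALL(cudaStreamSynchronize(ws.stream())) in the diff above.
  CHECK_CUDA(cudaStreamSynchronize(user_stream));

  // ... enqueue decode work that writes to `tmp` on decoder_stream here ...

  CHECK_CUDA(cudaStreamSynchronize(decoder_stream));  // ensure consumers are done before freeing
  CHECK_CUDA(cudaFreeAsync(tmp, user_stream));
  CHECK_CUDA(cudaStreamSynchronize(user_stream));
  CHECK_CUDA(cudaStreamDestroy(user_stream));
  CHECK_CUDA(cudaStreamDestroy(decoder_stream));
  return 0;
}

An event dependency (cudaEventRecord on the allocating stream plus cudaStreamWaitEvent on the consumer) would express the same ordering without blocking the host, but that requires access to the consumer streams, which the decoder keeps internal; a single host synchronize is the simplest correct guarantee here.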
9 changes: 4 additions & 5 deletions docs/examples/use_cases/pytorch/resnet50/main.py
@@ -93,13 +93,12 @@ def parse():
'"dali" for DALI data loader, or "dali_proxy" for PyTorch dataloader with DALI proxy preprocessing.')
parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic',
help="If enabled, random seeds are fixed to ensure deterministic results for reproducibility.",
action='store_true')
parser.add_argument('--deterministic', action='store_true')

parser.add_argument('--fp16-mode', default=False, action='store_true',
help='Enable half precision mode.')
parser.add_argument('--loss-scale', type=float, help="Loss scaling factor for mixed precision training. Default is 1.", default=1)
parser.add_argument('--channels-last', type=bool, help="Use channels-last memory format for model and data. Default is False.", default=False)
parser.add_argument('--loss-scale', type=float, default=1)
parser.add_argument('--channels-last', type=bool, default=False)
parser.add_argument('-t', '--test', action='store_true',
help='Launch test mode with preset arguments')
args = parser.parse_args()
65 changes: 26 additions & 39 deletions docs/examples/use_cases/pytorch/resnet50/pytorch-resnet50.rst
@@ -54,42 +54,29 @@ Usage
 PyTorch ImageNet Training
 positional arguments:
-DIR path(s) to dataset (if one path is provided, it is assumed to have subdirectories named "train" and "val"; alternatively, train and val paths can be specified
-directly by providing both paths as arguments)
-options:
--h, --help show this help message and exit
---arch ARCH, -a ARCH model architecture: alexnet | convnext_base | convnext_large | convnext_small | convnext_tiny | densenet121 | densenet161 | densenet169 | densenet201 |
-efficientnet_b0 | efficientnet_b1 | efficientnet_b2 | efficientnet_b3 | efficientnet_b4 | efficientnet_b5 | efficientnet_b6 | efficientnet_b7 | efficientnet_v2_l |
-efficientnet_v2_m | efficientnet_v2_s | get_model | get_model_builder | get_model_weights | get_weight | googlenet | inception_v3 | list_models | maxvit_t |
-mnasnet0_5 | mnasnet0_75 | mnasnet1_0 | mnasnet1_3 | mobilenet_v2 | mobilenet_v3_large | mobilenet_v3_small | regnet_x_16gf | regnet_x_1_6gf | regnet_x_32gf |
-regnet_x_3_2gf | regnet_x_400mf | regnet_x_800mf | regnet_x_8gf | regnet_y_128gf | regnet_y_16gf | regnet_y_1_6gf | regnet_y_32gf | regnet_y_3_2gf | regnet_y_400mf
-| regnet_y_800mf | regnet_y_8gf | resnet101 | resnet152 | resnet18 | resnet34 | resnet50 | resnext101_32x8d | resnext101_64x4d | resnext50_32x4d |
-shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 | squeezenet1_0 | squeezenet1_1 | swin_b | swin_s | swin_t | swin_v2_b | swin_v2_s
-| swin_v2_t | vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19 | vgg19_bn | vit_b_16 | vit_b_32 | vit_h_14 | vit_l_16 | vit_l_32 | wide_resnet101_2 |
-wide_resnet50_2 (default: resnet18)
--j N, --workers N number of data loading workers (default: 4)
---epochs N number of total epochs to run
---start-epoch N manual epoch number (useful on restarts)
--b N, --batch-size N mini-batch size per process (default: 256)
---lr LR, --learning-rate LR
-Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be
-applied over the first 5 epochs.
---momentum M momentum
---weight-decay W, --wd W
-weight decay (default: 1e-4)
---print-freq N, -p N print frequency (default: 10)
---resume PATH path to latest checkpoint (default: none)
--e, --evaluate evaluate model on validation set
---pretrained use pre-trained model
---dali_cpu Runs CPU based version of DALI pipeline.
---data_loader {pytorch,dali,dali_proxy}
-Select data loader: "pytorch" for native PyTorch data loader, "dali" for DALI data loader, or "dali_proxy" for PyTorch dataloader with DALI proxy preprocessing.
---prof PROF Only run 10 iterations for profiling.
---deterministic If enabled, random seeds are fixed to ensure deterministic results for reproducibility.
---fp16-mode Enable half precision mode.
---loss-scale LOSS_SCALE
-Loss scaling factor for mixed precision training. Default is 1.
---channels-last CHANNELS_LAST
-Use channels-last memory format for model and data. Default is False.
--t, --test Launch test mode with preset arguments
+DIR path(s) to dataset (if one path is provided, it is assumed to have subdirectories named "train" and "val"; alternatively, train and val paths can be specified directly by providing both paths as arguments)
+optional arguments (for the full list please check `Apex ImageNet example
+<https://github.com/NVIDIA/apex/tree/master/examples/imagenet>`_)
+-h, --help show this help message and exit
+--arch ARCH, -a ARCH model architecture: alexnet | resnet | resnet101
+| resnet152 | resnet18 | resnet34 | resnet50 | vgg
+| vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16
+| vgg16_bn | vgg19 | vgg19_bn (default: resnet18)
+-j N, --workers N number of data loading workers (default: 4)
+--epochs N number of total epochs to run
+--start-epoch N manual epoch number (useful on restarts)
+-b N, --batch-size N mini-batch size (default: 256)
+--lr LR, --learning-rate LR initial learning rate
+--momentum M momentum
+--weight-decay W, --wd W weight decay (default: 1e-4)
+--print-freq N, -p N print frequency (default: 10)
+--resume PATH path to latest checkpoint (default: none)
+-e, --evaluate evaluate model on validation set
+--pretrained use pre-trained model
+--dali_cpu use CPU based pipeline for DALI, for heavy GPU
+networks it may work better, for IO bottlenecked
+one like RN18 GPU default should be faster
+--data_loader Select data loader: "pytorch" for native PyTorch data loader,
+"dali" for DALI data loader, or "dali_proxy" for PyTorch dataloader with DALI proxy preprocessing.
+--fp16-mode enables mixed precision mode
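For reference, an illustrative way to launch the example with the options documented above (the dataset path and values are placeholders, not part of this commit, and any distributed launcher the example may require is omitted):

    python main.py -a resnet50 -b 256 --data_loader dali --fp16-mode /path/to/imagenet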
