Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Jan 23, 2025
1 parent 81946f1 commit 154ee55
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nvidia.dali.auto_aug import auto_augment, trivial_augment


def resnet_processing_training(
def efficientnet_processing_training(
jpegs_input,
interpolation,
image_size,
Expand Down Expand Up @@ -100,7 +100,7 @@ def training_pipe(
random_shuffle=True,
pad_last_batch=True,
)
outputs = resnet_processing_training(
outputs = efficientnet_processing_training(
jpegs,
interpolation,
image_size,
Expand All @@ -123,7 +123,7 @@ def training_pipe_external_source(
):
filepaths = fn.external_source(name="images", no_copy=True)
jpegs = fn.io.file.read(filepaths)
outputs = resnet_processing_training(
outputs = efficientnet_processing_training(
jpegs,
interpolation,
image_size,
Expand All @@ -134,7 +134,7 @@ def training_pipe_external_source(
return outputs


def resnet_processing_validation(
def efficientnet_processing_validation(
jpegs, interpolation, image_size, image_crop, output_layout
):
"""
Expand Down Expand Up @@ -178,7 +178,7 @@ def validation_pipe(
random_shuffle=False,
pad_last_batch=True,
)
outputs = resnet_processing_validation(
outputs = efficientnet_processing_validation(
jpegs, interpolation, image_size, image_crop, output_layout
)
return outputs, label
Expand All @@ -190,7 +190,7 @@ def validation_pipe_external_source(
):
filepaths = fn.external_source(name="images", no_copy=True)
jpegs = fn.io.file.read(filepaths)
outputs = resnet_processing_validation(
outputs = efficientnet_processing_validation(
jpegs, interpolation, image_size, image_crop, output_layout
)
return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def load_jpeg_from_file(path, cuda=True):


class DALIWrapper(object):

@staticmethod
def gen_wrapper(loader, num_classes, one_hot, memory_format):
for data in loader:
if memory_format == torch.channels_last:
Expand Down

0 comments on commit 154ee55

Please sign in to comment.