From d1e100e50e33b71fbaf4b3c4dd2a7fe54ab42e1b Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Mon, 12 Aug 2019 19:06:50 +0200 Subject: [PATCH] Fix fp16 bug from #1129 and add fp16 test case (#1160) Initializing member variable `OutputType padding_val = 0;` for `__half` in the host code produced wrong results. `padding_val` is now an fp32 that is cast to `OutputType` in the CUDA kernel. Signed-off-by: Serge Panev --- .../slice_flip_normalize_permute_common.h | 16 +++--- .../slice/slice_flip_normalize_permute_cpu.h | 2 +- .../slice/slice_flip_normalize_permute_gpu.h | 18 +++---- ...slice_flip_normalize_permute_kernel_test.h | 54 +++++++++---------- ...e_flip_normalize_permute_pad_cuda_impl.cuh | 8 +-- .../operators/fused/crop_mirror_normalize.cc | 2 +- .../operators/fused/crop_mirror_normalize.cu | 2 +- .../test_operator_crop_mirror_normalize.py | 2 +- 8 files changed, 52 insertions(+), 52 deletions(-) diff --git a/dali/kernels/slice/slice_flip_normalize_permute_common.h b/dali/kernels/slice/slice_flip_normalize_permute_common.h index 772e7aafda..679a78846a 100644 --- a/dali/kernels/slice/slice_flip_normalize_permute_common.h +++ b/dali/kernels/slice/slice_flip_normalize_permute_common.h @@ -24,7 +24,7 @@ namespace dali { namespace kernels { -template +template struct SliceFlipNormalizePermutePadArgs { template explicit SliceFlipNormalizePermutePadArgs(const Shape &_shape) { @@ -46,12 +46,12 @@ struct SliceFlipNormalizePermutePadArgs { size_t normalization_index = 0; std::vector mean; std::vector inv_stddev; - OutputType padding_val = 0; + float padding_val = 0.0f; }; namespace detail { -template +template struct SliceFlipNormalizePermutePadProcessedArgs { size_t input_offset; std::array in_strides; @@ -61,7 +61,7 @@ struct SliceFlipNormalizePermutePadProcessedArgs { std::vector mean; std::vector inv_stddev; size_t normalization_dim; - OutputType padding_val = 0; + float padding_val = 0.0f; }; template @@ -83,11 +83,11 @@ std::array inverse_permutation(const 
std::array &p return inv_perm; } -template -SliceFlipNormalizePermutePadProcessedArgs ProcessArgs( - const SliceFlipNormalizePermutePadArgs &args, +template +SliceFlipNormalizePermutePadProcessedArgs ProcessArgs( + const SliceFlipNormalizePermutePadArgs &args, const Shape &in_shape) { - SliceFlipNormalizePermutePadProcessedArgs processed_args; + SliceFlipNormalizePermutePadProcessedArgs processed_args; processed_args.input_offset = 0; processed_args.in_strides = GetStrides(in_shape); diff --git a/dali/kernels/slice/slice_flip_normalize_permute_cpu.h b/dali/kernels/slice/slice_flip_normalize_permute_cpu.h index fcd5fa8bb9..0073f822ba 100644 --- a/dali/kernels/slice/slice_flip_normalize_permute_cpu.h +++ b/dali/kernels/slice/slice_flip_normalize_permute_cpu.h @@ -188,7 +188,7 @@ void SliceFlipNormalizePermute(OutputType *output, const InputType *input, template class SliceFlipNormalizePermuteCPU { public: - using Args = SliceFlipNormalizePermutePadArgs; + using Args = SliceFlipNormalizePermutePadArgs; KernelRequirements Setup(KernelContext &context, const InTensorCPU &in, diff --git a/dali/kernels/slice/slice_flip_normalize_permute_gpu.h b/dali/kernels/slice/slice_flip_normalize_permute_gpu.h index 227d5ef070..33f0d982c7 100644 --- a/dali/kernels/slice/slice_flip_normalize_permute_gpu.h +++ b/dali/kernels/slice/slice_flip_normalize_permute_gpu.h @@ -38,15 +38,15 @@ class SliceFlipNormalizePermutePadGPU { size_t block_count_ = 0; public: - using Args = SliceFlipNormalizePermutePadArgs; + using Args = SliceFlipNormalizePermutePadArgs; KernelRequirements Setup(KernelContext &context, const InListGPU &in, const std::vector &args) { KernelRequirements req; ScratchpadEstimator se; const size_t num_samples = in.size(); - se.add>(AllocType::Host, num_samples); - se.add>(AllocType::GPU, num_samples); + se.add>(AllocType::Host, num_samples); + se.add>(AllocType::GPU, num_samples); DALI_ENFORCE(args[0].mean.size() == args[0].inv_stddev.size()); size_t norm_args_size = 
args[0].mean.size(); @@ -88,9 +88,9 @@ class SliceFlipNormalizePermutePadGPU { auto inv_stddev_data = args[0].inv_stddev; DALI_ENFORCE(mean_data.size() == inv_stddev_data.size()); - detail::SampleDesc* sample_descs_cpu = - context.scratchpad->Allocate>(AllocType::Host, - num_samples); + detail::SampleDesc* sample_descs_cpu = + context.scratchpad->Allocate>(AllocType::Host, + num_samples); float *norm_add_cpu = mean_data.empty() ? nullptr : context.scratchpad->Allocate(AllocType::Host, mean_data.size()); float *norm_mul_cpu = inv_stddev_data.empty() ? nullptr : @@ -135,8 +135,8 @@ class SliceFlipNormalizePermutePadGPU { } } - detail::SampleDesc *sample_descs = - context.scratchpad->Allocate>( + detail::SampleDesc *sample_descs = + context.scratchpad->Allocate>( AllocType::GPU, num_samples); float *norm_add = mean_data.empty() ? nullptr : @@ -152,7 +152,7 @@ class SliceFlipNormalizePermutePadGPU { AllocType::GPU, block_count_); // Memory is allocated contiguously, so we launch only one cudaMemcpyAsync - size_t total_bytes = num_samples * sizeof(detail::SampleDesc) + size_t total_bytes = num_samples * sizeof(detail::SampleDesc) + mean_data.size() * sizeof(float) + inv_stddev_data.size() * sizeof(float) + block_count_ * sizeof(detail::BlockDesc); diff --git a/dali/kernels/slice/slice_flip_normalize_permute_kernel_test.h b/dali/kernels/slice/slice_flip_normalize_permute_kernel_test.h index 48135bf4d9..41c3d5d0c6 100644 --- a/dali/kernels/slice/slice_flip_normalize_permute_kernel_test.h +++ b/dali/kernels/slice/slice_flip_normalize_permute_kernel_test.h @@ -35,7 +35,7 @@ class SliceFlipNormalizePermuteTest : public ::testing::Test { static constexpr size_t DimSize0 = TestArgs::DimSize0; static constexpr size_t DimSize1 = TestArgs::DimSize1; using ArgsGenerator = typename TestArgs::ArgsGenerator; - using KernelArgs = SliceFlipNormalizePermutePadArgs; + using KernelArgs = SliceFlipNormalizePermutePadArgs; void PrepareData(TestTensorList& test_data) { std::vector 
sample_dims(Dims, DimSize); @@ -145,30 +145,30 @@ class SliceFlipNormalizePermuteTest : public ::testing::Test { template struct SliceFlipNormPermArgsGen_CopyOnly { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); return args; } }; template struct SliceFlipNormPermArgsGen_SliceOnly { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { auto shape = input_shape; shape[0] /= 2; shape[1] /= 2; - SliceFlipNormalizePermutePadArgs args(shape); + SliceFlipNormalizePermutePadArgs args(shape); return args; } }; template struct SliceFlipNormPermArgsGen_SliceOnly_WithAnchor { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { auto shape = input_shape; shape[0] = input_shape[0]/2; shape[1] = input_shape[0]/2; - SliceFlipNormalizePermutePadArgs args(shape); + SliceFlipNormalizePermutePadArgs args(shape); args.anchor[0] = input_shape[0]/2; args.anchor[1] = input_shape[1]/2; return args; @@ -177,8 +177,8 @@ struct SliceFlipNormPermArgsGen_SliceOnly_WithAnchor { template struct SliceFlipNormPermArgsGen_FlipHW { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); // assuming last dims are HWC, flip H and W args.flip[Dims-2] = true; args.flip[Dims-3] = true; @@ -188,8 +188,8 @@ struct SliceFlipNormPermArgsGen_FlipHW { template struct SliceFlipNormPermArgsGen_FlipDim { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + 
SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); args.flip[FlipDim] = true; return args; } @@ -197,8 +197,8 @@ struct SliceFlipNormPermArgsGen_FlipDim { template struct SliceFlipNormPermArgsGen_NormalizeOnly { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); args.mean.resize(args.shape[Dims-1]); args.inv_stddev.resize(args.shape[Dims-1]); for (int i = 0; i < args.shape[Dims-1]; i++) { @@ -211,8 +211,8 @@ struct SliceFlipNormPermArgsGen_NormalizeOnly { template struct SliceFlipNormPermArgsGen_NormalizeOnly_Scalar { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); args.mean = { 3.5f }; args.inv_stddev = { 1.f / 8.0f }; return args; @@ -221,8 +221,8 @@ struct SliceFlipNormPermArgsGen_NormalizeOnly_Scalar { template struct SliceFlipNormPermArgsGen_NormalizeAndFlipDim { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); args.flip[FlipDim] = true; args.mean.resize(args.shape[Dims-1], 3.5f); args.inv_stddev.resize(args.shape[Dims-1], 1.0/3.5f); @@ -233,8 +233,8 @@ struct SliceFlipNormPermArgsGen_NormalizeAndFlipDim { template struct SliceFlipNormPermArgsGen_PermuteOnly_ReversedDims { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + 
SliceFlipNormalizePermutePadArgs args(input_shape); for (size_t d = 0; d < Dims; d++) { args.permuted_dims[d] = Dims-1-d; } @@ -244,8 +244,8 @@ struct SliceFlipNormPermArgsGen_PermuteOnly_ReversedDims { template struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_ReversedDims { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); for (size_t d = 0; d < Dims; d++) { args.anchor[d] = input_shape[d]/4; args.shape[d] = args.padded_shape[d] = input_shape[d]/2; @@ -257,8 +257,8 @@ struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_ReversedDims { template struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_PermuteHW { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); for (size_t d = 0; d < Dims; d++) { args.anchor[d] = input_shape[d]/4; args.shape[d] = args.padded_shape[d] = input_shape[d]/2; @@ -280,8 +280,8 @@ struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_PermuteHW { template struct SliceFlipNormPermArgsGen_SliceFlipNormalizePermute_PermuteHWC2CHW { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); for (size_t d = 0; d < Dims; d++) { args.anchor[d] = d == 0 || d == 1 ? 
input_shape[d]/2 : 0; @@ -311,8 +311,8 @@ struct SliceFlipNormPermArgsGen_SliceFlipNormalizePermute_PermuteHWC2CHW { template struct SliceFlipNormPermArgsGen_OnlyPad_GivenDim { - SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { - SliceFlipNormalizePermutePadArgs args(input_shape); + SliceFlipNormalizePermutePadArgs Get(const TensorShape& input_shape) { + SliceFlipNormalizePermutePadArgs args(input_shape); args.padded_shape[PaddedDim] += PadSize; return args; } diff --git a/dali/kernels/slice/slice_flip_normalize_permute_pad_cuda_impl.cuh b/dali/kernels/slice/slice_flip_normalize_permute_pad_cuda_impl.cuh index 7b323aebd4..da5069b60c 100644 --- a/dali/kernels/slice/slice_flip_normalize_permute_pad_cuda_impl.cuh +++ b/dali/kernels/slice/slice_flip_normalize_permute_pad_cuda_impl.cuh @@ -33,7 +33,7 @@ namespace kernels { namespace detail { -template +template struct SampleDesc { void *__restrict__ out; const void *__restrict__ in; @@ -41,7 +41,7 @@ struct SampleDesc { DeviceArray out_strides; DeviceArray out_shape; DeviceArray padded_out_shape; - OutputType padding_val; + float padding_val; }; struct BlockDesc { @@ -123,7 +123,7 @@ __device__ inline void SliceFlipNormalizePermutePadFunc(OutputType *__restrict__ } template -__global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc *samples, +__global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc *samples, const BlockDesc *blocks, const float *norm_add, const float *norm_mul, @@ -146,7 +146,7 @@ __global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc(sample.padding_val), offset, block_end); } } // namespace detail diff --git a/dali/pipeline/operators/fused/crop_mirror_normalize.cc b/dali/pipeline/operators/fused/crop_mirror_normalize.cc index f513278c60..d6fcc07cc0 100644 --- a/dali/pipeline/operators/fused/crop_mirror_normalize.cc +++ b/dali/pipeline/operators/fused/crop_mirror_normalize.cc @@ -73,7 +73,7 @@ void RunHelper(Tensor &output, 
VALUE_SWITCH(number_of_dims, Dims, (3, 4), ( auto in_view = view(input); - kernels::SliceFlipNormalizePermutePadArgs args(slice_shape); + kernels::SliceFlipNormalizePermutePadArgs args(slice_shape); for (std::size_t d = 0; d < Dims; d++) { args.anchor[d] = slice_anchor[d]; } diff --git a/dali/pipeline/operators/fused/crop_mirror_normalize.cu b/dali/pipeline/operators/fused/crop_mirror_normalize.cu index 202c856a49..229393f28a 100644 --- a/dali/pipeline/operators/fused/crop_mirror_normalize.cu +++ b/dali/pipeline/operators/fused/crop_mirror_normalize.cu @@ -43,7 +43,7 @@ void RunHelper(TensorList &output, ctx.gpu.stream = stream; auto in_view = view(input); - std::vector> per_sample_args; + std::vector> per_sample_args; per_sample_args.reserve(slice_anchors.size()); for (std::size_t i = 0; i < slice_anchors.size(); i++) { per_sample_args.emplace_back(slice_shapes[i]); diff --git a/dali/test/python/test_operator_crop_mirror_normalize.py b/dali/test/python/test_operator_crop_mirror_normalize.py index 6a5562ae78..94e32d896c 100644 --- a/dali/test/python/test_operator_crop_mirror_normalize.py +++ b/dali/test/python/test_operator_crop_mirror_normalize.py @@ -71,7 +71,7 @@ def check_cmn_cpu_vs_gpu(batch_size, output_dtype, output_layout, mirror_probabi def test_cmn_cpu_vs_gpu(): for batch_size in [1, 8]: - for output_dtype in [types.FLOAT, types.INT32]: + for output_dtype in [types.FLOAT, types.INT32, types.FLOAT16]: for output_layout in [types.NHWC, types.NCHW]: for mirror_probability in [0.0, 0.5, 1.0]: norm_data = [ ([0., 0., 0.], [1., 1., 1.]),