Skip to content

Commit

Permalink
Fix fp16 bug from #1129 and add fp16 test case (#1160)
Browse files Browse the repository at this point in the history
Initializing the member variable `OutputType padding_val = 0;` for `__half` in the host code produced wrong results.

`padding_val` is now an fp32 value that is cast to `OutputType` in the CUDA kernel.


Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored and klecki committed Aug 12, 2019
1 parent 5f4d60f commit d1e100e
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 52 deletions.
16 changes: 8 additions & 8 deletions dali/kernels/slice/slice_flip_normalize_permute_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace dali {
namespace kernels {

template <size_t Dims, typename OutputType>
template <size_t Dims>
struct SliceFlipNormalizePermutePadArgs {
template <typename Shape>
explicit SliceFlipNormalizePermutePadArgs(const Shape &_shape) {
Expand All @@ -46,12 +46,12 @@ struct SliceFlipNormalizePermutePadArgs {
size_t normalization_index = 0;
std::vector<float> mean;
std::vector<float> inv_stddev;
OutputType padding_val = 0;
float padding_val = 0.0f;
};

namespace detail {

template <size_t Dims, typename OutputType>
template <size_t Dims>
struct SliceFlipNormalizePermutePadProcessedArgs {
size_t input_offset;
std::array<int64_t, Dims> in_strides;
Expand All @@ -61,7 +61,7 @@ struct SliceFlipNormalizePermutePadProcessedArgs {
std::vector<float> mean;
std::vector<float> inv_stddev;
size_t normalization_dim;
OutputType padding_val = 0;
float padding_val = 0.0f;
};

template <size_t Dims, typename Container>
Expand All @@ -83,11 +83,11 @@ std::array<int64_t, Dims> inverse_permutation(const std::array<int64_t, Dims> &p
return inv_perm;
}

template <size_t Dims, typename Shape, typename OutputType>
SliceFlipNormalizePermutePadProcessedArgs<Dims, OutputType> ProcessArgs(
const SliceFlipNormalizePermutePadArgs<Dims, OutputType> &args,
template <size_t Dims, typename Shape>
SliceFlipNormalizePermutePadProcessedArgs<Dims> ProcessArgs(
const SliceFlipNormalizePermutePadArgs<Dims> &args,
const Shape &in_shape) {
SliceFlipNormalizePermutePadProcessedArgs<Dims, OutputType> processed_args;
SliceFlipNormalizePermutePadProcessedArgs<Dims> processed_args;

processed_args.input_offset = 0;
processed_args.in_strides = GetStrides<Dims>(in_shape);
Expand Down
2 changes: 1 addition & 1 deletion dali/kernels/slice/slice_flip_normalize_permute_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ void SliceFlipNormalizePermute(OutputType *output, const InputType *input,
template <typename OutputType, typename InputType, size_t Dims>
class SliceFlipNormalizePermuteCPU {
public:
using Args = SliceFlipNormalizePermutePadArgs<Dims, OutputType>;
using Args = SliceFlipNormalizePermutePadArgs<Dims>;

KernelRequirements Setup(KernelContext &context,
const InTensorCPU<InputType, Dims> &in,
Expand Down
18 changes: 9 additions & 9 deletions dali/kernels/slice/slice_flip_normalize_permute_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ class SliceFlipNormalizePermutePadGPU {
size_t block_count_ = 0;

public:
using Args = SliceFlipNormalizePermutePadArgs<Dims, OutputType>;
using Args = SliceFlipNormalizePermutePadArgs<Dims>;
KernelRequirements Setup(KernelContext &context,
const InListGPU<InputType, Dims> &in,
const std::vector<Args> &args) {
KernelRequirements req;
ScratchpadEstimator se;
const size_t num_samples = in.size();
se.add<detail::SampleDesc<Dims, OutputType>>(AllocType::Host, num_samples);
se.add<detail::SampleDesc<Dims, OutputType>>(AllocType::GPU, num_samples);
se.add<detail::SampleDesc<Dims>>(AllocType::Host, num_samples);
se.add<detail::SampleDesc<Dims>>(AllocType::GPU, num_samples);

DALI_ENFORCE(args[0].mean.size() == args[0].inv_stddev.size());
size_t norm_args_size = args[0].mean.size();
Expand Down Expand Up @@ -88,9 +88,9 @@ class SliceFlipNormalizePermutePadGPU {
auto inv_stddev_data = args[0].inv_stddev;
DALI_ENFORCE(mean_data.size() == inv_stddev_data.size());

detail::SampleDesc<Dims, OutputType>* sample_descs_cpu =
context.scratchpad->Allocate<detail::SampleDesc<Dims, OutputType>>(AllocType::Host,
num_samples);
detail::SampleDesc<Dims>* sample_descs_cpu =
context.scratchpad->Allocate<detail::SampleDesc<Dims>>(AllocType::Host,
num_samples);
float *norm_add_cpu = mean_data.empty() ? nullptr :
context.scratchpad->Allocate<float>(AllocType::Host, mean_data.size());
float *norm_mul_cpu = inv_stddev_data.empty() ? nullptr :
Expand Down Expand Up @@ -135,8 +135,8 @@ class SliceFlipNormalizePermutePadGPU {
}
}

detail::SampleDesc<Dims, OutputType> *sample_descs =
context.scratchpad->Allocate<detail::SampleDesc<Dims, OutputType>>(
detail::SampleDesc<Dims> *sample_descs =
context.scratchpad->Allocate<detail::SampleDesc<Dims>>(
AllocType::GPU, num_samples);

float *norm_add = mean_data.empty() ? nullptr :
Expand All @@ -152,7 +152,7 @@ class SliceFlipNormalizePermutePadGPU {
AllocType::GPU, block_count_);

// Memory is allocated contiguously, so we launch only one cudaMemcpyAsync
size_t total_bytes = num_samples * sizeof(detail::SampleDesc<Dims, OutputType>)
size_t total_bytes = num_samples * sizeof(detail::SampleDesc<Dims>)
+ mean_data.size() * sizeof(float)
+ inv_stddev_data.size() * sizeof(float)
+ block_count_ * sizeof(detail::BlockDesc);
Expand Down
54 changes: 27 additions & 27 deletions dali/kernels/slice/slice_flip_normalize_permute_kernel_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SliceFlipNormalizePermuteTest : public ::testing::Test {
static constexpr size_t DimSize0 = TestArgs::DimSize0;
static constexpr size_t DimSize1 = TestArgs::DimSize1;
using ArgsGenerator = typename TestArgs::ArgsGenerator;
using KernelArgs = SliceFlipNormalizePermutePadArgs<Dims, OutputType>;
using KernelArgs = SliceFlipNormalizePermutePadArgs<Dims>;

void PrepareData(TestTensorList<InputType, Dims>& test_data) {
std::vector<int> sample_dims(Dims, DimSize);
Expand Down Expand Up @@ -145,30 +145,30 @@ class SliceFlipNormalizePermuteTest : public ::testing::Test {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_CopyOnly {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
return args;
}
};

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_SliceOnly {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
auto shape = input_shape;
shape[0] /= 2;
shape[1] /= 2;
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(shape);
SliceFlipNormalizePermutePadArgs<Dims> args(shape);
return args;
}
};

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_SliceOnly_WithAnchor {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
auto shape = input_shape;
shape[0] = input_shape[0]/2;
shape[1] = input_shape[0]/2;
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(shape);
SliceFlipNormalizePermutePadArgs<Dims> args(shape);
args.anchor[0] = input_shape[0]/2;
args.anchor[1] = input_shape[1]/2;
return args;
Expand All @@ -177,8 +177,8 @@ struct SliceFlipNormPermArgsGen_SliceOnly_WithAnchor {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_FlipHW {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
// assuming last dims are HWC, flip H and W
args.flip[Dims-2] = true;
args.flip[Dims-3] = true;
Expand All @@ -188,17 +188,17 @@ struct SliceFlipNormPermArgsGen_FlipHW {

template <typename OutputType, size_t Dims, size_t FlipDim>
struct SliceFlipNormPermArgsGen_FlipDim {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
args.flip[FlipDim] = true;
return args;
}
};

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_NormalizeOnly {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
args.mean.resize(args.shape[Dims-1]);
args.inv_stddev.resize(args.shape[Dims-1]);
for (int i = 0; i < args.shape[Dims-1]; i++) {
Expand All @@ -211,8 +211,8 @@ struct SliceFlipNormPermArgsGen_NormalizeOnly {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_NormalizeOnly_Scalar {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
args.mean = { 3.5f };
args.inv_stddev = { 1.f / 8.0f };
return args;
Expand All @@ -221,8 +221,8 @@ struct SliceFlipNormPermArgsGen_NormalizeOnly_Scalar {

template <typename OutputType, size_t Dims, size_t FlipDim>
struct SliceFlipNormPermArgsGen_NormalizeAndFlipDim {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
args.flip[FlipDim] = true;
args.mean.resize(args.shape[Dims-1], 3.5f);
args.inv_stddev.resize(args.shape[Dims-1], 1.0/3.5f);
Expand All @@ -233,8 +233,8 @@ struct SliceFlipNormPermArgsGen_NormalizeAndFlipDim {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_PermuteOnly_ReversedDims {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
for (size_t d = 0; d < Dims; d++) {
args.permuted_dims[d] = Dims-1-d;
}
Expand All @@ -244,8 +244,8 @@ struct SliceFlipNormPermArgsGen_PermuteOnly_ReversedDims {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_ReversedDims {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
for (size_t d = 0; d < Dims; d++) {
args.anchor[d] = input_shape[d]/4;
args.shape[d] = args.padded_shape[d] = input_shape[d]/2;
Expand All @@ -257,8 +257,8 @@ struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_ReversedDims {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_PermuteHW {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
for (size_t d = 0; d < Dims; d++) {
args.anchor[d] = input_shape[d]/4;
args.shape[d] = args.padded_shape[d] = input_shape[d]/2;
Expand All @@ -280,8 +280,8 @@ struct SliceFlipNormPermArgsGen_PermuteAndSliceHalf_PermuteHW {

template <typename OutputType, size_t Dims>
struct SliceFlipNormPermArgsGen_SliceFlipNormalizePermute_PermuteHWC2CHW {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
for (size_t d = 0; d < Dims; d++) {
args.anchor[d] = d == 0 || d == 1 ?
input_shape[d]/2 : 0;
Expand Down Expand Up @@ -311,8 +311,8 @@ struct SliceFlipNormPermArgsGen_SliceFlipNormalizePermute_PermuteHWC2CHW {

template <typename OutputType, size_t Dims, size_t PaddedDim, size_t PadSize>
struct SliceFlipNormPermArgsGen_OnlyPad_GivenDim {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(input_shape);
SliceFlipNormalizePermutePadArgs<Dims> Get(const TensorShape<Dims>& input_shape) {
SliceFlipNormalizePermutePadArgs<Dims> args(input_shape);
args.padded_shape[PaddedDim] += PadSize;
return args;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ namespace kernels {

namespace detail {

template <size_t Dims, typename OutputType>
template <size_t Dims>
struct SampleDesc {
void *__restrict__ out;
const void *__restrict__ in;
DeviceArray<int64_t, Dims> in_strides;
DeviceArray<int64_t, Dims> out_strides;
DeviceArray<int64_t, Dims> out_shape;
DeviceArray<int64_t, Dims> padded_out_shape;
OutputType padding_val;
float padding_val;
};

struct BlockDesc {
Expand Down Expand Up @@ -123,7 +123,7 @@ __device__ inline void SliceFlipNormalizePermutePadFunc(OutputType *__restrict__
}

template <typename OutputType, typename InputType, size_t Dims, bool should_normalize>
__global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc<Dims, OutputType> *samples,
__global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc<Dims> *samples,
const BlockDesc *blocks,
const float *norm_add,
const float *norm_mul,
Expand All @@ -146,7 +146,7 @@ __global__ void SliceFlipNormalizePermutePadKernel(const SampleDesc<Dims, Output
out, in, sample.out_strides.data(), sample.in_strides.data(),
sample.out_shape.data(), sample.padded_out_shape.data(),
should_pad, normalization_dim, norm_add, norm_mul,
sample.padding_val, offset, block_end);
static_cast<OutputType>(sample.padding_val), offset, block_end);
}

} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/operators/fused/crop_mirror_normalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void RunHelper(Tensor<CPUBackend> &output,
VALUE_SWITCH(number_of_dims, Dims, (3, 4), (
auto in_view = view<const InputType, Dims>(input);

kernels::SliceFlipNormalizePermutePadArgs<Dims, OutputType> args(slice_shape);
kernels::SliceFlipNormalizePermutePadArgs<Dims> args(slice_shape);
for (std::size_t d = 0; d < Dims; d++) {
args.anchor[d] = slice_anchor[d];
}
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/operators/fused/crop_mirror_normalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void RunHelper(TensorList<GPUBackend> &output,
ctx.gpu.stream = stream;
auto in_view = view<const InputType, NumDims>(input);

std::vector<kernels::SliceFlipNormalizePermutePadArgs<NumDims, OutputType>> per_sample_args;
std::vector<kernels::SliceFlipNormalizePermutePadArgs<NumDims>> per_sample_args;
per_sample_args.reserve(slice_anchors.size());
for (std::size_t i = 0; i < slice_anchors.size(); i++) {
per_sample_args.emplace_back(slice_shapes[i]);
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/test_operator_crop_mirror_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def check_cmn_cpu_vs_gpu(batch_size, output_dtype, output_layout, mirror_probabi

def test_cmn_cpu_vs_gpu():
for batch_size in [1, 8]:
for output_dtype in [types.FLOAT, types.INT32]:
for output_dtype in [types.FLOAT, types.INT32, types.FLOAT16]:
for output_layout in [types.NHWC, types.NCHW]:
for mirror_probability in [0.0, 0.5, 1.0]:
norm_data = [ ([0., 0., 0.], [1., 1., 1.]),
Expand Down

0 comments on commit d1e100e

Please sign in to comment.