Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CWT operator #4860

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
937b963
add MotherWavelet helper and WaveletGpu kernel
May 18, 2023
cf7b6a6
Cwt WIP
mwdowski May 18, 2023
68bb330
Merge branch 'NVIDIA:main' into wavelet-computing
kubo11 May 18, 2023
9d6e0b0
Merge pull request #2 from mwdowski/wavelet-computing
mwdowski May 18, 2023
359d79c
Merge pull request #1 from mwdowski/mwdowski
mwdowski May 18, 2023
b034619
Rename namespace
mwdowski May 18, 2023
6bb49f5
Merge branch 'main' into mwdowski
mwdowski May 18, 2023
5eed0c5
add WaveletArgs class
May 22, 2023
09196c6
Merge pull request #3 from mwdowski/wavelet-computing
kubo11 May 29, 2023
279e61b
Improve wavelet computing kernel
Jun 5, 2023
c4814f9
Optimize and remove discrete wavelets
Jun 7, 2023
11df6aa
Merge pull request #4 from mwdowski/wavelet-computing-improvements
kubo11 Jun 7, 2023
d3a8d6a
add DALIWaveletName enum
Jun 11, 2023
27cedd3
fix linting errors
Jun 11, 2023
2875c95
replace MeyerWavelet with GaussianWavelet
Jun 13, 2023
20d5d7e
Merge pull request #5 from mwdowski/wavelet-computing-improvements
kubo11 Jun 13, 2023
0efec3d
Fix wavelet exceptions
Jul 3, 2023
1ed22bc
Add CWT operator docstr
Jul 4, 2023
3c36192
Merge pull request #6 from mwdowski/wavelet-fixes
kubo11 Jul 6, 2023
1cdc5e7
WIP
mwdowski Sep 8, 2023
e99099e
Merge branch 'NVIDIA:main' into main
mwdowski Sep 8, 2023
15ce332
Merge branch 'main' into mwdowski2
mwdowski Sep 8, 2023
101efc4
Good size but full of zeros
mwdowski Sep 12, 2023
276f87e
WIP
mwdowski Sep 12, 2023
1849a30
Merge pull request #7 from mwdowski/mwdowski2
mwdowski Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dali/kernels/signal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_subdirectory(decibel)
if (BUILD_FFTS)
add_subdirectory(fft)
endif()
add_subdirectory(wavelet)
add_subdirectory(window)

collect_headers(DALI_INST_HDRS PARENT_SCOPE)
Expand Down
17 changes: 17 additions & 0 deletions dali/kernels/signal/wavelet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Passes the DALI_INST_HDRS, DALI_KERNEL_SRCS and DALI_KERNEL_TEST_SRCS list names,
# together with PARENT_SCOPE, to the project's collect_* helpers.
# NOTE(review): the helpers are defined elsewhere in the project; presumably they append
# this directory's headers / sources / test sources to those lists - confirm.
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE)
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE)
33 changes: 33 additions & 0 deletions dali/kernels/signal/wavelet/cwt_args.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

/**
 * @brief Argument bundle passed to the CWT GPU kernel.
 *
 * @tparam T type of the stored parameter (defaults to float)
 */
template <typename T = float>
struct CwtArgs {
// NOTE(review): the meaning of `a` is not defined in this header - presumably the
// wavelet scale; confirm against the kernel that consumes it.
T a;
};

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
98 changes: 98 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <complex>
#include <vector>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"
#include "dali/kernels/signal/wavelets/cwt_gpu.h"

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

/**
 * @brief Per-sample descriptor: where to read from, where to write to and how many
 *        elements to process.
 *
 * A default-constructed descriptor has null pointers and zero size.
 */
template <typename T>
struct SampleDesc {
  const T *in{nullptr};  // input elements of the sample
  T *out{nullptr};       // output elements of the sample
  int64_t size{0};       // number of elements to process
};

/**
 * @brief Writes `in[i] * args.a` to `out[i]` for every element of every sample.
 *
 * `blockIdx.y` selects the sample; the blocks along `x` walk that sample's elements in a
 * grid-stride loop, so any number of blocks per sample covers the whole sample.
 *
 * @param sample_data array of sample descriptors, indexed by `blockIdx.y`
 * @param args        kernel arguments; only `args.a` is used
 */
template <typename T>
__global__ void CwtKernel(const SampleDesc<T> *sample_data, CwtArgs<T> args) {
  const int64_t threads_per_block = blockDim.y * blockDim.x;
  const int64_t stride = gridDim.x * threads_per_block;
  const int64_t thread_in_block = threadIdx.y * blockDim.x + threadIdx.x;
  const int64_t first = threads_per_block * blockIdx.x + thread_in_block;

  const int which_sample = blockIdx.y;
  const auto desc = sample_data[which_sample];

  int64_t pos = first;
  while (pos < desc.size) {
    desc.out[pos] = desc.in[pos] * args.a;
    pos += stride;
  }
}

// Defaulted destructor, defined out of line in this translation unit.
template <typename T>
CwtGpu<T>::~CwtGpu() = default;

/**
 * @brief Reports the output shape and scratch memory needed to process `in`.
 *
 * The output has the same shape as the input. Scratch space is requested for one
 * SampleDesc per sample, both in host and in device memory.
 *
 * @param context kernel context (not used here)
 * @param in      input batch
 * @return requirements with `output_shapes` and `scratch_sizes` filled in
 */
template <typename T>
KernelRequirements CwtGpu<T>::Setup(KernelContext &context,
                                    const InListGPU<T, DynamicDimensions> &in) {
  const size_t batch_size = in.size();

  ScratchpadEstimator estimator;
  estimator.add<mm::memory_kind::host, SampleDesc<T>>(batch_size);
  estimator.add<mm::memory_kind::device, SampleDesc<T>>(batch_size);

  KernelRequirements requirements;
  requirements.scratch_sizes = estimator.sizes;
  requirements.output_shapes = {in.shape};
  return requirements;
}

/**
 * @brief Fills sample descriptors, uploads them and launches CwtKernel on the batch.
 *
 * An empty batch is a no-op: without the early return `1024 / num_samples` would divide
 * by zero and the grid would have a zero dimension, which is an invalid launch.
 *
 * @param context kernel context; provides the scratchpad and the CUDA stream
 * @param out     output batch; each sample must have the same volume as its input
 * @param in      input batch
 * @param args    kernel arguments forwarded to CwtKernel
 */
template <typename T>
void CwtGpu<T>::Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
                    const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args) {
  auto num_samples = in.size();
  if (num_samples == 0) {
    return;  // nothing to process; also avoids division by zero and a zero-sized grid
  }

  auto *sample_data = context.scratchpad->AllocateHost<SampleDesc<T>>(num_samples);
  for (int i = 0; i < num_samples; i++) {
    auto &sample = sample_data[i];
    sample.out = out.tensor_data(i);
    sample.in = in.tensor_data(i);
    sample.size = volume(in.tensor_shape(i));
    assert(sample.size == volume(out.tensor_shape(i)));
  }

  // The kernel reads the descriptors from device memory, so copy them over on the
  // same stream the kernel is launched on.
  auto *sample_data_gpu = context.scratchpad->AllocateGPU<SampleDesc<T>>(num_samples);
  CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc<T>),
                            cudaMemcpyHostToDevice, context.gpu.stream));

  // 32x32 threads per block; grid.y indexes samples, grid.x holds at least 32 blocks
  // per sample.
  dim3 block(32, 32);
  auto blocks_per_sample = std::max(32, 1024 / num_samples);
  dim3 grid(blocks_per_sample, num_samples);
  CwtKernel<T><<<grid, block, 0, context.gpu.stream>>>(sample_data_gpu, args);
  // Kernel launches do not return a status; surface configuration errors explicitly.
  CUDA_CALL(cudaGetLastError());
}

template class CwtGpu<float>;
template class CwtGpu<double>;

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali
50 changes: 50 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_

#include <memory>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/core/util.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

template <typename T = float>
class DLL_PUBLIC CwtGpu {
public:
static_assert(std::is_floating_point<T>::value, "Only floating point types are supported");

DLL_PUBLIC ~CwtGpu();

DLL_PUBLIC KernelRequirements Setup(KernelContext &context,
const InListGPU<T, DynamicDimensions> &in);

DLL_PUBLIC void Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args);
};

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
162 changes: 162 additions & 0 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <vector>
#include "dali/kernels/signal/wavelet/mother_wavelet.cuh"
#include "dali/core/math_util.h"

namespace dali {
namespace kernels {
namespace signal {

/**
 * @brief Constructs a Haar wavelet, which takes no parameters.
 *
 * @param args wavelet parameters; must be empty
 * @throws std::invalid_argument if `args` is not empty
 */
template <typename T>
HaarWavelet<T>::HaarWavelet(const std::vector<T> &args) {
  if (!args.empty()) {
    // Throw by value: `throw new ...` throws a pointer, which `catch (const std::exception &)`
    // handlers never match, and the heap-allocated exception object leaks.
    throw std::invalid_argument("HaarWavelet doesn't accept any arguments.");
  }
}

/**
 * @brief Evaluates the Haar wavelet at `t`.
 *
 * @return 1 for t in [0, 0.5), -1 for t in [0.5, 1), 0 everywhere else (including NaN)
 */
template <typename T>
__device__ T HaarWavelet<T>::operator()(const T &t) const {
  const bool in_support = t >= 0.0 && t < 1.0;
  if (!in_support) {
    return 0.0;
  }
  return t < 0.5 ? 1.0 : -1.0;
}

template class HaarWavelet<float>;
template class HaarWavelet<double>;

/**
 * @brief Constructs a Gaussian wavelet of order `n`.
 *
 * @param args exactly one value: `n`, an integer in the range [1, 8]
 * @throws std::invalid_argument if the argument count is wrong, or `n` is out of range,
 *         not an integer, or NaN
 */
template <typename T>
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) {
  // Exceptions are thrown by value: `throw new ...` throws a pointer, which
  // `catch (const std::exception &)` never matches, and the object leaks.
  if (args.size() != 1) {
    throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
  }
  const T order = args[0];
  // Written as a positive range test so that NaN (for which every comparison is false)
  // is rejected; the floor check enforces the "integer" part of the contract instead of
  // silently truncating the value when the wavelet is evaluated.
  const bool in_range = order >= 1.0 && order <= 8.0;
  if (!in_range || std::floor(order) != order) {
    throw std::invalid_argument(
        "GaussianWavelet's argument n should be integer from range [1,8].");
  }
  this->n = order;
}

/**
 * @brief Evaluates the Gaussian wavelet of order `n` at `t`.
 *
 * Each order is a polynomial in `t` multiplied by exp(-t^2) and divided by a
 * normalization factor sqrt(k * sqrt(pi / 2)), where k depends on the order.
 * The polynomials are evaluated in Horner form on t^2 instead of calling std::pow
 * for small integer exponents.
 *
 * @return wavelet value; 0 if `n` is outside [1, 8] (the constructor rejects such values)
 */
template <typename T>
__device__ T GaussianWavelet<T>::operator()(const T &t) const {
  // The original mixed-type expressions were evaluated in double; keep that precision.
  const double x = t;
  const double x2 = x * x;
  const T expTerm = std::exp(-x2);
  const T sqrtTerm = 1.2533141373155001;  // std::sqrt(M_PI/2.0)
  switch (static_cast<int>(n)) {
    case 1:
      return -2.0 * x * expTerm / std::sqrt(sqrtTerm);
    case 2:
      return (2.0 - 4.0 * x2) * expTerm / std::sqrt(3.0 * sqrtTerm);
    case 3:
      return x * (8.0 * x2 - 12.0) * expTerm / std::sqrt(15.0 * sqrtTerm);
    case 4:
      return ((16.0 * x2 - 48.0) * x2 + 12.0) * expTerm / std::sqrt(105.0 * sqrtTerm);
    case 5:
      return x * ((-32.0 * x2 + 160.0) * x2 - 120.0) * expTerm / std::sqrt(945.0 * sqrtTerm);
    case 6:
      return (((-64.0 * x2 + 480.0) * x2 - 720.0) * x2 + 120.0) *
             expTerm / std::sqrt(10395.0 * sqrtTerm);
    case 7:
      return x * (((128.0 * x2 - 1344.0) * x2 + 3360.0) * x2 - 1680.0) *
             expTerm / std::sqrt(135135.0 * sqrtTerm);
    case 8:
      return ((((256.0 * x2 - 3584.0) * x2 + 13440.0) * x2 - 13440.0) * x2 + 1680.0) *
             expTerm / std::sqrt(2027025.0 * sqrtTerm);
    default:
      // Not reachable for a validated `n`, but flowing off the end of a non-void
      // function is undefined behaviour, so return a defined value.
      return 0.0;
  }
}

template class GaussianWavelet<float>;
template class GaussianWavelet<double>;

/**
 * @brief Constructs a Mexican hat wavelet with the given `sigma`.
 *
 * @param args exactly one value: `sigma`
 * @throws std::invalid_argument if `args` does not hold exactly one value
 */
template <typename T>
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {
  if (args.size() != 1) {
    // Throw by value: `throw new ...` throws a pointer, which `catch (const std::exception &)`
    // handlers never match, and the heap-allocated exception object leaks.
    throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
  }
  // NOTE(review): sigma is not range-checked; the evaluation takes sqrt(3 * sigma) and
  // divides by sigma, so it presumably has to be positive - confirm.
  this->sigma = args[0];
}

/**
 * @brief Evaluates the Mexican hat wavelet at `t`:
 *        2 / (sqrt(3 * sigma) * pi^(1/4)) * (1 - (t / sigma)^2) * exp(-t^2 / (2 * sigma^2)).
 */
template <typename T>
__device__ T MexicanHatWavelet<T>::operator()(const T &t) const {
  const auto normalization = 2.0 / (std::sqrt(3.0 * sigma) * std::pow(M_PI, 0.25));
  const auto shape = 1.0 - std::pow(t / sigma, 2.0);
  const auto envelope = std::exp(-std::pow(t, 2.0) / (2.0 * std::pow(sigma, 2.0)));
  return normalization * shape * envelope;
}

template class MexicanHatWavelet<float>;
template class MexicanHatWavelet<double>;

/**
 * @brief Constructs a Morlet wavelet with the multiplicative constant `C`.
 *
 * @param args exactly one value: `C`
 * @throws std::invalid_argument if `args` does not hold exactly one value
 */
template <typename T>
MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) {
  if (args.size() != 1) {
    // Throw by value: `throw new ...` throws a pointer, which `catch (const std::exception &)`
    // handlers never match, and the heap-allocated exception object leaks.
    throw std::invalid_argument("MorletWavelet accepts exactly 1 argument - C.");
  }
  this->C = args[0];
}

/**
 * @brief Evaluates the Morlet wavelet at `t`: C * exp(-t^2 / 2) * cos(5 * t).
 */
template <typename T>
__device__ T MorletWavelet<T>::operator()(const T &t) const {
  const auto envelope = std::exp(-std::pow(t, 2.0) / 2.0);
  const auto carrier = std::cos(5.0 * t);
  return C * envelope * carrier;
}

template class MorletWavelet<float>;
template class MorletWavelet<double>;

/**
 * @brief Constructs a Shannon wavelet from its two parameters.
 *
 * @param args exactly two values, in order: `fb`, `fc`
 * @throws std::invalid_argument if `args` does not hold exactly two values
 */
template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
  if (args.size() != 2) {
    // Throw by value: `throw new ...` throws a pointer, which `catch (const std::exception &)`
    // handlers never match, and the heap-allocated exception object leaks.
    throw std::invalid_argument(
        "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
  }
  this->fb = args[0];
  this->fc = args[1];
}

/**
 * @brief Evaluates the Shannon wavelet at `t`:
 *        sqrt(fb) * cos(2 * pi * fc * t) * sin(pi * fb * t) / (pi * fb * t).
 *
 * At t == 0 the sin(x)/x factor is skipped (its limit is 1), which avoids 0/0.
 */
template <typename T>
__device__ T ShannonWavelet<T>::operator()(const T &t) const {
  const auto carrier = std::cos(static_cast<T>(2.0*M_PI)*fc*t)*std::sqrt(fb);
  if (t == 0.0) {
    return carrier;
  }
  const auto x = t*fb*static_cast<T>(M_PI);
  return carrier*std::sin(x)/x;
}

template class ShannonWavelet<float>;
template class ShannonWavelet<double>;

/**
 * @brief Constructs an Fbsp wavelet from its three parameters.
 *
 * @param args exactly three values, in order: `m`, `fb`, `fc`
 * @throws std::invalid_argument if `args` does not hold exactly three values
 */
template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
  if (args.size() != 3) {
    // Throw by value: `throw new ...` throws a pointer, which `catch (const std::exception &)`
    // handlers never match, and the heap-allocated exception object leaks.
    throw std::invalid_argument(
        "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
  }
  // NOTE(review): m is not range-checked although the evaluation divides by it - confirm
  // whether m == 0 should be rejected here.
  this->m = args[0];
  this->fb = args[1];
  this->fc = args[2];
}

/**
 * @brief Evaluates the Fbsp wavelet at `t`:
 *        sqrt(fb) * cos(2 * pi * fc * t) * (sin(x) / x)^m, where x = pi * t * fb / m.
 *
 * At t == 0 the (sin(x)/x)^m factor is skipped (its limit is 1), which avoids 0/0.
 */
template <typename T>
__device__ T FbspWavelet<T>::operator()(const T &t) const {
  const auto carrier = std::cos(static_cast<T>(2.0*M_PI)*fc*t)*std::sqrt(fb);
  if (t == 0.0) {
    return carrier;
  }
  const auto x = static_cast<T>(M_PI)*t*fb/m;
  return carrier*std::pow(std::sin(x)/x, m);
}

template class FbspWavelet<float>;
template class FbspWavelet<double>;

} // namespace signal
} // namespace kernels
} // namespace dali
Loading