From e6ec7010d34e3ac58a931c733e6cf90c68910b63 Mon Sep 17 00:00:00 2001 From: Dan Smith Date: Tue, 23 Jan 2024 13:31:29 -0500 Subject: [PATCH] reduce duplicate code --- .../include/six/sicd/NearestNeighbors.h | 12 +-- .../source/NearestNeighbors_unseq.cpp | 98 ++++++------------- 2 files changed, 36 insertions(+), 74 deletions(-) diff --git a/six/modules/c++/six.sicd/include/six/sicd/NearestNeighbors.h b/six/modules/c++/six.sicd/include/six/sicd/NearestNeighbors.h index 6e53c3b81..e9629551e 100644 --- a/six/modules/c++/six.sicd/include/six/sicd/NearestNeighbors.h +++ b/six/modules/c++/six.sicd/include/six/sicd/NearestNeighbors.h @@ -68,15 +68,15 @@ struct NearestNeighbors final uint8_t find_nearest(six::zfloat phase_direction, six::zfloat v) const; uint8_t getPhase(six::zfloat) const; -#if SIX_sicd_ComplexToAMP8IPHS8I_unseq + #if SIX_sicd_ComplexToAMP8IPHS8I_unseq + void nearest_neighbors_(execution_policy, std::span inputs, std::span results) const; + template - auto nearest_neighbors_unseq_T(std::span p) const; // TODO: std::span ... ? - template - void nearest_neighbors_unseq_(std::span inputs, std::span results) const; + auto unseq_nearest_neighbors(std::span p) const; // TODO: std::span ... ? template - void nearest_neighbors_par_unseq_T(std::span inputs, std::span results) const; -#endif + void nearest_neighbors_T(execution_policy, std::span inputs, std::span results) const; + #endif }; } diff --git a/six/modules/c++/six.sicd/source/NearestNeighbors_unseq.cpp b/six/modules/c++/six.sicd/source/NearestNeighbors_unseq.cpp index 564c5415c..0de69a6fc 100644 --- a/six/modules/c++/six.sicd/source/NearestNeighbors_unseq.cpp +++ b/six/modules/c++/six.sicd/source/NearestNeighbors_unseq.cpp @@ -901,7 +901,7 @@ template using IntV = decltype(::getPhase(ZFloatV{}, 0.0f)); template -auto six::sicd::NearestNeighbors::nearest_neighbors_unseq_T(std::span p) const // TODO: std::span ... ? The compiler can sometimes do better optimization with fixed-size structures. +auto six::sicd::NearestNeighbors::unseq_nearest_neighbors(std::span p) const // TODO: std::span ... ? The compiler can sometimes do better optimization with fixed-size structures. { ZFloatV v; assert(p.size() == size(v)); @@ -931,7 +931,8 @@ static void finish_nearest_neighbors_unseq(const six::sicd::NearestNeighbors& im } template -void six::sicd::NearestNeighbors::nearest_neighbors_unseq_(std::span inputs, std::span results) const +void six::sicd::NearestNeighbors::nearest_neighbors_T(execution_policy policy, + std::span inputs, std::span results) const { // View the data as chunks of *elements_per_iteration*. This allows iterating // to go *elements_per_iteration* at a time; and each chunk can be processed @@ -949,38 +950,22 @@ void six::sicd::NearestNeighbors::nearest_neighbors_unseq_(std::span(v); + return unseq_nearest_neighbors(v); }; - std::transform(/*std::execution::unseq,*/ b, e, d, func); - // Then finish off anything left - finish_nearest_neighbors_unseq(*this, inputs, results); -} - -template -void six::sicd::NearestNeighbors::nearest_neighbors_par_unseq_T(std::span inputs, std::span results) const -{ - // View the data as chunks of *elements_per_iteration*. This allows iterating - // to go *elements_per_iteration* at a time; and each chunk can be processed - // using `nearest_neighbors_unseq_T()`, above. - using extents_t = coda_oss::dextents; // two dimensions: M×N - const extents_t extents{ inputs.size() / elements_per_iteration, elements_per_iteration }; - const coda_oss::mdspan md_inputs(inputs.data(), extents); - assert(md_inputs.size() <= inputs.size()); - auto const b = cbegin(md_inputs); - auto const e = cend(md_inputs); - - const coda_oss::mdspan md_results(results.data(), extents); - assert(md_results.size() <= results.size()); - auto const d = begin(md_results); - - const auto func = [&](const auto& v) + if (policy == execution_policy::unseq) { - return nearest_neighbors_unseq_T(v); - }; - //std::transform(std::execution::par_unseq, b, e, d, func); - mt::Transform_par(b, e, d, func); - + std::transform(/*std::execution::unseq,*/ b, e, d, func); + } + else if (policy == execution_policy::par_unseq) + { + //std::transform(std::execution::par_unseq, b, e, d, func); + mt::Transform_par(b, e, d, func); + } + else + { + throw std::logic_error("Unsupported execution_policy"); + } // Then finish off anything left finish_nearest_neighbors_unseq(*this, inputs, results); @@ -1019,74 +1004,51 @@ std::string SIX_SICD_API six_sicd_set_nearest_neighbors_unseq(std::string unseq) return retval; } -void six::sicd::NearestNeighbors::nearest_neighbors_unseq(std::span inputs, std::span results) const +void six::sicd::NearestNeighbors::nearest_neighbors_(execution_policy policy, + std::span inputs, std::span results) const { // TODO: there could be more complicated logic here to determine which UNSEQ // implementation to use. - // This is very simple as it's only used for unit-testing const auto& unseq = ::nearest_neighbors_unseq_; #if SIX_sicd_has_simd if (unseq == unseq_simd) { - return nearest_neighbors_unseq_(inputs, results); + return nearest_neighbors_T(policy, inputs, results); } #endif #if SIX_sicd_has_VCL if (unseq == unseq_vcl) { - return nearest_neighbors_unseq_(inputs, results); + return nearest_neighbors_T(policy, inputs, results); } #endif #if SIX_sicd_has_valarray if (unseq == unseq_valarray) { - return nearest_neighbors_unseq_(inputs, results); + return nearest_neighbors_T(policy, inputs, results); } #endif #if SIX_sicd_has_ximd if (unseq == unseq_ximd) { - return nearest_neighbors_unseq_(inputs, results); + return nearest_neighbors_T(policy, inputs, results); } #endif - throw std::logic_error("Don't know how to implement nearest_neighbors_unseq() for unseq=" + unseq); + throw std::logic_error("Don't know how to implement nearest_neighbors_() for unseq=" + unseq); +} +void six::sicd::NearestNeighbors::nearest_neighbors_unseq(std::span inputs, std::span results) const +{ + // TODO: there could be more complicated logic here to determine which UNSEQ + // implementation to use. + nearest_neighbors_(execution_policy::unseq, inputs, results); } - void six::sicd::NearestNeighbors::nearest_neighbors_par_unseq(std::span inputs, std::span results) const { // TODO: there could be more complicated logic here to determine which UNSEQ // implementation to use. - - // This is very simple as it's only used for unit-testing - const auto& unseq = ::nearest_neighbors_unseq_; - #if SIX_sicd_has_simd - if (unseq == unseq_simd) - { - return nearest_neighbors_par_unseq_T(inputs, results); - } - #endif - #if SIX_sicd_has_VCL - if (unseq == unseq_vcl) - { - return nearest_neighbors_par_unseq_T(inputs, results); - } - #endif - #if SIX_sicd_has_valarray - if (unseq == unseq_valarray) - { - return nearest_neighbors_par_unseq_T(inputs, results); - } - #endif - #if SIX_sicd_has_ximd - if (unseq == unseq_ximd) - { - return nearest_neighbors_par_unseq_T(inputs, results); - } - #endif - - throw std::logic_error("Don't know how to implement nearest_neighbors_par_unseq() for unseq=" + unseq); + nearest_neighbors_(execution_policy::par_unseq, inputs, results); } #endif // SIX_sicd_ComplexToAMP8IPHS8I_unseq